carlfeynman committed on
Commit
9c06ef3
•
1 Parent(s): ea785e1

batchnorm2d replaced with layernorm2d

Browse files
Files changed (2) hide show
  1. classifier.pth +0 -0
  2. mnist_classifier.ipynb +71 -54
classifier.pth CHANGED
Binary files a/classifier.pth and b/classifier.pth differ
 
mnist_classifier.ipynb CHANGED
@@ -46,7 +46,7 @@
46
  "output_type": "stream",
47
  "text": [
48
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
49
- "100%|██████████| 2/2 [00:00<00:00, 112.23it/s]\n"
50
  ]
51
  }
52
  ],
@@ -139,7 +139,7 @@
139
  },
140
  {
141
  "cell_type": "code",
142
- "execution_count": 147,
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
@@ -154,7 +154,7 @@
154
  },
155
  {
156
  "cell_type": "code",
157
- "execution_count": 148,
158
  "metadata": {},
159
  "outputs": [],
160
  "source": [
@@ -168,7 +168,7 @@
168
  "\n",
169
  "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
170
  " return nn.Sequential(\n",
171
- " conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n",
172
  " conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
173
  " )\n",
174
  "\n",
@@ -186,7 +186,7 @@
186
  },
187
  {
188
  "cell_type": "code",
189
- "execution_count": 149,
190
  "metadata": {},
191
  "outputs": [],
192
  "source": [
@@ -203,19 +203,30 @@
203
  "\n",
204
  "def cnn_classifier():\n",
205
  " return nn.Sequential(\n",
206
- " ResBlock(1, 8,),\n",
207
- " ResBlock(8, 16, ),\n",
208
- " ResBlock(16, 32,),\n",
209
- " ResBlock(32, 64, ),\n",
210
- " ResBlock(64, 64,),\n",
211
  " conv(64, 10, act=False),\n",
212
  " nn.Flatten(),\n",
213
- " )"
 
 
 
 
 
 
 
 
 
 
 
214
  ]
215
  },
216
  {
217
  "cell_type": "code",
218
- "execution_count": 150,
219
  "metadata": {},
220
  "outputs": [],
221
  "source": [
@@ -226,7 +237,7 @@
226
  },
227
  {
228
  "cell_type": "code",
229
- "execution_count": 151,
230
  "metadata": {
231
  "tags": [
232
  "exclude"
@@ -237,16 +248,16 @@
237
  "name": "stdout",
238
  "output_type": "stream",
239
  "text": [
240
- "train, epoch:1, loss: 1.3684, accuracy: 0.5153\n",
241
- "eval, epoch:1, loss: 0.4238, accuracy: 0.8648\n",
242
- "train, epoch:2, loss: 0.2660, accuracy: 0.9162\n",
243
- "eval, epoch:2, loss: 0.1468, accuracy: 0.9552\n",
244
- "train, epoch:3, loss: 0.1479, accuracy: 0.9545\n",
245
- "eval, epoch:3, loss: 0.1101, accuracy: 0.9647\n",
246
- "train, epoch:4, loss: 0.1149, accuracy: 0.9650\n",
247
- "eval, epoch:4, loss: 0.0997, accuracy: 0.9705\n",
248
- "train, epoch:5, loss: 0.2118, accuracy: 0.9399\n",
249
- "eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
250
  ]
251
  }
252
  ],
@@ -282,7 +293,7 @@
282
  },
283
  {
284
  "cell_type": "code",
285
- "execution_count": 152,
286
  "metadata": {
287
  "tags": [
288
  "exclude"
@@ -293,11 +304,11 @@
293
  "name": "stdout",
294
  "output_type": "stream",
295
  "text": [
296
- "eval, epoch:1, loss: 0.1625, accuracy: 0.9478\n",
297
- "eval, epoch:2, loss: 0.1625, accuracy: 0.9478\n",
298
- "eval, epoch:3, loss: 0.1625, accuracy: 0.9478\n",
299
- "eval, epoch:4, loss: 0.1625, accuracy: 0.9478\n",
300
- "eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
301
  ]
302
  }
303
  ],
@@ -320,7 +331,7 @@
320
  },
321
  {
322
  "cell_type": "code",
323
- "execution_count": 153,
324
  "metadata": {
325
  "tags": [
326
  "exclude"
@@ -329,7 +340,7 @@
329
  "outputs": [
330
  {
331
  "data": {
332
- "image/png": "",
333
  "text/plain": [
334
  "<Figure size 720x720 with 5 Axes>"
335
  ]
@@ -355,7 +366,7 @@
355
  },
356
  {
357
  "cell_type": "code",
358
- "execution_count": 158,
359
  "metadata": {
360
  "tags": [
361
  "exclude"
@@ -368,7 +379,7 @@
368
  },
369
  {
370
  "cell_type": "code",
371
- "execution_count": 159,
372
  "metadata": {},
373
  "outputs": [],
374
  "source": [
@@ -379,7 +390,7 @@
379
  },
380
  {
381
  "cell_type": "code",
382
- "execution_count": 160,
383
  "metadata": {
384
  "tags": [
385
  "exclude"
@@ -388,7 +399,7 @@
388
  "outputs": [
389
  {
390
  "data": {
391
- "image/png": "",
392
  "text/plain": [
393
  "<Figure size 720x720 with 5 Axes>"
394
  ]
@@ -415,7 +426,7 @@
415
  },
416
  {
417
  "cell_type": "code",
418
- "execution_count": 161,
419
  "metadata": {
420
  "tags": [
421
  "exclude"
@@ -424,7 +435,7 @@
424
  "outputs": [
425
  {
426
  "data": {
427
- "image/png": "",
428
  "text/plain": [
429
  "<Figure size 720x720 with 5 Axes>"
430
  ]
@@ -449,7 +460,7 @@
449
  },
450
  {
451
  "cell_type": "code",
452
- "execution_count": 164,
453
  "metadata": {},
454
  "outputs": [],
455
  "source": [
@@ -465,7 +476,7 @@
465
  },
466
  {
467
  "cell_type": "code",
468
- "execution_count": 167,
469
  "metadata": {
470
  "tags": [
471
  "exclude"
@@ -476,25 +487,32 @@
476
  "name": "stdout",
477
  "output_type": "stream",
478
  "text": [
479
- "tensor(4)\n"
 
 
 
 
 
 
 
480
  ]
481
  },
482
  {
483
  "data": {
484
  "text/plain": [
485
- "[{'digit': 0, 'prob': '0.12%', 'logits': tensor(-1.1319)},\n",
486
- " {'digit': 1, 'prob': '0.00%', 'logits': tensor(-4.7852)},\n",
487
- " {'digit': 2, 'prob': '2.15%', 'logits': tensor(1.7912)},\n",
488
- " {'digit': 3, 'prob': '0.07%', 'logits': tensor(-1.6584)},\n",
489
- " {'digit': 4, 'prob': '97.03%', 'logits': tensor(5.5990)},\n",
490
- " {'digit': 5, 'prob': '0.01%', 'logits': tensor(-3.5289)},\n",
491
- " {'digit': 6, 'prob': '0.00%', 'logits': tensor(-4.4016)},\n",
492
- " {'digit': 7, 'prob': '0.09%', 'logits': tensor(-1.3343)},\n",
493
- " {'digit': 8, 'prob': '0.07%', 'logits': tensor(-1.6577)},\n",
494
- " {'digit': 9, 'prob': '0.45%', 'logits': tensor(0.2194)}]"
495
  ]
496
  },
497
- "execution_count": 167,
498
  "metadata": {},
499
  "output_type": "execute_result"
500
  }
@@ -518,7 +536,7 @@
518
  },
519
  {
520
  "cell_type": "code",
521
- "execution_count": 168,
522
  "metadata": {
523
  "tags": [
524
  "exclude"
@@ -529,8 +547,7 @@
529
  "name": "stdout",
530
  "output_type": "stream",
531
  "text": [
532
- "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
533
- "[NbConvertApp] Writing 3187 bytes to mnist_classifier.py\n"
534
  ]
535
  }
536
  ],
 
46
  "output_type": "stream",
47
  "text": [
48
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
49
+ "100%|██████████| 2/2 [00:00<00:00, 75.69it/s]\n"
50
  ]
51
  }
52
  ],
 
139
  },
140
  {
141
  "cell_type": "code",
142
+ "execution_count": 8,
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
 
154
  },
155
  {
156
  "cell_type": "code",
157
+ "execution_count": 47,
158
  "metadata": {},
159
  "outputs": [],
160
  "source": [
 
168
  "\n",
169
  "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
170
  " return nn.Sequential(\n",
171
+ " conv(ni, nf, ks=ks, s=1, norm=None, act=act),\n",
172
  " conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
173
  " )\n",
174
  "\n",
 
186
  },
187
  {
188
  "cell_type": "code",
189
+ "execution_count": 48,
190
  "metadata": {},
191
  "outputs": [],
192
  "source": [
 
203
  "\n",
204
  "def cnn_classifier():\n",
205
  " return nn.Sequential(\n",
206
+ " ResBlock(1, 8, norm=nn.LayerNorm([8, 14, 14])),\n",
207
+ " ResBlock(8, 16, norm=nn.LayerNorm([16, 7, 7])),\n",
208
+ " ResBlock(16, 32, norm=nn.LayerNorm([32, 4, 4])),\n",
209
+ " ResBlock(32, 64, norm=nn.LayerNorm([64, 2, 2])),\n",
210
+ " ResBlock(64, 64, norm=nn.LayerNorm([64, 1, 1])),\n",
211
  " conv(64, 10, act=False),\n",
212
  " nn.Flatten(),\n",
213
+ " )\n",
214
+ "\n",
215
+ "# def cnn_classifier():\n",
216
+ "# return nn.Sequential(\n",
217
+ "# ResBlock(1, 8,),\n",
218
+ "# ResBlock(8, 16, ),\n",
219
+ "# ResBlock(16, 32,),\n",
220
+ "# ResBlock(32, 64, ),\n",
221
+ "# ResBlock(64, 64,),\n",
222
+ "# conv(64, 10, act=False),\n",
223
+ "# nn.Flatten(),\n",
224
+ "# )"
225
  ]
226
  },
227
  {
228
  "cell_type": "code",
229
+ "execution_count": 49,
230
  "metadata": {},
231
  "outputs": [],
232
  "source": [
 
237
  },
238
  {
239
  "cell_type": "code",
240
+ "execution_count": 50,
241
  "metadata": {
242
  "tags": [
243
  "exclude"
 
248
  "name": "stdout",
249
  "output_type": "stream",
250
  "text": [
251
+ "train, epoch:1, loss: 1.8902, accuracy: 0.3183\n",
252
+ "eval, epoch:1, loss: 1.0976, accuracy: 0.6274\n",
253
+ "train, epoch:2, loss: 0.5929, accuracy: 0.8003\n",
254
+ "eval, epoch:2, loss: 0.2895, accuracy: 0.9102\n",
255
+ "train, epoch:3, loss: 0.2396, accuracy: 0.9264\n",
256
+ "eval, epoch:3, loss: 0.1343, accuracy: 0.9597\n",
257
+ "train, epoch:4, loss: 0.1139, accuracy: 0.9651\n",
258
+ "eval, epoch:4, loss: 0.0801, accuracy: 0.9763\n",
259
+ "train, epoch:5, loss: 0.1368, accuracy: 0.9582\n",
260
+ "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
261
  ]
262
  }
263
  ],
 
293
  },
294
  {
295
  "cell_type": "code",
296
+ "execution_count": 51,
297
  "metadata": {
298
  "tags": [
299
  "exclude"
 
304
  "name": "stdout",
305
  "output_type": "stream",
306
  "text": [
307
+ "eval, epoch:1, loss: 0.0882, accuracy: 0.9722\n",
308
+ "eval, epoch:2, loss: 0.0882, accuracy: 0.9722\n",
309
+ "eval, epoch:3, loss: 0.0882, accuracy: 0.9722\n",
310
+ "eval, epoch:4, loss: 0.0882, accuracy: 0.9722\n",
311
+ "eval, epoch:5, loss: 0.0882, accuracy: 0.9722\n"
312
  ]
313
  }
314
  ],
 
331
  },
332
  {
333
  "cell_type": "code",
334
+ "execution_count": 52,
335
  "metadata": {
336
  "tags": [
337
  "exclude"
 
340
  "outputs": [
341
  {
342
  "data": {
343
+ "image/png": "",
344
  "text/plain": [
345
  "<Figure size 720x720 with 5 Axes>"
346
  ]
 
366
  },
367
  {
368
  "cell_type": "code",
369
+ "execution_count": 53,
370
  "metadata": {
371
  "tags": [
372
  "exclude"
 
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 54,
383
  "metadata": {},
384
  "outputs": [],
385
  "source": [
 
390
  },
391
  {
392
  "cell_type": "code",
393
+ "execution_count": 55,
394
  "metadata": {
395
  "tags": [
396
  "exclude"
 
399
  "outputs": [
400
  {
401
  "data": {
402
+ "image/png": "",
403
  "text/plain": [
404
  "<Figure size 720x720 with 5 Axes>"
405
  ]
 
426
  },
427
  {
428
  "cell_type": "code",
429
+ "execution_count": 56,
430
  "metadata": {
431
  "tags": [
432
  "exclude"
 
435
  "outputs": [
436
  {
437
  "data": {
438
+ "image/png": "",
439
  "text/plain": [
440
  "<Figure size 720x720 with 5 Axes>"
441
  ]
 
460
  },
461
  {
462
  "cell_type": "code",
463
+ "execution_count": 57,
464
  "metadata": {},
465
  "outputs": [],
466
  "source": [
 
476
  },
477
  {
478
  "cell_type": "code",
479
+ "execution_count": 58,
480
  "metadata": {
481
  "tags": [
482
  "exclude"
 
487
  "name": "stdout",
488
  "output_type": "stream",
489
  "text": [
490
+ "tensor(3)\n"
491
+ ]
492
+ },
493
+ {
494
+ "name": "stderr",
495
+ "output_type": "stream",
496
+ "text": [
497
+ "[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.\n"
498
  ]
499
  },
500
  {
501
  "data": {
502
  "text/plain": [
503
+ "[{'digit': 0, 'prob': '0.00%', 'logits': tensor(-5.5980)},\n",
504
+ " {'digit': 1, 'prob': '0.00%', 'logits': tensor(-0.4972)},\n",
505
+ " {'digit': 2, 'prob': '0.02%', 'logits': tensor(1.2516)},\n",
506
+ " {'digit': 3, 'prob': '99.95%', 'logits': tensor(9.9263)},\n",
507
+ " {'digit': 4, 'prob': '0.00%', 'logits': tensor(-5.5094)},\n",
508
+ " {'digit': 5, 'prob': '0.01%', 'logits': tensor(0.2367)},\n",
509
+ " {'digit': 6, 'prob': '0.00%', 'logits': tensor(-9.4633)},\n",
510
+ " {'digit': 7, 'prob': '0.00%', 'logits': tensor(-2.4315)},\n",
511
+ " {'digit': 8, 'prob': '0.02%', 'logits': tensor(1.4733)},\n",
512
+ " {'digit': 9, 'prob': '0.00%', 'logits': tensor(-0.0205)}]"
513
  ]
514
  },
515
+ "execution_count": 58,
516
  "metadata": {},
517
  "output_type": "execute_result"
518
  }
 
536
  },
537
  {
538
  "cell_type": "code",
539
+ "execution_count": 59,
540
  "metadata": {
541
  "tags": [
542
  "exclude"
 
547
  "name": "stdout",
548
  "output_type": "stream",
549
  "text": [
550
+ "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n"
 
551
  ]
552
  }
553
  ],