carlfeynman committed on
Commit
8e35bc7
•
1 Parent(s): 056ab4f

resblock added

Files changed (3)
  1. mlp_classifier.pkl +0 -0
  2. mnist_classifier.ipynb +117 -71
  3. mnist_classifier.py +61 -23
mlp_classifier.pkl CHANGED
Binary files a/mlp_classifier.pkl and b/mlp_classifier.pkl differ
 
mnist_classifier.ipynb CHANGED
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 60,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -23,7 +23,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 61,
    "metadata": {},
    "outputs": [
     {
@@ -31,7 +31,7 @@
      "output_type": "stream",
      "text": [
       "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
-      "100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
+      "100%|██████████| 2/2 [00:00<00:00, 71.77it/s]\n"
      ]
     }
    ],
@@ -43,7 +43,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 62,
    "metadata": {},
    "outputs": [
     {
@@ -70,7 +70,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 87,
    "metadata": {},
    "outputs": [
     {
@@ -79,7 +79,7 @@
       "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
      ]
     },
-    "execution_count": 4,
+    "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -89,7 +89,7 @@
    "class DataLoaders:\n",
    "    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
    "        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
-   "        self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
+   "        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
    "\n",
    "def collate_fn(b):\n",
    "    collate = default_collate(b)\n",
@@ -102,7 +102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 77,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -117,7 +117,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 78,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -135,23 +135,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 79,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
-      "eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
-      "train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
-      "eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
-      "train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
-      "eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
-      "train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
-      "eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
-      "train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
-      "eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
+      "train, epoch:1, loss: 0.3142, accuracy: 0.7951\n",
+      "eval, epoch:1, loss: 0.2298, accuracy: 0.9048\n",
+      "train, epoch:2, loss: 0.2198, accuracy: 0.9204\n",
+      "eval, epoch:2, loss: 0.1663, accuracy: 0.9350\n",
+      "train, epoch:3, loss: 0.1776, accuracy: 0.9420\n",
+      "eval, epoch:3, loss: 0.1267, accuracy: 0.9493\n",
+      "train, epoch:4, loss: 0.1328, accuracy: 0.9568\n",
+      "eval, epoch:4, loss: 0.0959, accuracy: 0.9598\n",
+      "train, epoch:5, loss: 0.1038, accuracy: 0.9637\n",
+      "eval, epoch:5, loss: 0.0913, accuracy: 0.9643\n"
      ]
     }
    ],
@@ -184,7 +184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 81,
    "metadata": {
     "tags": [
      "exclude"
@@ -192,42 +192,97 @@
    },
    "outputs": [],
    "source": [
-    "with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
-    "    pickle.dump(model, model_file)"
+    "# with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
+    "#     pickle.dump(model, model_file)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 82,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n",
+    "#     return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n",
+    "#                          conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n",
+    "\n",
+    "# class ResBlock(nn.Module):\n",
+    "#     def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n",
+    "#         super().__init__()\n",
+    "#         self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n",
+    "#         self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n",
+    "#         self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
+    "#         self.act = act()\n",
+    "\n",
+    "#     def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 83,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
+    "    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
+    "    if norm:\n",
+    "        layers.append(norm)\n",
+    "    if act:\n",
+    "        layers.append(act())\n",
+    "    return nn.Sequential(*layers)\n",
+    "\n",
+    "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
+    "    return nn.Sequential(\n",
+    "        conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n",
+    "        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
+    "    )\n",
+    "\n",
+    "class ResBlock(nn.Module):\n",
+    "    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
+    "        super().__init__()\n",
+    "        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
+    "        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
+    "        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
+    "        self.act = act()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.act(self.convs(x) + self.idconv(self.pool(x)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 92,
    "metadata": {},
    "outputs": [],
    "source": [
    "def cnn_classifier():\n",
-   "    ks,stride = 3,2\n",
    "    return nn.Sequential(\n",
-   "        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
-   "        nn.BatchNorm2d(8),\n",
-   "        nn.ReLU(),\n",
-   "        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
-   "        nn.BatchNorm2d(16),\n",
-   "        nn.ReLU(),\n",
-   "        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
-   "        nn.BatchNorm2d(32),\n",
-   "        nn.ReLU(),\n",
-   "        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
-   "        nn.BatchNorm2d(64),\n",
-   "        nn.ReLU(),\n",
-   "        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
-   "        nn.BatchNorm2d(64),\n",
-   "        nn.ReLU(),\n",
-   "        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
+   "        ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n",
+   "        ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n",
+   "        ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
+   "        ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
+   "        ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
+   "        conv(64, 10, act=False),\n",
    "        nn.Flatten(),\n",
-   "    )"
+   "    )\n",
+   "\n",
+   "\n",
+   "# def cnn_classifier():\n",
+   "#     return nn.Sequential(\n",
+   "#         ResBlock(1, 16, norm=nn.BatchNorm2d(16)),\n",
+   "#         ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
+   "#         ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
+   "#         ResBlock(64, 128, norm=nn.BatchNorm2d(128)),\n",
+   "#         ResBlock(128, 256, norm=nn.BatchNorm2d(256)),\n",
+   "#         ResBlock(256, 256, norm=nn.BatchNorm2d(256)),\n",
+   "#         conv(256, 10, act=False),\n",
+   "#         nn.Flatten(),\n",
+   "#     )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 93,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -238,23 +293,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 94,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
-      "eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
-      "train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
-      "eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
-      "train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
-      "eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
-      "train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
-      "eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
-      "train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
-      "eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
+      "train, epoch:1, loss: 0.0827, accuracy: 0.9102\n",
+      "eval, epoch:1, loss: 0.0448, accuracy: 0.9817\n",
+      "train, epoch:2, loss: 0.0382, accuracy: 0.9835\n",
+      "eval, epoch:2, loss: 0.0353, accuracy: 0.9863\n",
+      "train, epoch:3, loss: 0.0499, accuracy: 0.9856\n",
+      "eval, epoch:3, loss: 0.0300, accuracy: 0.9867\n",
+      "train, epoch:4, loss: 0.0361, accuracy: 0.9869\n",
+      "eval, epoch:4, loss: 0.0203, accuracy: 0.9877\n",
+      "train, epoch:5, loss: 0.0427, accuracy: 0.9846\n",
+      "eval, epoch:5, loss: 0.0250, accuracy: 0.9866\n"
      ]
     }
    ],
@@ -266,7 +321,6 @@
    "epochs = 5\n",
    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
-   "\n",
    "for epoch in range(epochs):\n",
    "    for train in (True, False):\n",
    "        accuracy = 0\n",
@@ -283,13 +337,12 @@
    "            if train:\n",
    "                sched.step()\n",
    "        accuracy /= len(dl)\n",
-   "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
-   "        "
+   "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 95,
    "metadata": {
     "tags": [
      "exclude"
@@ -297,8 +350,8 @@
    },
    "outputs": [],
    "source": [
-    "with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
-    "    pickle.dump(model, model_file)"
+    "# with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
+    "#     pickle.dump(model, model_file)"
    ]
   },
   {
@@ -314,7 +367,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 96,
    "metadata": {
     "tags": [
      "exclude"
@@ -325,22 +378,15 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n"
+      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
+      "[NbConvertApp] Writing 5934 bytes to mnist_classifier.py\n"
      ]
     }
    ],
    "source": [
-    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n",
-    "\n"
+    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": null,
mnist_classifier.py CHANGED
@@ -33,7 +33,7 @@ bs = 1024
 class DataLoaders:
     def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
         self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
-        self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
+        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
 
 def collate_fn(b):
     collate = default_collate(b)
@@ -91,29 +91,72 @@ for epoch in range(epochs):
     print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
 
 
+# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
+#     return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),
+#                          conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))
+
+# class ResBlock(nn.Module):
+#     def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
+#         super().__init__()
+#         self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)
+#         self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)
+#         self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
+#         self.act = act()
+
+#     def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))
+
+
+def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
+    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]
+    if norm:
+        layers.append(norm)
+    if act:
+        layers.append(act())
+    return nn.Sequential(*layers)
+
+def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
+    return nn.Sequential(
+        conv(ni, nf, ks=ks, s=1, norm=norm, act=act),
+        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),
+    )
+
+class ResBlock(nn.Module):
+    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):
+        super().__init__()
+        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)
+        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)
+        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)
+        self.act = act()
+
+    def forward(self, x):
+        return self.act(self.convs(x) + self.idconv(self.pool(x)))
+
+
 def cnn_classifier():
-    ks,stride = 3,2
     return nn.Sequential(
-        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(8),
-        nn.ReLU(),
-        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(16),
-        nn.ReLU(),
-        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(32),
-        nn.ReLU(),
-        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(64),
-        nn.ReLU(),
-        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(64),
-        nn.ReLU(),
-        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
+        ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
+        ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
+        ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
+        ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
+        ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
+        conv(64, 10, act=False),
         nn.Flatten(),
     )
 
 
+# def cnn_classifier():
+#     return nn.Sequential(
+#         ResBlock(1, 16, norm=nn.BatchNorm2d(16)),
+#         ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
+#         ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
+#         ResBlock(64, 128, norm=nn.BatchNorm2d(128)),
+#         ResBlock(128, 256, norm=nn.BatchNorm2d(256)),
+#         ResBlock(256, 256, norm=nn.BatchNorm2d(256)),
+#         conv(256, 10, act=False),
+#         nn.Flatten(),
+#     )
+
+
 def kaiming_init(m):
     if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         nn.init.kaiming_normal_(m.weight)
@@ -126,7 +169,6 @@ max_lr = 0.3
 epochs = 5
 opt = optim.AdamW(model.parameters(), lr=lr)
 sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
-
 for epoch in range(epochs):
     for train in (True, False):
         accuracy = 0
@@ -144,10 +186,6 @@ for epoch in range(epochs):
             sched.step()
         accuracy /= len(dl)
         print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
-
-
-
-
 
 
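
The DataLoaders change above is the one behavioral fix in this commit: self.valid previously wrapped train_ds, so the eval lines in the training logs were computed against training data. A minimal check of the fixed class (a sketch; the toy TensorDatasets and default_collate stand in for the notebook's Hugging Face dataset and custom collate_fn):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

class DataLoaders:
    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)

train_ds = TensorDataset(torch.randn(600, 1, 28, 28), torch.randint(0, 10, (600,)))
valid_ds = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
dls = DataLoaders(train_ds, valid_ds, bs=64, collate_fn=default_collate)

# Before the fix this summed to len(train_ds); now the valid loader
# really iterates the validation split.
assert sum(yb.shape[0] for _, yb in dls.valid) == len(valid_ds)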