carlfeynman committed on
Commit
5993d2f
•
1 Parent(s): 53075d2

cnn classifier added, accuracy is 98%

Browse files
Files changed (5) hide show
  1. cnn_classifier.pkl +0 -0
  2. linear_classifier.pkl +0 -0
  3. mlp_classifier.pkl +0 -0
  4. mnist.ipynb +115 -34
  5. mnist.py +55 -17
cnn_classifier.pkl ADDED
Binary file (286 kB). View file
 
linear_classifier.pkl DELETED
Binary file (173 kB)
 
mlp_classifier.pkl ADDED
Binary file (173 kB). View file
 
mnist.ipynb CHANGED
@@ -31,7 +31,7 @@
31
  "output_type": "stream",
32
  "text": [
33
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
34
- "100%|██████████| 2/2 [00:00<00:00, 65.21it/s]\n"
35
  ]
36
  }
37
  ],
@@ -117,72 +117,153 @@
117
  },
118
  {
119
  "cell_type": "code",
120
- "execution_count": 6,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  "metadata": {},
122
  "outputs": [],
123
  "source": [
124
  "def cnn_classifier():\n",
125
  " ks,stride = 3,2\n",
126
  " return nn.Sequential(\n",
127
- " nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n",
128
- " nn.ReLU(),\n",
129
- " nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
130
  " nn.ReLU(),\n",
131
  " nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
 
132
  " nn.ReLU(),\n",
133
  " nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
 
 
 
 
134
  " nn.ReLU(),\n",
135
- " nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
 
136
  " nn.ReLU(),\n",
137
- " nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
138
  " nn.Flatten(),\n",
139
  " )"
140
  ]
141
  },
142
  {
143
  "cell_type": "code",
144
- "execution_count": 7,
145
  "metadata": {},
146
  "outputs": [],
147
  "source": [
148
- "# model definition\n",
149
- "def linear_classifier():\n",
150
- " return nn.Sequential(\n",
151
- " Reshape((-1, 784)),\n",
152
- " nn.Linear(784, 50),\n",
153
- " nn.ReLU(),\n",
154
- " nn.Linear(50, 50),\n",
155
- " nn.ReLU(),\n",
156
- " nn.Linear(50, 10)\n",
157
- " )"
158
  ]
159
  },
160
  {
161
  "cell_type": "code",
162
- "execution_count": 8,
163
  "metadata": {},
164
  "outputs": [
165
  {
166
  "name": "stdout",
167
  "output_type": "stream",
168
  "text": [
169
- "train, epoch:1, loss: 0.2638, accuracy: 0.8032\n",
170
- "eval, epoch:1, loss: 0.2929, accuracy: 0.9011\n",
171
- "train, epoch:2, loss: 0.2497, accuracy: 0.9180\n",
172
- "eval, epoch:2, loss: 0.2317, accuracy: 0.9312\n",
173
- "train, epoch:3, loss: 0.1817, accuracy: 0.9391\n",
174
- "eval, epoch:3, loss: 0.1751, accuracy: 0.9496\n",
175
- "train, epoch:4, loss: 0.1589, accuracy: 0.9518\n",
176
- "eval, epoch:4, loss: 0.1630, accuracy: 0.9638\n",
177
- "train, epoch:5, loss: 0.1498, accuracy: 0.9603\n",
178
- "eval, epoch:5, loss: 0.1425, accuracy: 0.9655\n"
179
  ]
180
  }
181
  ],
182
  "source": [
183
- "model = linear_classifier()\n",
 
184
  "lr = 0.1\n",
185
- "max_lr = 0.1\n",
186
  "epochs = 5\n",
187
  "opt = optim.AdamW(model.parameters(), lr=lr)\n",
188
  "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
@@ -209,7 +290,7 @@
209
  },
210
  {
211
  "cell_type": "code",
212
- "execution_count": 31,
213
  "metadata": {
214
  "tags": [
215
  "exclude"
@@ -217,8 +298,8 @@
217
  },
218
  "outputs": [],
219
  "source": [
220
- "# with open('./linear_classifier.pkl', 'wb') as model_file:\n",
221
- "# pickle.dump(model, model_file)"
222
  ]
223
  },
224
  {
 
31
  "output_type": "stream",
32
  "text": [
33
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
34
+ "100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
35
  ]
36
  }
37
  ],
 
117
  },
118
  {
119
  "cell_type": "code",
120
+ "execution_count": 43,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "# model definition\n",
125
+ "def linear_classifier():\n",
126
+ " return nn.Sequential(\n",
127
+ " Reshape((-1, 784)),\n",
128
+ " nn.Linear(784, 50),\n",
129
+ " nn.ReLU(),\n",
130
+ " nn.Linear(50, 50),\n",
131
+ " nn.ReLU(),\n",
132
+ " nn.Linear(50, 10)\n",
133
+ " )"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 44,
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
146
+ "eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
147
+ "train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
148
+ "eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
149
+ "train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
150
+ "eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
151
+ "train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
152
+ "eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
153
+ "train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
154
+ "eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
155
+ ]
156
+ }
157
+ ],
158
+ "source": [
159
+ "model = linear_classifier()\n",
160
+ "lr = 0.1\n",
161
+ "max_lr = 0.1\n",
162
+ "epochs = 5\n",
163
+ "opt = optim.AdamW(model.parameters(), lr=lr)\n",
164
+ "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
165
+ "\n",
166
+ "for epoch in range(epochs):\n",
167
+ " for train in (True, False):\n",
168
+ " accuracy = 0\n",
169
+ " dl = dls.train if train else dls.valid\n",
170
+ " for xb,yb in dl:\n",
171
+ " preds = model(xb)\n",
172
+ " loss = F.cross_entropy(preds, yb)\n",
173
+ " if train:\n",
174
+ " loss.backward()\n",
175
+ " opt.step()\n",
176
+ " opt.zero_grad()\n",
177
+ " with torch.no_grad():\n",
178
+ " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
179
+ " if train:\n",
180
+ " sched.step()\n",
181
+ " accuracy /= len(dl)\n",
182
+ " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
183
+ " "
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 46,
189
+ "metadata": {
190
+ "tags": [
191
+ "exclude"
192
+ ]
193
+ },
194
+ "outputs": [],
195
+ "source": [
196
+ "with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
197
+ " pickle.dump(model, model_file)"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 35,
203
  "metadata": {},
204
  "outputs": [],
205
  "source": [
206
  "def cnn_classifier():\n",
207
  " ks,stride = 3,2\n",
208
  " return nn.Sequential(\n",
209
+ " nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
210
+ " nn.BatchNorm2d(8),\n",
 
211
  " nn.ReLU(),\n",
212
  " nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
213
+ " nn.BatchNorm2d(16),\n",
214
  " nn.ReLU(),\n",
215
  " nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
216
+ " nn.BatchNorm2d(32),\n",
217
+ " nn.ReLU(),\n",
218
+ " nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
219
+ " nn.BatchNorm2d(64),\n",
220
  " nn.ReLU(),\n",
221
+ " nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
222
+ " nn.BatchNorm2d(64),\n",
223
  " nn.ReLU(),\n",
224
+ " nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
225
  " nn.Flatten(),\n",
226
  " )"
227
  ]
228
  },
229
  {
230
  "cell_type": "code",
231
+ "execution_count": 36,
232
  "metadata": {},
233
  "outputs": [],
234
  "source": [
235
+ "def kaiming_init(m):\n",
236
+ " if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
237
+ " nn.init.kaiming_normal_(m.weight)"
 
 
 
 
 
 
 
238
  ]
239
  },
240
  {
241
  "cell_type": "code",
242
+ "execution_count": 37,
243
  "metadata": {},
244
  "outputs": [
245
  {
246
  "name": "stdout",
247
  "output_type": "stream",
248
  "text": [
249
+ "train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
250
+ "eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
251
+ "train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
252
+ "eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
253
+ "train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
254
+ "eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
255
+ "train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
256
+ "eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
257
+ "train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
258
+ "eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
259
  ]
260
  }
261
  ],
262
  "source": [
263
+ "model = cnn_classifier()\n",
264
+ "model.apply(kaiming_init)\n",
265
  "lr = 0.1\n",
266
+ "max_lr = 0.3\n",
267
  "epochs = 5\n",
268
  "opt = optim.AdamW(model.parameters(), lr=lr)\n",
269
  "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
 
290
  },
291
  {
292
  "cell_type": "code",
293
+ "execution_count": 41,
294
  "metadata": {
295
  "tags": [
296
  "exclude"
 
298
  },
299
  "outputs": [],
300
  "source": [
301
+ "with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
302
+ " pickle.dump(model, model_file)"
303
  ]
304
  },
305
  {
mnist.py CHANGED
@@ -53,39 +53,77 @@ class Reshape(nn.Module):
53
  return x.reshape(self.dim)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def cnn_classifier():
57
  ks,stride = 3,2
58
  return nn.Sequential(
59
- nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),
60
- nn.ReLU(),
61
- nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),
62
  nn.ReLU(),
63
  nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
 
64
  nn.ReLU(),
65
  nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
 
66
  nn.ReLU(),
67
- nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),
 
68
  nn.ReLU(),
69
- nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),
 
 
 
70
  nn.Flatten(),
71
  )
72
 
73
 
74
- # model definition
75
- def linear_classifier():
76
- return nn.Sequential(
77
- Reshape((-1, 784)),
78
- nn.Linear(784, 50),
79
- nn.ReLU(),
80
- nn.Linear(50, 50),
81
- nn.ReLU(),
82
- nn.Linear(50, 10)
83
- )
84
 
85
 
86
- model = linear_classifier()
 
87
  lr = 0.1
88
- max_lr = 0.1
89
  epochs = 5
90
  opt = optim.AdamW(model.parameters(), lr=lr)
91
  sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
 
53
  return x.reshape(self.dim)
54
 
55
 
56
+ # model definition
57
+ def linear_classifier():
58
+ return nn.Sequential(
59
+ Reshape((-1, 784)),
60
+ nn.Linear(784, 50),
61
+ nn.ReLU(),
62
+ nn.Linear(50, 50),
63
+ nn.ReLU(),
64
+ nn.Linear(50, 10)
65
+ )
66
+
67
+
68
+ model = linear_classifier()
69
+ lr = 0.1
70
+ max_lr = 0.1
71
+ epochs = 5
72
+ opt = optim.AdamW(model.parameters(), lr=lr)
73
+ sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
74
+
75
+ for epoch in range(epochs):
76
+ for train in (True, False):
77
+ accuracy = 0
78
+ dl = dls.train if train else dls.valid
79
+ for xb,yb in dl:
80
+ preds = model(xb)
81
+ loss = F.cross_entropy(preds, yb)
82
+ if train:
83
+ loss.backward()
84
+ opt.step()
85
+ opt.zero_grad()
86
+ with torch.no_grad():
87
+ accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
88
+ if train:
89
+ sched.step()
90
+ accuracy /= len(dl)
91
+ print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
92
+
93
+
94
+
95
  def cnn_classifier():
96
  ks,stride = 3,2
97
  return nn.Sequential(
98
+ nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
99
+ nn.BatchNorm2d(8),
 
100
  nn.ReLU(),
101
  nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
102
+ nn.BatchNorm2d(16),
103
  nn.ReLU(),
104
  nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
105
+ nn.BatchNorm2d(32),
106
  nn.ReLU(),
107
+ nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
108
+ nn.BatchNorm2d(64),
109
  nn.ReLU(),
110
+ nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
111
+ nn.BatchNorm2d(64),
112
+ nn.ReLU(),
113
+ nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
114
  nn.Flatten(),
115
  )
116
 
117
 
118
+ def kaiming_init(m):
119
+ if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
120
+ nn.init.kaiming_normal_(m.weight)
 
 
 
 
 
 
 
121
 
122
 
123
+ model = cnn_classifier()
124
+ model.apply(kaiming_init)
125
  lr = 0.1
126
+ max_lr = 0.3
127
  epochs = 5
128
  opt = optim.AdamW(model.parameters(), lr=lr)
129
  sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)