mingli committed
Commit ae37ff6
1 Parent(s): cba105a

Create mnist.py

Files changed (1)
  1. mnist.py +233 -0

mnist.py ADDED
import torch as th
import torch.nn.functional as F
import torch.nn as nn
import lightning as ltn
import argparse
import lightning.pytorch as pl

from torch import Tensor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


parser = argparse.ArgumentParser()
parser.add_argument("-n", "--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("-b", "--batch", type=int, default=256, help="batch size of training")
parser.add_argument("-m", "--model", type=str, default='mnist0', help="model to execute")  # parsed but never used below
opt = parser.parse_args()

if th.cuda.is_available():
    accelerator = 'gpu'
elif th.backends.mps.is_available():
    accelerator = 'cpu'  # MPS is detected, but the script still runs on CPU
else:
    accelerator = 'cpu'


class OptAEGV1(nn.Module):

    def __init__(self, points=11):
        super().__init__()
        self.points = points
        # learnable input/output scales for the normalized signal
        self.iscale = nn.Parameter(th.normal(0, 1, (1, 1, 1, 1)))
        self.oscale = nn.Parameter(th.normal(0, 1, (1, 1, 1, 1)))
        # fixed sampling grids for the angle and speed lookups
        self.theta = th.linspace(-th.pi, th.pi, points)
        self.velocity = th.linspace(0, th.e, points)
        self.weight = nn.Parameter(th.normal(0, 1, (points, points)))

    @th.compile
    def integral(self, param, index):
        # softmax-weighted sum of the sampled values at the given grid rows
        return th.sum(param[index].view(-1, 1) * th.softmax(self.weight, dim=1)[index, :], dim=1)

    @th.compile
    def interplot(self, param, index):
        # piecewise-linear interpolation between the two nearest grid points
        lmt = param.size(0) - 1

        p0 = index.floor().long()
        p1 = p0 + 1
        pos = index - p0
        p0 = p0.clamp(0, lmt)
        p1 = p1.clamp(0, lmt)

        v0 = self.integral(param, p0)
        v1 = self.integral(param, p1)

        return (1 - pos) * v0 + pos * v1

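    # Worked example (hypothetical values, not in the original file): with
    # param = th.tensor([0., 1., 2.]) and index = th.tensor([0.5]), interplot
    # returns 0.5 * integral(param, 0) + 0.5 * integral(param, 1), i.e. an
    # equal blend of the two neighboring softmax-weighted samples.
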
    @th.compile
    def forward(self, data: Tensor) -> Tensor:
        # move the fixed grids to the input device on first use
        if self.theta.device != data.device:
            self.theta = self.theta.to(data.device)
            self.velocity = self.velocity.to(data.device)
        shape = data.size()
        data = (data - data.mean()) / data.std() * self.iscale
        data = data.flatten(0)

        # look up an angle and a step size for every element
        theta = self.interplot(self.theta, th.sigmoid(data) * (self.points - 1))
        ds = self.interplot(self.velocity, th.abs(th.tanh(data) * (self.points - 1)))

        # advance each value along the looked-up direction
        dx = ds * th.cos(theta)
        dy = ds * th.sin(theta)
        data = data * th.exp(dy) + dx

        data = (data - data.mean()) / data.std() * self.oscale
        return data.view(*shape)


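# Hypothetical smoke test (not part of the original file): OptAEGV1 behaves as
# a shape-preserving element-wise activation, which this check makes explicit.
def _check_optaeg_shape():
    act = OptAEGV1()
    x = th.randn(4, 2, 7, 7)
    assert act(x).shape == x.shape

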
class MNISTModel(ltn.LightningModule):
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3
        # running totals for the epoch-level stats summarized at checkpoint time
        self.counter = 0
        self.labeled_loss = 0
        self.labeled_correct = 0

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = th.optim.lr_scheduler.CosineAnnealingLR(optimizer, 37)
        return [optimizer], [scheduler]

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(-1, 1, 28, 28)
        z = self.forward(x)
        loss = F.nll_loss(z, y)

        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(-1, 1, 28, 28)

        z = self.forward(x)
        loss = F.nll_loss(z, y)
        self.log('val_loss', loss, prog_bar=True)

        pred = z.data.max(1, keepdim=True)[1]
        correct = pred.eq(y.data.view_as(pred)).sum() / y.size()[0]
        self.log('correct_rate', correct, prog_bar=True)

        # accumulate per-sample totals for the checkpoint-time summary
        self.labeled_loss += loss.item() * y.size()[0]
        self.labeled_correct += correct.item() * y.size()[0]
        self.counter += y.size()[0]

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        x = x.view(-1, 1, 28, 28)
        z = self(x)

        pred = z.data.max(1, keepdim=True)[1]
        correct = pred.eq(y.data.view_as(pred)).sum() / y.size()[0]
        self.log('correct_rate', correct, prog_bar=True)

    def on_save_checkpoint(self, checkpoint) -> None:
        import glob, os

        correct = self.labeled_correct / self.counter
        loss = self.labeled_loss / self.counter
        # e.g. 'best-0.99120-042-0.02931.ckpt': accuracy comes first, so an
        # alphabetical sort doubles as an accuracy sort
        record = '%2.5f-%03d-%1.5f.ckpt' % (correct, checkpoint['epoch'], loss)
        fname = 'best-%s' % record
        with open(fname, 'bw') as f:
            th.save(checkpoint, f)
        # keep only the six best checkpoints
        for ix, ckpt in enumerate(sorted(glob.glob('best-*.ckpt'), reverse=True)):
            if ix > 5:
                os.unlink(ckpt)

        self.counter = 0
        self.labeled_loss = 0
        self.labeled_correct = 0

        print()


class MNIST_OptAEGV1(MNISTModel):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv0 = nn.Conv2d(1, 2, kernel_size=7, padding=3, bias=False)
        self.lnon0 = OptAEGV1()
        self.conv1 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
        self.lnon1 = OptAEGV1()
        self.conv2 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
        self.lnon2 = OptAEGV1()
        # conv3/lnon3 are defined but never used in forward below
        self.conv3 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
        self.lnon3 = OptAEGV1()
        self.fc = nn.Linear(2 * 3 * 3, 10)
        self.lnon4 = OptAEGV1()

    def forward(self, x):
        # 28x28 -> 14x14 -> 7x7 -> 3x3 through three conv/activation/pool stages
        x = self.conv0(x)
        x = self.lnon0(x)
        x = self.pool(x)
        x = self.conv1(x)
        x = self.lnon1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.lnon2(x)
        x = self.pool(x)
        x = th.flatten(x, 1)
        x = self.fc(x)
        x = self.lnon4(x)
        x = F.log_softmax(x, dim=1)
        return x


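# Hypothetical helper (not part of the original file): counts trainable
# parameters, making explicit how small MNIST_OptAEGV1 is (well under 2k
# weights, most of them in the OptAEGV1 lookup tables and the final Linear).
def _count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

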
def test_best():
    # relies on the module-level `model` and `test_loader` set up in __main__
    import glob
    fname = sorted(glob.glob('best-*.ckpt'), reverse=True)[0]
    with open(fname, 'rb') as f:
        checkpoint = th.load(f)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()

    with th.no_grad():
        counter, success = 0, 0
        for test_batch in test_loader:
            x, y = test_batch
            x = x.view(-1, 1, 28, 28)
            z = model(x)
            pred = z.data.max(1, keepdim=True)[1]
            correct = pred.eq(y.data.view_as(pred)).sum() / y.size()[0]
            print('.', end='', flush=True)
            if counter % 100 == 0:
                print('')
            success += correct.item()
            counter += 1
        print('')
        print('Accuracy: %2.5f' % (success / counter))
        th.save(model, 'mnist-optaeg-v1.pt')


if __name__ == '__main__':

    print('loading data...')
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms

    mnist_train = MNIST('datasets', train=True, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ]))

    mnist_test = MNIST('datasets', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ]))

    train_loader = DataLoader(mnist_train, shuffle=True, batch_size=opt.batch, num_workers=8)
    val_loader = DataLoader(mnist_test, batch_size=opt.batch, num_workers=8)
    test_loader = DataLoader(mnist_test, batch_size=opt.batch, num_workers=8)

    # training
    print('construct trainer...')
    trainer = pl.Trainer(accelerator=accelerator, precision=32, max_epochs=opt.n_epochs,
                         callbacks=[EarlyStopping(monitor="correct_rate", mode="max", patience=30)])

    print('construct model...')
    model = MNIST_OptAEGV1()

    print('training...')
    trainer.fit(model, train_loader, val_loader)

    print('testing...')
    test_best()
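
# Typical invocation, using the flags defined by the parser above:
#
#   python mnist.py -n 200 -b 256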