revise the model to enhance the training stability
Browse files
mnist.py
CHANGED
@@ -11,7 +11,7 @@ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
|
|
11 |
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
14 |
-
parser.add_argument("-n", "--n_epochs", type=int, default=
|
15 |
parser.add_argument("-b", "--batch", type=int, default=256, help="batch size of training")
|
16 |
parser.add_argument("-m", "--model", type=str, default='mnist0', help="model to execute")
|
17 |
opt = parser.parse_args()
|
@@ -147,14 +147,15 @@ class MNIST_OptAEGV1(MNISTModel):
|
|
147 |
self.pool = nn.MaxPool2d(2)
|
148 |
self.conv0 = nn.Conv2d(1, 2, kernel_size=7, padding=3, bias=False)
|
149 |
self.lnon0 = OptAEGV1()
|
150 |
-
self.conv1 = nn.Conv2d(2, 2, kernel_size=
|
151 |
self.lnon1 = OptAEGV1()
|
152 |
-
self.conv2 = nn.Conv2d(2, 2, kernel_size=
|
153 |
self.lnon2 = OptAEGV1()
|
154 |
-
self.conv3 = nn.Conv2d(2, 2, kernel_size=
|
155 |
self.lnon3 = OptAEGV1()
|
156 |
-
self.
|
157 |
self.lnon4 = OptAEGV1()
|
|
|
158 |
|
159 |
def forward(self, x):
|
160 |
x = self.conv0(x)
|
@@ -167,8 +168,9 @@ class MNIST_OptAEGV1(MNISTModel):
|
|
167 |
x = self.lnon2(x)
|
168 |
x = self.pool(x)
|
169 |
x = th.flatten(x, 1)
|
170 |
-
x = self.
|
171 |
x = self.lnon4(x)
|
|
|
172 |
x = F.log_softmax(x, dim=1)
|
173 |
return x
|
174 |
|
|
|
11 |
|
12 |
|
13 |
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("-n", "--n_epochs", type=int, default=1000, help="number of epochs of training")
|
15 |
parser.add_argument("-b", "--batch", type=int, default=256, help="batch size of training")
|
16 |
parser.add_argument("-m", "--model", type=str, default='mnist0', help="model to execute")
|
17 |
opt = parser.parse_args()
|
|
|
147 |
self.pool = nn.MaxPool2d(2)
|
148 |
self.conv0 = nn.Conv2d(1, 2, kernel_size=7, padding=3, bias=False)
|
149 |
self.lnon0 = OptAEGV1()
|
150 |
+
self.conv1 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
|
151 |
self.lnon1 = OptAEGV1()
|
152 |
+
self.conv2 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
|
153 |
self.lnon2 = OptAEGV1()
|
154 |
+
self.conv3 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
|
155 |
self.lnon3 = OptAEGV1()
|
156 |
+
self.fc1 = nn.Linear(2 * 3 * 3, 10)
|
157 |
self.lnon4 = OptAEGV1()
|
158 |
+
self.fc2 = nn.Linear(10, 10, bias=False)
|
159 |
|
160 |
def forward(self, x):
|
161 |
x = self.conv0(x)
|
|
|
168 |
x = self.lnon2(x)
|
169 |
x = self.pool(x)
|
170 |
x = th.flatten(x, 1)
|
171 |
+
x = self.fc1(x)
|
172 |
x = self.lnon4(x)
|
173 |
+
x = self.fc2(x)
|
174 |
x = F.log_softmax(x, dim=1)
|
175 |
return x
|
176 |
|