mingli commited on
Commit
dcb1e53
1 Parent(s): 202bf36

revise the model to enhance the training stability

Browse files
Files changed (1) hide show
  1. mnist.py +8 -6
mnist.py CHANGED
@@ -11,7 +11,7 @@ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
11
 
12
 
13
  parser = argparse.ArgumentParser()
14
- parser.add_argument("-n", "--n_epochs", type=int, default=200, help="number of epochs of training")
15
  parser.add_argument("-b", "--batch", type=int, default=256, help="batch size of training")
16
  parser.add_argument("-m", "--model", type=str, default='mnist0', help="model to execute")
17
  opt = parser.parse_args()
@@ -147,14 +147,15 @@ class MNIST_OptAEGV1(MNISTModel):
147
  self.pool = nn.MaxPool2d(2)
148
  self.conv0 = nn.Conv2d(1, 2, kernel_size=7, padding=3, bias=False)
149
  self.lnon0 = OptAEGV1()
150
- self.conv1 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
151
  self.lnon1 = OptAEGV1()
152
- self.conv2 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
153
  self.lnon2 = OptAEGV1()
154
- self.conv3 = nn.Conv2d(2, 2, kernel_size=7, padding=3, bias=False)
155
  self.lnon3 = OptAEGV1()
156
- self.fc = nn.Linear(2 * 3 * 3, 10)
157
  self.lnon4 = OptAEGV1()
 
158
 
159
  def forward(self, x):
160
  x = self.conv0(x)
@@ -167,8 +168,9 @@ class MNIST_OptAEGV1(MNISTModel):
167
  x = self.lnon2(x)
168
  x = self.pool(x)
169
  x = th.flatten(x, 1)
170
- x = self.fc(x)
171
  x = self.lnon4(x)
 
172
  x = F.log_softmax(x, dim=1)
173
  return x
174
 
 
11
 
12
 
13
  parser = argparse.ArgumentParser()
14
+ parser.add_argument("-n", "--n_epochs", type=int, default=1000, help="number of epochs of training")
15
  parser.add_argument("-b", "--batch", type=int, default=256, help="batch size of training")
16
  parser.add_argument("-m", "--model", type=str, default='mnist0', help="model to execute")
17
  opt = parser.parse_args()
 
147
  self.pool = nn.MaxPool2d(2)
148
  self.conv0 = nn.Conv2d(1, 2, kernel_size=7, padding=3, bias=False)
149
  self.lnon0 = OptAEGV1()
150
+ self.conv1 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
151
  self.lnon1 = OptAEGV1()
152
+ self.conv2 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
153
  self.lnon2 = OptAEGV1()
154
+ self.conv3 = nn.Conv2d(2, 2, kernel_size=5, padding=2)
155
  self.lnon3 = OptAEGV1()
156
+ self.fc1 = nn.Linear(2 * 3 * 3, 10)
157
  self.lnon4 = OptAEGV1()
158
+ self.fc2 = nn.Linear(10, 10, bias=False)
159
 
160
  def forward(self, x):
161
  x = self.conv0(x)
 
168
  x = self.lnon2(x)
169
  x = self.pool(x)
170
  x = th.flatten(x, 1)
171
+ x = self.fc1(x)
172
  x = self.lnon4(x)
173
+ x = self.fc2(x)
174
  x = F.log_softmax(x, dim=1)
175
  return x
176