import torch


class MNIST(torch.nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two 3x3 convolutions with padding 1 preserve spatial size. A 2x2 max-pool
    then halves it, so a ``(N, 1, 28, 28)`` input becomes ``(N, 64, 14, 14)``
    features. A two-layer MLP maps those features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        dropout: Dropout probability between the two linear layers.
            Defaults to 0.2.
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.2) -> None:
        super().__init__()
        # Submodule names and Sequential indices are kept fixed so parameter
        # keys such as "conv.0.weight" and "dense.3.bias" stay stable.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(14 * 14 * 64, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(1024, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (no softmax) for a batch.

        Args:
            x: Image batch, expected shape ``(N, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.

        Raises:
            RuntimeError: If the spatial size does not yield 14x14x64
                features, the first Linear layer rejects the shape.
        """
        x = self.conv(x)
        # Flatten per sample (from dim 1) so the batch size N is preserved.
        # A wrong spatial size then fails in the first Linear layer instead
        # of being silently folded into extra batch rows.
        x = torch.flatten(x, 1)
        return self.dense(x)