import torch | |
class MNIST(torch.nn.Module):
    """Small convolutional classifier for 28x28 images (e.g. MNIST digits).

    Two 3x3 convolutions with padding 1 keep the spatial size unchanged, and a
    2x2 max-pool then halves it. So a ``(N, in_channels, 28, 28)`` batch becomes
    ``(N, 64, 14, 14)`` feature maps. These are flattened and passed through a
    fully-connected head (1024 hidden units, ReLU, dropout) that produces raw,
    unnormalized class scores (logits).

    Args:
        num_classes: Number of output logits. Defaults to 10.
        dropout: Dropout probability used in the dense head. Defaults to 0.2.
        in_channels: Number of input image channels. Defaults to 1.
    """

    # Flattened feature size after the conv stack for a 28x28 input:
    # 64 channels * (28 / 2) * (28 / 2).
    _FLAT_FEATURES = 64 * 14 * 14

    def __init__(self, num_classes: int = 10, dropout: float = 0.2,
                 in_channels: int = 1):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 32, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(self._FLAT_FEATURES, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for a batch ``x``.

        ``x`` is expected to be ``(N, in_channels, 28, 28)``. Other spatial
        sizes make the first Linear layer raise a shape-mismatch error rather
        than silently producing the wrong number of rows.
        """
        x = self.conv(x)
        # Keep the batch dimension explicit. The old ``view(-1, 12544)`` would
        # regroup a wrongly-sized input into a bogus batch size, and it fails
        # on non-contiguous (e.g. channels_last) activations.
        x = torch.flatten(x, start_dim=1)
        return self.dense(x)