# mymnist / Model.py
# Uploaded by jiew ("Upload 5 files", commit d852849).
import torch
class MNIST(torch.nn.Module):
    """Small convolutional classifier for single-channel images.

    Layers (the ``Sequential`` indices form the ``state_dict`` keys
    ``conv.0``, ``conv.2``, ``dense.0`` and ``dense.3`` -- keep them stable
    so previously saved checkpoints still load):

    * ``conv``: Conv2d(1 -> 32, 3x3, stride 1, padding 1) -> ReLU ->
      Conv2d(32 -> 64, 3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2, 2).
      The padded 3x3 convolutions keep the spatial size and the pool halves
      it, so a 28x28 input becomes a 64 x 14 x 14 feature map -- the size
      the first Linear layer is built for.
    * ``dense``: Linear(14*14*64 -> 1024) -> ReLU -> Dropout(p=0.2) ->
      Linear(1024 -> 10).

    ``forward`` returns the raw output of the last Linear layer; no softmax
    is applied.
    """

    # Flattened per-sample size of the conv feature map (64 channels, 14x14).
    _FLAT_FEATURES = 14 * 14 * 64

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(self._FLAT_FEATURES, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(1024, 10),
        )

    def forward(self, x):
        """Compute the 10 output scores for each image.

        Args:
            x: image tensor of shape (N, 1, H, W), or an unbatched
                (1, H, W) tensor, with H and W such that ``self.conv``
                yields a 64 x 14 x 14 map (e.g. 28x28 images).

        Returns:
            Tensor of shape (N, 10); (1, 10) for unbatched input.

        Raises:
            RuntimeError: if the flattened feature map does not have
                14*14*64 elements per sample.
        """
        x = self.conv(x)
        # Flatten only the trailing (C, H, W) dims rather than using
        # ``view(-1, 14*14*64)``. A size mismatch then raises in the Linear
        # layer instead of being folded into the batch dimension (silently
        # returning the wrong number of rows). flatten also works on
        # non-contiguous maps such as channels_last, where view fails.
        x = torch.flatten(x, start_dim=-3)
        if x.dim() == 1:
            # Unbatched (C, H, W) input: keep a leading batch dim of 1,
            # matching the (1, 10) output shape of the original code.
            x = x.unsqueeze(0)
        return self.dense(x)