Gosula's picture
Create model.py
62a1d57
raw
history blame
No virus
802 Bytes
import torch
import torch.nn.functional as F
from torch import nn


class Cnn(nn.Module):
    """Small two-stage convolutional classifier with 10 outputs.

    Layer order:
        conv1 (1 -> 32, 3x3) -> max-pool(2) -> ReLU
        conv2 (32 -> 64, 3x3) -> Dropout2d -> max-pool(2) -> ReLU
        flatten -> fc1 (1600 -> 100) -> Dropout -> ReLU
        fc2 (100 -> 10) -> softmax over the last dimension

    ``fc1`` expects exactly 1600 = 64 * 5 * 5 features, i.e. the conv stack
    must produce 64 maps of 5x5. With 3x3 convolutions and 2x2 pooling that
    holds for single-channel inputs 26-29 px per side, e.g. 28x28
    (28 -> 26 -> 13 -> 11 -> 5); other sizes fail at ``fc1``.

    Args:
        dropout: Drop probability shared by ``conv2_drop`` and ``fc1_drop``.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        # Attribute names are kept unchanged so previously saved state_dicts
        # still load into this module.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        # 1600 = channels * height * width after the conv stack (64 * 5 * 5).
        self.fc1 = nn.Linear(64 * 5 * 5, 100)
        self.fc2 = nn.Linear(100, 10)
        self.fc1_drop = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map a batch of images to per-class probabilities.

        Args:
            x: Tensor of shape (N, 1, H, W) with H and W such that the conv
                stack yields 64x5x5 feature maps (e.g. 28x28).

        Returns:
            Tensor of shape (N, 10); each row is a softmax distribution.
        """
        x = torch.relu(F.max_pool2d(self.conv1(x), 2))
        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Collapse channel, height and width into one feature axis: (N, 1600).
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1_drop(self.fc1(x)))
        # NOTE(review): this returns probabilities, not logits. If training
        # uses nn.CrossEntropyLoss (which applies log-softmax internally),
        # softmax ends up applied twice -- confirm against the training code.
        return torch.softmax(self.fc2(x), dim=-1)