Gosula commited on
Commit
6220ea2
1 Parent(s): 6ef77e0

Create model.py (#4)

Browse files

- Create model.py (62a1d5783ec19601f57b8df6df214692b4796c06)

Files changed (1) hide show
  1. model.py +20 -0
model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Cnn(nn.Module):
2
+ def __init__(self, dropout=0.5):
3
+ super(Cnn, self).__init__()
4
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
5
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
6
+ self.conv2_drop = nn.Dropout2d(p=dropout)
7
+ self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height
8
+ self.fc2 = nn.Linear(100, 10)
9
+ self.fc1_drop = nn.Dropout(p=dropout)
10
+
11
+ def forward(self, x):
12
+ x = torch.relu(F.max_pool2d(self.conv1(x), 2))
13
+ x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
14
+
15
+ # flatten over channel, height and width = 1600
16
+ x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
17
+
18
+ x = torch.relu(self.fc1_drop(self.fc1(x)))
19
+ x = torch.softmax(self.fc2(x), dim=-1)
20
+ return x