ffcm commited on
Commit
c7da96c
1 Parent(s): 908930d

adds network class back

Browse files
Files changed (1) hide show
  1. networktorch.py +50 -0
networktorch.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ class NeuralNetworkTorch(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ self.stack = nn.Sequential(
9
+ nn.Linear(784, 64),
10
+ nn.Sigmoid(),
11
+
12
+ nn.Linear(64, 10),
13
+ nn.Sigmoid()
14
+ )
15
+
16
+ def forward(self, x):
17
+ return self.stack(x)
18
+
19
+
20
+ class ConvNeuralNetworkTorch(nn.Module):
21
+ def __init__(self):
22
+ super().__init__()
23
+
24
+ self.conv = nn.Sequential(
25
+ nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
26
+ nn.ReLU(),
27
+
28
+ nn.MaxPool2d(kernel_size=2, stride=2),
29
+
30
+ nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
31
+ nn.ReLU(),
32
+
33
+ # nn.MaxPool2d(kernel_size=2, stride=2),
34
+ )
35
+
36
+ self.fc = nn.Sequential(
37
+ nn.Linear(16 * 14 * 14, 10),
38
+ nn.Softmax(dim=1),
39
+ )
40
+
41
+ def forward(self, x):
42
+ # we do some reshaping here simply to avoid making changes to the caller
43
+ # so it continues to work with the fully conected network above
44
+ x = x.reshape(-1, 1, 28, 28) / 255
45
+
46
+ conv_output = self.conv(x)
47
+ flat = conv_output.reshape(len(x), -1)
48
+ final_output = self.fc(flat)
49
+
50
+ return final_output