Diego Carpintero commited on
Commit
31cd6a1
1 Parent(s): aeb5dd0
Files changed (2) hide show
  1. model.py +62 -0
  2. model/digit_classifier.pt +2 -2
model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
5
+
6
+
7
+ class Linear(nn.Module):
8
+ def __init__(self, in_features: int, out_features: int):
9
+ super(Linear, self).__init__()
10
+ self.in_features = in_features
11
+ self.out_features = out_features
12
+
13
+ self.weight = nn.Parameter(
14
+ (
15
+ torch.randn((self.in_features, self.out_features), device=device) * 0.1
16
+ ).requires_grad_()
17
+ )
18
+ self.bias = nn.Parameter(
19
+ (torch.randn(self.out_features, device=device) * 0.1).requires_grad_()
20
+ )
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return x @ self.weight + self.bias
24
+
25
+
26
+ class ReLU(nn.Module):
27
+ @staticmethod
28
+ def forward(x: torch.Tensor) -> torch.Tensor:
29
+ return torch.max(x, torch.tensor(0))
30
+
31
+
32
+ class Sequential(nn.Module):
33
+ def __init__(self, *layers):
34
+ super(Sequential, self).__init__()
35
+ self.layers = nn.ModuleList(layers)
36
+
37
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
38
+ for layer in self.layers:
39
+ x = layer(x)
40
+ return x
41
+
42
+
43
+ class Flatten(nn.Module):
44
+ @staticmethod
45
+ def forward(x: torch.Tensor) -> torch.Tensor:
46
+ return x.view(x.size(0), -1)
47
+
48
+
49
+ class DigitClassifier(nn.Module):
50
+ def __init__(self):
51
+ super(DigitClassifier, self).__init__()
52
+ self.main = Sequential(
53
+ Flatten(),
54
+ Linear(in_features=784, out_features=256),
55
+ ReLU(),
56
+ Linear(in_features=256, out_features=64),
57
+ ReLU(),
58
+ Linear(in_features=64, out_features=10),
59
+ )
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ return self.main(x)
model/digit_classifier.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5a3b2bdb0b1c1e16b95b0367795b98b77ed61309281cafb9d44956e599af7cab
3
- size 439575
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0d2908a8b36b225cc6cb6eef0f8ef5fbcb660ef79a13b93c27df22082115a48
3
+ size 875543