CircleStar commited on
Commit
01ce719
·
verified ·
1 Parent(s): bb17c8c

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +41 -0
model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from config import IMAGE_SIZE
4
+
5
+
6
+ class SimpleCNN(nn.Module):
7
+ def __init__(
8
+ self,
9
+ num_classes: int,
10
+ conv1_channels: int = 16,
11
+ conv2_channels: int = 32,
12
+ kernel_size: int = 3,
13
+ dropout: float = 0.2,
14
+ fc_dim: int = 128,
15
+ ):
16
+ super().__init__()
17
+
18
+ padding = kernel_size // 2
19
+
20
+ self.features = nn.Sequential(
21
+ nn.Conv2d(3, conv1_channels, kernel_size=kernel_size, padding=padding),
22
+ nn.ReLU(),
23
+ nn.MaxPool2d(2),
24
+
25
+ nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
26
+ nn.ReLU(),
27
+ nn.MaxPool2d(2),
28
+ )
29
+
30
+ flattened_dim = conv2_channels * (IMAGE_SIZE // 4) * (IMAGE_SIZE // 4)
31
+
32
+ self.classifier = nn.Sequential(
33
+ nn.Flatten(),
34
+ nn.Linear(flattened_dim, fc_dim),
35
+ nn.ReLU(),
36
+ nn.Dropout(dropout),
37
+ nn.Linear(fc_dim, num_classes),
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.classifier(self.features(x))