CircleStar commited on
Commit
0c1cefc
·
verified ·
1 Parent(s): 60e1829

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -19
model.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch.nn as nn
2
-
3
- from config import IMAGE_SIZE
4
 
5
 
6
  class SimpleCNN(nn.Module):
@@ -15,27 +14,17 @@ class SimpleCNN(nn.Module):
15
  ):
16
  super().__init__()
17
 
18
- padding = kernel_size // 2
19
-
20
- self.features = nn.Sequential(
21
- nn.Conv2d(3, conv1_channels, kernel_size=kernel_size, padding=padding),
22
- nn.ReLU(),
23
- nn.MaxPool2d(2),
24
-
25
- nn.Conv2d(conv1_channels, conv2_channels, kernel_size=kernel_size, padding=padding),
26
- nn.ReLU(),
27
- nn.MaxPool2d(2),
28
- )
29
 
30
- flattened_dim = conv2_channels * (IMAGE_SIZE // 4) * (IMAGE_SIZE // 4)
31
-
32
- self.classifier = nn.Sequential(
33
- nn.Flatten(),
34
- nn.Linear(flattened_dim, fc_dim),
35
  nn.ReLU(),
36
  nn.Dropout(dropout),
37
  nn.Linear(fc_dim, num_classes),
38
  )
39
 
40
  def forward(self, x):
41
- return self.classifier(self.features(x))
 
1
  import torch.nn as nn
2
+ from torchvision import models
 
3
 
4
 
5
  class SimpleCNN(nn.Module):
 
14
  ):
15
  super().__init__()
16
 
17
+ weights = models.ResNet18_Weights.DEFAULT
18
+ self.backbone = models.resnet18(weights=weights)
 
 
 
 
 
 
 
 
 
19
 
20
+ in_features = self.backbone.fc.in_features
21
+ self.backbone.fc = nn.Sequential(
22
+ nn.Dropout(dropout),
23
+ nn.Linear(in_features, fc_dim),
 
24
  nn.ReLU(),
25
  nn.Dropout(dropout),
26
  nn.Linear(fc_dim, num_classes),
27
  )
28
 
29
  def forward(self, x):
30
+ return self.backbone(x)