CircleStar commited on
Commit
ba2a7fa
·
verified ·
1 Parent(s): 958eb86

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +20 -4
model.py CHANGED
@@ -6,9 +6,9 @@ class ResNet18Classifier(nn.Module):
6
  def __init__(
7
  self,
8
  num_classes: int,
9
- dropout: float = 0.5,
10
  fc_dim: int = 256,
11
- freeze_backbone: bool = True,
12
  ):
13
  super().__init__()
14
 
@@ -17,9 +17,24 @@ class ResNet18Classifier(nn.Module):
17
 
18
  in_features = self.backbone.fc.in_features
19
 
20
- if freeze_backbone:
 
 
 
 
 
 
 
 
 
 
 
 
21
  for param in self.backbone.parameters():
22
- param.requires_grad = False
 
 
 
23
 
24
  self.backbone.fc = nn.Sequential(
25
  nn.Dropout(dropout),
@@ -29,6 +44,7 @@ class ResNet18Classifier(nn.Module):
29
  nn.Linear(fc_dim, num_classes),
30
  )
31
 
 
32
  for param in self.backbone.fc.parameters():
33
  param.requires_grad = True
34
 
 
6
  def __init__(
7
  self,
8
  num_classes: int,
9
+ dropout: float = 0.4,
10
  fc_dim: int = 256,
11
+ fine_tune_mode: str = "layer4",
12
  ):
13
  super().__init__()
14
 
 
17
 
18
  in_features = self.backbone.fc.in_features
19
 
20
+ # Freeze everything first
21
+ for param in self.backbone.parameters():
22
+ param.requires_grad = False
23
+
24
+ # Fine-tuning strategy
25
+ if fine_tune_mode == "frozen":
26
+ pass
27
+
28
+ elif fine_tune_mode == "layer4":
29
+ for param in self.backbone.layer4.parameters():
30
+ param.requires_grad = True
31
+
32
+ elif fine_tune_mode == "full":
33
  for param in self.backbone.parameters():
34
+ param.requires_grad = True
35
+
36
+ else:
37
+ raise ValueError(f"Unsupported fine_tune_mode: {fine_tune_mode}")
38
 
39
  self.backbone.fc = nn.Sequential(
40
  nn.Dropout(dropout),
 
44
  nn.Linear(fc_dim, num_classes),
45
  )
46
 
47
+ # Always train classifier head
48
  for param in self.backbone.fc.parameters():
49
  param.requires_grad = True
50