gyrojeff commited on
Commit
2d6a578
1 Parent(s): 8b9de80

feat: add different resnet models

Browse files
Files changed (1) hide show
  1. detector/model.py +55 -2
detector/model.py CHANGED
@@ -11,7 +11,7 @@ import pytorch_lightning as ptl
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
- def __init__(self, regression_use_tanh: bool=False):
15
  super().__init__()
16
  self.model = torchvision.models.resnet18(weights=False)
17
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
@@ -27,6 +27,57 @@ class ResNet18Regressor(nn.Module):
27
  return X
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class FontDetectorLoss(nn.Module):
31
  def __init__(self, lambda_font, lambda_direction, lambda_regression):
32
  super().__init__()
@@ -134,7 +185,9 @@ class FontDetector(ptl.LightningModule):
134
  def on_train_epoch_end(self) -> None:
135
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
136
  self.log(
137
- "train_direction_accur", self.direction_accur_train.compute(), sync_dist=True
 
 
138
  )
139
  self.font_accur_train.reset()
140
  self.direction_accur_train.reset()
 
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
+ def __init__(self, regression_use_tanh: bool = False):
15
  super().__init__()
16
  self.model = torchvision.models.resnet18(weights=False)
17
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
 
27
  return X
28
 
29
 
30
+ class ResNet34Regressor(nn.Module):
31
+ def __init__(self, regression_use_tanh: bool = False):
32
+ super().__init__()
33
+ self.model = torchvision.models.resnet34(weights=False)
34
+ self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
35
+ self.regression_use_tanh = regression_use_tanh
36
+
37
+ def forward(self, X):
38
+ X = self.model(X)
39
+ # [0, 1]
40
+ if not self.regression_use_tanh:
41
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
42
+ else:
43
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
44
+ return X
45
+
46
+
47
+ class ResNet50Regressor(nn.Module):
48
+ def __init__(self, regression_use_tanh: bool = False):
49
+ super().__init__()
50
+ self.model = torchvision.models.resnet50(weights=False)
51
+ self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
52
+ self.regression_use_tanh = regression_use_tanh
53
+
54
+ def forward(self, X):
55
+ X = self.model(X)
56
+ # [0, 1]
57
+ if not self.regression_use_tanh:
58
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
59
+ else:
60
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
61
+ return X
62
+
63
+
64
+ class ResNet101Regressor(nn.Module):
65
+ def __init__(self, regression_use_tanh: bool = False):
66
+ super().__init__()
67
+ self.model = torchvision.models.resnet101(weights=False)
68
+ self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
69
+ self.regression_use_tanh = regression_use_tanh
70
+
71
+ def forward(self, X):
72
+ X = self.model(X)
73
+ # [0, 1]
74
+ if not self.regression_use_tanh:
75
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
76
+ else:
77
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
78
+ return X
79
+
80
+
81
  class FontDetectorLoss(nn.Module):
82
  def __init__(self, lambda_font, lambda_direction, lambda_regression):
83
  super().__init__()
 
185
  def on_train_epoch_end(self) -> None:
186
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
187
  self.log(
188
+ "train_direction_accur",
189
+ self.direction_accur_train.compute(),
190
+ sync_dist=True,
191
  )
192
  self.font_accur_train.reset()
193
  self.direction_accur_train.reset()