feat: add different resnet models
detector/model.py  CHANGED  +55 -2
@@ -11,7 +11,7 @@ import pytorch_lightning as ptl
 
 
 class ResNet18Regressor(nn.Module):
-    def __init__(self, regression_use_tanh: bool=False):
+    def __init__(self, regression_use_tanh: bool = False):
         super().__init__()
         self.model = torchvision.models.resnet18(weights=False)
         self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
@@ -27,6 +27,57 @@ class ResNet18Regressor(nn.Module):
         return X
 
 
+class ResNet34Regressor(nn.Module):
+    def __init__(self, regression_use_tanh: bool = False):
+        super().__init__()
+        self.model = torchvision.models.resnet34(weights=False)
+        self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
+        self.regression_use_tanh = regression_use_tanh
+
+    def forward(self, X):
+        X = self.model(X)
+        # [0, 1]
+        if not self.regression_use_tanh:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
+        else:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
+        return X
+
+
+class ResNet50Regressor(nn.Module):
+    def __init__(self, regression_use_tanh: bool = False):
+        super().__init__()
+        self.model = torchvision.models.resnet50(weights=False)
+        self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
+        self.regression_use_tanh = regression_use_tanh
+
+    def forward(self, X):
+        X = self.model(X)
+        # [0, 1]
+        if not self.regression_use_tanh:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
+        else:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
+        return X
+
+
+class ResNet101Regressor(nn.Module):
+    def __init__(self, regression_use_tanh: bool = False):
+        super().__init__()
+        self.model = torchvision.models.resnet101(weights=False)
+        self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
+        self.regression_use_tanh = regression_use_tanh
+
+    def forward(self, X):
+        X = self.model(X)
+        # [0, 1]
+        if not self.regression_use_tanh:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
+        else:
+            X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
+        return X
+
+
 class FontDetectorLoss(nn.Module):
     def __init__(self, lambda_font, lambda_direction, lambda_regression):
         super().__init__()
@@ -134,7 +185,9 @@ class FontDetector(ptl.LightningModule):
     def on_train_epoch_end(self) -> None:
         self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
         self.log(
-            "train_direction_accur", self.direction_accur_train.compute(), sync_dist=True
+            "train_direction_accur",
+            self.direction_accur_train.compute(),
+            sync_dist=True,
         )
         self.font_accur_train.reset()
         self.direction_accur_train.reset()
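All four regressor classes share the same constructor signature and output layout: a single `nn.Linear` head producing `config.FONT_COUNT + 12` values, with everything past index `FONT_COUNT + 2` squashed by sigmoid or tanh. The only structural difference is the width of the backbone's final feature vector (512 for resnet18/34, 2048 for resnet50/101 because of the Bottleneck expansion), which each `nn.Linear` head accounts for. Below is a minimal usage sketch, assuming the module is importable as `detector.model`; the `build_regressor` helper and the backbone name strings are hypothetical conveniences, not part of this commit.

# Sketch only: build_regressor and the name strings are illustrative,
# not part of detector/model.py.
import torch

from detector.model import (
    ResNet18Regressor,
    ResNet34Regressor,
    ResNet50Regressor,
    ResNet101Regressor,
)

_BACKBONES = {
    "resnet18": ResNet18Regressor,
    "resnet34": ResNet34Regressor,
    "resnet50": ResNet50Regressor,
    "resnet101": ResNet101Regressor,
}


def build_regressor(name: str, regression_use_tanh: bool = False) -> torch.nn.Module:
    """Pick one of the regressor variants by backbone name."""
    try:
        cls = _BACKBONES[name]
    except KeyError:
        raise ValueError(f"unknown backbone {name!r}, expected one of {sorted(_BACKBONES)}")
    return cls(regression_use_tanh=regression_use_tanh)


# All variants return (batch, config.FONT_COUNT + 12) outputs, so downstream
# code (FontDetectorLoss, FontDetector) does not need to change when the
# backbone is swapped.
model = build_regressor("resnet50")
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))  # dummy batch; real input size is set elsewhere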