fix: update accuracy for classification only
detector/model.py (+48 -36)
```diff
@@ -150,27 +150,29 @@ class FontDetector(ptl.LightningModule):
         self.font_accur_train = torchmetrics.Accuracy(
             task="multiclass", num_classes=config.FONT_COUNT
         )
-        self.direction_accur_train = torchmetrics.Accuracy(
-            task="multiclass", num_classes=2
-        )
         self.font_accur_val = torchmetrics.Accuracy(
             task="multiclass", num_classes=config.FONT_COUNT
         )
-        self.direction_accur_val = torchmetrics.Accuracy(
-            task="multiclass", num_classes=2
-        )
         self.font_accur_test = torchmetrics.Accuracy(
             task="multiclass", num_classes=config.FONT_COUNT
         )
-        self.direction_accur_test = torchmetrics.Accuracy(
-            task="multiclass", num_classes=2
-        )
+        if not font_classification_only:
+            self.direction_accur_train = torchmetrics.Accuracy(
+                task="multiclass", num_classes=2
+            )
+            self.direction_accur_val = torchmetrics.Accuracy(
+                task="multiclass", num_classes=2
+            )
+            self.direction_accur_test = torchmetrics.Accuracy(
+                task="multiclass", num_classes=2
+            )
         self.lr = lr
         self.betas = betas
         self.num_warmup_iters = num_warmup_iters
         self.num_iters = num_iters
         self.num_epochs = num_epochs
         self.load_epoch = 0
+        self.font_classification_only = font_classification_only

     def forward(self, x):
         return self.model(x)
@@ -188,24 +190,26 @@ class FontDetector(ptl.LightningModule):
             self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
             sync_dist=True,
         )
-        self.log(
-            "train_direction_accur",
-            self.direction_accur_train(
-                y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
-            ),
-            sync_dist=True,
-        )
+        if not self.font_classification_only:
+            self.log(
+                "train_direction_accur",
+                self.direction_accur_train(
+                    y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
+                ),
+                sync_dist=True,
+            )
         return {"loss": loss}

     def on_train_epoch_end(self) -> None:
         self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
-        self.log(
-            "train_direction_accur",
-            self.direction_accur_train.compute(),
-            sync_dist=True,
-        )
         self.font_accur_train.reset()
-        self.direction_accur_train.reset()
+        if not self.font_classification_only:
+            self.log(
+                "train_direction_accur",
+                self.direction_accur_train.compute(),
+                sync_dist=True,
+            )
+            self.direction_accur_train.reset()

     def validation_step(
         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
@@ -215,18 +219,22 @@ class FontDetector(ptl.LightningModule):
         loss = self.loss(y_hat, y)
         self.log("val_loss", loss, prog_bar=True, sync_dist=True)
         self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
-        self.direction_accur_val.update(
-            y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
-        )
+        if not self.font_classification_only:
+            self.direction_accur_val.update(
+                y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
+            )
         return {"loss": loss}

     def on_validation_epoch_end(self):
         self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
-        self.log(
-            "val_direction_accur", self.direction_accur_val.compute(), sync_dist=True
-        )
         self.font_accur_val.reset()
-        self.direction_accur_val.reset()
+        if not self.font_classification_only:
+            self.log(
+                "val_direction_accur",
+                self.direction_accur_val.compute(),
+                sync_dist=True,
+            )
+            self.direction_accur_val.reset()

     def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
         X, y = batch
@@ -234,18 +242,22 @@ class FontDetector(ptl.LightningModule):
         loss = self.loss(y_hat, y)
         self.log("test_loss", loss, prog_bar=True, sync_dist=True)
         self.font_accur_test.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
-        self.direction_accur_test.update(
-            y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
-        )
+        if not self.font_classification_only:
+            self.direction_accur_test.update(
+                y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
+            )
         return {"loss": loss}

     def on_test_epoch_end(self) -> None:
         self.log("test_font_accur", self.font_accur_test.compute(), sync_dist=True)
-        self.log(
-            "test_direction_accur", self.direction_accur_test.compute(), sync_dist=True
-        )
         self.font_accur_test.reset()
-        self.direction_accur_test.reset()
+        if not self.font_classification_only:
+            self.log(
+                "test_direction_accur",
+                self.direction_accur_test.compute(),
+                sync_dist=True,
+            )
+            self.direction_accur_test.reset()

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
```
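For context on the slicing above: the first `config.FONT_COUNT` entries of `y_hat` are the font logits and the next two are the direction logits, while `y[..., 0]` and `y[..., 1]` carry the font and direction labels. Below is a minimal standalone sketch of the gated construction-and-update pattern this diff introduces, with a stand-in `FONT_COUNT` of 3 in place of the real `config.FONT_COUNT`:

```python
import torch
import torchmetrics

FONT_COUNT = 3  # stand-in for config.FONT_COUNT, illustration only
font_classification_only = False

# The direction metric only exists when the direction head is trained.
direction_accur = (
    None
    if font_classification_only
    else torchmetrics.Accuracy(task="multiclass", num_classes=2)
)

# Fake batch: FONT_COUNT font logits + 2 direction logits per sample;
# each target row holds [font label, direction label].
y_hat = torch.randn(4, FONT_COUNT + 2)
y = torch.tensor([[0, 1], [2, 0], [1, 1], [0, 0]])

if not font_classification_only:
    direction_accur.update(y_hat[..., FONT_COUNT : FONT_COUNT + 2], y[..., 1])
```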
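The epoch-end hooks follow the standard torchmetrics lifecycle: `compute()` aggregates every `update()` call made since the last `reset()`, so each metric must be computed (and logged) before it is reset. A minimal sketch of that ordering:

```python
import torch
import torchmetrics

acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

# Simulate a few per-step updates within one epoch.
for _ in range(3):
    acc.update(torch.randn(8, 2), torch.randint(0, 2, (8,)))

epoch_accuracy = acc.compute()  # aggregate over all updates this epoch
acc.reset()                     # clear state before the next epoch begins
```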
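A hypothetical construction call showing how the new flag threads through. Only the keyword names assigned in `__init__` above are taken from the diff; the values, and any parameters the diff does not show, are illustrative assumptions rather than the project's actual training setup:

```python
# Hypothetical sketch: the full FontDetector signature is not shown in this
# diff, so the values (and any omitted parameters) are assumptions.
detector = FontDetector(
    lr=1e-4,
    betas=(0.9, 0.999),
    num_warmup_iters=1_000,
    num_iters=100_000,
    num_epochs=10,
    font_classification_only=True,  # direction metrics never created or logged
)
```

With `font_classification_only=True`, the `direction_accur_*` attributes are never created, so an accidental direction-metric access would fail loudly with `AttributeError` instead of logging values for a head that is not being trained.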