gyrojeff commited on
Commit
d1e10d9
1 Parent(s): afbe904

fix: update accuracy for classification only

Browse files
Files changed (1) hide show
  1. detector/model.py +48 -36
detector/model.py CHANGED
@@ -150,27 +150,29 @@ class FontDetector(ptl.LightningModule):
150
  self.font_accur_train = torchmetrics.Accuracy(
151
  task="multiclass", num_classes=config.FONT_COUNT
152
  )
153
- self.direction_accur_train = torchmetrics.Accuracy(
154
- task="multiclass", num_classes=2
155
- )
156
  self.font_accur_val = torchmetrics.Accuracy(
157
  task="multiclass", num_classes=config.FONT_COUNT
158
  )
159
- self.direction_accur_val = torchmetrics.Accuracy(
160
- task="multiclass", num_classes=2
161
- )
162
  self.font_accur_test = torchmetrics.Accuracy(
163
  task="multiclass", num_classes=config.FONT_COUNT
164
  )
165
- self.direction_accur_test = torchmetrics.Accuracy(
166
- task="multiclass", num_classes=2
167
- )
 
 
 
 
 
 
 
168
  self.lr = lr
169
  self.betas = betas
170
  self.num_warmup_iters = num_warmup_iters
171
  self.num_iters = num_iters
172
  self.num_epochs = num_epochs
173
  self.load_epoch = 0
 
174
 
175
  def forward(self, x):
176
  return self.model(x)
@@ -188,24 +190,26 @@ class FontDetector(ptl.LightningModule):
188
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
189
  sync_dist=True,
190
  )
191
- self.log(
192
- "train_direction_accur",
193
- self.direction_accur_train(
194
- y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
195
- ),
196
- sync_dist=True,
197
- )
 
198
  return {"loss": loss}
199
 
200
  def on_train_epoch_end(self) -> None:
201
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
202
- self.log(
203
- "train_direction_accur",
204
- self.direction_accur_train.compute(),
205
- sync_dist=True,
206
- )
207
  self.font_accur_train.reset()
208
- self.direction_accur_train.reset()
 
 
 
 
 
 
209
 
210
  def validation_step(
211
  self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
@@ -215,18 +219,22 @@ class FontDetector(ptl.LightningModule):
215
  loss = self.loss(y_hat, y)
216
  self.log("val_loss", loss, prog_bar=True, sync_dist=True)
217
  self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
218
- self.direction_accur_val.update(
219
- y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
220
- )
 
221
  return {"loss": loss}
222
 
223
  def on_validation_epoch_end(self):
224
  self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
225
- self.log(
226
- "val_direction_accur", self.direction_accur_val.compute(), sync_dist=True
227
- )
228
  self.font_accur_val.reset()
229
- self.direction_accur_val.reset()
 
 
 
 
 
 
230
 
231
  def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
232
  X, y = batch
@@ -234,18 +242,22 @@ class FontDetector(ptl.LightningModule):
234
  loss = self.loss(y_hat, y)
235
  self.log("test_loss", loss, prog_bar=True, sync_dist=True)
236
  self.font_accur_test.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
237
- self.direction_accur_test.update(
238
- y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
239
- )
 
240
  return {"loss": loss}
241
 
242
  def on_test_epoch_end(self) -> None:
243
  self.log("test_font_accur", self.font_accur_test.compute(), sync_dist=True)
244
- self.log(
245
- "test_direction_accur", self.direction_accur_test.compute(), sync_dist=True
246
- )
247
  self.font_accur_test.reset()
248
- self.direction_accur_test.reset()
 
 
 
 
 
 
249
 
250
  def configure_optimizers(self):
251
  optimizer = torch.optim.Adam(
 
150
  self.font_accur_train = torchmetrics.Accuracy(
151
  task="multiclass", num_classes=config.FONT_COUNT
152
  )
 
 
 
153
  self.font_accur_val = torchmetrics.Accuracy(
154
  task="multiclass", num_classes=config.FONT_COUNT
155
  )
 
 
 
156
  self.font_accur_test = torchmetrics.Accuracy(
157
  task="multiclass", num_classes=config.FONT_COUNT
158
  )
159
+ if not font_classification_only:
160
+ self.direction_accur_train = torchmetrics.Accuracy(
161
+ task="multiclass", num_classes=2
162
+ )
163
+ self.direction_accur_val = torchmetrics.Accuracy(
164
+ task="multiclass", num_classes=2
165
+ )
166
+ self.direction_accur_test = torchmetrics.Accuracy(
167
+ task="multiclass", num_classes=2
168
+ )
169
  self.lr = lr
170
  self.betas = betas
171
  self.num_warmup_iters = num_warmup_iters
172
  self.num_iters = num_iters
173
  self.num_epochs = num_epochs
174
  self.load_epoch = 0
175
+ self.font_classification_only = font_classification_only
176
 
177
  def forward(self, x):
178
  return self.model(x)
 
190
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
191
  sync_dist=True,
192
  )
193
+ if not self.font_classification_only:
194
+ self.log(
195
+ "train_direction_accur",
196
+ self.direction_accur_train(
197
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
198
+ ),
199
+ sync_dist=True,
200
+ )
201
  return {"loss": loss}
202
 
203
  def on_train_epoch_end(self) -> None:
204
  self.log("train_font_accur", self.font_accur_train.compute(), sync_dist=True)
 
 
 
 
 
205
  self.font_accur_train.reset()
206
+ if not self.font_classification_only:
207
+ self.log(
208
+ "train_direction_accur",
209
+ self.direction_accur_train.compute(),
210
+ sync_dist=True,
211
+ )
212
+ self.direction_accur_train.reset()
213
 
214
  def validation_step(
215
  self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
 
219
  loss = self.loss(y_hat, y)
220
  self.log("val_loss", loss, prog_bar=True, sync_dist=True)
221
  self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
222
+ if not self.font_classification_only:
223
+ self.direction_accur_val.update(
224
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
225
+ )
226
  return {"loss": loss}
227
 
228
  def on_validation_epoch_end(self):
229
  self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
 
 
 
230
  self.font_accur_val.reset()
231
+ if not self.font_classification_only:
232
+ self.log(
233
+ "val_direction_accur",
234
+ self.direction_accur_val.compute(),
235
+ sync_dist=True,
236
+ )
237
+ self.direction_accur_val.reset()
238
 
239
  def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
240
  X, y = batch
 
242
  loss = self.loss(y_hat, y)
243
  self.log("test_loss", loss, prog_bar=True, sync_dist=True)
244
  self.font_accur_test.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
245
+ if not self.font_classification_only:
246
+ self.direction_accur_test.update(
247
+ y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
248
+ )
249
  return {"loss": loss}
250
 
251
  def on_test_epoch_end(self) -> None:
252
  self.log("test_font_accur", self.font_accur_test.compute(), sync_dist=True)
 
 
 
253
  self.font_accur_test.reset()
254
+ if not self.font_classification_only:
255
+ self.log(
256
+ "test_direction_accur",
257
+ self.direction_accur_test.compute(),
258
+ sync_dist=True,
259
+ )
260
+ self.direction_accur_test.reset()
261
 
262
  def configure_optimizers(self):
263
  optimizer = torch.optim.Adam(