Soumic commited on
Commit
681d043
·
1 Parent(s): a458b71

:zap: Made some more modifications

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -168,7 +168,7 @@ class TorchMetrics:
168
  self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
169
  pass
170
 
171
- def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
172
  b_accuracy = self.binary_accuracy.compute()
173
  b_auc = self.binary_auc.compute()
174
  b_f1_score = self.binary_f1_score.compute()
@@ -181,13 +181,14 @@ class TorchMetrics:
181
  log(f"{log_prefix}_precision", b_precision)
182
  log(f"{log_prefix}_recall", b_recall)
183
 
 
 
 
184
  self.binary_accuracy.reset()
185
  self.binary_auc.reset()
186
  self.binary_f1_score.reset()
187
  self.binary_precision.reset()
188
  self.binary_recall.reset()
189
- pass
190
-
191
 
192
 
193
  class MQtlBertClassifierLightningModule(LightningModule):
@@ -241,12 +242,14 @@ class MQtlBertClassifierLightningModule(LightningModule):
241
  self.log("train_loss", loss)
242
  # calculate the scores start
243
  self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
 
244
  # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train")
245
  # calculate the scores end
246
  return loss
247
 
248
  def on_train_epoch_end(self) -> None:
249
- self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")
 
250
  pass
251
 
252
  def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -259,13 +262,15 @@ class MQtlBertClassifierLightningModule(LightningModule):
259
  self.log("valid_loss", loss)
260
  # calculate the scores start
261
  self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
 
262
  # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue)
263
 
264
  # calculate the scores end
265
  return loss
266
 
267
  def on_validation_epoch_end(self) -> None:
268
- self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
 
269
  return None
270
 
271
  def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -276,13 +281,15 @@ class MQtlBertClassifierLightningModule(LightningModule):
276
  self.log("test_loss", loss) # do we need this?
277
  # calculate the scores start
278
  self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
 
279
  # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta)
280
 
281
  # calculate the scores end
282
  return loss
283
 
284
  def on_test_epoch_end(self) -> None:
285
- self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
 
286
  return None
287
 
288
  pass
 
168
  self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
169
  pass
170
 
171
+ def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
172
  b_accuracy = self.binary_accuracy.compute()
173
  b_auc = self.binary_auc.compute()
174
  b_f1_score = self.binary_f1_score.compute()
 
181
  log(f"{log_prefix}_precision", b_precision)
182
  log(f"{log_prefix}_recall", b_recall)
183
 
184
+ pass
185
+
186
+ def reset_on_epoch_end(self):
187
  self.binary_accuracy.reset()
188
  self.binary_auc.reset()
189
  self.binary_f1_score.reset()
190
  self.binary_precision.reset()
191
  self.binary_recall.reset()
 
 
192
 
193
 
194
  class MQtlBertClassifierLightningModule(LightningModule):
 
242
  self.log("train_loss", loss)
243
  # calculate the scores start
244
  self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
245
+ self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
246
  # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train")
247
  # calculate the scores end
248
  return loss
249
 
250
  def on_train_epoch_end(self) -> None:
251
+ self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
252
+ self.train_metrics.reset_on_epoch_end()
253
  pass
254
 
255
  def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
 
262
  self.log("valid_loss", loss)
263
  # calculate the scores start
264
  self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
265
+ self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
266
  # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue)
267
 
268
  # calculate the scores end
269
  return loss
270
 
271
  def on_validation_epoch_end(self) -> None:
272
+ self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
273
+ self.validate_metrics.reset_on_epoch_end()
274
  return None
275
 
276
  def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
 
281
  self.log("test_loss", loss) # do we need this?
282
  # calculate the scores start
283
  self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
284
+ self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
285
  # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta)
286
 
287
  # calculate the scores end
288
  return loss
289
 
290
  def on_test_epoch_end(self) -> None:
291
+ self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
292
+ self.test_metrics.reset_on_epoch_end()
293
  return None
294
 
295
  pass