Commit 681d043 · Soumic committed
1 Parent(s): a458b71

:zap: Made some more modifications
app.py CHANGED
@@ -168,7 +168,7 @@ class TorchMetrics:
         self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)
         pass

-    def
+    def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
         b_accuracy = self.binary_accuracy.compute()
         b_auc = self.binary_auc.compute()
         b_f1_score = self.binary_f1_score.compute()
@@ -181,13 +181,14 @@ class TorchMetrics:
         log(f"{log_prefix}_precision", b_precision)
         log(f"{log_prefix}_recall", b_recall)

+        pass
+
+    def reset_on_epoch_end(self):
         self.binary_accuracy.reset()
         self.binary_auc.reset()
         self.binary_f1_score.reset()
         self.binary_precision.reset()
         self.binary_recall.reset()
-        pass
-


 class MQtlBertClassifierLightningModule(LightningModule):
@@ -241,12 +242,14 @@ class MQtlBertClassifierLightningModule(LightningModule):
         self.log("train_loss", loss)
         # calculate the scores start
         self.train_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
         # self.train_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="train")
         # calculate the scores end
         return loss

     def on_train_epoch_end(self) -> None:
-        self.train_metrics.
+        self.train_metrics.compute_metrics_and_log(log=self.log, log_prefix="train")
+        self.train_metrics.reset_on_epoch_end()
         pass

     def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -259,13 +262,15 @@ class MQtlBertClassifierLightningModule(LightningModule):
         self.log("valid_loss", loss)
         # calculate the scores start
         self.validate_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
         # self.validate_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="validate", log_color=blue)

         # calculate the scores end
         return loss

     def on_validation_epoch_end(self) -> None:
-        self.validate_metrics.
+        self.validate_metrics.compute_metrics_and_log(log=self.log, log_prefix="validate", log_color=blue)
+        self.validate_metrics.reset_on_epoch_end()
         return None

     def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -276,13 +281,15 @@ class MQtlBertClassifierLightningModule(LightningModule):
         self.log("test_loss", loss)  # do we need this?
         # calculate the scores start
         self.test_metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(), batch_actual_labels=y)
+        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
         # self.test_metrics.compute_and_log_on_each_step(log=self.log, log_prefix="test", log_color=magenta)

         # calculate the scores end
         return loss

     def on_test_epoch_end(self) -> None:
-        self.test_metrics.
+        self.test_metrics.compute_metrics_and_log(log=self.log, log_prefix="test", log_color=magenta)
+        self.test_metrics.reset_on_epoch_end()
         return None

     pass
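
Note: pieced together from the hunks above, the reworked TorchMetrics helper plausibly reads as the sketch below. This is a reconstruction under assumptions, not the file's exact contents: the update(input=..., target=...) keyword signature matches torcheval.metrics (torchmetrics would use preds= instead), the precision/recall compute calls and the accuracy/auc/f1 log calls hidden between the two hunks are filled in by analogy with the visible lines, and green is assumed to be an ANSI color string defined elsewhere in app.py (the log_color parameter is accepted by the diff but its use is not shown).

# Reconstruction sketch; assumptions flagged in the note above.
from torcheval.metrics import (
    BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall,
)

green = "\033[92m"  # assumed ANSI color constant, per the log_color default


class TorchMetrics:
    def __init__(self):
        self.binary_accuracy = BinaryAccuracy()
        self.binary_auc = BinaryAUROC()
        self.binary_f1_score = BinaryF1Score()
        self.binary_precision = BinaryPrecision()
        self.binary_recall = BinaryRecall()

    def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):
        # Accumulate per-batch state; nothing is computed or logged here.
        self.binary_accuracy.update(input=batch_predicted_labels, target=batch_actual_labels)
        self.binary_auc.update(input=batch_predicted_labels, target=batch_actual_labels)
        self.binary_f1_score.update(input=batch_predicted_labels, target=batch_actual_labels)
        self.binary_precision.update(input=batch_predicted_labels, target=batch_actual_labels)
        self.binary_recall.update(input=batch_predicted_labels, target=batch_actual_labels)

    def compute_metrics_and_log(self, log, log_prefix: str, log_color: str = green):
        # Compute over everything accumulated since the last reset, then log.
        b_accuracy = self.binary_accuracy.compute()
        b_auc = self.binary_auc.compute()
        b_f1_score = self.binary_f1_score.compute()
        b_precision = self.binary_precision.compute()
        b_recall = self.binary_recall.compute()
        log(f"{log_prefix}_accuracy", b_accuracy)
        log(f"{log_prefix}_auc", b_auc)
        log(f"{log_prefix}_f1_score", b_f1_score)
        log(f"{log_prefix}_precision", b_precision)
        log(f"{log_prefix}_recall", b_recall)

    def reset_on_epoch_end(self):
        # Clear accumulated state so the next epoch starts fresh.
        self.binary_accuracy.reset()
        self.binary_auc.reset()
        self.binary_f1_score.reset()
        self.binary_precision.reset()
        self.binary_recall.reset()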
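
Review note on the pattern: this commit replaces the commented-out per-step compute_and_log_on_each_step calls with an explicit accumulate / compute / reset cycle. update_on_each_step only accumulates state; compute_metrics_and_log computes over everything accumulated since the last reset, so the call added inside each *_step logs a running value for the current epoch, while the on_*_epoch_end call logs the final epoch value just before reset_on_epoch_end clears the state for the next epoch. One caveat worth flagging: calling compute() on every step grows more expensive as the epoch progresses, especially for AUROC, since each call re-evaluates all accumulated scores. A tiny self-contained illustration of the cycle (assuming a torcheval-style metric, as above):

import torch
from torcheval.metrics import BinaryAccuracy

metric = BinaryAccuracy()
for step in range(3):                       # batches within one epoch
    preds = torch.tensor([0.9, 0.2, 0.7])   # dummy predicted probabilities
    target = torch.tensor([1, 0, 1])        # dummy ground-truth labels
    metric.update(input=preds, target=target)
    print(metric.compute())                 # running accuracy over the epoch so far
metric.reset()                              # state cleared; next epoch starts from zero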