fix: update to ptl v2.0
Browse files- detector/model.py +2 -2
detector/model.py
CHANGED
@@ -122,7 +122,7 @@ class FontDetector(ptl.LightningModule):
|
|
122 |
),
|
123 |
)
|
124 |
|
125 |
-
def
|
126 |
self.font_accur_train.reset()
|
127 |
self.direction_accur_train.reset()
|
128 |
|
@@ -139,7 +139,7 @@ class FontDetector(ptl.LightningModule):
|
|
139 |
)
|
140 |
return {"loss": loss, "pred": y_hat, "target": y}
|
141 |
|
142 |
-
def
|
143 |
self.log("val_font_accur", self.font_accur_val.compute())
|
144 |
self.log("val_direction_accur", self.direction_accur_val.compute())
|
145 |
self.font_accur_val.reset()
|
|
|
122 |
),
|
123 |
)
|
124 |
|
125 |
+
def on_train_epoch_end(self) -> None:
|
126 |
self.font_accur_train.reset()
|
127 |
self.direction_accur_train.reset()
|
128 |
|
|
|
139 |
)
|
140 |
return {"loss": loss, "pred": y_hat, "target": y}
|
141 |
|
142 |
+
def on_validation_epoch_end(self):
|
143 |
self.log("val_font_accur", self.font_accur_val.compute())
|
144 |
self.log("val_direction_accur", self.direction_accur_val.compute())
|
145 |
self.font_accur_val.reset()
|