[#1] logging the training loss
Browse files- idiomify/models.py +3 -0
idiomify/models.py
CHANGED
@@ -44,6 +44,9 @@ class Alpha(pl.LightningModule): # noqa
|
|
44 |
"loss": loss
|
45 |
}
|
46 |
|
|
|
|
|
|
|
47 |
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
48 |
pred_ids = self.bart.generate(
|
49 |
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
|
|
44 |
"loss": loss
|
45 |
}
|
46 |
|
47 |
+
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
|
48 |
+
self.log("Train/Loss", outputs['loss'])
|
49 |
+
|
50 |
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
51 |
pred_ids = self.bart.generate(
|
52 |
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|