eubinecto commited on
Commit
d3d3e90
1 Parent(s): d8d4c8d

[#1] logging the training loss

Browse files
Files changed (1) hide show
  1. idiomify/models.py +3 -0
idiomify/models.py CHANGED
@@ -44,6 +44,9 @@ class Alpha(pl.LightningModule): # noqa
44
  "loss": loss
45
  }
46
 
 
 
 
47
  def predict(self, srcs: torch.Tensor) -> torch.Tensor:
48
  pred_ids = self.bart.generate(
49
  inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
 
44
  "loss": loss
45
  }
46
 
47
+ def on_train_batch_end(self, outputs: dict, *args, **kwargs):
48
+ self.log("Train/Loss", outputs['loss'])
49
+
50
  def predict(self, srcs: torch.Tensor) -> torch.Tensor:
51
  pred_ids = self.bart.generate(
52
  inputs=srcs[:, 0], # (N, 2, L) -> (N, L)