eubinecto commited on
Commit
143c53f
1 Parent(s): 2bd8a1e

[#2] refactoring; running transpose subsequently

Browse files
Files changed (1) hide show
  1. idiomify/models.py +1 -2
idiomify/models.py CHANGED
@@ -38,8 +38,7 @@ class Seq2Seq(pl.LightningModule): # noqa
38
 
39
  def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
40
  srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
41
- logits = self.forward(srcs, tgts_r) # -> (N, L, |V|)
42
- logits = logits.transpose(1, 2) # (N, L, |V|) -> (N, |V|, L)
43
  loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
44
  .sum() # (N, L, |V|), (N, L) -> (N,) -> (1,)
45
  return {
 
38
 
39
  def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
40
  srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
41
+ logits = self.forward(srcs, tgts_r).transpose(1, 2) # ... -> (N, L, |V|) -> (N, |V|, L)
 
42
  loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
43
  .sum() # (N, L, |V|), (N, L) -> (N,) -> (1,)
44
  return {