[#2] refactoring: apply the transpose immediately after the forward pass
Browse files- idiomify/models.py +1 -2
idiomify/models.py
CHANGED
@@ -38,8 +38,7 @@ class Seq2Seq(pl.LightningModule): # noqa
|
|
38 |
|
39 |
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
|
40 |
srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
|
41 |
-        logits = self.forward(srcs, tgts_r)  # -> (N, L, |V|)
|
42 |
-        logits = logits.transpose(1, 2)  # (N, L, |V|) -> (N, |V|, L)
|
43 |
loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
|
44 |
.sum() # (N, L, |V|), (N, L) -> (N,) -> (1,)
|
45 |
return {
|
|
|
38 |
|
39 |
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> dict:
|
40 |
srcs, tgts_r, tgts = batch # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
|
41 |
+        logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
|
|
|
42 |
loss = F.cross_entropy(logits, tgts, ignore_index=self.hparams['pad_token_id'])\
|
43 |
.sum() # (N, L, |V|), (N, L) -> (N,) -> (1,)
|
44 |
return {
|