import torch
import torch.nn.functional as F


class CTC(torch.nn.Module):
    """CTC module."""

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        reduce: bool = True,
    ):
        """Construct a CTC module.

        Args:
            odim: dimension of outputs
            encoder_output_size: number of encoder projection units
            dropout_rate: dropout rate (0.0 ~ 1.0)
            reduce: reduce the CTC loss into a scalar
        """
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        # Projection from the encoder output space to the output vocabulary.
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        # "sum" is normalized by batch size in forward(); "none" keeps
        # one loss value per utterance.
        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)

        Returns:
            CTC loss value, normalized by the batch size.
        """
        # Project encoder states to vocabulary logits; pass the module's
        # training flag so dropout is disabled in eval mode (F.dropout
        # defaults to training=True).
        ys_hat = self.ctc_lo(
            F.dropout(hs_pad, p=self.dropout_rate, training=self.training)
        )
        # torch.nn.CTCLoss expects log-probabilities shaped (Tmax, B, odim).
        ys_hat = ys_hat.transpose(0, 1)
        ys_hat = ys_hat.log_softmax(2)
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
        # Normalize the summed loss by batch size (dim 1 after the transpose).
        loss = loss / ys_hat.size(1)
        return loss

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """log_softmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """argmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
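

# --- Minimal usage sketch. Everything below is an illustrative assumption,
# not part of the original module: the shapes (batch of 4 utterances, 50
# encoder frames of dim 256, a 30-token vocabulary where id 0 is the blank,
# the torch.nn.CTCLoss default) and the greedy_ctc_decode helper are
# hypothetical, showing how the module's forward() and argmax() are used.


def greedy_ctc_decode(ctc: CTC, hs_pad: torch.Tensor, blank: int = 0) -> list:
    """Hypothetical helper: collapse repeats and drop blanks per utterance."""
    hyps = []
    for path in ctc.argmax(hs_pad):  # path: (Tmax,) frame-level token ids
        prev = blank
        hyp = []
        for tok in path.tolist():
            # Keep a token only when it is not blank and differs from the
            # previous frame (standard CTC best-path collapsing).
            if tok != blank and tok != prev:
                hyp.append(tok)
            prev = tok
        hyps.append(hyp)
    return hyps


if __name__ == "__main__":
    B, Tmax, D, odim, Lmax = 4, 50, 256, 30, 10
    ctc = CTC(odim=odim, encoder_output_size=D, dropout_rate=0.1)

    hs_pad = torch.randn(B, Tmax, D)
    hlens = torch.full((B,), Tmax, dtype=torch.long)
    # Random non-blank targets (ids 1..odim-1); 0 is reserved for blank.
    ys_pad = torch.randint(1, odim, (B, Lmax))
    ys_lens = torch.full((B,), Lmax, dtype=torch.long)

    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
    print("CTC loss:", loss.item())
    print("greedy hyp lengths:", [len(h) for h in greedy_ctc_decode(ctc, hs_pad)])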