| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from scaling import Balancer |
|
|
|
|
class Decoder(nn.Module):
    """This class modifies the stateless decoder from the following paper:

        RNN-transducer with stateless prediction network
        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

    It removes the recurrent connection from the decoder, i.e., the prediction
    network. Different from the above paper, it adds an extra Conv1d
    right after the embedding layer.

    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        blank_id: int,
        context_size: int,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit including blank.
          decoder_dim:
            Dimension of the input embedding, and of the decoder output.
            When ``context_size > 1`` the convolution uses
            ``decoder_dim // 4`` groups, so that value must be at least 1
            and must divide ``decoder_dim`` (any multiple of 4 works).
          blank_id:
            The ID of the blank symbol.
          context_size:
            Number of previous words to use to predict the next word.
            1 means bigram; 2 means trigram. n means (n+1)-gram.

        Raises:
          AssertionError: if ``context_size < 1``.
          ValueError: if ``context_size > 1`` and ``decoder_dim`` cannot be
            split into ``decoder_dim // 4`` equal convolution groups.
        """
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=decoder_dim,
        )
        # Applied to the embedding output in forward().
        self.balancer = Balancer(
            decoder_dim,
            channel_dim=-1,
            min_positive=0.0,
            max_positive=1.0,
            min_abs=0.5,
            max_abs=1.0,
            prob=0.05,
        )

        self.blank_id = blank_id

        assert context_size >= 1, context_size
        self.context_size = context_size
        self.vocab_size = vocab_size

        if context_size > 1:
            # Grouped convolution over the last `context_size` embeddings,
            # with decoder_dim // 4 groups (4 channels per group when
            # decoder_dim is a multiple of 4). Validate up front so that a
            # bad dimension fails with an explanation instead of an opaque
            # error from nn.Conv1d.
            groups = decoder_dim // 4
            if groups < 1 or decoder_dim % groups != 0:
                raise ValueError(
                    f"decoder_dim={decoder_dim} cannot be split into "
                    f"decoder_dim // 4 = {groups} convolution groups; "
                    "use decoder_dim >= 4 (any multiple of 4 works)"
                )
            self.conv = nn.Conv1d(
                in_channels=decoder_dim,
                out_channels=decoder_dim,
                kernel_size=context_size,
                padding=0,
                groups=groups,
                bias=False,
            )
            self.balancer2 = Balancer(
                decoder_dim,
                channel_dim=-1,
                min_positive=0.0,
                max_positive=1.0,
                min_abs=0.5,
                max_abs=1.0,
                prob=0.05,
            )
        else:
            # Placeholders only: forward() does not call conv/balancer2
            # when context_size == 1.
            self.conv = nn.Identity()
            self.balancer2 = nn.Identity()

    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U) containing token IDs. It is cast to
            int64. Negative IDs produce an all-zero embedding vector.
          need_pad:
            True to left pad the input. Should be True during training.
            False to not pad the input. Should be False during inference;
            when ``context_size > 1`` the input must then contain exactly
            ``context_size`` tokens (U == context_size).
        Returns:
          Return a tensor of shape (N, U, decoder_dim). When
          ``context_size > 1`` and ``need_pad`` is false, U is 1 in the
          output.
        """
        y = y.to(torch.int64)
        # nn.Embedding rejects negative indices, so clamp them to 0 and then
        # zero out the rows whose original ID was negative.
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
        embedding_out = self.balancer(embedding_out)

        if self.context_size > 1:
            # Conv1d expects (N, C, T).
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad:
                # Left-pad by context_size - 1 so that output frame t covers
                # tokens t - context_size + 1 .. t and the length stays U.
                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
            else:
                # Unpadded mode: the caller supplies exactly the last
                # `context_size` tokens, so the conv yields a single frame.
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
            embedding_out = F.relu(embedding_out)
            embedding_out = self.balancer2(embedding_out)

        return embedding_out
|
|