File size: 338 Bytes
b1e1a76
 
52e32c0
 
 
 
 
 
b1e1a76
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn
# from icefall.utils import make_pad_mask

from .symbol_table import SymbolTable

# NOTE(review): disabled together with the commented-out icefall
# ``make_pad_mask`` import above; re-enable both lines together if needed.
# make_pad_mask = make_pad_mask
# Self-assignment of the name imported from ``.symbol_table``; presumably done
# to mark ``SymbolTable`` as an intentional re-export of this module (e.g. so
# linters do not report the import as unused) — confirm intent.
SymbolTable = SymbolTable


class Transpose(nn.Identity):
    """Swap tensor axes 1 and 2, e.g. (N, T, D) -> (N, D, T).

    Only axes 1 and 2 are exchanged; every other axis keeps its position.
    Only ``forward`` is overridden; everything else comes from ``nn.Identity``.
    NOTE(review): inheriting ``nn.Identity`` rather than ``nn.Module`` presumably
    lets the constructor accept and ignore arbitrary arguments — confirm intent.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` with axes 1 and 2 swapped.

        Args:
            input: tensor with at least 3 dimensions, e.g. (N, T, D)
                (presumably batch, time, feature — confirm with callers).

        Returns:
            Tensor with axes 1 and 2 exchanged, e.g. (N, D, T). Per PyTorch's
            ``torch.transpose`` semantics this is a view sharing storage with
            ``input`` and is usually non-contiguous.
        """
        swapped = torch.transpose(input, 1, 2)
        return swapped