| """Cutom encoder definition for transducer models.""" | |
| import torch | |
| from espnet.nets.pytorch_backend.transducer.blocks import build_blocks | |
| from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L | |
| from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm | |
| from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling | |


class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim (int): input dimension
        enc_arch (list): list of encoder blocks (type and parameters)
        input_layer (str): input layer type
        repeat_block (int): repeat the provided blocks N times if N > 1
        self_attn_type (str): self-attention type
        positional_encoding_type (str): positional encoding type
        positionwise_layer_type (str): positionwise layer type
        positionwise_activation_type (str): positionwise activation type
        conv_mod_activation_type (str): convolutional module activation type
        normalize_before (bool): whether to use layer norm before the first block
        aux_task_layer_list (list): list of block ids for intermediate outputs
        padding_idx (int): padding index for the embedding input layer (if specified)

    """

    def __init__(
        self,
        idim,
        enc_arch,
        input_layer="linear",
        repeat_block=0,
        self_attn_type="selfattn",
        positional_encoding_type="abs_pos",
        positionwise_layer_type="linear",
        positionwise_activation_type="relu",
        conv_mod_activation_type="relu",
        normalize_before=True,
        aux_task_layer_list=[],
        padding_idx=-1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()
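
        # build_blocks assembles the input layer, the encoder block stack,
        # the encoder output dimension, and the overall subsampling factor
        # of any convolutional front-end.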
        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            padding_idx=padding_idx,
        )

        self.normalize_before = normalize_before
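
        # Pre-norm blocks leave the last block's output unnormalized,
        # so a final LayerNorm is applied on top of the stack.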
        if self.normalize_before:
            self.after_norm = LayerNorm(self.enc_out)

        self.n_blocks = (
            len(enc_arch) * repeat_block if repeat_block > 1 else len(enc_arch)
        )

        self.aux_task_layer_list = aux_task_layer_list

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): input tensor
            masks (torch.Tensor): input mask

        Returns:
            xs (torch.Tensor or tuple): encoder output or
                (encoder output, list of auxiliary outputs)
            masks (torch.Tensor): output mask

        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
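
        # With auxiliary task layers, run the blocks one by one so that
        # intermediate outputs can be collected along the way.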
        if self.aux_task_layer_list:
            aux_xs_list = []

            for b in range(self.n_blocks):
                xs, masks = self.encoders[b](xs, masks)

                if b in self.aux_task_layer_list:
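                    # Blocks using relative positional encoding return
                    # (output, pos_emb) tuples; keep only the output.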
                    if isinstance(xs, tuple):
                        aux_xs = xs[0]
                    else:
                        aux_xs = xs

                    if self.normalize_before:
                        aux_xs_list.append(self.after_norm(aux_xs))
                    else:
                        aux_xs_list.append(aux_xs)
        else:
            xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.aux_task_layer_list:
            return (xs, aux_xs_list), masks

        return xs, masks
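

# A minimal usage sketch, left as a comment. The block dictionary below is
# illustrative only: see build_blocks for the exact keys each block type
# accepts, and adjust the shapes to your input features and front-end.
#
#     enc_arch = [{"type": "transformer", "d_hidden": 256, "d_ff": 1024, "heads": 4}]
#     encoder = CustomEncoder(80, enc_arch, input_layer="linear")
#     xs = torch.randn(4, 100, 80)                     # (batch, frames, idim)
#     masks = torch.ones(4, 1, 100, dtype=torch.bool)  # (batch, 1, frames)
#     out, out_masks = encoder(xs, masks)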