mpc001's picture
Upload 125 files
09481f3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import torch
class Conv2dSubsampling(torch.nn.Module):
    """Subsample a feature sequence to about 1/4 of its length with 2-D convolutions.

    Two ``Conv2d(kernel_size=3, stride=2)`` layers (no padding), each followed
    by ReLU, shrink both the time and feature axes. The conv output is
    flattened over (channels, features), projected to ``odim`` by a linear
    layer, and passed through ``pos_enc_class``.

    :param int idim: input feature dimension
    :param int odim: output dimension (also the number of conv channels)
    :param float dropout_rate: dropout rate; accepted for interface
        compatibility but not used by this layer
    :param torch.nn.Module pos_enc_class: positional-encoding module
        *instance* applied after the linear projection
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc_class):
        """Build the convolution stack and the output projection."""
        super().__init__()
        # An unpadded Conv2d(kernel=3, stride=2) maps an axis of length L to
        # (L - 1) // 2; two of them are stacked, so apply the reduction twice.
        freq_len = idim
        for _ in range(2):
            freq_len = (freq_len - 1) // 2
        # Submodule order/indices define the state_dict keys (conv.0, conv.2,
        # out.0, ...), so this layout must stay fixed for checkpoint loading.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, kernel_size=3, stride=2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * freq_len, odim),
            pos_enc_class,
        )

    def forward(self, x, x_mask):
        """Subsample ``x`` and, when given, its mask.

        :param torch.Tensor x: input tensor; presumably (batch, time, idim),
            since a channel axis is inserted at dim 1 before the 2-D convs
        :param torch.Tensor x_mask: input mask (indexed as ``[:, :, time]``,
            presumably (batch, 1, time)) or None
        :return: subsampled x and subsampled mask (None if ``x_mask`` is None)
        :rtype Tuple[torch.Tensor, torch.Tensor]
            or Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
        """
        # Treat the features as a single-channel image: (b, 1, t, f).
        feats = self.conv(x.unsqueeze(1))
        batch, channels, frames, bins = feats.size()
        # (b, c, t, f) -> (b, t, c, f) -> (b, t, c * f) for the linear layer.
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * bins)
        # Result is a Tensor, or a (Tensor, Tensor) tuple when the positional
        # encoding returns one (e.g. RelPositionalEncoding).
        projected = self.out(flat)
        if x_mask is None:
            return projected, None
        # Decimate the mask's time axis once per conv layer so its length
        # matches the subsampled output.
        sub_mask = x_mask
        for _ in range(2):
            sub_mask = sub_mask[:, :, :-2:2]
        return projected, sub_mask