"""
Taken from ESPNet
"""

import torch


def _largest_group_count(num_channels, preferred):
    """
    Return the largest divisor of ``num_channels`` that is <= ``preferred``.

    ``torch.nn.GroupNorm`` requires the channel count to be divisible by the
    number of groups. When ``preferred`` already divides ``num_channels`` it
    is returned unchanged, so configurations that worked with a fixed group
    count keep exactly the same normalization.
    """
    for groups in range(min(preferred, num_channels), 1, -1):
        if num_channels % groups == 0:
            return groups
    return 1


class PostNet(torch.nn.Module):
    """
    From Tacotron2

    Postnet module for Spectrogram prediction network.

    This is a module of the Postnet in the Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on
    Mel Spectrogram Predictions`_. The Postnet refines the predicted
    Mel-filterbank of the decoder, which helps to compensate the detailed
    structure of the spectrogram.

    The module is a stack of 1D convolution blocks held in ``self.postnet``.
    Every block except the last is Conv1d -> (GroupNorm) -> Tanh -> Dropout;
    the last block is Conv1d -> (GroupNorm) -> Dropout and projects back to
    ``odim`` channels.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    # Preferred GroupNorm group counts for the hidden blocks and the output
    # block. They are lowered automatically when they do not divide the
    # channel count (see _largest_group_count).
    _HIDDEN_NORM_GROUPS = 32
    _OUTPUT_NORM_GROUPS = 20

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
        """
        Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Kept for interface
                compatibility; it is not read, the first convolution takes
                ``odim`` input channels.
            odim (int): Dimension of the outputs (and of the expected input
                channels).
            n_layers (int, optional): The number of convolution blocks.
            n_chans (int, optional): The number of filter channels of the
                hidden blocks.
            n_filts (int, optional): The filter (kernel) size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to normalize after each
                convolution. Despite the name, the normalization layer that
                is inserted is ``torch.nn.GroupNorm``.
        """
        super().__init__()
        hidden_groups = self._HIDDEN_NORM_GROUPS if use_batch_norm else None
        output_groups = self._OUTPUT_NORM_GROUPS if use_batch_norm else None

        self.postnet = torch.nn.ModuleList()
        # All blocks but the last map to n_chans channels and end in Tanh.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet.append(
                self._conv_block(ichans, n_chans, n_filts, dropout_rate, hidden_groups, use_tanh=True)
            )

        # The last block projects back to odim and has no Tanh. With a single
        # layer there is no hidden block, so it reads odim channels directly.
        ichans = n_chans if n_layers != 1 else odim
        self.postnet.append(
            self._conv_block(ichans, odim, n_filts, dropout_rate, output_groups, use_tanh=False)
        )

    @staticmethod
    def _conv_block(ichans, ochans, n_filts, dropout_rate, norm_groups, use_tanh):
        """
        Build one Conv1d -> [GroupNorm] -> [Tanh] -> Dropout block.

        ``norm_groups`` is the preferred GroupNorm group count, or ``None``
        to leave normalization out. The sub-module order is fixed so that
        the indices inside the returned ``Sequential`` stay stable.
        """
        layers = [
            torch.nn.Conv1d(
                ichans,
                ochans,
                n_filts,
                stride=1,
                # For odd n_filts this padding keeps the time length unchanged.
                padding=(n_filts - 1) // 2,
                bias=False,
            )
        ]
        if norm_groups is not None:
            layers.append(
                torch.nn.GroupNorm(
                    num_groups=_largest_group_count(ochans, norm_groups),
                    num_channels=ochans,
                )
            )
        if use_tanh:
            layers.append(torch.nn.Tanh())
        layers.append(torch.nn.Dropout(dropout_rate))
        return torch.nn.Sequential(*layers)

    def forward(self, xs):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors
                (B, odim, Tmax); the first convolution expects ``odim``
                channels.

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).
        """
        for block in self.postnet:
            xs = block(xs)
        return xs