|
""" |
|
Adapted from ESPnet (Tacotron2 postnet).
|
""" |
|
|
|
import torch |
|
|
|
|
|
class PostNet(torch.nn.Module):
    """
    Postnet module for a spectrogram prediction network (from Tacotron2).

    This is the Postnet of the spectrogram prediction network described in
    `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the Mel-filterbank predicted by the decoder, which
    helps to recover the detailed structure of the spectrogram.

    Each hidden layer is Conv1d -> [GroupNorm] -> Tanh -> Dropout; the output
    layer is Conv1d -> [GroupNorm] -> Dropout. Normalization, when enabled,
    uses ``torch.nn.GroupNorm``.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
        https://arxiv.org/abs/1712.05884
    """

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True,
                 norm_groups=32, out_norm_groups=20):
        """
        Initialize postnet module.

        Args:
            idim (int): Unused; kept for interface compatibility. The network
                consumes and produces ``odim`` channels.
            odim (int): Number of channels of both the inputs and the outputs.
            n_layers (int, optional): Number of convolution layers (must be >= 1).
            n_chans (int, optional): Number of channels of the hidden layers.
            n_filts (int, optional): Convolution kernel size.
            dropout_rate (float, optional): Dropout rate applied after every layer.
            use_batch_norm (bool, optional): Whether to add normalization layers.
                Despite the name, ``torch.nn.GroupNorm`` is used.
            norm_groups (int, optional): Number of GroupNorm groups in the hidden
                layers; ``n_chans`` must be divisible by it.
            out_norm_groups (int, optional): Number of GroupNorm groups in the
                output layer; ``odim`` must be divisible by it.

        Raises:
            ValueError: If ``n_layers`` < 1, or (raised by ``GroupNorm``) if a
                channel count is not divisible by its group count.
        """
        super(PostNet, self).__init__()
        if n_layers < 1:
            # With n_layers <= 0 the output layer would expect n_chans input
            # channels while receiving odim, failing only later in forward().
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.postnet = torch.nn.ModuleList()

        # Hidden layers: the first one reads the odim-channel input, the rest
        # map n_chans -> n_chans.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet.append(
                self._make_layer(ichans, n_chans, n_filts, dropout_rate,
                                 norm_groups if use_batch_norm else None, activation=True))

        # Output layer projects back to odim; it has no Tanh, so its output is
        # not squashed into (-1, 1).
        ichans = n_chans if n_layers > 1 else odim
        self.postnet.append(
            self._make_layer(ichans, odim, n_filts, dropout_rate,
                             out_norm_groups if use_batch_norm else None, activation=False))

    @staticmethod
    def _make_layer(ichans, ochans, n_filts, dropout_rate, num_groups, activation):
        """
        Build one postnet layer: Conv1d -> [GroupNorm] -> [Tanh] -> Dropout.

        ``num_groups=None`` omits the GroupNorm; ``activation=False`` omits the
        Tanh. The module order matches the original construction, so the
        ``state_dict`` keys (``postnet.<i>.<j>.*``) are unchanged.
        """
        modules = [
            # (n_filts - 1) // 2 keeps the time length unchanged for odd
            # kernel sizes; an even kernel shortens it by one frame.
            torch.nn.Conv1d(ichans, ochans, n_filts, stride=1, padding=(n_filts - 1) // 2, bias=False),
        ]
        if num_groups is not None:
            modules.append(torch.nn.GroupNorm(num_groups=num_groups, num_channels=ochans))
        if activation:
            modules.append(torch.nn.Tanh())
        modules.append(torch.nn.Dropout(dropout_rate))
        return torch.nn.Sequential(*modules)

    def forward(self, xs):
        """
        Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded input sequences (B, odim, Tmax); the
                first convolution expects ``odim`` input channels.

        Returns:
            Tensor: Batch of padded output sequences (B, odim, Tmax) for odd
                ``n_filts`` (one frame shorter per layer for even ``n_filts``).
        """
        for layer in self.postnet:
            xs = layer(xs)
        return xs
|
|