from collections import OrderedDict

import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F

from .SubLayers import MultiHeadAttention, PositionwiseFeedForward

class FFTBlock(torch.nn.Module):
    """FFT block: multi-head self-attention followed by a position-wise
    feed-forward layer, with the padding mask re-applied after each sub-layer."""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        # Self-attention over the input sequence; `slf_attn_mask` blocks
        # attention to padded positions.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        # Zero out padded frames (`mask` is True at padded positions).
        if mask is not None:
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        enc_output = self.pos_ffn(enc_output)
        if mask is not None:
            enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        return enc_output, enc_slf_attn

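# Illustrative usage sketch (not part of the original module): one forward pass
# through FFTBlock with a padding mask. The hyper-parameter values and the
# `kernel_size=[9, 1]` format are assumptions for illustration; the exact
# argument format depends on the PositionwiseFeedForward implementation
# imported from .SubLayers.
#
#     block = FFTBlock(d_model=256, n_head=2, d_k=64, d_v=64,
#                      d_inner=1024, kernel_size=[9, 1])
#     x = torch.randn(4, 100, 256)                           # [batch, time, d_model]
#     mask = torch.zeros(4, 100, dtype=torch.bool)           # True = padded frame
#     slf_attn_mask = mask.unsqueeze(1).expand(-1, 100, -1)  # [batch, time, time]
#     out, attn = block(x, mask=mask, slf_attn_mask=slf_attn_mask)
#     # out: [4, 100, 256]; attn: self-attention weights from MultiHeadAttention
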
class ConvNorm(torch.nn.Module):
    """1-D convolution with "same"-style default padding and Xavier weight
    initialization scaled by `w_init_gain`."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            # Infer "same" padding; this requires an odd kernel size.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # Apply the requested initialization gain; without this line the
        # `w_init_gain` argument would be accepted but never used.
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, signal):
        conv_signal = self.conv(signal)

        return conv_signal

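# Illustrative usage sketch (not part of the original module): with the default
# padding of dilation * (kernel_size - 1) / 2 and stride 1, ConvNorm preserves
# the time dimension, i.e. it behaves like a "same"-padded 1-D convolution.
#
#     conv = ConvNorm(80, 512, kernel_size=5, w_init_gain="tanh")
#     x = torch.randn(4, 80, 100)      # [batch, channels, time]
#     y = conv(x)                      # [4, 512, 100] -- time length unchanged
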
class PostNet(nn.Module):
    """PostNet: five 1-D convolutions with 512 channels and kernel size 5."""

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        # First convolution: mel channels -> embedding dimension.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        # Intermediate convolutions: embedding dimension -> embedding dimension.
        for _ in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        # Final convolution: embedding dimension -> mel channels, linear gain.
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        # [batch, time, n_mel_channels] -> [batch, n_mel_channels, time]
        x = x.contiguous().transpose(1, 2)

        # tanh + dropout on all convolutions except the last one.
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        # Back to [batch, time, n_mel_channels].
        x = x.contiguous().transpose(1, 2)
        return x

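# Illustrative usage sketch (not part of the original module): PostNet is
# typically used to predict a residual correction that is added to a decoder's
# coarse mel-spectrogram output; the shapes below are assumptions for
# illustration.
#
#     postnet = PostNet()
#     mel = torch.randn(4, 200, 80)        # [batch, time, n_mel_channels]
#     mel_refined = mel + postnet(mel)     # residual refinement, same shape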