# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from .SubLayers import MultiHeadAttention, PositionwiseFeedForward


class FFTBlock(torch.nn.Module):
    """FFT Block: self-attention followed by a position-wise feed-forward layer."""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        # Zero out padded positions after attention.
        enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        enc_output = self.pos_ffn(enc_output)
        # Zero out padded positions again after the feed-forward layer.
        enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0)

        return enc_output, enc_slf_attn


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            # "Same" padding requires an odd kernel size.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, signal):
        conv_signal = self.conv(signal)

        return conv_signal


class PostNet(nn.Module):
    """
    PostNet: five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        # (batch, time, mels) -> (batch, mels, time) for Conv1d.
        x = x.contiguous().transpose(1, 2)

        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        # Final convolution has no tanh activation.
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        x = x.contiguous().transpose(1, 2)

        return x
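

# Minimal usage sketch (illustrative, not part of the original module): PostNet
# takes mel spectrograms shaped (batch, time, n_mel_channels) and returns a
# residual of the same shape. The shapes below are assumptions for the demo.
# Note: the relative import of .SubLayers above means this file must be run
# from within its package (e.g. via `python -m ...`) for this block to execute.
if __name__ == "__main__":
    dummy_mel = torch.randn(2, 100, 80)  # (batch, frames, mel bins)
    postnet = PostNet(n_mel_channels=80)
    residual = postnet(dummy_mel)
    print(residual.shape)  # expected: torch.Size([2, 100, 80])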