zyingt's picture
Upload 685 files
0d80816
raw
history blame
4.13 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn import functional as F
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
class FFTBlock(torch.nn.Module):
    """FFT Block: a self-attention layer followed by a position-wise feed-forward layer.

    Args:
        d_model, n_head, d_k, d_v: forwarded to ``MultiHeadAttention``.
        d_inner, kernel_size: forwarded (with ``d_model``) to
            ``PositionwiseFeedForward``.
        dropout: dropout value handed to both sub-layers.
    """

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        """Run self-attention and the feed-forward layer over ``enc_input``.

        Args:
            enc_input: input tensor; used as query, key and value of the
                self-attention layer.
            mask: optional boolean tensor. Positions where it is ``True`` are
                zeroed after each sub-layer. It is expanded with a trailing
                singleton dimension before use (presumably shaped
                ``(batch, length)`` -- TODO confirm against callers). When
                ``None``, no zero-filling is performed.
            slf_attn_mask: forwarded unchanged as the ``mask`` argument of the
                self-attention layer.

        Returns:
            Tuple ``(enc_output, enc_slf_attn)``: the block output and the
            second value returned by the self-attention layer.
        """
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )

        # ``mask`` defaults to None, so it must not be dereferenced blindly.
        fill_mask = None if mask is None else mask.unsqueeze(-1)

        if fill_mask is not None:
            enc_output = enc_output.masked_fill(fill_mask, 0)

        enc_output = self.pos_ffn(enc_output)

        # The feed-forward layer may write non-zero values into masked
        # positions again, hence the second fill.
        if fill_mask is not None:
            enc_output = enc_output.masked_fill(fill_mask, 0)

        return enc_output, enc_slf_attn
class ConvNorm(torch.nn.Module):
    """Thin wrapper around ``torch.nn.Conv1d`` with automatic padding.

    When ``padding`` is omitted it is derived from ``kernel_size`` and
    ``dilation`` so that, for ``stride=1``, the output has the same length as
    the input. Despite the class name, no normalization layer is applied.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: convolution kernel size; must be odd when ``padding``
            is ``None``.
        stride: convolution stride.
        padding: explicit padding, or ``None`` to derive it automatically.
        dilation: convolution dilation.
        bias: whether the convolution has a bias term.
        w_init_gain: accepted for interface compatibility; this
            implementation does not use it.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()

        if padding is None:
            # Symmetric padding only works out for an odd kernel size.
            assert kernel_size % 2 == 1
            receptive_span = dilation * (kernel_size - 1)
            padding = int(receptive_span / 2)

        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "bias": bias,
        }
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, signal):
        """Apply the wrapped 1-d convolution to ``signal`` and return the result."""
        return self.conv(signal)
class PostNet(nn.Module):
    """PostNet: a stack of 1-d convolutions, each followed by batch norm.

    With the default arguments this is five convolutions with 512 channels
    and kernel size 5. The first layer maps ``n_mel_channels`` to
    ``postnet_embedding_dim``, the last maps back to ``n_mel_channels`` and
    the remaining ``postnet_n_convolutions - 2`` layers keep
    ``postnet_embedding_dim`` channels. The first and last layers are always
    created, so values of ``postnet_n_convolutions`` below 2 still build two
    layers.

    Args:
        n_mel_channels: channel count of the input and of the output.
        postnet_embedding_dim: channel count of the hidden layers.
        postnet_kernel_size: kernel size of every convolution.
        postnet_n_convolutions: total number of convolution layers.
        postnet_dropout: dropout probability applied after every layer while
            the module is in training mode.
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        postnet_dropout=0.5,
    ):
        super(PostNet, self).__init__()

        self.dropout_p = postnet_dropout

        n_hidden = max(postnet_n_convolutions - 2, 0)
        # (in_channels, out_channels, w_init_gain) for each layer, in order.
        layer_plan = (
            [(n_mel_channels, postnet_embedding_dim, "tanh")]
            + [(postnet_embedding_dim, postnet_embedding_dim, "tanh")] * n_hidden
            + [(postnet_embedding_dim, n_mel_channels, "linear")]
        )

        self.convolutions = nn.ModuleList(
            [
                self._conv_bn(in_dim, out_dim, postnet_kernel_size, gain)
                for in_dim, out_dim, gain in layer_plan
            ]
        )

    @staticmethod
    def _conv_bn(in_channels, out_channels, kernel_size, w_init_gain):
        """Build one ``ConvNorm`` + ``BatchNorm1d`` pair."""
        return nn.Sequential(
            ConvNorm(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain=w_init_gain,
            ),
            nn.BatchNorm1d(out_channels),
        )

    def forward(self, x):
        """Apply the convolution stack.

        Args:
            x: 3-d tensor whose last dimension holds the channels, i.e.
                ``(batch, length, n_mel_channels)``; dimensions 1 and 2 are
                swapped before the convolutions and swapped back afterwards.

        Returns:
            Tensor with channels in the last dimension again.
        """
        x = x.contiguous().transpose(1, 2)

        *hidden_layers, output_layer = self.convolutions
        for layer in hidden_layers:
            x = F.dropout(torch.tanh(layer(x)), self.dropout_p, self.training)
        # No tanh on the final layer.
        x = F.dropout(output_layer(x), self.dropout_p, self.training)

        x = x.contiguous().transpose(1, 2)
        return x