|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
|
|
class CNNPrenet(torch.nn.Module):
    """Convolutional prenet operating on a single-channel 1-D input.

    The input is given a channel axis of size one and passed through a stack
    of ``Conv1d -> BatchNorm1d -> ReLU -> Dropout`` groups, after which a
    ``tanh`` is applied to the result.

    Arguments
    ---------
    n_channels: int
        Number of output channels of every convolution (default 512).
    kernel_size: int
        Convolution kernel size (default 3). Padding is ``kernel_size // 2``,
        which keeps the time length unchanged for odd kernel sizes.
    n_layers: int
        Number of convolution groups in the stack (default 3).
    dropout: float
        Dropout probability applied after every ReLU (default 0.1).

    Example
    -------
    >>> net = CNNPrenet().eval()
    >>> net(torch.rand(4, 50)).shape
    torch.Size([4, 512, 50])
    """

    def __init__(self, n_channels=512, kernel_size=3, n_layers=3, dropout=0.1):
        super().__init__()

        layers = []
        # The first convolution consumes the single channel that ``forward``
        # inserts; every later one maps n_channels -> n_channels.
        in_channels = 1
        for _ in range(n_layers):
            layers.extend(
                [
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=n_channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                    ),
                    nn.BatchNorm1d(n_channels),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                ]
            )
            in_channels = n_channels

        # Kept as one flat Sequential so parameter names
        # (``conv_layers.0.weight``, ``conv_layers.4.weight``, ...) are stable.
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Computes the forward pass.

        Arguments
        ---------
        x: torch.Tensor
            A (batch, time) input tensor; the channel axis is added here.

        Returns
        -------
        output: torch.Tensor
            A (batch, n_channels, time) tensor with values in [-1, 1].
        """
        # (batch, time) -> (batch, 1, time), the layout Conv1d expects.
        x = x.unsqueeze(1)
        x = self.conv_layers(x)
        # NOTE: the former ``x.squeeze(1)`` was dropped. Axis 1 holds the
        # convolution channels here, so with the original 512 channels the
        # call never changed the tensor.
        return torch.tanh(x)
|
|
|
|
|
|
|
class CNNDecoderPrenet(nn.Module):
    """Decoder prenet built from fully connected layers.

    Two ``Linear -> ReLU -> Dropout`` stages are followed by a final linear
    projection. The linear layers act on the last axis, so the input is
    transposed to put features last and transposed back before returning.

    Arguments
    ---------
    input_dim: int
        Feature dimension of the input (default 80).
    hidden_dim: int
        Output dimension of the first linear layer (default 256).
    output_dim: int
        Output dimension of the second linear layer (default 256).
    final_dim: int
        Output dimension of the final projection (default 512).
    dropout_rate: float
        Dropout probability used after each ReLU (default 0.5).

    Example
    -------
    >>> net = CNNDecoderPrenet().eval()
    >>> net(torch.rand(4, 80, 30)).shape
    torch.Size([4, 512, 30])
    """

    def __init__(
        self,
        input_dim=80,
        hidden_dim=256,
        output_dim=256,
        final_dim=512,
        dropout_rate=0.5,
    ):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.linear_projection = nn.Linear(output_dim, final_dim)
        # A single Dropout module is reused after both hidden layers.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Computes the forward pass.

        Arguments
        ---------
        x: torch.Tensor
            A (batch, input_dim, time) input tensor.

        Returns
        -------
        output: torch.Tensor
            A (batch, final_dim, time) tensor.
        """
        # (batch, input_dim, time) -> (batch, time, input_dim) so that the
        # linear layers act on the feature axis.
        x = x.transpose(1, 2)

        x = self.dropout(F.relu(self.layer1(x)))
        x = self.dropout(F.relu(self.layer2(x)))
        x = self.linear_projection(x)

        # Back to channel-first layout: (batch, final_dim, time).
        return x.transpose(1, 2)
|
|
|
|
|
|
|
|
|
class CNNPostNet(torch.nn.Module):
    """
    Conv Postnet

    A stack of 1-D convolutions, each followed by layer normalization over
    the channel axis. The first and intermediate convolutions are followed
    by ``tanh`` and dropout; the last one by dropout only.

    Arguments
    ---------
    n_mel_channels: int
        input feature dimension for convolution layers
    postnet_embedding_dim: int
        output feature dimension for convolution layers
    postnet_kernel_size: int
        postnet convolution kernel size
    postnet_n_convolutions: int
        number of convolution layers. The first and last convolution are
        always created, so values below 2 still yield two layers.
    postnet_dropout: float
        dropout probability for postnet

    Example
    -------
    >>> net = CNNPostNet().eval()
    >>> net(torch.rand(4, 80, 30)).shape
    torch.Size([4, 80, 30])
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        postnet_dropout=0.1,
    ):
        super().__init__()

        self.conv_pre = nn.Conv1d(
            in_channels=n_mel_channels,
            out_channels=postnet_embedding_dim,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        # NOTE: the attribute name (including its spelling) is kept as-is so
        # that parameter names in existing state dicts still match.
        self.convs_intermedite = nn.ModuleList(
            nn.Conv1d(
                in_channels=postnet_embedding_dim,
                out_channels=postnet_embedding_dim,
                kernel_size=postnet_kernel_size,
                padding="same",
            )
            for _ in range(1, postnet_n_convolutions - 1)
        )

        self.conv_post = nn.Conv1d(
            in_channels=postnet_embedding_dim,
            out_channels=n_mel_channels,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.tanh = nn.Tanh()
        self.ln1 = nn.LayerNorm(postnet_embedding_dim)
        # ln2 / dropout2 are shared by every intermediate convolution.
        self.ln2 = nn.LayerNorm(postnet_embedding_dim)
        self.ln3 = nn.LayerNorm(n_mel_channels)
        self.dropout1 = nn.Dropout(postnet_dropout)
        self.dropout2 = nn.Dropout(postnet_dropout)
        self.dropout3 = nn.Dropout(postnet_dropout)

    def forward(self, x):
        """Computes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            a (batch, n_mel_channels, time_steps) input tensor; it is fed
            directly to a Conv1d, so channels come before time.

        Returns
        -------
        output: torch.Tensor (the spectrogram predicted)
            a (batch, n_mel_channels, time_steps) tensor
        """
        x = self.conv_pre(x)
        # LayerNorm normalizes the last axis, so move channels last for the
        # normalization and restore the channel-first layout afterwards.
        x = self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.tanh(x)
        x = self.dropout1(x)

        for conv in self.convs_intermedite:
            x = conv(x)
            x = self.ln2(x.permute(0, 2, 1)).permute(0, 2, 1)
            x = self.tanh(x)
            x = self.dropout2(x)

        x = self.conv_post(x)
        x = self.ln3(x.permute(0, 2, 1)).permute(0, 2, 1)
        # No tanh after the final convolution.
        x = self.dropout3(x)

        return x
|
|
|
|
|
class ScaledPositionalEncoding(nn.Module):
    """
    This class implements the absolute sinusoidal positional encoding function
    with an adaptive weight parameter alpha.

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))

    The module returns only the scaled encoding; it does not add it to the
    input.

    Arguments
    ---------
    input_size: int
        Embedding dimension. Must be even.
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = ScaledPositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
            )
        self.max_len = max_len
        # Learnable scalar weight applied to the fixed sinusoidal table.
        self.alpha = nn.Parameter(torch.ones(1))

        pe = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        # exp(-2i * ln(10000) / d) == 1 / 10000^(2i / d); despite its name
        # this is the multiplicative inverse of the formula's denominator.
        denominator = torch.exp(
            torch.arange(0, input_size, 2).float()
            * -(math.log(10000.0) / input_size)
        )

        pe[:, 0::2] = torch.sin(positions * denominator)
        pe[:, 1::2] = torch.cos(positions * denominator)
        # Stored with a leading axis of size 1: (1, max_len, input_size).
        pe = pe.unsqueeze(0)
        # A buffer, not a parameter: saved in the state dict and moved with
        # the module, but never trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea)

        Returns
        -------
        pe_scaled : tensor
            The encoding for the first ``time`` positions multiplied by
            alpha, shape (1, time, fea).

        Raises
        ------
        ValueError
            If the time dimension of ``x`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            # Slicing past the table would silently return a shorter
            # encoding; fail loudly instead.
            raise ValueError(
                f"Input sequence length {seq_len} exceeds max_len={self.max_len}"
            )
        # detach() keeps the table out of autograd; the multiplication
        # allocates a new tensor, so no explicit copy is required.
        pe_scaled = self.pe[:, :seq_len].detach() * self.alpha
        return pe_scaled
|
|
|
|