|
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.nn import functional as F
|
|
class CapacitronVAE(nn.Module):
    """Capacitron prosody-transfer VAE.

    Implements the variational reference encoder from "Effective Use of
    Variational Embedding Capacity in Expressive End-to-End Speech Synthesis".
    See https://arxiv.org/abs/1906.03402
    """

    def __init__(
        self,
        num_mel,
        capacitron_VAE_embedding_dim,
        encoder_output_dim=256,
        reference_encoder_out_dim=128,
        speaker_embedding_dim=None,
        text_summary_embedding_dim=None,
    ):
        super().__init__()
        # Standard-normal prior p(z), created on CPU; the posterior parameters are
        # moved to CPU in forward() to match.
        self.prior_distribution = MVN(
            torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim)
        )
        self.approximate_posterior_distribution = None
        # Reference encoder maps a mel spectrogram to a fixed-size prosody embedding.
        self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim)

        # Unconstrained parameter for the KL capacity constraint, initialised with the
        # inverse softplus of 1.0 so that softplus(beta) starts at 1.0.
        self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True)

        # The post-encoder MLP sees the reference embedding, optionally concatenated
        # with a text summary embedding and/or a speaker embedding.
        mlp_input_dimension = reference_encoder_out_dim
        if text_summary_embedding_dim is not None:
            self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
            mlp_input_dimension += text_summary_embedding_dim
        if speaker_embedding_dim is not None:
            mlp_input_dimension += speaker_embedding_dim
        self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
|
    def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
        # Training / prosody transfer: encode the reference mel into the posterior q(z|x).
        if reference_mel_info is not None:
            reference_mels = reference_mel_info[0]  # [batch_size, num_frames, num_mel]
            mel_lengths = reference_mel_info[1]  # [batch_size]
            enc_out = self.encoder(reference_mels, mel_lengths)

            # Optionally concatenate the text summary and/or speaker embedding.
            if text_info is not None:
                text_inputs = text_info[0]  # [batch_size, num_characters, encoder_output_dim]
                input_lengths = text_info[1]
                text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
                enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
            if speaker_embedding is not None:
                speaker_embedding = torch.squeeze(speaker_embedding)
                enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)

            # MLP maps the combined embedding to the parameters of the approximate posterior.
            mu, sigma = self.post_encoder_mlp(enc_out)
            # Move to CPU because the prior distribution was created on CPU.
            mu = mu.cpu()
            sigma = sigma.cpu()

            # Sample z ~ q(z|x) with the reparameterisation trick so gradients can flow.
            self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
            VAE_embedding = self.approximate_posterior_distribution.rsample()
        # Inference without a reference: sample z ~ p(z) from the prior.
        else:
            VAE_embedding = self.prior_distribution.sample().unsqueeze(0)

        # Reshape to [batch_size, 1, capacitron_VAE_embedding_dim].
        return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta
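

# The tuple returned by CapacitronVAE.forward() is consumed by a capacity-constrained
# KL term in the training objective. The helper below is only an illustrative sketch of
# that term, beta * (KL(q(z|x) || p(z)) - capacity), with the raw beta parameter passed
# through a softplus to keep it positive. It is not part of the original module; the
# name `capacitron_kl_term` and the `capacity` argument are hypothetical.
def capacitron_kl_term(posterior, prior, beta_param, capacity):
    """Illustrative sketch of the capacity-constrained KL penalty (not the original loss)."""
    beta = F.softplus(beta_param)[0]
    kl = torch.distributions.kl_divergence(posterior, prior).mean()
    # beta acts as a Lagrange multiplier: it is trained to push the KL towards the
    # target capacity while the rest of the model minimises the weighted term.
    return beta * (kl - capacity)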
|
|
|
|
|
class ReferenceEncoder(nn.Module):
    """NN module creating a fixed size prosody embedding from a spectrogram.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim]
    """

    def __init__(self, num_mel, out_dim):
        super().__init__()
        self.num_mel = num_mel
        # Stack of 6 strided 2D convolutions; the spectrogram is treated as a
        # single-channel image.
        filters = [1] + [32, 32, 64, 64, 128, 128]
        num_layers = len(filters) - 1
        convs = [
            nn.Conv2d(
                in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
            )
            for i in range(num_layers)
        ]
        self.convs = nn.ModuleList(convs)
        self.training = False
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])

        # The LSTM consumes the flattened (channels x post-conv frequency) features per frame.
        post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
        self.recurrence = nn.LSTM(
            input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
        )
|
    def forward(self, inputs, input_lengths):
        batch_size = inputs.size(0)
        x = inputs.view(batch_size, 1, -1, self.num_mel)  # [batch_size, 1, num_frames, num_mel]
        valid_lengths = input_lengths.float()  # [batch_size]
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x)
            x = bn(x)
            x = F.relu(x)

            # Update the valid (unpadded) lengths after each strided convolution.
            # With kernel_size=3, stride=2 and padding=2 the valid output length
            # is ceil(input_length / 2) + 1.
            valid_lengths = (valid_lengths / 2).float()
            valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1  # 2 is the stride
            post_conv_max_width = x.size(2)

            # Zero out positions that come from padding so they do not leak into
            # the next convolution.
            mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
                len(valid_lengths), post_conv_max_width
            ) < valid_lengths.unsqueeze(1)
            mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2)  # [batch_size, 1, width, 1]
            x = x * mask

        x = x.transpose(1, 2)
        # x: [batch_size, post_conv_width, num_channels==128, post_conv_height]

        post_conv_width = x.size(1)
        x = x.contiguous().view(batch_size, post_conv_width, -1)
        # x: [batch_size, post_conv_width, num_channels * post_conv_height]

        # Pack the padded sequences so the LSTM only sees valid frames, then take the
        # last hidden state as the fixed-size reference embedding.
        post_conv_input_lengths = valid_lengths
        packed_seqs = nn.utils.rnn.pack_padded_sequence(
            x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
        )
        self.recurrence.flatten_parameters()
        _, (ht, _) = self.recurrence(packed_seqs)
        last_output = ht[-1]

        return last_output.to(inputs.device)  # [batch_size, out_dim]
|
    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
        """Height of spec after n convolutions with fixed kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height - kernel_size + 2 * pad) // stride + 1
        return height
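
    # Worked example (assuming 80 mel bins, which this module does not fix):
    #   calculate_post_conv_height(80, 3, 2, 2, 6)
    #   80 -> 41 -> 22 -> 12 -> 7 -> 5 -> 4
    # so the LSTM above is built with input_size = 128 * 4 = 512.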
|
|
|
|
|
class TextSummary(nn.Module):
    """Summarise variable-length text encoder outputs into a single embedding vector."""

    def __init__(self, embedding_dim, encoder_output_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            encoder_output_dim,
            embedding_dim,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, inputs, input_lengths):
        # Pack the padded text sequences and keep the LSTM's last hidden state as the
        # fixed-size text summary.
        packed_seqs = nn.utils.rnn.pack_padded_sequence(
            inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False
        )
        self.lstm.flatten_parameters()
        _, (ht, _) = self.lstm(packed_seqs)
        last_output = ht[-1]
        return last_output
|
|
class PostEncoderMLP(nn.Module):
    """MLP mapping the combined reference embedding to the posterior parameters (mu, sigma)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        modules = [
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size * 2),  # first half is mu, second half is sigma
        ]
        self.net = nn.Sequential(*modules)
        self.softplus = nn.Softplus()

    def forward(self, _input):
        mlp_output = self.net(_input)
        mu = mlp_output[:, : self.hidden_size]
        # Softplus keeps sigma strictly positive for the diagonal covariance.
        sigma = self.softplus(mlp_output[:, self.hidden_size :])
        return mu, sigma
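

# Minimal usage sketch (not part of the original module). The shapes below, 80 mel
# bins, 100 reference frames and a batch of 2, are assumed values chosen only to
# exercise the layer end to end.
if __name__ == "__main__":
    capacitron = CapacitronVAE(num_mel=80, capacitron_VAE_embedding_dim=128)
    ref_mels = torch.randn(2, 100, 80)  # [batch_size, num_frames, num_mel]
    mel_lengths = torch.tensor([100, 83])
    z, posterior, prior, beta = capacitron(reference_mel_info=[ref_mels, mel_lengths])
    print(z.shape)  # expected: torch.Size([2, 1, 128])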
|
|