# Page-status residue from the hosting site ("Spaces: Sleeping") - not part of this module.
from math import sqrt | |
import torch | |
from torch import nn | |
from Encoder import Encoder | |
from Decoder import Decoder | |
from Postnet import Postnet | |
from GST import GST | |
from utils import to_gpu, get_mask_from_lengths | |
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32 | |
class tacotron_2(nn.Module):
    """Tacotron 2 network with an additional GST (style token) branch.

    The model chains a symbol embedding, ``Encoder``, ``GST``, ``Decoder`` and
    ``Postnet``. The style embedding returned by ``GST`` is added to the
    encoder outputs before they are handed to the decoder.
    """

    def __init__(self, tacotron_hyperparams):
        """Build all sub-modules from a hyperparameter mapping.

        Args:
            tacotron_hyperparams: Mapping indexed here with the keys
                ``'mask_padding'``, ``'fp16_run'``, ``'n_mel_channels'``,
                ``'number_frames_step'``, ``'n_symbols'`` and
                ``'symbols_embedding_length'``. The same mapping is passed
                unchanged to ``Encoder``, ``Decoder``, ``Postnet`` and ``GST``.
        """
        super().__init__()
        self.mask_padding = tacotron_hyperparams['mask_padding']
        self.fp16_run = tacotron_hyperparams['fp16_run']
        self.n_mel_channels = tacotron_hyperparams['n_mel_channels']
        self.n_frames_per_step = tacotron_hyperparams['number_frames_step']

        n_symbols = tacotron_hyperparams['n_symbols']
        embedding_dim = tacotron_hyperparams['symbols_embedding_length']
        self.embedding = nn.Embedding(n_symbols, embedding_dim)

        # Uniform init in [-val, val] with val = sqrt(3) * sqrt(2 / (fan_a + fan_b)),
        # i.e. sqrt(6 / (n_symbols + embedding_dim)) - the Glorot/Xavier uniform bound.
        std = sqrt(2.0 / (n_symbols + embedding_dim))
        val = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-val, val)

        self.encoder = Encoder(tacotron_hyperparams)
        self.decoder = Decoder(tacotron_hyperparams)
        self.postnet = Postnet(tacotron_hyperparams)
        self.gst = GST(tacotron_hyperparams)

    def parse_batch(self, batch):
        """Split a collated batch into model inputs and training targets.

        Every tensor is passed through ``to_gpu`` and cast (``long`` for the
        text and length tensors, ``float`` for mel, gate and prosody tensors).

        Args:
            batch: Six-element sequence ``(text_padded, input_lengths,
                mel_padded, gate_padded, output_lengths, prosody_padded)``.
                ``prosody_padded`` carries the prosody features used to train
                the GST tokens.

        Returns:
            A pair ``(inputs, targets)`` where ``inputs`` is
            ``(text_padded, input_lengths, mel_padded, max_len,
            output_lengths, prosody_padded)`` and ``targets`` is
            ``(mel_padded, gate_padded)``.
        """
        (text_padded, input_lengths, mel_padded, gate_padded,
         output_lengths, prosody_padded) = batch

        text_padded = to_gpu(text_padded).long()
        # Taken from the lengths tensor as received (before to_gpu);
        # .item() yields a plain Python number instead of a 0-d tensor.
        max_len = int(torch.max(input_lengths.data).item())
        input_lengths = to_gpu(input_lengths).long()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        prosody_padded = to_gpu(prosody_padded).float()

        return (
            (text_padded, input_lengths, mel_padded, max_len,
             output_lengths, prosody_padded),
            (mel_padded, gate_padded))

    def parse_input(self, inputs):
        """Return ``inputs`` run through ``fp32_to_fp16`` when ``fp16_run`` is set.

        When ``self.fp16_run`` is falsy the very same object is returned.
        """
        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
        return inputs

    def parse_output(self, outputs, output_lengths=None):
        """Optionally mask the model outputs in place and undo the fp16 cast.

        Args:
            outputs: Indexable container; when masking is active, elements
                0 and 1 are filled with ``0.0`` and element 2 with ``1e3`` at
                the masked positions.
            output_lengths: Lengths handed to ``get_mask_from_lengths``.
                Masking is skipped when this is ``None`` or when
                ``self.mask_padding`` is falsy.

        Returns:
            ``outputs`` (converted with ``fp16_to_fp32`` when ``fp16_run`` is
            set, otherwise the same object that was passed in).
        """
        if self.mask_padding and output_lengths is not None:
            # Inverted length mask - presumably True at padded frames
            # (TODO confirm against get_mask_from_lengths).
            mask = ~get_mask_from_lengths(output_lengths)
            # Replicate across mel channels, then reorder to
            # (mask.size(0), n_mel_channels, mask.size(1)).
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
        return outputs

    def forward(self, inputs):
        """Run the training-time forward pass.

        Args:
            inputs: Six-element sequence laid out like the first element
                returned by ``parse_batch``: ``(text, input_lengths, targets,
                max_len, output_lengths, gst_prosody_padded)``.

        Returns:
            The result of ``parse_output`` applied to ``[mel_outputs,
            mel_outputs_postnet, gate_outputs, alignments, gst_scores]`` with
            ``output_lengths``.
        """
        (inputs, input_lengths, targets, max_len,
         output_lengths, gst_prosody_padded) = self.parse_input(inputs)
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        # GST style embedding is added to the encoder outputs before decoding.
        # NOTE(review): earlier, now removed, commented-out code split
        # gst_prosody_padded into [:, 0, :] (bin locations) and [:, 1:, :]
        # (pitch intensities); the tensor is passed to GST whole here.
        # The original note gives the embedding as [N, 512] - TODO confirm;
        # expand_as needs it to be expandable to encoder_outputs' shape.
        gst_style_embedding, gst_scores = self.gst(gst_prosody_padded, output_lengths)
        gst_style_embedding = gst_style_embedding.expand_as(encoder_outputs)
        encoder_outputs = encoder_outputs + gst_style_embedding

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)

        # The postnet output is added back onto the decoder mel output.
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, gst_scores],
            output_lengths)

    def inference(self, inputs, gst_scores):
        """Synthesize from symbol ids using externally supplied GST scores.

        Args:
            inputs: Symbol-id tensor fed to the embedding layer.
            gst_scores: Scores passed to ``GST.inference``; must be a torch
                tensor.

        Returns:
            The result of ``parse_output`` applied to ``[mel_outputs,
            mel_outputs_postnet, gate_outputs, alignments]``. No lengths are
            passed, so no padding mask is applied.
        """
        inputs = self.parse_input(inputs)
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

        # GST inference: the style embedding is built from the given scores.
        gst_style_embedding = self.gst.inference(gst_scores)
        gst_style_embedding = gst_style_embedding.expand_as(encoder_outputs)
        encoder_outputs = encoder_outputs + gst_style_embedding

        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
        return outputs