Spaces:
Sleeping
Sleeping
File size: 4,829 Bytes
af7ac2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from math import sqrt
import torch
from torch import nn
from Encoder import Encoder
from Decoder import Decoder
from Postnet import Postnet
from GST import GST
from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32
torch.manual_seed(1234)
class tacotron_2(nn.Module):
def __init__(self, tacotron_hyperparams):
super(tacotron_2, self).__init__()
self.mask_padding = tacotron_hyperparams['mask_padding']
self.fp16_run = tacotron_hyperparams['fp16_run']
self.n_mel_channels = tacotron_hyperparams['n_mel_channels']
self.n_frames_per_step = tacotron_hyperparams['number_frames_step']
self.embedding = nn.Embedding(
tacotron_hyperparams['n_symbols'], tacotron_hyperparams['symbols_embedding_length'])
# CHECK THIS OUT!!!
std = sqrt(2.0 / (tacotron_hyperparams['n_symbols'] + tacotron_hyperparams['symbols_embedding_length']))
val = sqrt(3.0) * std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(tacotron_hyperparams)
self.decoder = Decoder(tacotron_hyperparams)
self.postnet = Postnet(tacotron_hyperparams)
self.gst = GST(tacotron_hyperparams)
def parse_batch(self, batch):
# GST I add the new tensor from prosody features to train GST tokens:
text_padded, input_lengths, mel_padded, gate_padded, output_lengths, prosody_padded = batch
text_padded = to_gpu(text_padded).long()
max_len = int(torch.max(input_lengths.data).item()) # With item() you get the pure value (not in a tensor)
input_lengths = to_gpu(input_lengths).long()
mel_padded = to_gpu(mel_padded).float()
gate_padded = to_gpu(gate_padded).float()
output_lengths = to_gpu(output_lengths).long()
prosody_padded = to_gpu(prosody_padded).float()
return (
(text_padded, input_lengths, mel_padded, max_len, output_lengths, prosody_padded),
(mel_padded, gate_padded))
def parse_input(self, inputs):
inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
return inputs
def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths)
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
outputs[0].data.masked_fill_(mask, 0.0)
outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
return outputs
def forward(self, inputs):
inputs, input_lengths, targets, max_len, output_lengths, gst_prosody_padded = self.parse_input(inputs)
input_lengths, output_lengths = input_lengths.data, output_lengths.data
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, input_lengths)
# GST style embedding plus embedded_inputs before entering the decoder
# bin_locations = gst_prosody_padded[:, 0, :]
# pitch_intensities = gst_prosody_padded[:, 1:, :]
# bin_locations = bin_locations.unsqueeze(2)
gst_style_embedding, gst_scores = self.gst(gst_prosody_padded, output_lengths) # [N, 512]
gst_style_embedding = gst_style_embedding.expand_as(encoder_outputs)
encoder_outputs = encoder_outputs + gst_style_embedding
mel_outputs, gate_outputs, alignments = self.decoder(
encoder_outputs, targets, memory_lengths=input_lengths)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
return self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments, gst_scores],
output_lengths)
def inference(self, inputs, gst_scores): # gst_scores must be a torch tensor
inputs = self.parse_input(inputs)
embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
# GST inference:
gst_style_embedding = self.gst.inference(gst_scores)
gst_style_embedding = gst_style_embedding.expand_as(encoder_outputs)
encoder_outputs = encoder_outputs + gst_style_embedding
mel_outputs, gate_outputs, alignments = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
outputs = self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
return outputs
|