""" BSD 3-Clause License Copyright (c) 2018, NVIDIA Corporation All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ from math import sqrt import torch from torch.autograd import Variable from torch import nn from torch.nn import functional as F from training.tacotron2_model.layers import ConvNorm, LinearNorm from training.tacotron2_model.utils import to_gpu, get_mask_from_lengths, get_x class LocationLayer(nn.Module): def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): super(LocationLayer, self).__init__() padding = int((attention_kernel_size - 1) / 2) self.location_conv = ConvNorm( 2, attention_n_filters, kernel_size=attention_kernel_size, padding=padding, bias=False, stride=1, dilation=1 ) self.location_dense = LinearNorm(attention_n_filters, attention_dim, bias=False, w_init_gain="tanh") def forward(self, attention_weights_cat): processed_attention = self.location_conv(attention_weights_cat) processed_attention = processed_attention.transpose(1, 2) processed_attention = self.location_dense(processed_attention) return processed_attention class Attention(nn.Module): def __init__( self, attention_rnn_dim, embedding_dim, attention_dim, attention_location_n_filters, attention_location_kernel_size, ): super(Attention, self).__init__() self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh") self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain="tanh") self.v = LinearNorm(attention_dim, 1, bias=False) self.location_layer = LocationLayer(attention_location_n_filters, attention_location_kernel_size, attention_dim) self.score_mask_value = -float("inf") def get_alignment_energies(self, query, processed_memory, attention_weights_cat): """ PARAMS ------ query: decoder output (batch, n_mel_channels * n_frames_per_step) processed_memory: processed encoder outputs (B, T_in, attention_dim) attention_weights_cat: cumulative and prev. 


class Attention(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh")
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain="tanh")
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters, attention_location_kernel_size, attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
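
# Minimal usage sketch for Attention (illustrative comment only; dims follow the
# Tacotron2 defaults below). Padded encoder steps are filled with -inf before the
# softmax, so they receive zero attention weight:
#
#   attention = Attention(1024, 512, 128, 32, 31)
#   memory = torch.randn(16, 100, 512)             # encoder outputs (B, T_in, 512)
#   context, weights = attention(
#       torch.zeros(16, 1024),                     # attention RNN hidden state
#       memory,
#       attention.memory_layer(memory),            # (16, 100, 128)
#       torch.zeros(16, 2, 100),                   # previous + cumulative weights
#       mask=None,
#   )                                              # context: (16, 512), weights: (16, 100)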


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, sizes)]
        )

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x
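
# Note: the prenet applies dropout with training=True unconditionally, so it
# stays active at inference time as well. This follows the Tacotron 2 paper,
# which keeps prenet dropout on during synthesis to introduce output variation;
# it is intentional, not a bug.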


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    n_mel_channels,
                    postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="tanh",
                ),
                nn.BatchNorm1d(postnet_embedding_dim),
            )
        )

        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        postnet_embedding_dim,
                        postnet_embedding_dim,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain="tanh",
                    ),
                    nn.BatchNorm1d(postnet_embedding_dim),
                )
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(
                    postnet_embedding_dim,
                    n_mel_channels,
                    kernel_size=postnet_kernel_size,
                    stride=1,
                    padding=int((postnet_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="linear",
                ),
                nn.BatchNorm1d(n_mel_channels),
            )
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """

    def __init__(self, encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                ),
                nn.BatchNorm1d(encoder_embedding_dim),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(
            encoder_embedding_dim, int(encoder_embedding_dim / 2), 1, batch_first=True, bidirectional=True
        )

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # pack_padded_sequence expects lengths on the CPU, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs
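
# Encoder shape sketch (illustrative comment only): character embeddings go in
# channels-first and come out time-major, with the bidirectional LSTM
# concatenating two 256-dim directions back to 512:
#
#   encoder = Encoder(5, 3, 512)
#   x = torch.randn(16, 512, 100)                       # embedded text (B, 512, T_in)
#   lengths = torch.full((16,), 100, dtype=torch.long)  # all sequences full length
#   outputs = encoder(x, lengths)                       # -> (16, 100, 512)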


class Decoder(nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_frames_per_step,
        encoder_embedding_dim,
        attention_dim,
        attention_rnn_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
        decoder_rnn_dim,
        prenet_dim,
        max_decoder_steps,
        gate_threshold,
        p_attention_dropout,
        p_decoder_dropout,
    ):
        super(Decoder, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.max_decoder_steps = max_decoder_steps
        self.gate_threshold = gate_threshold
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout

        self.prenet = Prenet(n_mel_channels * n_frames_per_step, [prenet_dim, prenet_dim])

        self.attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)

        self.attention_layer = Attention(
            attention_rnn_dim,
            encoder_embedding_dim,
            attention_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
        )

        self.decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, bias=True)

        self.linear_projection = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, n_mel_channels * n_frames_per_step)

        self.gate_layer = LinearNorm(decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid")

    def get_go_frame(self, memory):
        """Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0), int(decoder_inputs.size(1) / self.n_frames_per_step), -1
        )
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: mel outputs
        gate_outputs: gate output energies
        alignments: sequence of attention weights

        RETURNS
        -------
        mel_outputs: mel outputs
        gate_outputs: gate output energies
        alignments: sequence of attention weights
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output: predicted mel frame
        gate_output: gate output energies
        attention_weights: attention weights for this step
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell)
        )
        self.attention_hidden = F.dropout(self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1
        )
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask
        )

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell)
        )
        self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths, device):
        """Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths, device))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory, max_decoder_steps=None):
        """Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        if not max_decoder_steps:
            # Use default max decoder steps if not given
            max_decoder_steps = self.max_decoder_steps

        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == max_decoder_steps:
                raise Exception(
                    "Reached max decoder steps. Either the model is low quality or the given sentence is too short/long"
                )

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
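
# The decoder runs in two modes: Decoder.forward consumes the full ground-truth
# mel-spectrogram for teacher forcing (each step is conditioned on the previous
# *reference* frame), while Decoder.inference feeds back its own predictions and
# stops once the sigmoid of the gate output crosses gate_threshold, raising if
# max_decoder_steps is reached first.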


class Tacotron2(nn.Module):
    def __init__(
        self,
        mask_padding=True,
        fp16_run=False,
        n_mel_channels=80,
        n_symbols=148,
        symbols_embedding_dim=512,
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        attention_rnn_dim=1024,
        attention_dim=128,
        attention_location_n_filters=32,
        attention_location_kernel_size=31,
        decoder_rnn_dim=1024,
        prenet_dim=256,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(Tacotron2, self).__init__()
        self.mask_padding = mask_padding
        self.fp16_run = fp16_run
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = 1
        self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
        std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(encoder_kernel_size, encoder_n_convolutions, encoder_embedding_dim)
        self.decoder = Decoder(
            n_mel_channels,
            self.n_frames_per_step,
            encoder_embedding_dim,
            attention_dim,
            attention_rnn_dim,
            attention_location_n_filters,
            attention_location_kernel_size,
            decoder_rnn_dim,
            prenet_dim,
            max_decoder_steps,
            gate_threshold,
            p_attention_dropout,
            p_decoder_dropout,
        )
        self.postnet = Postnet(n_mel_channels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolutions)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return ((text_padded, input_lengths, mel_padded, max_len, output_lengths), (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths, mask_size, alignment_mask_size, device):
        if self.mask_padding:
            mask = ~get_mask_from_lengths(output_lengths, device, mask_size)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        if outputs[3].size(2) != alignment_mask_size:
            outputs[3] = nn.ConstantPad1d((0, alignment_mask_size - outputs[3].size(2)), 0)(outputs[3])

        return outputs

    def forward(self, inputs, mask_size, alignment_mask_size):
        text_inputs, text_lengths, mels, output_lengths = get_x(inputs)
        device = text_inputs.device
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, memory_lengths=text_lengths, device=device
        )

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths,
            mask_size,
            alignment_mask_size,
            device,
        )
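
    # Training-step sketch (illustrative comment only; assumes a DataLoader
    # yielding the batch tuple parse_batch expects, with mask_size and
    # alignment_mask_size supplied by the caller's training loop):
    #
    #   x, y = model.parse_batch(batch)
    #   mel_out, mel_out_postnet, gate_out, alignments = model(x, mask_size, alignment_mask_size)
    #
    # forward returns [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
    # with padded positions zeroed (and their gate energies forced high) by parse_output.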
    def inference(self, inputs, max_decoder_steps=None):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(encoder_outputs, max_decoder_steps)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
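

if __name__ == "__main__":
    # Minimal shape smoke test (a sketch, illustrative only; sizes follow the
    # class defaults above). The full autoregressive inference path needs
    # trained weights to terminate early, so only the deterministic submodules
    # are exercised here.
    model = Tacotron2()
    model.eval()
    with torch.no_grad():
        tokens = torch.randint(0, 148, (2, 50))             # (B, T_in) symbol ids
        embedded = model.embedding(tokens).transpose(1, 2)  # (B, 512, T_in)
        encoded = model.encoder.inference(embedded)         # (B, T_in, 512)
        mels = torch.randn(2, 80, 120)                      # (B, n_mel_channels, T_out)
        residual = model.postnet(mels)                      # (B, 80, T_out)
    print(encoded.shape, residual.shape)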