# coding: utf-8
from __future__ import with_statement, print_function, absolute_import

import torch
from torch.autograd import Variable
from torch import nn

from .attention import BahdanauAttention, AttentionWrapper
from .attention import get_mask_from_lengths


class Prenet(nn.Module):
    """Stack of Linear -> ReLU -> Dropout(0.5) layers.

    Args:
        in_dim: Size of the last dimension of the input.
        sizes: Output size of each successive linear layer; the last entry
            is the size of the last dimension of the output.
    """

    def __init__(self, in_dim, sizes=(256, 128)):
        super(Prenet, self).__init__()
        # Copy into a fresh list: accepts any sequence and never shares or
        # mutates a default argument object across instances.
        sizes = list(sizes)
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
             for (in_size, out_size) in zip(in_sizes, sizes)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs):
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs


class BatchNormConv1d(nn.Module):
    """Bias-free Conv1d, optional activation, then BatchNorm1d.

    Note the order: the activation (if given) is applied *before* batch
    normalization.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
                 activation=None):
        super(BatchNormConv1d, self).__init__()
        self.conv1d = nn.Conv1d(in_dim, out_dim,
                                kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm1d(out_dim)
        self.activation = activation

    def forward(self, x):
        x = self.conv1d(x)
        if self.activation is not None:
            x = self.activation(x)
        return self.bn(x)


class Highway(nn.Module):
    """Highway layer computing ``H(x) * T(x) + x * (1 - T(x))``.

    ``H`` is a ReLU-activated linear layer (bias initialized to 0) and ``T``
    a sigmoid gate (bias initialized to -1, so the gate initially favours
    carrying the input through). The carry term requires
    ``in_size == out_size``.
    """

    def __init__(self, in_size, out_size):
        super(Highway, self).__init__()
        self.H = nn.Linear(in_size, out_size)
        self.H.bias.data.zero_()
        self.T = nn.Linear(in_size, out_size)
        self.T.bias.data.fill_(-1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        H = self.relu(self.H(inputs))
        T = self.sigmoid(self.T(inputs))
        return H * T + inputs * (1.0 - T)


class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

    Args:
        in_dim: Feature size of the input; also the width of the highway
            layers and of each GRU direction (output is ``2 * in_dim``).
        K: Number of convolutions in the bank (kernel sizes 1..K).
        projections: Output channels of each conv projection layer.
    """

    def __init__(self, in_dim, K=16, projections=(128, 128)):
        super(CBHG, self).__init__()
        # Copy so tuples are accepted and the default is never shared.
        projections = list(projections)
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        in_sizes = [K * in_dim] + projections[:-1]
        # ReLU on every projection except the last one.
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        # Only used in forward() when projections[-1] != in_dim.
        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs, input_lengths=None):
        """Run the CBHG stack.

        Args:
            inputs: (B, T_in, in_dim) tensor.
            input_lengths: Optional per-sequence lengths; when given, the
                GRU runs on a packed sequence (ordering requirements are
                those of ``pack_padded_sequence`` for the torch version in
                use).

        Returns:
            (B, T_in, in_dim * 2) bidirectional GRU outputs.
        """
        # (B, T_in, in_dim)
        x = inputs

        # Needed to perform conv1d on time-axis
        # (B, in_dim, T_in)
        if x.size(-1) == self.in_dim:
            x = x.transpose(1, 2)

        T = x.size(-1)

        # (B, in_dim*K, T_in)
        # Concat conv1d bank outputs. Even kernel sizes with padding k // 2
        # yield T_in + 1 frames, so every output is trimmed back to T_in.
        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks],
                      dim=1)
        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
        # Max-pooling (kernel 2, stride 1, padding 1) also yields T_in + 1.
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

        # (B, T_in, in_dim)
        # Back to the original shape
        x = x.transpose(1, 2)

        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection. Done out-of-place so that no tensor belonging
        # to the autograd graph (here possibly a transposed view) is
        # modified in place.
        x = x + inputs
        for highway in self.highways:
            x = highway(x)

        if input_lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(
                x, input_lengths, batch_first=True)

        # (B, T_in, in_dim*2)
        outputs, _ = self.gru(x)

        if input_lengths is not None:
            outputs, _ = nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)

        return outputs


class Encoder(nn.Module):
    """Text encoder: a Prenet followed by a CBHG module.

    The CBHG's bidirectional GRU has 128 units per direction, so the output
    feature size is 256 regardless of ``in_dim``.
    """

    def __init__(self, in_dim):
        super(Encoder, self).__init__()
        self.prenet = Prenet(in_dim, sizes=[256, 128])
        self.cbhg = CBHG(128, K=16, projections=[128, 128])

    def forward(self, inputs, input_lengths=None):
        inputs = self.prenet(inputs)
        return self.cbhg(inputs, input_lengths)


class Decoder(nn.Module):
    """Attention-based autoregressive decoder emitting ``r`` frames per step.

    Args:
        in_dim: Size of one output frame (e.g. number of mel bins).
        r: Number of frames predicted per decoder step.
        max_decoder_steps: Maximum number of decoder steps taken during
            greedy (free-running) decoding.
    """

    def __init__(self, in_dim, r, max_decoder_steps=200):
        super(Decoder, self).__init__()
        self.in_dim = in_dim
        self.r = r
        self.prenet = Prenet(in_dim * r, sizes=[256, 128])
        # (prenet_out + attention context) -> output
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            BahdanauAttention(256)
        )
        self.memory_layer = nn.Linear(256, 256, bias=False)
        self.project_to_decoder_in = nn.Linear(512, 256)

        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])

        self.proj_to_mel = nn.Linear(256, in_dim * r)
        self.max_decoder_steps = max_decoder_steps

    def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
        """
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        Tacotron paper, greedy decoding is adopted.

        Args:
            encoder_outputs: Encoder outputs. (B, T_encoder, dim)
            inputs: Decoder inputs. i.e., mel-spectrogram. If None (at
              eval-time), decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used
              for attention masking.

        Returns:
            tuple: ``(outputs, alignments)``. ``outputs`` has shape
            (B, T_decoder, in_dim * r). ``alignments`` stacks the per-step
            attention weights batch-first; presumably (B, T_decoder,
            T_encoder) -- depends on AttentionWrapper, TODO confirm.
        """
        B = encoder_outputs.size(0)

        processed_memory = self.memory_layer(encoder_outputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_memory, memory_lengths)
        else:
            mask = None

        # Run greedy decoding if inputs is None
        greedy = inputs is None

        if inputs is not None:
            # Grouping multiple frames if necessary:
            # (B, T, in_dim) -> (B, T // r, in_dim * r)
            if inputs.size(-1) == self.in_dim:
                inputs = inputs.view(B, inputs.size(1) // self.r, -1)
            assert inputs.size(-1) == self.in_dim * self.r
            T_decoder = inputs.size(1)

        # go frames
        initial_input = Variable(
            encoder_outputs.data.new(B, self.in_dim * self.r).zero_())

        # Init decoder states
        attention_rnn_hidden = Variable(
            encoder_outputs.data.new(B, 256).zero_())
        decoder_rnn_hiddens = [Variable(
            encoder_outputs.data.new(B, 256).zero_())
            for _ in range(len(self.decoder_rnns))]
        current_attention = Variable(
            encoder_outputs.data.new(B, 256).zero_())

        # Time first (T_decoder, B, in_dim * r)
        if inputs is not None:
            inputs = inputs.transpose(0, 1)

        outputs = []
        alignments = []

        t = 0
        current_input = initial_input
        while True:
            if t > 0:
                # Greedy: feed back the previous prediction.
                # Otherwise (teacher forcing): feed ground-truth step t - 1.
                current_input = outputs[-1] if greedy else inputs[t - 1]
            # Prenet
            current_input = self.prenet(current_input)

            # Attention RNN
            attention_rnn_hidden, current_attention, alignment = \
                self.attention_rnn(
                    current_input, current_attention, attention_rnn_hidden,
                    encoder_outputs, processed_memory=processed_memory,
                    mask=mask)

            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_attention), -1))

            # Pass through the decoder RNNs
            for idx in range(len(self.decoder_rnns)):
                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input

            output = decoder_input
            output = self.proj_to_mel(output)

            outputs += [output]
            alignments += [alignment]

            t += 1

            if greedy:
                if t > 1 and is_end_of_frames(output):
                    break
                elif t >= self.max_decoder_steps:
                    # `>=` so that at most max_decoder_steps steps are taken.
                    print("Warning! doesn't seems to be converged")
                    break
            else:
                if t >= T_decoder:
                    break

        assert greedy or len(outputs) == T_decoder

        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()

        return outputs, alignments


def is_end_of_frames(output, eps=0.2):
    """Return whether every element of ``output`` is ``<= eps``.

    Used by ``Decoder.forward`` as the greedy-decoding stop criterion.
    """
    return (output.data <= eps).all()


class Tacotron(nn.Module):
    """Tacotron: embedding -> encoder -> attention decoder -> CBHG postnet.

    Args:
        n_vocab: Number of input symbols.
        embedding_dim: Size of the symbol embedding.
        mel_dim: Number of mel bins predicted by the decoder.
        linear_dim: Size of the linear-spectrogram output of the postnet.
        r: Frames predicted per decoder step.
        padding_idx: Passed to ``nn.Embedding``.
        use_memory_mask: If True, ``input_lengths`` are used to mask
            attention over padded encoder outputs.
    """

    def __init__(self, n_vocab, embedding_dim=256, mel_dim=80, linear_dim=1025,
                 r=5, padding_idx=None, use_memory_mask=False):
        super(Tacotron, self).__init__()
        self.mel_dim = mel_dim
        self.linear_dim = linear_dim
        self.use_memory_mask = use_memory_mask
        self.embedding = nn.Embedding(n_vocab, embedding_dim,
                                      padding_idx=padding_idx)
        # Trying smaller std
        self.embedding.weight.data.normal_(0, 0.3)
        self.encoder = Encoder(embedding_dim)
        self.decoder = Decoder(mel_dim, r)

        self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
        # The postnet's bidirectional GRU outputs mel_dim * 2 features.
        self.last_linear = nn.Linear(mel_dim * 2, linear_dim)

    def forward(self, inputs, targets=None, input_lengths=None):
        """Return ``(mel_outputs, linear_outputs, alignments)``.

        ``mel_outputs`` is (B, T, mel_dim) and ``linear_outputs`` is
        (B, T, linear_dim). When ``targets`` is None the decoder runs
        greedily.
        """
        B = inputs.size(0)

        inputs = self.embedding(inputs)
        # (B, T', in_dim)
        encoder_outputs = self.encoder(inputs, input_lengths)

        if self.use_memory_mask:
            memory_lengths = input_lengths
        else:
            memory_lengths = None
        # (B, T', mel_dim*r)
        mel_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=memory_lengths)

        # Post net processing below

        # Reshape
        # (B, T, mel_dim)
        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)

        linear_outputs = self.postnet(mel_outputs)
        linear_outputs = self.last_linear(linear_outputs)

        return mel_outputs, linear_outputs, alignments