Spaces:
Build error
Build error
| # coding: utf-8 | |
| # adapted from https://github.com/r9y9/tacotron_pytorch | |
| import torch | |
| from torch import nn | |
| from .attentions import init_attn | |
| from .common_layers import Prenet | |
class BatchNormConv1d(nn.Module):
    r"""Explicitly padded Conv1d followed by BatchNorm1d and an optional activation.

    Processing order is: constant zero padding (``nn.ConstantPad1d``) ->
    bias-free ``nn.Conv1d`` (no internal padding) -> ``nn.BatchNorm1d`` ->
    ``activation`` (skipped when it is ``None``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the conv filters.
        stride: stride of the conv filters.
        padding: padding handed to ``nn.ConstantPad1d`` — an int or a
            ``[left, right]`` pair, which allows asymmetric padding.
        activation: optional module applied after batch norm.

    Shapes:
        - input: (B, in_channels, T)
        - output: (B, out_channels, T') where T' depends on padding, kernel_size and stride.

    NOTE(review): BatchNorm is built with ``momentum=0.99, eps=1e-3`` to mirror
    TensorFlow's defaults. PyTorch's ``momentum`` is the weight given to the *new*
    batch statistic (the complement of TF's decay), so the TF-equivalent PyTorch
    value would be 0.01. The value is left unchanged to preserve existing training
    behavior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None):
        super().__init__()
        self.padding = padding
        self.padder = nn.ConstantPad1d(padding, 0)
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,  # padding is applied explicitly by self.padder
            bias=False,  # the batch norm that follows makes a conv bias redundant
        )
        # eps/momentum copied from TensorFlow's defaults (see NOTE in the class docstring)
        self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
        self.activation = activation
        # init_layers() is not invoked automatically; call it to apply Xavier init.

    def init_layers(self):
        """Xavier-initialize the conv weight with a gain matching the activation.

        Raises:
            RuntimeError: if the activation is not ``nn.ReLU``, ``nn.Tanh`` or ``None``.
        """
        if self.activation is None:
            gain_name = "linear"
        elif isinstance(self.activation, torch.nn.ReLU):
            gain_name = "relu"
        elif isinstance(self.activation, torch.nn.Tanh):
            gain_name = "tanh"
        else:
            raise RuntimeError("Unknown activation function")
        gain = torch.nn.init.calculate_gain(gain_name)
        torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=gain)

    def forward(self, x):
        normed = self.bn(self.conv1d(self.padder(x)))
        if self.activation is None:
            return normed
        return self.activation(normed)
class Highway(nn.Module):
    r"""Highway layer as described in https://arxiv.org/abs/1505.00387

    Computes ``relu(H(x)) * t + x * (1 - t)`` with the gate ``t = sigmoid(T(x))``.
    Because the carry term ``x * (1 - t)`` is added elementwise to the transform,
    ``in_features`` must equal ``out_feature``. ``T``'s bias starts at -1, which
    biases the initial gate toward the carry path.

    Args:
        in_features (int): size of each input sample
        out_feature (int): size of each output sample (must equal ``in_features``)

    Shapes:
        - input: (B, *, in_features)
        - output: (B, *, out_feature)
    """

    # TODO: Try GLU layer
    def __init__(self, in_features, out_feature):
        super().__init__()
        self.H = nn.Linear(in_features, out_feature)
        nn.init.zeros_(self.H.bias)
        self.T = nn.Linear(in_features, out_feature)
        nn.init.constant_(self.T.bias, -1.0)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # init_layers() is not invoked automatically; call it to apply Xavier init.

    def init_layers(self):
        """Xavier-initialize H (ReLU gain) and T (sigmoid gain)."""
        for layer, nonlinearity in ((self.H, "relu"), (self.T, "sigmoid")):
            torch.nn.init.xavier_uniform_(layer.weight, gain=torch.nn.init.calculate_gain(nonlinearity))

    def forward(self, inputs):
        gate = self.sigmoid(self.T(inputs))
        transformed = self.relu(self.H(inputs))
        return transformed * gate + inputs * (1.0 - gate)
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:

    - 1-d convolution banks
    - Highway networks + residual connections
    - Bidirectional gated recurrent units

    Args:
        in_features (int): number of input channels. Must equal
            ``conv_projections[-1]`` because the projection output is added back
            onto the input (residual connection).
        K (int): max filter size in the conv bank; kernels of size 1..K are used.
        conv_bank_features (int): output channels of each conv-bank layer.
        conv_projections (list): output channel sizes of the conv projection
            layers. Defaults to ``[128, 128]``. Any sequence is accepted; it is copied.
        highway_features (int): width of the highway layers. When it differs from
            ``conv_projections[-1]``, a bias-free linear layer maps to this width first.
        gru_features (int): hidden size of each direction of the bidirectional GRU.
        num_highways (int): number of highway layers.

    Shapes:
        - input: (B, in_features, T_in)
        - output: (B, T_in, gru_features * 2)
    """

    def __init__(
        self,
        in_features,
        K=16,
        conv_bank_features=128,
        conv_projections=None,
        highway_features=128,
        gru_features=128,
        num_highways=4,
    ):
        super().__init__()
        # Fresh list per instance: avoids a shared mutable default and accepts tuples.
        conv_projections = [128, 128] if conv_projections is None else list(conv_projections)
        self.in_features = in_features
        self.conv_bank_features = conv_bank_features
        self.highway_features = highway_features
        self.gru_features = gru_features
        self.conv_projections = conv_projections
        self.relu = nn.ReLU()
        # conv bank with filter sizes k = 1..K; padding [(k - 1) // 2, k // 2] adds
        # k - 1 frames in total, so every bank output keeps the input time length.
        # TODO: try dilational layers instead
        self.conv1d_banks = nn.ModuleList(
            [
                BatchNormConv1d(
                    in_features,
                    conv_bank_features,
                    kernel_size=k,
                    stride=1,
                    padding=[(k - 1) // 2, k // 2],
                    activation=self.relu,
                )
                for k in range(1, K + 1)
            ]
        )
        # The concatenated conv-bank outputs are fed straight to the projections (no pooling).
        # TODO: try average pooling OR larger kernel size
        # Projections use ReLU on every layer except the last, which stays linear
        # because its output is added to the block input.
        in_sizes = [K * conv_bank_features] + conv_projections[:-1]
        activations = [self.relu] * (len(conv_projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [
                BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac)
                for in_size, out_size, ac in zip(in_sizes, conv_projections, activations)
            ]
        )
        # setup Highway layers (with an optional width-changing projection in front)
        if self.highway_features != conv_projections[-1]:
            self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False)
        self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)])
        # bi-directional GRU layer; it consumes the highway output, so its input size
        # is highway_features (hidden size per direction is gru_features).
        self.gru = nn.GRU(highway_features, gru_features, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs):
        """Run the CBHG stack.

        Args:
            inputs (Tensor): (B, in_features, T_in).

        Returns:
            Tensor: (B, T_in, gru_features * 2), the bidirectional GRU outputs.
        """
        # (B, in_features, T_in) -> (B, conv_bank_features * K, T_in)
        x = torch.cat([conv1d(inputs) for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
        for conv1d in self.conv1d_projections:
            x = conv1d(x)
        # Residual connection; requires conv_projections[-1] == in_features.
        x = x + inputs
        # (B, C, T_in) -> (B, T_in, C) for the Linear-based highway layers and the GRU
        x = x.transpose(1, 2)
        if self.highway_features != self.conv_projections[-1]:
            x = self.pre_highway(x)
        # TODO: try residual scaling as in Deep Voice 3
        # TODO: try plain residual layers
        for highway in self.highways:
            x = highway(x)
        # Re-pack GRU weights for the fast cuDNN path (a no-op off GPU/cuDNN).
        # TODO: replace GRU with convolution as in Deep Voice 3
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)
        return outputs
class EncoderCBHG(nn.Module):
    r"""Encoder-side CBHG: a ``CBHG`` with fixed encoder hyper-parameters.

    128 input channels, a K=16 conv bank of 128 channels each, two 128-channel
    projections, four 128-wide highway layers and a 128-unit bidirectional GRU.

    Shapes:
        - input: (B, 128, T)
        - output: (B, T, 256)
    """

    def __init__(self):
        super().__init__()
        encoder_settings = dict(
            K=16,
            conv_bank_features=128,
            conv_projections=[128, 128],
            highway_features=128,
            gru_features=128,
            num_highways=4,
        )
        self.cbhg = CBHG(128, **encoder_settings)

    def forward(self, x):
        return self.cbhg(x)
class Encoder(nn.Module):
    r"""Encoder stack: a Prenet followed by the encoder CBHG.

    Args:
        in_features (int): size of the input embedding features (last dim of ``inputs``).

    Shapes:
        - inputs: (B, T, in_features)
        - outputs: (B, T, 128 * 2)
    """

    def __init__(self, in_features):
        super().__init__()
        self.prenet = Prenet(in_features, out_features=[256, 128])
        self.cbhg = EncoderCBHG()

    def forward(self, inputs):
        # Prenet output is presumably (B, T, 128), the last of its out_features; confirm in Prenet.
        prenet_out = self.prenet(inputs)
        # CBHG expects channels first: (B, C, T).
        channels_first = prenet_out.transpose(1, 2)
        return self.cbhg(channels_first)
class PostCBHG(nn.Module):
    r"""CBHG applied to decoder frames (K=8, projections ``[256, mel_dim]``).

    Args:
        mel_dim (int): feature channels per frame. Used both as the CBHG input
            size and as the last conv projection size, so CBHG's residual add
            matches the input.

    Shapes:
        - input: (B, mel_dim, T)
        - output: (B, T, 256)
    """

    def __init__(self, mel_dim):
        super().__init__()
        projection_sizes = [256, mel_dim]
        self.cbhg = CBHG(
            mel_dim,
            K=8,
            conv_bank_features=128,
            conv_projections=projection_sizes,
            highway_features=128,
            gru_features=128,
            num_highways=4,
        )

    def forward(self, x):
        return self.cbhg(x)
class Decoder(nn.Module):
    """Tacotron decoder.

    One call to ``decode`` runs a single decoder step:
    - The current memory input goes through the Prenet.
    - The attention GRU cell computes a query from the Prenet output and the
      previous context vector.
    - Attention over ``inputs`` yields a new context vector.
    - The projected [query | context] vector passes through two residual GRU cells.
    - Linear layers predict the next ``r`` frames and a stop-token logit.

    ``forward`` feeds ground-truth frames back as memory (teacher forcing);
    ``inference`` feeds back its own predictions.

    Args:
        in_channels (int): number of input channels (feature size of ``inputs``).
        frame_channels (int): number of feature frame channels.
        r (int): number of outputs per time step (reduction rate).
        memory_size (int): number of past frames fed to the Prenet. If <= 0, the
            memory queue is disabled, ``memory_size`` is set to ``r`` and only the
            last predicted frame is fed back.
        attn_type (string): type of attention used in decoder.
        attn_windowing (bool): if true, define an attention window centered to maximum
            attention response. It provides more robust attention alignment especially
            at inference time.
        attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
        prenet_type (string): 'original' or 'bn'.
        prenet_dropout (float): prenet dropout rate.
        forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
        trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
        forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
        location_attn (bool): if true, use location sensitive attention.
        attn_K (int): number of attention heads for GravesAttention.
        separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
        max_decoder_steps (int): ``inference`` stops once more than this many steps
            have run, even without a stop signal.
    """

    # Pylint gets confused by PyTorch conventions here
    # pylint: disable=attribute-defined-outside-init
    def __init__(
        self,
        in_channels,
        frame_channels,
        r,
        memory_size,
        attn_type,
        attn_windowing,
        attn_norm,
        prenet_type,
        prenet_dropout,
        forward_attn,
        trans_agent,
        forward_attn_mask,
        location_attn,
        attn_K,
        separate_stopnet,
        max_decoder_steps,
    ):
        super().__init__()
        # r_init fixes the sizes of proj_to_mel / stopnet; set_r may later change r,
        # and decode slices the projected frames down to r.
        self.r_init = r
        self.r = r
        self.in_channels = in_channels
        self.max_decoder_steps = max_decoder_steps
        self.use_memory_queue = memory_size > 0
        self.memory_size = memory_size if memory_size > 0 else r
        self.frame_channels = frame_channels
        self.separate_stopnet = separate_stopnet
        self.query_dim = 256
        # memory -> |Prenet| -> processed_memory
        prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels
        self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128])
        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
        # attention_rnn generates queries for the attention mechanism; its input is the
        # 128-dim Prenet output concatenated with the previous context vector (in_channels).
        self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
        self.attention = init_attn(
            attn_type=attn_type,
            query_dim=self.query_dim,
            embedding_dim=in_channels,
            attention_dim=128,
            location_attention=location_attn,
            attention_location_n_filters=32,
            attention_location_kernel_size=31,
            windowing=attn_windowing,
            norm=attn_norm,
            forward_attn=forward_attn,
            trans_agent=trans_agent,
            forward_attn_mask=forward_attn_mask,
            attn_K=attn_K,
        )
        # (attention_rnn_hidden | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
        self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
        # RNN_state -> |Linear| -> mel_spec
        self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
        # (decoder_output | predicted frames) -> |StopNet| -> stop-token logit
        self.stopnet = StopNet(256 + frame_channels * self.r_init)

    def set_r(self, new_r):
        """Set the number of frames emitted per decoder step.

        NOTE(review): proj_to_mel and stopnet are sized for r_init and decode only
        slices the output, so new_r presumably should not exceed r_init.
        """
        self.r = new_r

    def _reshape_memory(self, memory):
        """
        Reshape the spectrograms for given 'r'

        (B, T_mel, frame_channels) is grouped into (B, T_mel // r, r * frame_channels)
        and then made time-first. Input whose last dim is not frame_channels is
        assumed to be grouped already and is only transposed.
        """
        # Grouping multiple frames if necessary
        # NOTE(review): view() presumably needs T_mel to be a multiple of r — confirm callers pad.
        if memory.size(-1) == self.frame_channels:
            memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
        # Time first (T_decoder, B, r * frame_channels)
        memory = memory.transpose(0, 1)
        return memory

    def _init_states(self, inputs):
        """
        Initialization of decoder states

        Zeroes the go frame, the attention/decoder GRU hiddens and the context
        vector for batch size inputs.size(0), and caches the attention's
        preprocessed inputs.
        """
        B = inputs.size(0)
        # go frame as zeros matrix
        if self.use_memory_queue:
            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size)
        else:
            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels)
        # decoder states (256 matches query_dim and the decoder GRUCell sizes)
        self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
        self.decoder_rnn_hiddens = [
            torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns))
        ]
        self.context_vec = inputs.data.new(B, self.in_channels).zero_()
        # cache attention inputs
        self.processed_inputs = self.attention.preprocess_inputs(inputs)

    def _parse_outputs(self, outputs, attentions, stop_tokens):
        """Stack per-step lists into batch-first tensors.

        outputs: T_dec tensors of (B, r * frame_channels) -> (B, frame_channels, T_dec * r).
        attentions and stop_tokens are stacked along a new time axis and made batch-first.
        """
        # Back to batch first
        attentions = torch.stack(attentions).transpose(0, 1)
        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        # (B, T_dec, r * frame_channels) -> (B, T_dec * r, frame_channels)
        outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
        # channels first: (B, frame_channels, T_dec * r)
        outputs = outputs.transpose(1, 2)
        return outputs, attentions, stop_tokens

    def decode(self, inputs, mask=None):
        """Run one decoder step from the current state.

        Updates attention_rnn_hidden, context_vec and decoder_rnn_hiddens.

        Returns:
            tuple: (output, stop_token, attention_weights) where output is
            (B, r * frame_channels), stop_token is a (B, 1) logit (no sigmoid),
            and attention_weights is ``self.attention.attention_weights``.
        """
        # Prenet
        processed_memory = self.prenet(self.memory_input)
        # Attention RNN
        self.attention_rnn_hidden = self.attention_rnn(
            torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden
        )
        self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
        # Concat RNN output and attention context vector
        decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
        # Pass through the decoder RNNs
        for idx, decoder_rnn in enumerate(self.decoder_rnns):
            self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
            # Residual connection
            decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
        decoder_output = decoder_input
        # predict mel vectors from decoder vectors
        output = self.proj_to_mel(decoder_output)
        # output = torch.sigmoid(output)
        # predict stop token
        stopnet_input = torch.cat([decoder_output, output], -1)
        if self.separate_stopnet:
            stop_token = self.stopnet(stopnet_input.detach())
        else:
            stop_token = self.stopnet(stopnet_input)
        # proj_to_mel emits r_init frames; keep only the current r (see set_r).
        output = output[:, : self.r * self.frame_channels]
        return output, stop_token, self.attention.attention_weights

    def _update_memory_input(self, new_memory):
        """Feed the latest predicted/ground-truth frames back as the next Prenet input.

        new_memory is presumably (B, r * frame_channels) (see the commented assert below).
        With a memory queue and memory_size > r, the new frames are prepended and the
        queue is truncated to memory_size frames. With memory_size <= r, the first
        memory_size frames of new_memory are kept. Without a queue, only the last
        frame is kept.
        """
        if self.use_memory_queue:
            if self.memory_size > self.r:
                # memory queue size is larger than number of frames per decoder iter
                self.memory_input = torch.cat(
                    [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()],
                    dim=-1,
                )
            else:
                # memory queue size smaller than number of frames per decoder iter
                self.memory_input = new_memory[:, : self.memory_size * self.frame_channels]
        else:
            # use only the last frame prediction
            # assert new_memory.shape[-1] == self.r * self.frame_channels
            self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :]

    def forward(self, inputs, memory, mask):
        """
        Args:
            inputs: Encoder outputs.
            memory: Decoder memory — ground-truth frames fed back as decoder input
                (teacher forcing). Must not be None; use ``inference`` for
                free-running decoding.
            mask: Attention mask for sequence padding.

        Shapes:
            - inputs: (B, T, D_out_enc)
            - memory: (B, T_mel, D_mel)

        Returns:
            tuple: (outputs, attentions, stop_tokens) from ``_parse_outputs``; one
            decoder step is run per group of r frames in memory.
        """
        # Teacher forcing: step t is fed the ground-truth frames of step t - 1;
        # step 0 uses the zero go frame set by _init_states.
        memory = self._reshape_memory(memory)
        outputs = []
        attentions = []
        stop_tokens = []
        t = 0
        self._init_states(inputs)
        self.attention.init_states(inputs)
        while len(outputs) < memory.size(0):
            if t > 0:
                new_memory = memory[t - 1]
                self._update_memory_input(new_memory)
            output, stop_token, attention = self.decode(inputs, mask)
            outputs += [output]
            attentions += [attention]
            stop_tokens += [stop_token.squeeze(1)]
            t += 1
        return self._parse_outputs(outputs, attentions, stop_tokens)

    def inference(self, inputs):
        """
        Args:
            inputs: encoder outputs.

        Shapes:
            - inputs: batch x time x encoder_out_dim

        Free-running decoding: each step's predicted frames are fed back as memory.
        Decoding stops when:
            - more than inputs.shape[1] / 4 steps have run, and either the stop
              probability or attention[:, -1] (presumably the weight on the last
              encoder step) exceeds 0.6; or
            - more than max_decoder_steps steps have run.

        NOTE(review): the stop test calls ``.item()`` and uses a tensor in a boolean
        context, so this effectively requires batch size 1.
        """
        outputs = []
        attentions = []
        stop_tokens = []
        t = 0
        self._init_states(inputs)
        self.attention.init_states(inputs)
        while True:
            if t > 0:
                new_memory = outputs[-1]
                self._update_memory_input(new_memory)
            output, stop_token, attention = self.decode(inputs, None)
            # logit -> probability
            stop_token = torch.sigmoid(stop_token.data)
            outputs += [output]
            attentions += [attention]
            stop_tokens += [stop_token]
            t += 1
            if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
                break
            if t > self.max_decoder_steps:
                # NOTE(review): this message has an unbalanced quote around max_decoder_steps.
                print(" | > Decoder stopped with 'max_decoder_steps")
                break
        return self._parse_outputs(outputs, attentions, stop_tokens)
class StopNet(nn.Module):
    r"""Stop-token predictor.

    Applies dropout (p=0.1) followed by a single linear layer that maps the
    features to one unnormalized score. No activation is applied here.

    Args:
        in_features (int): feature dimension of input.

    Shapes:
        - input: (*, in_features)
        - output: (*, 1)
    """

    def __init__(self, in_features):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(in_features, 1)
        # Re-draw the weight with Xavier init; the gain for "linear" is 1.
        nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain("linear"))

    def forward(self, inputs):
        return self.linear(self.dropout(inputs))