# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# 1x1InvertibleConv and WN based on implementation from WaveGlow https://github.com/NVIDIA/waveglow/blob/master/glow.py
# Original license:
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
from torch import nn
from torch.nn import functional as F
from torch.cuda import amp
from torch.cuda.amp import autocast as autocast
import numpy as np
import ast

from splines import (
    piecewise_linear_transform,
    piecewise_linear_inverse_transform,
    unbounded_piecewise_quadratic_transform,
)
from partialconv1d import PartialConv1d as pconv1d
from typing import Tuple


def update_params(config, params):
    for param in params:
        print(param)
        k, v = param.split("=")
        try:
            v = ast.literal_eval(v)
        except:
            pass

        k_split = k.split(".")
        if len(k_split) > 1:
            parent_k = k_split[0]
            cur_param = [".".join(k_split[1:]) + "=" + str(v)]
            update_params(config[parent_k], cur_param)
        elif k in config and len(k_split) == 1:
            print(f"overriding {k} with {v}")
            config[k] = v
        else:
            print("{}, {} params not updated".format(k, v))


def get_mask_from_lengths(lengths):
    """Constructs binary mask from a 1D torch tensor of input lengths

    Args:
        lengths (torch.tensor): 1D tensor
    Returns:
        mask (torch.tensor): num_sequences x max_length binary tensor
    """
    max_len = torch.max(lengths).item()
    if torch.cuda.is_available():
        ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    else:
        ids = torch.arange(0, max_len, out=torch.LongTensor(max_len))
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask
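

# Illustrative usage sketch (assumed toy values, not part of the original module):
# for lengths = [3, 5] the mask is a 2 x 5 boolean tensor where row 0 is
# [True, True, True, False, False] and row 1 is all True. Callers below unsqueeze
# it to B x 1 x T before multiplying with B x C x T activations, e.g.:
#
#   lengths = torch.tensor([3, 5])
#   mask = get_mask_from_lengths(lengths).unsqueeze(1).float()  # B x 1 x T
#   masked = torch.randn(2, 80, 5) * mask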


class ExponentialClass(torch.nn.Module):
    def __init__(self):
        super(ExponentialClass, self).__init__()

    def forward(self, x):
        return torch.exp(x)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_partial_padding=False,
        use_weight_norm=False,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_partial_padding = use_partial_padding
        self.use_weight_norm = use_weight_norm
        conv_fn = torch.nn.Conv1d
        if self.use_partial_padding:
            conv_fn = pconv1d
        self.conv = conv_fn(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )
        if self.use_weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, signal, mask=None):
        if self.use_partial_padding:
            conv_signal = self.conv(signal, mask)
        else:
            conv_signal = self.conv(signal)
        if mask is not None:
            # always re-zero output if mask is
            # available to match zero-padding
            conv_signal = conv_signal * mask
        return conv_signal


class DenseLayer(nn.Module):
    def __init__(self, in_dim=1024, sizes=[1024, 1024]):
        super(DenseLayer, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                LinearNorm(in_size, out_size, bias=True)
                for (in_size, out_size) in zip(in_sizes, sizes)
            ]
        )

    def forward(self, x):
        for linear in self.layers:
            x = torch.tanh(linear(x))
        return x


class LengthRegulator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, dur):
        output = []
        for x_i, dur_i in zip(x, dur):
            expanded = self.expand(x_i, dur_i)
            output.append(expanded)
        output = self.pad(output)
        return output

    def expand(self, x, dur):
        output = []
        for i, frame in enumerate(x):
            expanded_len = int(dur[i] + 0.5)
            expanded = frame.expand(expanded_len, -1)
            output.append(expanded)
        output = torch.cat(output, 0)
        return output

    def pad(self, x):
        output = []
        max_len = max([x[i].size(0) for i in range(len(x))])
        for i, seq in enumerate(x):
            padded = F.pad(seq, [0, 0, 0, max_len - seq.size(0)], "constant", 0.0)
            output.append(padded)
        output = torch.stack(output)
        return output
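

# Illustrative sketch of the duration expansion above (assumed toy shapes, not
# part of the original module): each frame i is repeated round(dur[i]) times
# along the time axis, then sequences are right-padded to a common length.
#
#   regulator = LengthRegulator()
#   x = torch.randn(2, 4, 8)                      # B x T_text x D
#   dur = torch.tensor([[1.0, 2.0, 0.0, 3.0],
#                       [2.0, 2.0, 2.0, 2.0]])    # per-token durations in frames
#   out = regulator(x, dur)                       # B x max(sum(dur)) x D -> 2 x 8 x 8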


class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super(ConvLSTMLinear, self).__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            conv_layer = torch.nn.utils.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = int(n_channels // 2)
            self.bilstm = nn.LSTM(
                n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
            )
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i in range(len(ids_sorted)):
            unsort_ids[ids_sorted[i]] = i
        lens_sorted = lens_sorted.long().cpu()

        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(
            context, lens_sorted, batch_first=True
        )
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]

        # map back to original indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to B, D, T
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat

    def infer(self, z, txt_enc, spk_emb):
        x_hat = self.forward(txt_enc, spk_emb)["x_hat"]
        x_hat = self.feature_processing.denormalize(x_hat)
        return x_hat


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """

    def __init__(
        self,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        encoder_kernel_size=5,
        norm_fn=nn.BatchNorm1d,
        lstm_norm_fn=None,
    ):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                    use_partial_padding=True,
                ),
                norm_fn(encoder_embedding_dim, affine=True),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        if lstm_norm_fn is not None:
            if "spectral" in lstm_norm_fn:
                print("Applying spectral norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
            elif "weight" in lstm_norm_fn:
                print("Applying weight norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.weight_norm
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")

    def forward(self, x, in_lens):
        """
        Args:
            x (torch.tensor): N x C x L padded input of text embeddings
            in_lens (torch.tensor): 1D tensor of sequence lengths
        """
        if x.size()[0] > 1:
            x_embedded = []
            for b_ind in range(x.size()[0]):  # TODO: improve speed
                curr_x = x[b_ind : b_ind + 1, :, : in_lens[b_ind]].clone()
                for conv in self.convolutions:
                    curr_x = F.dropout(F.relu(conv(curr_x)), 0.5, self.training)
                x_embedded.append(curr_x[0].transpose(0, 1))
            x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
        else:
            for conv in self.convolutions:
                x = F.dropout(F.relu(conv(x)), 0.5, self.training)
            x = x.transpose(1, 2)

        # recent amp change -- change in_lens to int
        in_lens = in_lens.int().cpu()

        x = nn.utils.rnn.pack_padded_sequence(x, in_lens, batch_first=True)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs

    def infer(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
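

# Illustrative sketch of how the encoder above is typically driven (assumed toy
# shapes, not part of the original module): padded text embeddings go in as
# B x C x T together with their true lengths (descending, since the packed
# sequence is not re-sorted here), and the bidirectional LSTM output comes back
# as B x T x C.
#
#   encoder = Encoder(encoder_embedding_dim=512)
#   txt_emb = torch.randn(2, 512, 20)             # B x C x T_text, zero-padded
#   in_lens = torch.tensor([20, 15])              # true lengths, descending
#   enc_out = encoder(txt_emb, in_lens)           # B x T_text x 512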


class Invertible1x1ConvLUS(torch.nn.Module):
    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        self.register_buffer("p", p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer("lower_diag", lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
        self.cache_inverse = cache_inverse

    def forward(self, z, inverse=False):
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if inverse:
            if not hasattr(self, "W_inverse"):
                # inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            W = W[..., None]
            z = F.conv1d(z, W, bias=None, stride=1, padding=0)
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W


class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If inverse=True it does convolution with
    inverse
    """

    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(
            c, c, kernel_size=1, stride=1, padding=0, bias=False
        )
        # Sample a random orthonormal matrix to initialize weights
        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]

        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W
        self.cache_inverse = cache_inverse

    def forward(self, z, inverse=False):
        # DO NOT apply n_of_groups, as it doesn't account for padded sequences
        W = self.conv.weight.squeeze()

        if inverse:
            if not hasattr(self, "W_inverse"):
                # Inverse computation
                W_inverse = W.float().inverse()
                if z.type() == "torch.cuda.HalfTensor":
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            # Forward computation
            log_det_W = torch.logdet(W).clone()
            z = self.conv(z)
            return z, log_det_W
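

# Illustrative sketch of the invertibility property (assumed toy shapes, not part
# of the original module): running the forward pass and then the inverse pass with
# cache_inverse=True recovers the input up to numerical error, and the returned
# log_det_W is the log-determinant term used in the flow's likelihood.
#
#   conv = Invertible1x1Conv(8, cache_inverse=True)
#   z = torch.randn(2, 8, 50)                     # B x C x T
#   z_fwd, log_det_W = conv(z)
#   z_rec = conv(z_fwd, inverse=True)             # close to z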


class SimpleConvNet(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        final_out_channels,
        n_layers=2,
        kernel_size=5,
        with_dilation=True,
        max_channels=1024,
        zero_init=True,
        use_partial_padding=True,
    ):
        super(SimpleConvNet, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.n_layers = n_layers
        in_channels = n_mel_channels + n_context_dim
        out_channels = -1
        self.use_partial_padding = use_partial_padding
        for i in range(n_layers):
            dilation = 2**i if with_dilation else 1
            padding = int((kernel_size * dilation - dilation) / 2)
            out_channels = min(max_channels, in_channels * 2)
            self.layers.append(
                ConvNorm(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=True,
                    w_init_gain="relu",
                    use_partial_padding=use_partial_padding,
                )
            )
            in_channels = out_channels

        self.last_layer = torch.nn.Conv1d(
            out_channels, final_out_channels, kernel_size=1
        )

        if zero_init:
            self.last_layer.weight.data *= 0
            self.last_layer.bias.data *= 0

    def forward(self, z_w_context, seq_lens: torch.Tensor = None):
        # seq_lens: tensor of sequence lengths
        # output should be b x n_mel_channels x z_w_context.shape(2)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()

        for i in range(self.n_layers):
            z_w_context = self.layers[i](z_w_context, mask)
            z_w_context = torch.relu(z_w_context)

        z_w_context = self.last_layer(z_w_context)
        return z_w_context


class WN(torch.nn.Module):
    """
    Adapted from WN() module in WaveGlow with modifications to variable names
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        n_channels,
        kernel_size=5,
        affine_activation="softplus",
        use_partial_padding=True,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert n_channels % 2 == 0
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name="weight")
        self.start = start
        self.softplus = torch.nn.Softplus()
        self.affine_activation = affine_activation
        self.use_partial_padding = use_partial_padding

        # Initializing last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps with training stability
        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = ConvNorm(
                n_channels,
                n_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=padding,
                use_partial_padding=use_partial_padding,
                use_weight_norm=True,
            )
            # in_layer = nn.Conv1d(n_channels, n_channels, kernel_size,
            #                      dilation=dilation, padding=padding)
            # in_layer = nn.utils.weight_norm(in_layer)
            self.in_layers.append(in_layer)
            res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer)
            self.res_skip_layers.append(res_skip_layer)

    def forward(
        self,
        forward_input: Tuple[torch.Tensor, torch.Tensor],
        seq_lens: torch.Tensor = None,
    ):
        z, context = forward_input
        z = torch.cat((z, context), 1)  # append context to z as well
        z = self.start(z)
        output = torch.zeros_like(z)
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
        non_linearity = torch.relu
        if self.affine_activation == "softplus":
            non_linearity = self.softplus

        for i in range(self.n_layers):
            z = non_linearity(self.in_layers[i](z, mask))
            res_skip_acts = non_linearity(self.res_skip_layers[i](z))
            output = output + res_skip_acts

        output = self.end(output)  # [B, dim, seq_len]
        return output
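

# Illustrative sketch of how WN is used as an affine-parameter predictor below
# (assumed toy shapes, not part of the original module): given the first half of
# the channels plus conditioning context, it emits 2 * n_in_channels values per
# frame, later split into scale and bias for the coupling transform.
#
#   wn = WN(n_in_channels=40, n_context_dim=512, n_layers=4, n_channels=256)
#   z_0 = torch.randn(2, 40, 50)                  # B x C/2 x T
#   context = torch.randn(2, 512, 50)             # B x n_context_dim x T
#   params = wn((z_0, context))                   # B x 80 x T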


# Affine Coupling Layers
class SplineTransformationLayerAR(torch.nn.Module):
    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        kernel_size=1,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-6,
        right=6,
        bottom=-6,
        top=6,
        use_quadratic=False,
    ):
        super(SplineTransformationLayerAR, self).__init__()
        self.n_in_channels = n_in_channels  # input dimensions
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.n_in_channels * self.n_bins

        # autoregressive flow, kernel size 1 and no dilation
        self.param_predictor = SimpleConvNet(
            n_context_dim,
            0,
            final_out_channels,
            n_layers,
            with_dilation=False,
            kernel_size=1,
            zero_init=True,
            use_partial_padding=False,
        )

        # output is unnormalized bin weights

    def normalize(self, z, inverse):
        # normalize to [0, 1]
        if inverse:
            z = (z - self.bottom) / (self.top - self.bottom)
        else:
            z = (z - self.left) / (self.right - self.left)
        return z

    def denormalize(self, z, inverse):
        if inverse:
            z = z * (self.right - self.left) + self.left
        else:
            z = z * (self.top - self.bottom) + self.bottom
        return z

    def forward(self, z, context, inverse=False):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        z = self.normalize(z, inverse)

        if z.min() < 0.0 or z.max() > 1.0:
            print("spline z scaled beyond [0, 1]", z.min(), z.max())

        z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
        affine_params = self.param_predictor(context)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
        with amp.autocast(enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_tformed, log_s = self.spline_fn(
                    z_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
            else:
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())

        z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        z = self.denormalize(z, inverse)
        if inverse:
            return z

        log_s = log_s.reshape(b_s, t_s, -1)
        log_s = log_s.permute(0, 2, 1)
        log_s = log_s + c_s * (
            np.log(self.top - self.bottom) - np.log(self.right - self.left)
        )
        return z, log_s


class SplineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-4,
        right=4,
        bottom=-4,
        top=4,
        use_quadratic=False,
    ):
        super(SplineTransformationLayer, self).__init__()
        self.n_mel_channels = n_mel_channels  # input dimensions
        self.half_mel_channels = int(n_mel_channels / 2)  # half, because we split
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.half_mel_channels * self.n_bins

        self.param_predictor = SimpleConvNet(
            self.half_mel_channels,
            n_context_dim,
            final_out_channels,
            n_layers,
            with_dilation=with_dilation,
            kernel_size=kernel_size,
            zero_init=False,
        )

        # output is unnormalized bin weights

    def forward(self, z, context, inverse=False, seq_lens=None):
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        # condition on z_0, transform z_1
        n_half = self.half_mel_channels
        z_0, z_1 = z[:, :n_half], z[:, n_half:]

        # normalize to [0, 1]
        if inverse:
            z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
        else:
            z_1 = (z_1 - self.left) / (self.right - self.left)

        z_w_context = torch.cat((z_0, context), 1)
        affine_params = self.param_predictor(z_w_context, seq_lens)
        z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)

        with autocast(enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_1_tformed, log_s = self.spline_fn(
                    z_1_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
                if not inverse:
                    log_s = torch.sum(log_s, 1)
            else:
                if inverse:
                    z_1_tformed, _dc = self.inv_spline_fn(
                        z_1_reshaped.float(), q_tilde.float(), False
                    )
                else:
                    z_1_tformed, log_s = self.spline_fn(
                        z_1_reshaped.float(), q_tilde.float()
                    )

        z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)

        # undo [0, 1] normalization
        if inverse:
            z_1 = z_1 * (self.right - self.left) + self.left
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:  # training
            z_1 = z_1 * (self.top - self.bottom) + self.bottom
            z = torch.cat((z_0, z_1), dim=1)
            log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
                np.log(self.top - self.bottom) - np.log(self.right - self.left)
            )
            return z, log_s


class AffineTransformationLayer(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        use_partial_padding=False,
    ):
        super(AffineTransformationLayer, self).__init__()
        if affine_model not in ("wavenet", "simple_conv"):
            raise Exception("{} affine model not supported".format(affine_model))
        if isinstance(scaling_fn, list):
            if not all(
                [x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]
            ):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        else:
            if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
                raise Exception("{} scaling fn not supported".format(scaling_fn))

        self.affine_model = affine_model
        self.scaling_fn = scaling_fn
        if affine_model == "wavenet":
            self.affine_param_predictor = WN(
                int(n_mel_channels / 2),
                n_context_dim,
                n_layers=n_layers,
                n_channels=n_channels,
                affine_activation=affine_activation,
                use_partial_padding=use_partial_padding,
            )
        elif affine_model == "simple_conv":
            self.affine_param_predictor = SimpleConvNet(
                int(n_mel_channels / 2),
                n_context_dim,
                n_mel_channels,
                n_layers,
                with_dilation=with_dilation,
                kernel_size=kernel_size,
                use_partial_padding=use_partial_padding,
            )
        self.n_mel_channels = n_mel_channels

    def get_scaling_and_logs(self, scale_unconstrained):
        if self.scaling_fn == "translate":
            s = torch.exp(scale_unconstrained * 0)
            log_s = scale_unconstrained * 0
        elif self.scaling_fn == "exp":
            s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) = x
        elif self.scaling_fn == "tanh":
            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
            log_s = torch.log(s)
        elif self.scaling_fn == "sigmoid":
            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
            log_s = torch.log(s)
        elif isinstance(self.scaling_fn, list):
            s_list, log_s_list = [], []
            for i in range(scale_unconstrained.shape[1]):
                scaling_i = self.scaling_fn[i]
                if scaling_i == "translate":
                    s_i = torch.exp(scale_unconstrained[:, i] * 0)
                    log_s_i = scale_unconstrained[:, i] * 0
                elif scaling_i == "exp":
                    s_i = torch.exp(scale_unconstrained[:, i])
                    log_s_i = scale_unconstrained[:, i]
                elif scaling_i == "tanh":
                    s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
                    log_s_i = torch.log(s_i)
                elif scaling_i == "sigmoid":
                    s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
                    log_s_i = torch.log(s_i)
                s_list.append(s_i[:, None])
                log_s_list.append(log_s_i[:, None])
            s = torch.cat(s_list, dim=1)
            log_s = torch.cat(log_s_list, dim=1)
        return s, log_s

    def forward(self, z, context, inverse=False, seq_lens=None):
        n_half = int(self.n_mel_channels / 2)
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        if self.affine_model == "wavenet":
            affine_params = self.affine_param_predictor(
                (z_0, context), seq_lens=seq_lens
            )
        elif self.affine_model == "simple_conv":
            z_w_context = torch.cat((z_0, context), 1)
            affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)

        scale_unconstrained = affine_params[:, :n_half, :]
        b = affine_params[:, n_half:, :]
        s, log_s = self.get_scaling_and_logs(scale_unconstrained)

        if inverse:
            z_1 = (z_1 - b) / s
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = s * z_1 + b
            z = torch.cat((z_0, z_1), dim=1)
            return z, log_s
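

# Illustrative sketch of the coupling-layer invertibility above (assumed toy
# shapes, not part of the original module): the forward pass applies
# z_1 <- s * z_1 + b with s, b predicted from (z_0, context), so the inverse pass
# with the same context recovers the original tensor up to numerical error.
#
#   layer = AffineTransformationLayer(80, 512, n_layers=2, affine_model="simple_conv")
#   z = torch.randn(2, 80, 50)                    # B x n_mel_channels x T
#   context = torch.randn(2, 512, 50)             # B x n_context_dim x T
#   z_fwd, log_s = layer(z, context)
#   z_rec = layer(z_fwd, context, inverse=True)   # close to z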


class ConvAttention(torch.nn.Module):
    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0
    ):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            ConvNorm(
                n_text_channels,
                n_text_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        self.query_proj = nn.Sequential(
            ConvNorm(
                n_mel_channels,
                n_mel_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def run_padded_sequence(
        self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
    ):
        """Sorts input data by provided ordering (and un-ordering) and runs the
        packed data through the recurrent model

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
            lens: lengths of input data (sorted in descending order)
            padded_data (torch.tensor): input sequences (padded)
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
            unsorted, ordering
        """
        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def forward(
        self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None
    ):
        """Attention mechanism for radtts. Unlike in Flowtron, we have no
        restrictions such as causality etc, since we only need this during
        training.

        Args:
            queries (torch.tensor): B x C x T1 tensor (likely mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
            mask (torch.tensor): uint8 binary mask for variable length entries
                (should be in the T2 domain)
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
        """
        temp = 0.0005
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2

        # Beware can only do this since query_dim = attn_dim = n_mel_channels
        queries_enc = self.query_proj(queries)

        # Gaussian Isotropic Attention
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2

        # compute log-likelihood from gaussian
        eps = 1e-8
        attn = -temp * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)

        attn_logprob = attn.clone()

        if mask is not None:
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob
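

# Illustrative sketch of the alignment attention above (assumed toy shapes, not
# part of the original module): mel frames act as queries and text embeddings as
# keys; the returned attn is a B x 1 x T1 x T2 soft alignment whose last dimension
# sums to 1, and attn_logprob is the pre-softmax log-likelihood used for the
# alignment loss. The mask marks padded text positions as True, shaped B x T2 x 1.
#
#   attention = ConvAttention(n_mel_channels=80, n_text_channels=512)
#   mel = torch.randn(2, 80, 100)                 # B x n_mel_channels x T1
#   txt = torch.randn(2, 512, 20)                 # B x n_text_channels x T2
#   key_mask = (get_mask_from_lengths(torch.tensor([20, 15])) == 0)[..., None]
#   attn, attn_logprob = attention(mel, txt, None, mask=key_mask)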