Pendrokar's picture
relocate folders
ed18ebf
raw
history blame contribute delete
No virus
9.23 kB
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
class ConvNorm(torch.nn.Module):
    """Conv1d wrapper whose weight is Xavier-uniform initialised.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution width; must be odd when ``padding`` is None.
        stride: convolution stride.
        padding: explicit padding. When None it is derived as
            ``dilation * (kernel_size - 1) / 2``, which keeps the length of
            the last axis unchanged for stride 1.
        dilation: convolution dilation.
        bias: whether the convolution has a bias term.
        w_init_gain: nonlinearity name handed to
            ``torch.nn.init.calculate_gain`` to scale the Xavier init.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # Symmetric "same" padding only exists for odd kernel widths.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        conv = torch.nn.Conv1d(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, dilation=dilation, bias=bias)
        xavier_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(conv.weight, gain=xavier_gain)
        self.conv = conv

    def forward(self, signal):
        """Return ``signal`` passed through the wrapped convolution."""
        return self.conv(signal)
class Invertible1x1ConvLUS(torch.nn.Module):
    """Invertible 1x1 convolution with an LU-parameterised weight matrix.

    The c x c weight is built as ``W = P @ L @ U``:

    - ``P`` is a fixed permutation stored as a buffer.
    - ``L`` is unit lower-triangular.
    - ``U`` is upper-triangular with a learned diagonal.

    Since ``det(P) = +-1`` and ``det(L) = 1``, ``log|det W|`` reduces to
    ``sum(log|diag(U)|)``.

    Args:
        c: number of channels (size of the square weight matrix).
    """

    def __init__(self, c):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.randn(c, c))
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1*W[:, 0]
        if hasattr(torch.linalg, 'lu'):
            p, lower, upper = torch.linalg.lu(W)
        else:
            # torch.linalg.lu only exists from PyTorch 1.13; older releases
            # have just the (now deprecated) torch.lu + torch.lu_unpack pair.
            p, lower, upper = torch.lu_unpack(*torch.lu(W))
        self.register_buffer('p', p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))

    def forward(self, z, reverse=False):
        """Apply W (or W^-1 when ``reverse``) as a 1x1 convolution.

        Args:
            z: input to ``F.conv1d`` with a (c, c, 1) kernel, i.e. channels
                in dim 1; presumably (batch, c, time) -- TODO confirm.
            reverse: apply the inverse transform instead of the forward one.

        Returns:
            When ``reverse`` is True, ``W^-1 z``. Otherwise the tuple
            ``(W z, log|det W|)``, where the log-determinant is for a single
            application of W (not scaled by batch or time).
        """
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if reverse:
            W_inverse = self._inverse_kernel(W, z.dtype)
            return F.conv1d(z, W_inverse, bias=None, stride=1, padding=0)
        # Match the kernel dtype to z so half/double inputs work without
        # autocast; this is a no-op for the common float32 case.
        z = F.conv1d(z, W.to(z.dtype)[..., None], bias=None, stride=1,
                     padding=0)
        log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
        return z, log_det_W

    def _inverse_kernel(self, W, dtype):
        """Return W^-1 as a (c, c, 1) conv kernel in ``dtype``.

        Caching only happens for graph-free W (e.g. under ``torch.no_grad``).
        The result is stored on ``self.W_inverse``, together with the W it
        was computed from. The cached kernel is reused only while W is
        bit-identical and the dtype/device still match.

        Any of the following therefore triggers a fresh inverse:

        - optimizer steps or ``load_state_dict``;
        - ``.to()``/``.half()`` conversions;
        - calls that need autograd through the inverse.
        """
        cached = getattr(self, 'W_inverse', None)
        source = getattr(self, '_W_inverse_source', None)
        if (not W.requires_grad and cached is not None
                and source is not None and cached.dtype == dtype
                and source.dtype == W.dtype and source.device == W.device
                and torch.equal(source, W)):
            return cached
        # Matrix inversion is not available for half/bfloat16, so invert in
        # at least float32 (float64 models keep full precision).
        compute_dtype = torch.promote_types(W.dtype, torch.float32)
        W_inverse = W.to(compute_dtype).inverse().to(dtype)[..., None]
        if not W.requires_grad:
            self.W_inverse = W_inverse
            self._W_inverse_source = W
        return W_inverse
class ConvAttention(torch.nn.Module):
    """Alignment attention between query (mel) and key (text) sequences.

    How ``forward`` scores the sequences:

    1. Keys and queries are projected into a shared channel space.
    2. Each (query frame, key token) pair is scored by a negative, scaled
       squared Euclidean distance.
    3. The scores are optionally combined with a log attention prior.
    4. The scores are normalised with a softmax over the key axis.

    Args:
        n_mel_channels: channels of the query input.
        n_speaker_dim: accepted but not used by this class.
        n_text_channels: channels of the key input.
        n_att_channels: output channels of ``key_proj`` (and of the
            ``"3xconv"`` query projection).
        temperature: stored on ``self.temperature``; not used by
            ``forward``.
        n_mel_convs: accepted but not used by this class.
        align_query_enc_type: ``"inv_conv"`` (invertible 1x1 conv) or
            ``"3xconv"`` (three ConvNorm layers).
        use_query_proj: when False, queries are compared unprojected.

    The distance in ``forward`` subtracts projected queries from projected
    keys channel-wise. So with ``"inv_conv"`` or ``use_query_proj=False``,
    ``n_mel_channels`` must equal ``n_att_channels``.

    Raises:
        ValueError: if ``align_query_enc_type`` is not recognised.
    """

    def __init__(self, n_mel_channels=80, n_speaker_dim=128,
                 n_text_channels=512, n_att_channels=80, temperature=1.0,
                 n_mel_convs=2, align_query_enc_type='3xconv',
                 use_query_proj=True):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.att_scaling_factor = np.sqrt(n_att_channels)
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        # Not used by forward(), but it is a registered submodule and thus
        # part of state_dict; removing it would break strict checkpoint loads.
        self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
        self.align_query_enc_type = align_query_enc_type
        self.use_query_proj = bool(use_query_proj)
        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels,
                     n_text_channels * 2,
                     kernel_size=3,
                     bias=True,
                     w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2,
                     n_att_channels,
                     kernel_size=1,
                     bias=True))
        if align_query_enc_type == "inv_conv":
            self.query_proj = Invertible1x1ConvLUS(n_mel_channels)
        elif align_query_enc_type == "3xconv":
            self.query_proj = nn.Sequential(
                ConvNorm(n_mel_channels,
                         n_mel_channels * 2,
                         kernel_size=3,
                         bias=True,
                         w_init_gain='relu'),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels * 2,
                         n_mel_channels,
                         kernel_size=1,
                         bias=True),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels,
                         n_att_channels,
                         kernel_size=1,
                         bias=True))
        else:
            raise ValueError("Unknown query encoder type specified")

    def run_padded_sequence(self, sorted_idx, unsort_idx, lens, padded_data,
                            recurrent_model):
        """Sort padded data by the provided ordering and run the packed
        sequences through a recurrent model.

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of
                sorted_idx)
            lens: lengths of input data (sorted in descending order), as a
                list or tensor
            padded_data (torch.tensor): input sequences (padded), batch in
                dim=1
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the
                original, unsorted, ordering
        """
        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        if torch.is_tensor(lens):
            # pack_padded_sequence requires tensor lengths to be on the CPU.
            lens = lens.cpu()
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        return hidden_vectors[:, unsort_idx]

    def encode_query(self, query, query_lens):
        """Run ``self.query_lstm`` over variable-length queries.

        NOTE(review): ``query_lstm`` is never created in ``__init__``, so
        this method fails with AttributeError unless a caller attaches one.

        Args:
            query: tensor permuted as (2, 0, 1) before packing; presumably
                B x C x T -- TODO confirm.
            query_lens: per-item lengths used to sort the batch.
        Returns:
            Encoded queries permuted back with (1, 2, 0).
        """
        query = query.permute(2, 0, 1)  # seq_len, batch, feature dim
        lens, ids = torch.sort(query_lens, descending=True)
        # argsort of a permutation is its inverse: maps sorted -> original.
        original_ids = torch.argsort(ids)
        query_encoded = self.run_padded_sequence(ids, original_ids, lens,
                                                 query, self.query_lstm)
        return query_encoded.permute(1, 2, 0)

    def forward(self, queries, keys, query_lens, mask=None, key_lens=None,
                keys_encoded=None, attn_prior=None):
        """Attention mechanism for flowtron parallel.

        Unlike in Flowtron, we have no restrictions such as causality etc,
        since we only need this during training.

        Args:
            queries (torch.tensor): B x C x T1 tensor
                (probably going to be mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: unused here (accepted for interface compatibility)
            mask (torch.tensor): binary mask (bool, or 0/1 integers) for
                variable length entries in the T2 domain. It is permuted
                (0, 2, 1) and unsqueezed at dim 2, so presumably
                B x T2 x 1 -- TODO confirm. True/1 entries are excluded.
            key_lens, keys_encoded: unused (accepted for interface
                compatibility)
            attn_prior (torch.tensor): optional prior, indexed as
                ``attn_prior[:, None]``; when given, log-softmaxed scores
                are combined with its log.
        Returns:
            attn (torch.tensor): B x 1 x T1 x T2 attention weights; the
                final dim T2 sums to 1.
            attn_logprob (torch.tensor): the scores before masking and
                softmax.
        """
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        # Beware can only do this since query_dim = attn_dim = n_mel_channels
        if not self.use_query_proj:
            queries_enc = queries
        elif self.align_query_enc_type == "3xconv":
            queries_enc = self.query_proj(queries)
        else:
            # Invertible1x1ConvLUS returns (z, log_det_W); the log-determinant
            # is not needed to compute attention.
            queries_enc, _ = self.query_proj(queries)
        # Simplistic isotropic Gaussian attention (one Gaussian per phoneme):
        # squared L2 distance per channel, B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
        # compute log likelihood from a gaussian (up to a constant)
        attn = -0.0005 * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]+1e-8)
        # masked_fill below is out-of-place, so attn_logprob is never mutated.
        attn_logprob = attn
        if mask is not None:
            # masked_fill only accepts bool masks; .bool() is a no-op for bool
            # input and makes uint8 masks work.
            attn = attn.masked_fill(mask.permute(0, 2, 1).unsqueeze(2).bool(),
                                    -float("inf"))
        attn = self.softmax(attn)  # Softmax along T2
        return attn, attn_logprob