|
from typing import Sequence
|
|
|
|
import math
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
class VectorQuantizer(nn.Module):
|
|
"""
|
|
Reference:
|
|
[1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
|
|
"""
|
|
def __init__(self,
|
|
num_embeddings: int,
|
|
hidden_dim: int,
|
|
beta: float = 0.25):
|
|
super().__init__()
|
|
self.K = num_embeddings
|
|
self.D = hidden_dim
|
|
self.beta = 0.05
|
|
|
|
self.embedding = nn.Embedding(self.K, self.D)
|
|
self.embedding.weight.data.normal_(0.8, 0.1)
|
|
|
|
def forward(self, latents: torch.Tensor) -> torch.Tensor:
|
|
|
|
latents_shape = latents.shape
|
|
flat_latents = latents.view(-1, self.D)
|
|
|
|
|
|
dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
|
|
torch.sum(self.embedding.weight ** 2, dim=1) - \
|
|
2 * torch.matmul(flat_latents, self.embedding.weight.t())
|
|
|
|
|
|
encoding_inds = torch.argmin(dist, dim=1)
|
|
output_inds = encoding_inds.view(latents_shape[0], latents_shape[1])
|
|
encoding_inds = encoding_inds.unsqueeze(1)
|
|
|
|
|
|
device = latents.device
|
|
encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
|
|
encoding_one_hot.scatter_(1, encoding_inds, 1)
|
|
|
|
|
|
|
|
quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)
|
|
quantized_latents = quantized_latents.view(latents_shape)
|
|
|
|
|
|
commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
|
|
embedding_loss = F.mse_loss(quantized_latents, latents.detach())
|
|
|
|
vq_loss = commitment_loss * self.beta + embedding_loss
|
|
|
|
|
|
quantized_latents = latents + (quantized_latents - latents).detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
avg_probs = torch.mean(encoding_one_hot, dim=0)
|
|
|
|
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
|
|
|
return quantized_latents, vq_loss, output_inds, self.embedding, perplexity
|
|
|
|
|
|
class ProsodyEncoder(nn.Module):
|
|
"""VQ-VAE prosody encoder module.
|
|
|
|
Args:
|
|
odim (int): Number of input channels (mel spectrogram channels).
|
|
ref_enc_conv_layers (int, optional):
|
|
The number of conv layers in the reference encoder.
|
|
ref_enc_conv_chans_list: (Sequence[int], optional):
|
|
List of the number of channels of conv layers in the referece encoder.
|
|
ref_enc_conv_kernel_size (int, optional):
|
|
Kernal size of conv layers in the reference encoder.
|
|
ref_enc_conv_stride (int, optional):
|
|
Stride size of conv layers in the reference encoder.
|
|
ref_enc_gru_layers (int, optional):
|
|
The number of GRU layers in the reference encoder.
|
|
ref_enc_gru_units (int, optional):
|
|
The number of GRU units in the reference encoder.
|
|
ref_emb_integration_type: How to integrate reference embedding.
|
|
adim (int, optional): This value is not that important.
|
|
This will not change the capacity in the information-bottleneck.
|
|
num_embeddings (int, optional): The higher this value, the higher the
|
|
capacity in the information bottleneck.
|
|
hidden_dim (int, optional): Number of hidden channels.
|
|
"""
|
|
def __init__(
|
|
self,
|
|
odim: int,
|
|
adim: int = 64,
|
|
num_embeddings: int = 10,
|
|
hidden_dim: int = 3,
|
|
beta: float = 0.25,
|
|
ref_enc_conv_layers: int = 2,
|
|
ref_enc_conv_chans_list: Sequence[int] = (32, 32),
|
|
ref_enc_conv_kernel_size: int = 3,
|
|
ref_enc_conv_stride: int = 1,
|
|
global_enc_gru_layers: int = 1,
|
|
global_enc_gru_units: int = 32,
|
|
global_emb_integration_type: str = "add",
|
|
) -> None:
|
|
assert check_argument_types()
|
|
super().__init__()
|
|
|
|
|
|
self.global_emb_integration_type = global_emb_integration_type
|
|
|
|
padding = (ref_enc_conv_kernel_size - 1) // 2
|
|
|
|
self.ref_encoder = RefEncoder(
|
|
ref_enc_conv_layers=ref_enc_conv_layers,
|
|
ref_enc_conv_chans_list=ref_enc_conv_chans_list,
|
|
ref_enc_conv_kernel_size=ref_enc_conv_kernel_size,
|
|
ref_enc_conv_stride=ref_enc_conv_stride,
|
|
ref_enc_conv_padding=padding,
|
|
)
|
|
|
|
|
|
ref_enc_output_units = odim
|
|
for i in range(ref_enc_conv_layers):
|
|
ref_enc_output_units = (
|
|
ref_enc_output_units - ref_enc_conv_kernel_size + 2 * padding
|
|
) // ref_enc_conv_stride + 1
|
|
ref_enc_output_units *= ref_enc_conv_chans_list[-1]
|
|
|
|
self.fg_encoder = FGEncoder(
|
|
ref_enc_output_units + global_enc_gru_units,
|
|
hidden_dim=hidden_dim,
|
|
)
|
|
|
|
self.global_encoder = GlobalEncoder(
|
|
ref_enc_output_units,
|
|
global_enc_gru_layers=global_enc_gru_layers,
|
|
global_enc_gru_units=global_enc_gru_units,
|
|
)
|
|
|
|
|
|
if self.global_emb_integration_type == "add":
|
|
self.global_projection = nn.Linear(global_enc_gru_units, adim)
|
|
else:
|
|
self.global_projection = nn.Linear(
|
|
adim + global_enc_gru_units, adim
|
|
)
|
|
|
|
self.ar_prior = ARPrior(
|
|
adim,
|
|
num_embeddings=num_embeddings,
|
|
hidden_dim=hidden_dim,
|
|
)
|
|
|
|
self.vq_layer = VectorQuantizer(num_embeddings, hidden_dim, beta)
|
|
|
|
|
|
self.qfg_projection = nn.Linear(hidden_dim, adim)
|
|
|
|
def forward(
|
|
self,
|
|
ys: torch.Tensor,
|
|
ds: torch.Tensor,
|
|
hs: torch.Tensor,
|
|
global_embs: torch.Tensor = None,
|
|
train_ar_prior: bool = False,
|
|
ar_prior_inference: bool = False,
|
|
fg_inds: torch.Tensor = None,
|
|
) -> Sequence[torch.Tensor]:
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
ys (Tensor): Batch of padded target features (B, Lmax, odim).
|
|
ds (LongTensor): Batch of padded durations (B, Tmax).
|
|
hs (Tensor): Batch of phoneme embeddings (B, Tmax, D).
|
|
global_embs (Tensor, optional): Global embeddings (B, D)
|
|
|
|
Returns:
|
|
Tensor: Fine-grained quantized prosody embeddings (B, Tmax, adim).
|
|
Tensor: VQ loss.
|
|
Tensor: Global prosody embeddings (B, ref_enc_gru_units)
|
|
"""
|
|
if ys is not None:
|
|
print('generating global_embs')
|
|
ref_embs = self.ref_encoder(ys)
|
|
global_embs = self.global_encoder(ref_embs)
|
|
|
|
if ar_prior_inference:
|
|
print('Using ar prior')
|
|
hs_integrated = self._integrate_with_global_embs(hs, global_embs)
|
|
qs, top_inds = self.ar_prior.inference(
|
|
hs_integrated, fg_inds, self.vq_layer.embedding
|
|
)
|
|
|
|
qs = self.qfg_projection(qs)
|
|
assert hs.size(2) == qs.size(2)
|
|
|
|
p_embs = self._integrate_with_global_embs(qs, global_embs)
|
|
assert hs.shape == p_embs.shape
|
|
|
|
return p_embs, 0, 0, 0, top_inds
|
|
|
|
|
|
global_embs_expanded = global_embs.unsqueeze(1).expand(-1, ref_embs.size(1), -1)
|
|
|
|
ref_embs_integrated = torch.cat([ref_embs, global_embs_expanded], dim=-1)
|
|
|
|
|
|
fg_embs = self.fg_encoder(ref_embs_integrated, ds, ys.size(1))
|
|
|
|
|
|
qs, vq_loss, inds, codebook, perplexity = self.vq_layer(fg_embs)
|
|
|
|
assert hs.size(1) == qs.size(1)
|
|
|
|
qs = self.qfg_projection(qs)
|
|
assert hs.size(2) == qs.size(2)
|
|
|
|
p_embs = self._integrate_with_global_embs(qs, global_embs)
|
|
assert hs.shape == p_embs.shape
|
|
|
|
ar_prior_loss = 0
|
|
if train_ar_prior:
|
|
|
|
hs_integrated = self._integrate_with_global_embs(hs, global_embs)
|
|
qs, ar_prior_loss = self.ar_prior(hs_integrated, inds, codebook)
|
|
qs = self.qfg_projection(qs)
|
|
assert hs.size(2) == qs.size(2)
|
|
|
|
p_embs = self._integrate_with_global_embs(qs, global_embs)
|
|
assert hs.shape == p_embs.shape
|
|
|
|
return p_embs, vq_loss, ar_prior_loss, perplexity, global_embs
|
|
|
|
def _integrate_with_global_embs(
|
|
self,
|
|
qs: torch.Tensor,
|
|
global_embs: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""Integrate ref embedding with spectrogram hidden states.
|
|
|
|
Args:
|
|
qs (Tensor): Batch of quantized FG embeddings (B, Tmax, adim).
|
|
global_embs (Tensor): Batch of global embeddings (B, global_enc_gru_units).
|
|
|
|
Returns:
|
|
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
|
|
"""
|
|
if self.global_emb_integration_type == "add":
|
|
|
|
global_embs = self.global_projection(global_embs)
|
|
res = qs + global_embs.unsqueeze(1)
|
|
elif self.global_emb_integration_type == "concat":
|
|
|
|
|
|
global_embs = global_embs.unsqueeze(1).expand(-1, qs.size(1), -1)
|
|
|
|
res = self.prosody_projection(torch.cat([qs, global_embs], dim=-1))
|
|
else:
|
|
raise NotImplementedError("support only add or concat.")
|
|
|
|
return res
|
|
|
|
|
|
class RefEncoder(nn.Module):
|
|
def __init__(
|
|
self,
|
|
ref_enc_conv_layers: int = 2,
|
|
ref_enc_conv_chans_list: Sequence[int] = (32, 32),
|
|
ref_enc_conv_kernel_size: int = 3,
|
|
ref_enc_conv_stride: int = 1,
|
|
ref_enc_conv_padding: int = 1,
|
|
):
|
|
"""Initilize reference encoder module."""
|
|
assert check_argument_types()
|
|
super().__init__()
|
|
|
|
|
|
assert ref_enc_conv_kernel_size % 2 == 1, "kernel size must be odd."
|
|
assert (
|
|
len(ref_enc_conv_chans_list) == ref_enc_conv_layers
|
|
), "the number of conv layers and length of channels list must be the same."
|
|
|
|
convs = []
|
|
for i in range(ref_enc_conv_layers):
|
|
conv_in_chans = 1 if i == 0 else ref_enc_conv_chans_list[i - 1]
|
|
conv_out_chans = ref_enc_conv_chans_list[i]
|
|
convs += [
|
|
nn.Conv2d(
|
|
conv_in_chans,
|
|
conv_out_chans,
|
|
kernel_size=ref_enc_conv_kernel_size,
|
|
stride=ref_enc_conv_stride,
|
|
padding=ref_enc_conv_padding,
|
|
),
|
|
nn.ReLU(inplace=True),
|
|
|
|
]
|
|
self.convs = nn.Sequential(*convs)
|
|
|
|
def forward(self, ys: torch.Tensor) -> torch.Tensor:
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
ys (Tensor): Batch of padded target features (B, Lmax, odim).
|
|
|
|
Returns:
|
|
Tensor: Batch of spectrogram hiddens (B, L', ref_enc_output_units)
|
|
|
|
"""
|
|
B = ys.size(0)
|
|
ys = ys.unsqueeze(1)
|
|
hs = self.convs(ys)
|
|
hs = hs.transpose(1, 2)
|
|
L = hs.size(1)
|
|
|
|
hs = hs.contiguous().view(B, L, -1)
|
|
|
|
return hs
|
|
|
|
|
|
class GlobalEncoder(nn.Module):
|
|
"""Module that creates a global embedding from a hidden spectrogram sequence.
|
|
|
|
Args:
|
|
"""
|
|
def __init__(
|
|
self,
|
|
ref_enc_output_units: int,
|
|
global_enc_gru_layers: int = 1,
|
|
global_enc_gru_units: int = 32,
|
|
):
|
|
super().__init__()
|
|
self.gru = torch.nn.GRU(ref_enc_output_units, global_enc_gru_units,
|
|
global_enc_gru_layers, batch_first=True)
|
|
|
|
def forward(
|
|
self,
|
|
hs: torch.Tensor,
|
|
):
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
hs (Tensor): Batch of spectrogram hiddens (B, L', ref_enc_output_units).
|
|
|
|
Returns:
|
|
Tensor: Reference embedding (B, ref_enc_gru_units).
|
|
"""
|
|
self.gru.flatten_parameters()
|
|
_, global_embs = self.gru(hs)
|
|
global_embs = global_embs[-1]
|
|
|
|
return global_embs
|
|
|
|
|
|
class FGEncoder(nn.Module):
|
|
"""Spectrogram to phoneme alignment module.
|
|
|
|
Args:
|
|
"""
|
|
def __init__(
|
|
self,
|
|
input_units: int,
|
|
hidden_dim: int = 3,
|
|
):
|
|
assert check_argument_types()
|
|
super().__init__()
|
|
|
|
self.projection = nn.Sequential(
|
|
nn.Sequential(
|
|
nn.Linear(input_units, input_units // 2),
|
|
nn.ReLU(),
|
|
nn.Dropout(p=0.2),
|
|
),
|
|
nn.Sequential(
|
|
nn.Linear(input_units // 2, hidden_dim),
|
|
nn.ReLU(),
|
|
nn.Dropout(p=0.2),
|
|
)
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
hs: torch.Tensor,
|
|
ds: torch.Tensor,
|
|
Lmax: int
|
|
):
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
hs (Tensor): Batch of spectrogram hiddens
|
|
(B, L', ref_enc_output_units + global_enc_gru_units).
|
|
ds (LongTensor): Batch of padded durations (B, Tmax).
|
|
|
|
Returns:
|
|
Tensor: aligned spectrogram hiddens (B, Tmax, hidden_dim).
|
|
"""
|
|
|
|
hs = self._align_durations(hs, ds, Lmax)
|
|
hs = self.projection(hs)
|
|
|
|
return hs
|
|
|
|
def _align_durations(self, hs, ds, Lmax):
|
|
"""Transform the spectrogram hiddens according to the ground-truth durations
|
|
so that there's only one hidden per phoneme hidden.
|
|
|
|
Args:
|
|
# (B, L', ref_enc_output_units + global_enc_gru_units)
|
|
hs (Tensor): Batch of spectrogram hidden state sequences .
|
|
ds (LongTensor): Batch of padded durations (B, Tmax)
|
|
|
|
Returns:
|
|
# (B, Tmax, ref_enc_output_units + global_enc_gru_units)
|
|
Tensor: Batch of averaged spectrogram hidden state sequences.
|
|
"""
|
|
B = hs.size(0)
|
|
L = hs.size(1)
|
|
D = hs.size(2)
|
|
|
|
Tmax = ds.size(1)
|
|
|
|
device = hs.device
|
|
hs_res = torch.zeros(
|
|
[B, Tmax, D],
|
|
device=device
|
|
)
|
|
|
|
with torch.no_grad():
|
|
for b_i in range(B):
|
|
durations = ds[b_i]
|
|
multiplier = L / Lmax
|
|
i = 0
|
|
for d_i in range(Tmax):
|
|
|
|
d = max(math.floor(durations[d_i].item() * multiplier), 1)
|
|
if durations[d_i].item() > 0:
|
|
hs_slice = hs[b_i, i:i + d, :]
|
|
hs_res[b_i, d_i, :] = torch.mean(hs_slice, 0)
|
|
i += d
|
|
hs_res.requires_grad_(hs.requires_grad)
|
|
return hs_res
|
|
|
|
|
|
class ARPrior(nn.Module):
|
|
|
|
"""Autoregressive prior.
|
|
|
|
This module is inspired by the AR prior described in `Generating diverse and
|
|
natural text-to-speech samples using a quantized fine-grained VAE and
|
|
auto-regressive prosody prior`. This prior is fit in the continuous latent space.
|
|
"""
|
|
def __init__(
|
|
self,
|
|
adim: int,
|
|
num_embeddings: int = 10,
|
|
hidden_dim: int = 3,
|
|
):
|
|
assert check_argument_types()
|
|
super().__init__()
|
|
|
|
|
|
self.adim = adim
|
|
self.hidden_dim = hidden_dim
|
|
self.num_embeddings = num_embeddings
|
|
|
|
self.qs_projection = nn.Linear(hidden_dim, adim)
|
|
|
|
self.lstm = nn.LSTMCell(
|
|
self.adim,
|
|
self.num_embeddings,
|
|
)
|
|
|
|
self.criterion = nn.NLLLoss()
|
|
|
|
def inds_to_embs(self, inds, codebook, device):
|
|
"""Returns the quantized embeddings from the codebook,
|
|
corresponding to the indices.
|
|
|
|
Args:
|
|
inds (Tensor): Batch of indices (B, Tmax, 1).
|
|
codebook (Embedding): (num_embeddings, D).
|
|
|
|
Returns:
|
|
Tensor: Quantized embeddings (B, Tmax, D).
|
|
"""
|
|
flat_inds = torch.flatten(inds).unsqueeze(1)
|
|
|
|
|
|
encoding_one_hot = torch.zeros(
|
|
flat_inds.size(0),
|
|
self.num_embeddings,
|
|
device=device
|
|
)
|
|
encoding_one_hot.scatter_(1, flat_inds, 1)
|
|
|
|
|
|
|
|
quantized_embs = torch.matmul(encoding_one_hot, codebook.weight)
|
|
|
|
quantized_embs = quantized_embs.view(
|
|
inds.size(0), inds.size(1), self.hidden_dim
|
|
)
|
|
|
|
return quantized_embs
|
|
|
|
def top_embeddings(self, emb_scores: torch.Tensor, codebook):
|
|
"""Returns the top quantized embeddings from the codebook using the scores.
|
|
|
|
Args:
|
|
emb_scores (Tensor): Batch of embedding scores (B, Tmax, num_embeddings).
|
|
codebook (Embedding): (num_embeddings, D).
|
|
|
|
Returns:
|
|
Tensor: Top quantized embeddings (B, Tmax, D).
|
|
Tensor: Top 3 inds (B, Tmax, 3).
|
|
"""
|
|
_, top_inds = emb_scores.topk(1, dim=-1)
|
|
quantized_embs = self.inds_to_embs(
|
|
top_inds,
|
|
codebook,
|
|
emb_scores.device,
|
|
)
|
|
_, top3_inds = emb_scores.topk(3, dim=-1)
|
|
return quantized_embs, top3_inds
|
|
|
|
def _forward(self, hs_ref_embs, codebook, fg_inds=None):
|
|
inds = []
|
|
scores = []
|
|
embs = []
|
|
|
|
if fg_inds is not None:
|
|
init_embs = self.inds_to_embs(fg_inds, codebook, hs_ref_embs.device)
|
|
embs = [init_emb.unsqueeze(1) for init_emb in init_embs.transpose(1, 0)]
|
|
|
|
start = fg_inds.size(1) if fg_inds is not None else 0
|
|
hidden = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size)
|
|
cell = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size)
|
|
|
|
for i in range(start, hs_ref_embs.size(1)):
|
|
|
|
input = hs_ref_embs[:, i]
|
|
if i != 0:
|
|
|
|
qs = self.qs_projection(embs[-1])
|
|
|
|
input = hs_ref_embs[:, i] + qs.squeeze()
|
|
hidden, cell = self.lstm(input, (hidden, cell))
|
|
out = hidden.unsqueeze(1)
|
|
|
|
emb_scores = F.log_softmax(out, dim=2)
|
|
quantized_embs, top_inds = self.top_embeddings(emb_scores, codebook)
|
|
|
|
embs.append(quantized_embs)
|
|
scores.append(emb_scores)
|
|
inds.append(top_inds)
|
|
|
|
out_embs = torch.cat(embs, dim=1)
|
|
assert(out_embs.size(0) == hs_ref_embs.size(0))
|
|
assert(out_embs.size(1) == hs_ref_embs.size(1))
|
|
out_emb_scores = torch.cat(scores, dim=1) if start < hs_ref_embs.size(1) else scores
|
|
out_inds = torch.cat(inds, dim=1) if start < hs_ref_embs.size(1) else fg_inds
|
|
|
|
return out_embs, out_emb_scores, out_inds
|
|
|
|
def forward(self, hs_ref_embs, inds, codebook):
|
|
"""Calculate forward propagation.
|
|
|
|
Args:
|
|
hs_p_embs (Tensor): Batch of phoneme embeddings
|
|
with integrated global prosody embeddings (B, Tmax, D).
|
|
inds (Tensor): Batch of ground-truth codebook indices
|
|
(B, Tmax).
|
|
|
|
Returns:
|
|
Tensor: Batch of predicted quantized latents (B, Tmax, D).
|
|
Tensor: Cross entropy loss value.
|
|
|
|
"""
|
|
quantized_embs, emb_scores, _ = self._forward(hs_ref_embs, codebook)
|
|
emb_scores = emb_scores.permute(0, 2, 1).contiguous()
|
|
loss = self.criterion(emb_scores, inds)
|
|
return quantized_embs, loss
|
|
|
|
def inference(self, hs_ref_embs, fg_inds, codebook):
|
|
"""Inference duration.
|
|
|
|
Args:
|
|
hs_p_embs (Tensor): Batch of phoneme embeddings
|
|
with integrated global prosody embeddings (B, Tmax, D).
|
|
|
|
Returns:
|
|
Tensor: Batch of predicted quantized latents (B, Tmax, D).
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
quantized_embs, _, top_inds = self._forward(hs_ref_embs, codebook, fg_inds)
|
|
return quantized_embs, top_inds
|
|
|