Spaces:
Running
Running
import math | |
import torch | |
from grad.ssim import SSIM | |
from grad.base import BaseModule | |
from grad.encoder import TextEncoder | |
from grad.diffusion import Diffusion | |
from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments | |
# Module-level loss objects shared by every GradTTS instance.
# SpeakerLoss: cosine-embedding loss; with its default arguments written out
# explicitly (margin=0.0, mean reduction), it is identical to the no-arg form.
SpeakerLoss = torch.nn.CosineEmbeddingLoss(margin=0.0, reduction="mean")
# SsimLoss: SSIM-based loss from grad.ssim; GradTTS.compute_loss calls it as
# SsimLoss(mu_x, mel, mask_x).
SsimLoss = SSIM()
class GradTTS(BaseModule):
    """Encoder + score-based diffusion decoder that produces mel-spectrograms.

    Inputs are speech feature vectors (``vec``), a pitch contour (``pit``) and
    a speaker vector (``spk``). The pitch is quantized with ``f0_to_coarse``
    and looked up in an embedding table. The speaker vector is projected by a
    linear layer. ``TextEncoder`` predicts a mel estimate ``mu_x``, and
    ``Diffusion`` generates the final mel from it by reverse dynamics,
    conditioned on the speaker embedding.

    Args:
        n_mels (int): number of mel bins; also used to normalize the prior loss.
        n_vecs (int): feature dimension of ``vec`` (passed to ``TextEncoder``).
        n_pits (int): number of rows in the pitch-embedding table; presumably
            the number of coarse bins ``f0_to_coarse`` can emit — TODO confirm.
        n_spks (int): dimension of the input speaker vector (Linear in-features).
        n_embs (int): width of the pitch/speaker embeddings; also passed to the
            encoder and decoder.
        n_enc_channels (int): encoder channel count (passed to ``TextEncoder``).
        filter_channels (int): encoder filter channels (passed to ``TextEncoder``).
        dec_dim (int): decoder dimension (passed to ``Diffusion``).
        beta_min, beta_max (float): passed to ``Diffusion`` (presumably the
            bounds of its noise schedule — TODO confirm).
        pe_scale: passed to ``Diffusion`` (presumably a positional-encoding
            scale — TODO confirm).
    """

    def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs,
                 n_enc_channels, filter_channels,
                 dec_dim, beta_min, beta_max, pe_scale):
        super().__init__()
        # common
        self.n_mels = n_mels
        self.n_vecs = n_vecs
        self.n_spks = n_spks
        self.n_embs = n_embs
        # encoder
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels
        # decoder
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        self.pit_emb = torch.nn.Embedding(n_pits, n_embs)
        self.spk_emb = torch.nn.Linear(n_spks, n_embs)
        self.encoder = TextEncoder(n_vecs,
                                   n_mels,
                                   n_embs,
                                   n_enc_channels,
                                   filter_channels)
        self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, beta_max, pe_scale)

    def fine_tune(self):
        """Freeze the pitch and speaker embeddings, then call ``encoder.fine_tune()``.

        The decoder's parameters are not touched here.
        """
        for module in (self.pit_emb, self.spk_emb):
            for p in module.parameters():
                p.requires_grad = False
        self.encoder.fine_tune()

    def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False):
        """Generate a mel-spectrogram from speech features.

        Args:
            lengths (torch.Tensor): valid length of each item in the batch.
            vec (torch.Tensor): batch of speech feature vectors. Dims 1 and -1
                are swapped before encoding (presumably (batch, frames, n_vecs)
                in — TODO confirm).
            pit (torch.Tensor): batch of pitch values, quantized by
                ``f0_to_coarse`` before embedding (presumably f0 — TODO confirm).
            spk (torch.Tensor): batch of speaker vectors of size ``n_spks``.
            n_timesteps (int): number of reverse-diffusion steps in the decoder.
            temperature (float, optional): the initial noise added to ``mu_x``
                is divided by this value, so larger values shrink the variance
                of the terminal distribution.
            stoc (bool, optional): forwarded to the decoder; adds a stochastic
                term to the sampler. Usually does not improve synthesis.

        Returns:
            tuple: ``(encoder_outputs, decoder_outputs)``. ``encoder_outputs``
            is the encoder mean ``mu_x`` plus unit Gaussian noise (added after
            decoding), not the raw encoder mean.
        """
        # Moved by BaseModule.relocate_input (presumably onto the model's device).
        lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])

        # Pitch: quantize to coarse bins, then embed.
        pit = self.pit_emb(f0_to_coarse(pit))
        # Speaker: linear projection to n_embs.
        spk = self.spk_emb(spk)

        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)
        encoder_outputs = mu_x

        # Sample the starting point from the terminal distribution N(mu_x, I / temperature^2).
        z = mu_x + torch.randn_like(mu_x) / temperature
        # Generate the sample by running the decoder's reverse dynamics.
        decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)
        # The returned encoder output is deliberately noised; the decoder above
        # was conditioned on the clean mu_x.
        encoder_outputs = encoder_outputs + torch.randn_like(encoder_outputs)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, lengths, vec, pit, spk, mel, out_size=None, skip_diff=False):
        """Compute the training losses.

        Args:
            lengths (torch.Tensor): valid length of each item in the batch.
            vec (torch.Tensor): batch of speech feature vectors.
            pit (torch.Tensor): batch of pitch values (quantized by ``f0_to_coarse``).
            spk (torch.Tensor): batch of speaker vectors of size ``n_spks``.
            mel (torch.Tensor): batch of target mel-spectrograms.
            out_size (int, optional): length (in mel frames) of the random
                segment the decoder is trained on. It should be divisible by
                2^{number of UNet downsamplings}; cropping allows a larger
                batch size. When ``None``, the decoder is trained on the full,
                uncropped sequences.
            skip_diff (bool, optional): if True, the decoder is not run and
                ``diff_loss`` is a zero tensor.

        Returns:
            tuple: ``(prior_loss, diff_loss, mel_loss, spk_loss)``:
                * ``prior_loss``: Gaussian NLL of ``mel`` under N(mu_x, I),
                  averaged over valid frames and mel bins.
                * ``diff_loss``: the decoder's ``compute_loss`` loss.
                * ``mel_loss``: ``SsimLoss`` between ``mu_x`` and ``mel``.
                * ``spk_loss``: cosine-embedding loss between the input speaker
                  vector and the encoder's speaker prediction.
        """
        lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel])

        # Pitch: quantize to coarse bins, then embed.
        pit = self.pit_emb(f0_to_coarse(pit))
        # Speaker: linear projection to n_embs. The raw `spk` is kept for spk_loss.
        spk_64 = self.spk_emb(spk)

        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        mu_x, mask_x, spk_preds = self.encoder(lengths, vec, pit, spk_64, training=True)

        # Per-element NLL of a unit-variance Gaussian, masked, then averaged over
        # the number of valid positions times the number of mel bins.
        prior_loss = torch.sum(0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)
        prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)

        # SSIM between the encoder output and the full (uncropped) target mel.
        mel_loss = SsimLoss(mu_x, mel, mask_x)

        # Speaker loss for the GRL (gradient reversal; not visible in this file).
        # A target of +1 makes CosineEmbeddingLoss compute 1 - cos(spk, spk_preds).
        spk_target = torch.ones(spk_preds.size(0), device=spk.device)
        spk_loss = SpeakerLoss(spk, spk_preds, spk_target)

        if skip_diff:
            # Zero loss with prior_loss's dtype/device; clone() + fill_() keeps it
            # attached to the autograd graph like the original did.
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            if out_size is not None:
                # Train the decoder on a random crop to allow larger batches.
                ids = rand_ids_segments(lengths, out_size)
                mel = slice_segments(mel, ids, out_size)
                mask_y = slice_segments(mask_x, ids, out_size)
                mu_y = slice_segments(mu_x, ids, out_size)
            else:
                # No cropping requested: use the full sequences. Previously this
                # path raised UnboundLocalError. NOTE(review): the decoder may
                # require a length compatible with its UNet downsampling — confirm.
                mask_y = mask_x
                mu_y = mu_x
            # The decoder is conditioned on a noised encoder mean.
            mu_y = mu_y + torch.randn_like(mu_y)
            diff_loss, _ = self.decoder.compute_loss(spk_64, mel, mask_y, mu_y)

        return prior_loss, diff_loss, mel_loss, spk_loss