import math

import torch

from grad.ssim import SSIM
from grad.base import BaseModule
from grad.encoder import TextEncoder
from grad.diffusion import Diffusion
from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments


# Module-level loss objects reused by every call to GradTTS.compute_loss.
SpeakerLoss = torch.nn.CosineEmbeddingLoss()
SsimLoss = SSIM()


class GradTTS(BaseModule):
    """Encoder + diffusion-decoder model conditioned on pitch and speaker.

    The module owns four sub-modules:

    * ``pit_emb``  - embedding table applied to ``f0_to_coarse(pit)``.
    * ``spk_emb``  - linear projection of the speaker input to ``n_embs``.
    * ``encoder``  - ``TextEncoder`` producing ``mu_x`` and its mask.
    * ``decoder``  - ``Diffusion`` model run on top of ``mu_x``.
    """

    def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs,
                 n_enc_channels, filter_channels,
                 dec_dim, beta_min, beta_max, pe_scale):
        """Store the hyper-parameters and build the sub-modules.

        Args:
            n_mels: forwarded to the encoder and decoder; also used to
                normalise the prior loss.
            n_vecs: forwarded to the encoder.
            n_pits: number of entries in the pitch embedding table.
            n_spks: input width of the speaker projection.
            n_embs: width of the pitch and speaker embeddings.
            n_enc_channels: forwarded to the encoder.
            filter_channels: forwarded to the encoder.
            dec_dim: forwarded to the decoder.
            beta_min: forwarded to the decoder.
            beta_max: forwarded to the decoder.
            pe_scale: forwarded to the decoder.
        """
        super().__init__()

        # common
        self.n_mels = n_mels
        self.n_vecs = n_vecs
        self.n_spks = n_spks
        self.n_embs = n_embs

        # encoder
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels

        # decoder
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale

        self.pit_emb = torch.nn.Embedding(n_pits, n_embs)
        self.spk_emb = torch.nn.Linear(n_spks, n_embs)
        self.encoder = TextEncoder(n_vecs,
                                   n_mels,
                                   n_embs,
                                   n_enc_channels,
                                   filter_channels)
        self.decoder = Diffusion(n_mels, dec_dim, n_embs,
                                 beta_min, beta_max, pe_scale)

    def fine_tune(self):
        """Freeze the pitch and speaker embeddings and put the encoder in
        its own fine-tune mode (``self.encoder.fine_tune()``)."""
        for module in (self.pit_emb, self.spk_emb):
            for param in module.parameters():
                param.requires_grad = False
        self.encoder.fine_tune()

    @torch.no_grad()
    def forward(self, lengths, vec, pit, spk, n_timesteps,
                temperature=1.0, stoc=False):
        """Generate a mel-spectrogram from ``vec``, ``pit`` and ``spk``.

        Runs without gradient tracking.

        Args:
            lengths (torch.Tensor): per-item lengths passed to the encoder.
            vec (torch.Tensor): batch of speech vec.
            pit (torch.Tensor): batch of speech pit.
            spk (torch.Tensor): batch of speaker.
            n_timesteps (int): number of steps to use for reverse diffusion
                in the decoder.
            temperature (float, optional): divisor applied to the noise
                that is added to ``mu_x`` to form the decoder's start point.
            stoc (bool, optional): forwarded to the decoder; flag that adds
                a stochastic term to the decoder sampler.

        Returns:
            tuple: ``(encoder_outputs, decoder_outputs)``. Note that the
            returned ``encoder_outputs`` is ``mu_x`` plus standard Gaussian
            noise, while the decoder itself is conditioned on the
            noise-free ``mu_x``.
        """
        lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])

        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))

        # Get speaker embedding
        spk = self.spk_emb(spk)

        # Transpose: swap dim 1 with the last dim before the encoder
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        # Get encoder_outputs `mu_x`
        mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)

        # Sample latent representation from the terminal distribution
        # centred on mu_x; `temperature` scales the noise down.
        z = mu_x + torch.randn_like(mu_x) / temperature

        # Generate sample by performing reverse dynamics
        decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)

        # Noise is drawn after the decoder call so the RNG call order
        # (z noise -> decoder -> output noise) matches the original code.
        encoder_outputs = mu_x + torch.randn_like(mu_x)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, lengths, vec, pit, spk, mel, out_size=None,
                     skip_diff=False):
        """Compute the training losses.

        Four values are returned:
            1. prior loss: Gaussian negative log-likelihood style loss
               between the mel-spectrogram and the encoder outputs,
               averaged over ``sum(mask_x) * n_mels``.
            2. diffusion loss: loss reported by the diffusion decoder
               (zero when ``skip_diff`` is true).
            3. mel loss: ``SsimLoss(mu_x, mel, mask_x)``.
            4. speaker loss: cosine-embedding loss between ``spk`` and the
               encoder's speaker predictions with an all-ones target
               (the original code labels this the speaker loss for GRL).

        Args:
            lengths (torch.Tensor): per-item lengths passed to the encoder
                and to the random segment selection.
            vec (torch.Tensor): batch of speech vec.
            pit (torch.Tensor): batch of speech pit.
            spk (torch.Tensor): batch of speaker.
            mel (torch.Tensor): batch of corresponding mel-spectrogram.
            out_size (int, optional): length (in mel's sampling rate) of the
                segment to cut, on which the decoder will be trained.
                Should be divisible by 2^{num of UNet downsamplings}.
                Needed to increase batch size. When ``None`` the decoder is
                trained on the full, un-sliced sequences.
            skip_diff (bool, optional): when true the decoder is not run and
                the diffusion loss is a zero tensor.

        Returns:
            tuple: ``(prior_loss, diff_loss, mel_loss, spk_loss)``.
        """
        lengths, vec, pit, spk, mel = self.relocate_input(
            [lengths, vec, pit, spk, mel])

        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))

        # Get speaker embedding (kept separate: the raw `spk` is still
        # needed below as the speaker-loss reference)
        spk_embedding = self.spk_emb(spk)

        # Transpose: swap dim 1 with the last dim before the encoder
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)

        # Get encoder_outputs `mu_x`
        mu_x, mask_x, spk_preds = self.encoder(
            lengths, vec, pit, spk_embedding, training=True)

        # Compute loss between aligned encoder outputs and mel-spectrogram
        prior_loss = torch.sum(
            0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)
        prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)

        # Mel ssim
        mel_loss = SsimLoss(mu_x, mel, mask_x)

        # Compute loss of speaker for GRL. A target of +1 asks
        # CosineEmbeddingLoss to treat every pair as "similar".
        spk_target = torch.ones(spk_preds.size(0), device=spk.device)
        spk_loss = SpeakerLoss(spk, spk_preds, spk_target)

        # Compute loss of score-based decoder
        if skip_diff:
            # Zero tensor with the same dtype/device as prior_loss.
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            if out_size is not None:
                # Cut a small segment of mel-spectrogram in order to
                # increase batch size
                ids = rand_ids_segments(lengths, out_size)
                mel = slice_segments(mel, ids, out_size)
                mask_y = slice_segments(mask_x, ids, out_size)
                mu_y = slice_segments(mu_x, ids, out_size)
            else:
                # No cropping requested: train the decoder on the full
                # sequences. Previously this path hit a NameError because
                # mask_y / mu_y were only defined inside the slicing branch.
                mask_y = mask_x
                mu_y = mu_x

            # Perturb the decoder's conditioning with unit Gaussian noise.
            mu_y = mu_y + torch.randn_like(mu_y)
            diff_loss, _ = self.decoder.compute_loss(
                spk_embedding, mel, mask_y, mu_y)

        return prior_loss, diff_loss, mel_loss, spk_loss