# grad-svc / grad/model.py
# (source: Hugging Face upload by maxmax20160403, commit 3aa4060, 39-file upload, 5.83 kB)
import math
import torch
from grad.ssim import SSIM
from grad.base import BaseModule
from grad.encoder import TextEncoder
from grad.diffusion import Diffusion
from grad.utils import f0_to_coarse, rand_ids_segments, slice_segments
SpeakerLoss = torch.nn.CosineEmbeddingLoss()
SsimLoss = SSIM()
class GradTTS(BaseModule):
    """Diffusion-based acoustic model (Grad-TTS variant for voice conversion).

    An encoder maps speech content vectors plus pitch and speaker embeddings
    to a mel-spectrogram prior ``mu_x``; a score-based diffusion decoder then
    refines a noisy sample of that prior toward the target mel-spectrogram.
    """

    def __init__(self, n_mels, n_vecs, n_pits, n_spks, n_embs,
                 n_enc_channels, filter_channels,
                 dec_dim, beta_min, beta_max, pe_scale):
        """
        Args:
            n_mels: number of mel-spectrogram channels.
            n_vecs: dimensionality of the input content ("vec") features.
            n_pits: size of the coarse-pitch vocabulary (embedding rows).
            n_spks: dimensionality of the raw speaker vector.
            n_embs: shared embedding size for pitch/speaker conditioning.
            n_enc_channels: hidden channel count of the encoder.
            filter_channels: filter channel count of the encoder.
            dec_dim: base channel width of the diffusion decoder.
            beta_min: lower endpoint of the diffusion noise schedule.
            beta_max: upper endpoint of the diffusion noise schedule.
            pe_scale: positional-encoding scale passed to the decoder.
        """
        super(GradTTS, self).__init__()
        # common
        self.n_mels = n_mels
        self.n_vecs = n_vecs
        self.n_spks = n_spks
        self.n_embs = n_embs
        # encoder
        self.n_enc_channels = n_enc_channels
        self.filter_channels = filter_channels
        # decoder
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.pe_scale = pe_scale
        # pitch is quantized to coarse bins, then embedded; the raw speaker
        # vector is projected linearly into the same embedding size
        self.pit_emb = torch.nn.Embedding(n_pits, n_embs)
        self.spk_emb = torch.nn.Linear(n_spks, n_embs)
        self.encoder = TextEncoder(n_vecs,
                                   n_mels,
                                   n_embs,
                                   n_enc_channels,
                                   filter_channels)
        self.decoder = Diffusion(n_mels, dec_dim, n_embs, beta_min, beta_max, pe_scale)

    def fine_tune(self):
        """Freeze the pitch and speaker embeddings (and, via the encoder's own
        ``fine_tune``, part of the encoder), leaving the decoder trainable."""
        for p in self.pit_emb.parameters():
            p.requires_grad = False
        for p in self.spk_emb.parameters():
            p.requires_grad = False
        self.encoder.fine_tune()

    @torch.no_grad()
    def forward(self, lengths, vec, pit, spk, n_timesteps, temperature=1.0, stoc=False):
        """
        Generates mel-spectrogram from vec. Returns:
            1. encoder outputs
            2. decoder outputs
        Args:
            lengths (torch.Tensor): lengths of texts in batch.
            vec (torch.Tensor): batch of speech vec
            pit (torch.Tensor): batch of speech pit
            spk (torch.Tensor): batch of speaker
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            stoc (bool, optional): flag that adds stochastic term to the decoder sampler.
                Usually, does not provide synthesis improvements.
        """
        lengths, vec, pit, spk = self.relocate_input([lengths, vec, pit, spk])
        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))
        # Get speaker embedding
        spk = self.spk_emb(spk)
        # Transpose to channel-first layout expected by the encoder
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)
        # Get encoder_outputs `mu_x`
        mu_x, mask_x, _ = self.encoder(lengths, vec, pit, spk)
        encoder_outputs = mu_x
        # Sample latent representation from terminal distribution N(mu_x, I);
        # randn_like already matches mu_x's device and dtype
        z = mu_x + torch.randn_like(mu_x) / temperature
        # Generate sample by performing reverse dynamics
        decoder_outputs = self.decoder(spk, z, mask_x, mu_x, n_timesteps, stoc)
        # NOTE(review): the returned encoder output is deliberately perturbed
        # with unit Gaussian noise, mirroring the noise added during training
        encoder_outputs = encoder_outputs + torch.randn_like(encoder_outputs)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, lengths, vec, pit, spk, mel, out_size, skip_diff=False):
        """
        Computes losses:
            1. prior loss: loss between mel-spectrogram and encoder outputs.
            2. diffusion loss: loss between gaussian noise and its reconstruction
               by the diffusion-based decoder.
            3. mel SSIM loss between the prior and the target mel.
            4. speaker loss (cosine embedding, gradient-reversal branch).
        Args:
            lengths (torch.Tensor): lengths of texts in batch.
            vec (torch.Tensor): batch of speech vec
            pit (torch.Tensor): batch of speech pit
            spk (torch.Tensor): batch of speaker
            mel (torch.Tensor): batch of corresponding mel-spectrogram
            out_size (int, optional): length (in mel's sampling rate) of segment to cut,
                on which decoder will be trained. Should be divisible by
                2^{num of UNet downsamplings}. Needed to increase batch size.
                If None, the decoder is trained on the full-length sequences.
            skip_diff (bool, optional): if True, skip the (expensive) diffusion
                loss and return a zero tensor in its place.
        Returns:
            Tuple of (prior_loss, diff_loss, mel_loss, spk_loss).
        """
        lengths, vec, pit, spk, mel = self.relocate_input([lengths, vec, pit, spk, mel])
        # Get pitch embedding
        pit = self.pit_emb(f0_to_coarse(pit))
        # Get speaker embedding
        spk_64 = self.spk_emb(spk)
        # Transpose to channel-first layout expected by the encoder
        vec = torch.transpose(vec, 1, -1)
        pit = torch.transpose(pit, 1, -1)
        # Get encoder_outputs `mu_x`
        mu_x, mask_x, spk_preds = self.encoder(lengths, vec, pit, spk_64, training=True)
        # Negative log-likelihood of mel under N(mu_x, I), masked and normalized
        prior_loss = torch.sum(0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * mask_x)
        prior_loss = prior_loss / (torch.sum(mask_x) * self.n_mels)
        # Mel ssim
        mel_loss = SsimLoss(mu_x, mel, mask_x)
        # Speaker loss for GRL: cosine similarity target of +1 for every item
        # (torch.ones replaces the uninitialized torch.Tensor(...).fill_(1.0))
        spk_loss = SpeakerLoss(spk, spk_preds,
                               torch.ones(spk_preds.size(0), device=spk.device))
        # Compute loss of score-based decoder
        if skip_diff:
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            if out_size is not None:
                # Cut a small segment of mel-spectrogram in order to increase batch size
                ids = rand_ids_segments(lengths, out_size)
                mel = slice_segments(mel, ids, out_size)
                mask_y = slice_segments(mask_x, ids, out_size)
                mu_y = slice_segments(mu_x, ids, out_size)
            else:
                # BUGFIX: previously mask_y/mu_y were unbound when out_size was
                # None, raising NameError below; fall back to full sequences.
                mask_y = mask_x
                mu_y = mu_x
            # Perturb the prior before conditioning the diffusion loss
            mu_y = mu_y + torch.randn_like(mu_y)
            diff_loss, xt = self.decoder.compute_loss(
                spk_64, mel, mask_y, mu_y)
        return prior_loss, diff_loss, mel_loss, spk_loss