import torch
import numpy as np
import torch.nn.functional as F


class SLMAdversarialLoss(torch.nn.Module):
    """Adversarial loss computed on speech synthesized from sampled styles.

    ``forward`` predicts durations for a batch of reference texts, turns them
    into a differentiable soft alignment, decodes a random clip per utterance
    and scores the result with the discriminator/generator losses exposed by
    ``wl``.

    Args:
        model: Container exposing ``bert``, ``bert_encoder``, ``predictor``
            (which also has ``F0Ntrain``), ``text_encoder`` and ``decoder``
            as attributes.
        wl: Object exposing ``discriminator(a, b)``,
            ``discriminator_forward(x)`` and ``generator(x)``.
        sampler: Callable producing style vectors; called with ``noise``,
            ``embedding``, ``embedding_scale``, ``embedding_mask_proba``,
            ``num_steps`` and optionally ``features``.
        min_len: Lower bound (before halving) on the sampled clip length.
        max_len: Upper bound (before halving) on the sampled clip length.
        batch_percentage: Stop collecting clips once this fraction of
            ``len(waves)`` has been gathered (limits memory use).
        skip_update: The discriminator loss is only computed when
            ``(iters + 1) % skip_update == 0``.
        sig: Width of the Gaussian kernel used for the soft alignment.
        style_dim: Index at which a style vector is split: the first
            ``style_dim`` channels are passed to the decoder, the remainder
            to the predictor. Defaults to the previously hard-coded 128.
        hop_length: Number of waveform samples per mel frame used when
            cropping ground-truth audio. Defaults to the previously
            hard-coded 300.
    """

    def __init__(
        self,
        model,
        wl,
        sampler,
        min_len,
        max_len,
        batch_percentage=0.5,
        skip_update=10,
        sig=1.5,
        style_dim=128,
        hop_length=300,
    ):
        super().__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler

        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage

        self.sig = sig
        self.skip_update = skip_update

        self.style_dim = style_dim
        self.hop_length = hop_length

    def forward(
        self,
        iters,
        y_rec_gt,
        y_rec_gt_pred,
        waves,
        mel_input_length,
        ref_text,
        ref_lengths,
        use_ind,
        s_trg,
        ref_s=None,
    ):
        """Compute discriminator and generator losses for one batch.

        Args:
            iters: Current iteration index; gates the discriminator update.
            y_rec_gt: Ground-truth audio paired with ``y_rec_gt_pred``; only
                used for the reconstruction-artifact regularizer.
            y_rec_gt_pred: Reconstructed audio; indexed as a 3-D tensor.
            waves: Sequence of 1-D numpy waveforms, one per utterance.
            mel_input_length: Per-utterance mel lengths (tensor).
            ref_text: Token ids of the reference texts (tensor).
            ref_lengths: Per-utterance token counts (tensor).
            use_ind: If truthy, ``s_trg`` is used as the style half of the
                time instead of sampling a new one.
            s_trg: Target style vectors; also defines the noise shape.
            ref_s: Optional reference features forwarded to the sampler.

        Returns:
            ``None`` when fewer than two usable clips could be sampled,
            otherwise ``(d_loss, gen_loss, y_pred)`` where ``d_loss`` is the
            int ``0`` on iterations that skip the discriminator update and
            ``y_pred`` is the generated audio as a numpy array.
        """
        device = ref_text.device

        text_mask = length_to_mask(ref_lengths).to(device)
        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        # NOTE: the order of the random draws below is kept exactly as in the
        # original implementation so seeded runs stay reproducible.
        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg
        else:
            num_steps = np.random.randint(3, 5)
            sampler_kwargs = {
                "noise": torch.randn_like(s_trg).unsqueeze(1).to(device),
                "embedding": bert_dur,
                "embedding_scale": 1,
                "embedding_mask_proba": 0.1,
                "num_steps": num_steps,
            }
            if ref_s is not None:
                # reference from the same speaker as the embedding
                sampler_kwargs["features"] = ref_s
            s_preds = self.sampler(**sampler_kwargs).squeeze(1)

        s_dur = s_preds[:, self.style_dim:]

        d, _ = self.model.predictor(
            d_en,
            s_dur,
            ref_lengths,
            torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(device),
            text_mask,
        )

        # differentiable duration modeling
        attn_preds = []
        output_lengths = []
        for dur_logits, text_length in zip(d, ref_lengths):
            attn, n_frames = self._soft_alignment(
                dur_logits[:text_length, :], device
            )
            attn_preds.append(attn)
            output_lengths.append(n_frames)

        max_frames = max(output_lengths)

        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)

        # Zero-padded batch of per-utterance alignments.
        s2s_attn = torch.zeros(
            len(ref_lengths), int(ref_lengths.max()), max_frames, device=device
        )
        for idx, (text_length, n_frames, attn) in enumerate(
            zip(ref_lengths, output_lengths, attn_preds)
        ):
            s2s_attn[idx, :text_length, :n_frames] = attn

        asr_pred = t_en @ s2s_attn

        _, p_pred = self.model.predictor(
            d_en, s_dur, ref_lengths, s2s_attn, text_mask
        )

        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)

        # get clips
        en = []
        p_en = []
        sp = []
        wav = []

        # Ground-truth lengths are halved above, so one clip step covers two
        # mel frames of waveform samples.
        samples_per_step = 2 * self.hop_length

        for idx, mel_length_pred in enumerate(output_lengths):
            mel_length_gt = int(mel_input_length[idx].item() / 2)
            if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
                continue

            sp.append(s_preds[idx])

            start = np.random.randint(0, mel_length_pred - mel_len)
            en.append(asr_pred[idx, :, start : start + mel_len])
            p_en.append(p_pred[idx, :, start : start + mel_len])

            # get ground truth clips
            start = np.random.randint(0, mel_length_gt - mel_len)
            y = waves[idx][
                start * samples_per_step : (start + mel_len) * samples_per_step
            ]
            wav.append(torch.from_numpy(y).to(device))

            # prevent OOM due to longer lengths
            if len(wav) >= self.batch_percentage * len(waves):
                break

        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)

        F0_fake, N_fake = self.model.predictor.F0Ntrain(
            p_en, sp[:, self.style_dim:]
        )
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, : self.style_dim])

        # discriminator loss
        if (iters + 1) % self.skip_update == 0:
            d_loss = self._discriminator_loss(wav, y_pred, y_rec_gt, y_rec_gt_pred)
        else:
            d_loss = 0

        # generator loss
        gen_loss = self.wl.generator(y_pred.squeeze()).mean()

        return d_loss, gen_loss, y_pred.detach().cpu().numpy()

    def _soft_alignment(self, dur_logits, device=None):
        """Build a soft token-to-frame alignment from duration logits.

        Args:
            dur_logits: 2-D tensor of duration logits for a single utterance,
                already trimmed to its real token count (one row per token).
            device: Device for the frame-index tensor; defaults to the device
                of ``dur_logits``.

        Returns:
            ``(attn, n_frames)`` where ``attn`` has shape
            ``(num_tokens, n_frames)`` and is softmax-normalized over tokens,
            and ``n_frames`` is the rounded sum of the sigmoid durations.
        """
        if device is None:
            device = dur_logits.device

        dur_probs = torch.sigmoid(dur_logits)
        token_durs = dur_probs.sum(dim=-1)
        n_frames = int(torch.round(dur_probs.sum()).item())

        frame_idx = (
            torch.arange(n_frames, device=device)
            .unsqueeze(0)
            .expand((len(dur_probs), n_frames))
        )
        # Centre of each token on the time axis.
        loc = torch.cumsum(token_durs, dim=0) - token_durs / 2

        # One Gaussian kernel per token.
        h = torch.exp(
            -0.5
            * torch.square(frame_idx - (n_frames - loc.unsqueeze(-1)))
            / (self.sig) ** 2
        )

        # Grouped convolution: each token row is filtered by its own kernel.
        out = F.conv1d(
            dur_logits.unsqueeze(0),
            h.unsqueeze(1),
            padding=h.shape[-1] - 1,
            groups=dur_logits.shape[0],
        )[..., :n_frames]

        # Only drop the leading batch axis. A bare ``squeeze()`` would also
        # drop the token axis for one-token texts (softmax would then run over
        # frames) or the frame axis for one-frame outputs (the later padded
        # assignment would then fail).
        return F.softmax(out.squeeze(0), dim=0), n_frames

    def _discriminator_loss(self, wav, y_pred, y_rec_gt, y_rec_gt_pred):
        """Discriminator loss with optional length/reconstruction regularizers.

        With probability 1/2 the sampled ground-truth clips ``wav`` are
        compared against ``y_pred`` directly. Otherwise ``y_rec_gt_pred`` is
        used as the reference, the longer of the two signals is also scored
        in cropped form (length-invariance regularizer), and an L1 term
        between the discriminator outputs for ``y_rec_gt`` and
        ``y_rec_gt_pred`` is added.
        """
        if np.random.randint(0, 2) != 0:
            return self.wl.discriminator(
                wav.detach().squeeze(), y_pred.detach().squeeze()
            ).mean()

        # use reconstructed (shorter lengths), do length invariant regularization
        real = y_rec_gt_pred
        crop_size = min(real.size(-1), y_pred.size(-1))

        if real.size(-1) > y_pred.size(-1):
            real_crop = real[:, :, :crop_size]
            out_crop = self.wl.discriminator_forward(real_crop.detach().squeeze())
            out_org = self.wl.discriminator_forward(real.detach().squeeze())
            loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])

            if np.random.randint(0, 2) == 0:
                d_loss = self.wl.discriminator(
                    real_crop.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()
            else:
                d_loss = self.wl.discriminator(
                    real.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()
        else:
            fake_crop = y_pred[:, :, :crop_size]
            out_crop = self.wl.discriminator_forward(fake_crop.detach().squeeze())
            out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
            loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])

            if np.random.randint(0, 2) == 0:
                d_loss = self.wl.discriminator(
                    real.detach().squeeze(), fake_crop.detach().squeeze()
                ).mean()
            else:
                d_loss = self.wl.discriminator(
                    real.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()

        # regularization (ignore length variation)
        d_loss += loss_reg

        out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
        out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())

        # regularization (ignore reconstruction artifacts)
        d_loss += F.l1_loss(out_gt, out_rec)

        return d_loss


def length_to_mask(lengths):
    """Return a boolean padding mask for a batch of sequence lengths.

    Args:
        lengths: 1-D integer tensor of sequence lengths.

    Returns:
        Boolean tensor of shape ``(len(lengths), lengths.max())`` that is
        ``True`` at positions at or beyond each sequence's length (padding)
        and ``False`` at valid positions.
    """
    positions = (
        torch.arange(lengths.max())
        .unsqueeze(0)
        .expand(lengths.shape[0], -1)
        .type_as(lengths)
    )
    return torch.gt(positions + 1, lengths.unsqueeze(1))