import torch
import numpy as np
import torch.nn.functional as F


class SLMAdversarialLoss(torch.nn.Module):
    """Speech-language-model (SLM) adversarial loss.

    Samples a style vector, models durations differentiably, decodes short
    waveform clips, and scores them against ground-truth clips with the
    discriminator/generator heads exposed by `wl`.
    """

    def __init__(
        self,
        model,
        wl,
        sampler,
        min_len,
        max_len,
        batch_percentage=0.5,
        skip_update=10,
        sig=1.5,
    ):
        super().__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler
        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage
        self.sig = sig
        self.skip_update = skip_update

    # ------------------------------------------------------------------ #
    def forward(
        self,
        iters,
        y_rec_gt,
        y_rec_gt_pred,
        waves,
        mel_input_length,
        ref_text,
        ref_lengths,
        use_ind,
        s_trg,
        ref_s=None,
    ):
        # ---- full-width mask (matches ref_text.size(1)) ---------------- #
        seq_len = ref_text.size(1)
        text_mask = (
            torch.arange(seq_len, device=ref_text.device).unsqueeze(0)
            >= ref_lengths.unsqueeze(1)
        )  # [B, seq_len]; True marks padding

        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        # ----- style / prosody sampling --------------------------------- #
        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg
        else:
            num_steps = np.random.randint(3, 5)
            noise = torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device)
            sampler_kwargs = dict(
                noise=noise,
                embedding=bert_dur,
                embedding_scale=1,
                embedding_mask_proba=0.1,
                num_steps=num_steps,
            )
            if ref_s is not None:
                # reference style from the same speaker as the embedding
                sampler_kwargs["features"] = ref_s
            s_preds = self.sampler(**sampler_kwargs).squeeze(1)

        s_dur, s = s_preds[:, 128:], s_preds[:, :128]

        # random alignment placeholder must match the *padded* token width
        rand_align = torch.randn(
            ref_text.size(0), seq_len, 2, device=ref_text.device
        )
        d, _ = self.model.predictor(d_en, s_dur, ref_lengths, rand_align, text_mask)

        # ----- differentiable duration modelling ------------------------ #
        attn_preds, output_lengths = [], []
        for _s2s_pred, _len in zip(d, ref_lengths):
            _len = int(_len)
            _s2s_pred_org = _s2s_pred[:_len]
            _s2s_pred_sig = torch.sigmoid(_s2s_pred_org)
            _dur_pred = _s2s_pred_sig.sum(dim=-1)

            # total predicted frame count, and a Gaussian kernel centred on
            # each token's cumulative duration
            l = int(torch.round(_s2s_pred_sig.sum()).item())
            t = torch.arange(l, device=ref_text.device).unsqueeze(0).expand(_len, l)
            loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
            h = torch.exp(-0.5 * (t - (l - loc.unsqueeze(-1))) ** 2 / (self.sig**2))

            out = F.conv1d(
                _s2s_pred_org.unsqueeze(0),
                h.unsqueeze(1),
                padding=h.size(-1) - 1,
                groups=_len,
            )[..., :l]
            attn_preds.append(F.softmax(out.squeeze(), dim=0))
            output_lengths.append(l)

        max_len = max(output_lengths)

        # ----- build full-width alignment matrix ------------------------ #
        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)

        s2s_attn = torch.zeros(
            len(ref_lengths), seq_len, max_len, device=ref_text.device
        )
        for bib, (attn, L) in enumerate(zip(attn_preds, output_lengths)):
            s2s_attn[bib, : ref_lengths[bib], :L] = attn

        asr_pred = t_en @ s2s_attn
        _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)

        # ----- clip extraction ------------------------------------------ #
        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)

        en, p_en, sp, wav = [], [], [], []
        for bib, L_pred in enumerate(output_lengths):
            L_gt = int(mel_input_length[bib].item() / 2)
            if L_gt <= mel_len or L_pred <= mel_len:
                continue
            sp.append(s_preds[bib])

            start = np.random.randint(0, L_pred - mel_len)
            en.append(asr_pred[bib, :, start : start + mel_len])
            p_en.append(p_pred[bib, :, start : start + mel_len])

            # matching ground-truth clip (x2 for mel downsampling, 300 = hop size)
            start_gt = np.random.randint(0, L_gt - mel_len)
            y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
            wav.append(torch.from_numpy(y).to(ref_text.device))

            # cap the sub-batch to limit memory use on long utterances
            if len(wav) >= self.batch_percentage * len(waves):
                break

        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)

        F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])

        # -------------- adversarial losses ------------------------------ #
        if (iters + 1) % self.skip_update == 0:
            d_loss = self.wl.discriminator(
                wav.squeeze(), y_pred.detach().squeeze()
            ).mean()
        else:
            d_loss = 0

        gen_loss = self.wl.generator(y_pred.squeeze()).mean()

        return d_loss, gen_loss, y_pred.detach().cpu().numpy()


# ------------------------------------------------------------------ #
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Classic length mask: 1 → PAD, 0 → real token."""
    max_len = lengths.max()
    mask = (
        torch.arange(max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(lengths.size(0), -1)
    )
    return mask >= lengths.unsqueeze(1)
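

# ------------------------------------------------------------------ #
# Illustrative usage sketch -- not part of the loss itself.  The batch keys,
# the `gen_optimizer` / `disc_optimizer` objects, and this helper's name are
# hypothetical stand-ins for whatever the surrounding training script
# provides; the only interface assumed from this file is
# SLMAdversarialLoss.forward() as defined above.
def example_slm_adv_step(slmadv, iters, batch, gen_optimizer, disc_optimizer):
    out = slmadv(
        iters,
        batch["y_rec_gt"],
        batch["y_rec_gt_pred"],
        batch["waves"],
        batch["mel_input_length"],
        batch["ref_text"],
        batch["ref_lengths"],
        batch["use_ind"],
        batch["s_trg"],
        ref_s=batch.get("ref_s"),
    )
    if out is None:  # fewer than two usable clips in this batch
        return None
    d_loss, gen_loss, y_pred = out

    # Generator update: push the decoded clips towards the SLM
    # discriminator's "real" decision.
    gen_optimizer.zero_grad()
    gen_loss.backward()
    gen_optimizer.step()

    # Discriminator update only happens on the iterations where forward()
    # produced a tensor (every `skip_update` steps); otherwise d_loss is 0.
    if torch.is_tensor(d_loss):
        disc_optimizer.zero_grad()
        d_loss.backward()
        disc_optimizer.step()

    return float(gen_loss.detach()), y_pred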