# styletts2 / Modules / slmadv.py
# ak36 — commit bf65828 ("second_stage_v1")
import torch
import numpy as np
import torch.nn.functional as F
class SLMAdversarialLoss(torch.nn.Module):
    """Adversarial loss on clips decoded from sampled (or target) styles.

    ``forward`` predicts durations for ``ref_text`` under a style vector,
    turns them into a differentiable soft token-to-frame alignment, decodes
    random fixed-length clips from it and scores them with ``self.wl``
    against random clips cut from the ground-truth ``waves``.

    Args:
        model: container exposing ``bert``, ``bert_encoder``, ``predictor``
            (callable, and with an ``F0Ntrain`` method), ``text_encoder`` and
            ``decoder``.
        wl: loss object called as ``discriminator(real_wav, fake_wav)`` and
            ``generator(fake_wav)``.
        sampler: style sampler called with ``noise``, ``embedding``,
            ``embedding_scale``, ``embedding_mask_proba``, ``num_steps`` and,
            optionally, ``features`` keyword arguments.
        min_len, max_len: the clip length is clamped using
            ``min_len // 2`` and ``max_len // 2``.
        batch_percentage: clip collection stops once
            ``batch_percentage * len(waves)`` clips have been gathered.
        skip_update: the discriminator loss is only computed every
            ``skip_update`` iterations.
        sig: width of the Gaussian kernel used by the soft alignment.
    """

    def __init__(
        self,
        model,
        wl,
        sampler,
        min_len,
        max_len,
        batch_percentage=0.5,
        skip_update=10,
        sig=1.5,
    ):
        super().__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler
        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage
        self.sig = sig
        self.skip_update = skip_update

    # ------------------------------------------------------------------ #
    def forward(
        self,
        iters,
        y_rec_gt,
        y_rec_gt_pred,
        waves,
        mel_input_length,
        ref_text,
        ref_lengths,
        use_ind,
        s_trg,
        ref_s=None,
    ):
        """Compute the adversarial losses for one batch.

        Args:
            iters: current iteration; the discriminator loss is computed only
                when ``(iters + 1) % self.skip_update == 0``.
            y_rec_gt, y_rec_gt_pred: accepted for interface compatibility;
                not used by this implementation.
            waves: per-utterance waveforms as NumPy arrays (they are
                converted with ``torch.from_numpy``).
            mel_input_length: ground-truth lengths per utterance; halved
                before being compared with the clip length.
            ref_text: padded token ids, ``[B, T_text]``.
            ref_lengths: number of valid tokens per utterance, ``[B]``.
            use_ind: if truthy, ``s_trg`` is used as the style half of the
                time instead of sampling one.
            s_trg: target style; channels ``:128`` feed the decoder and
                channels ``128:`` the duration/prosody predictor.
            ref_s: optional reference features forwarded to the sampler as
                ``features``.

        Returns:
            ``None`` when fewer than two clips could be extracted, otherwise
            ``(d_loss, gen_loss, y_pred)``. ``d_loss`` is the int ``0`` on
            iterations that skip the discriminator, and ``y_pred`` is a
            detached NumPy copy of the generated clips.
        """
        device = ref_text.device
        batch_size, seq_len = ref_text.size(0), ref_text.size(1)

        # Padding mask over the full padded width of ref_text (True = pad).
        # ref_lengths is moved to ref_text's device so the comparison cannot
        # mix devices.
        text_mask = (
            torch.arange(seq_len, device=device).unsqueeze(0)
            >= ref_lengths.to(device).unsqueeze(1)
        )
        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        # ----- style / prosody sampling ---------------------------------
        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg
        else:
            num_steps = np.random.randint(3, 5)
            sampler_kwargs = dict(
                noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
                embedding=bert_dur,
                embedding_scale=1,
                embedding_mask_proba=0.1,
                num_steps=num_steps,
            )
            if ref_s is not None:
                sampler_kwargs["features"] = ref_s
            s_preds = self.sampler(**sampler_kwargs).squeeze(1)
        s_dur = s_preds[:, 128:]

        # Random alignment placeholder at the padded token width; only the
        # duration output of this predictor call is used.
        rand_align = torch.randn(batch_size, seq_len, 2, device=device)
        d, _ = self.model.predictor(d_en, s_dur, ref_lengths, rand_align, text_mask)

        # ----- differentiable duration modelling -----------------------
        attn_preds, output_lengths = self._soft_alignment(d, ref_lengths, device)
        max_len = max(output_lengths)

        # ----- build full-width alignment matrix -----------------------
        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
        s2s_attn = torch.zeros(len(ref_lengths), seq_len, max_len, device=device)
        for bib, (attn, n_frames) in enumerate(zip(attn_preds, output_lengths)):
            s2s_attn[bib, : ref_lengths[bib], :n_frames] = attn
        asr_pred = t_en @ s2s_attn
        _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)

        # ----- clip extraction -----------------------------------------
        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)
        en, p_en, sp, wav = [], [], [], []
        for bib, len_pred in enumerate(output_lengths):
            len_gt = int(mel_input_length[bib].item() / 2)
            if len_gt <= mel_len or len_pred <= mel_len:
                continue
            sp.append(s_preds[bib])
            start = np.random.randint(0, len_pred - mel_len)
            en.append(asr_pred[bib, :, start : start + mel_len])
            p_en.append(p_pred[bib, :, start : start + mel_len])
            # Ground-truth clip: one clip frame spans 2 * 300 wave samples.
            start_gt = np.random.randint(0, len_gt - mel_len)
            y = waves[bib][(start_gt * 2) * 300 : ((start_gt + mel_len) * 2) * 300]
            wav.append(torch.from_numpy(y).to(device))
            # Stop once batch_percentage of the batch has been collected.
            if len(wav) >= self.batch_percentage * len(waves):
                break
        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)
        F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])

        # -------------- adversarial losses -----------------------------
        if (iters + 1) % self.skip_update == 0:
            d_loss = self.wl.discriminator(wav.squeeze(), y_pred.detach().squeeze()).mean()
        else:
            d_loss = 0
        gen_loss = self.wl.generator(y_pred.squeeze()).mean()
        return d_loss, gen_loss, y_pred.detach().cpu().numpy()

    # ------------------------------------------------------------------ #
    def _soft_alignment(self, d, ref_lengths, device):
        """Turn per-token duration outputs into soft token-to-frame alignments.

        For utterance ``i`` the first ``ref_lengths[i]`` rows of ``d[i]`` go
        through a sigmoid. The rounded sum of the result gives the number of
        frames ``n_frames``, and each row's sum gives that token's duration.
        A Gaussian kernel of width ``self.sig``, placed from the cumulative
        durations, is convolved (one group per token) with the raw rows. A
        softmax over the token axis then yields an ``[n_tok, n_frames]``
        alignment.

        Returns:
            ``(attn_preds, output_lengths)``: the per-utterance alignment
            tensors and their frame counts as Python ints.
        """
        attn_preds, output_lengths = [], []
        for pred_full, length in zip(d, ref_lengths):
            n_tok = int(length)
            pred = pred_full[:n_tok]
            probs = torch.sigmoid(pred)
            dur = probs.sum(dim=-1)
            # At least one frame: a zero-width kernel (padding -1) would make
            # conv1d raise.
            n_frames = max(1, int(torch.round(probs.sum()).item()))
            t = torch.arange(n_frames, device=device).unsqueeze(0).expand(n_tok, n_frames)
            loc = torch.cumsum(dur, dim=0) - dur / 2
            h = torch.exp(-0.5 * (t - (n_frames - loc.unsqueeze(-1))) ** 2 / (self.sig**2))
            out = F.conv1d(
                pred.unsqueeze(0),
                h.unsqueeze(1),
                padding=h.size(-1) - 1,
                groups=n_tok,
            )[..., :n_frames]
            # squeeze(0), not squeeze(): keep the [n_tok, n_frames] layout for
            # single-token / single-frame utterances so the softmax is always
            # over tokens and the result fits s2s_attn[bib, :n_tok, :n_frames].
            attn_preds.append(F.softmax(out.squeeze(0), dim=0))
            output_lengths.append(n_frames)
        return attn_preds, output_lengths
# ------------------------------------------------------------------ #
def length_to_mask(lengths: torch.Tensor, max_len=None) -> torch.Tensor:
    """Build a boolean padding mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor of sequence lengths, ``[B]``.
        max_len: width of the mask. Defaults to ``lengths.max()`` (``0`` when
            ``lengths`` is empty). Pass a larger value to match a padded
            tensor that is wider than its longest sequence.

    Returns:
        Boolean tensor ``[B, max_len]`` on ``lengths.device``. It is ``True``
        at padding positions (``index >= length``) and ``False`` at real
        tokens.
    """
    if max_len is None:
        # max() of an empty tensor raises, so fall back to a zero-width mask.
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    # [1, max_len] vs [B, 1] broadcasts to [B, max_len].
    return positions >= lengths.unsqueeze(1)