# Spaces: Running on Zero
import torch
import numpy as np
import torch.nn.functional as F
class SLMAdversarialLoss(torch.nn.Module):
    """Adversarial loss on speech synthesized from a sampled (or given) style.

    ``forward`` predicts durations for the reference text, turns them into a
    soft text-to-frame alignment, decodes random fixed-length clips and scores
    them with the discriminator/generator losses exposed by ``wl``.

    Args:
        model: Object exposing ``bert``, ``bert_encoder``, ``predictor``
            (callable, with an ``F0Ntrain`` method), ``text_encoder`` and
            ``decoder``.
        wl: Object exposing ``discriminator``, ``discriminator_forward`` and
            ``generator``.
        sampler: Callable producing a style vector; called with ``noise``,
            ``embedding``, ``embedding_scale``, ``embedding_mask_proba``,
            ``num_steps`` and, when a reference is given, ``features``.
        min_len: Lower bound for the clip length; ``min_len // 2`` frames.
        max_len: Upper bound for the clip length; ``max_len // 2`` frames.
        batch_percentage: Clip collection stops once the number of clips
            reaches ``batch_percentage * len(waves)``.
        skip_update: The discriminator loss is computed only on iterations
            where ``(iters + 1) % skip_update == 0``.
        sig: Standard deviation of the Gaussian kernel used by the
            differentiable duration alignment.
    """

    def __init__(
        self,
        model,
        wl,
        sampler,
        min_len,
        max_len,
        batch_percentage=0.5,
        skip_update=10,
        sig=1.5,
    ):
        super().__init__()
        self.model = model
        self.wl = wl
        self.sampler = sampler
        self.min_len = min_len
        self.max_len = max_len
        self.batch_percentage = batch_percentage
        self.sig = sig
        self.skip_update = skip_update

    def forward(
        self,
        iters,
        y_rec_gt,
        y_rec_gt_pred,
        waves,
        mel_input_length,
        ref_text,
        ref_lengths,
        use_ind,
        s_trg,
        ref_s=None,
    ):
        """Compute the discriminator and generator losses for one batch.

        Args:
            iters: Current iteration index; gates the discriminator update.
            y_rec_gt: Ground-truth audio, used only for the reconstruction
                regularizer of the discriminator loss.
            y_rec_gt_pred: Reconstructed audio; may stand in for the real
                clips in the discriminator loss.
            waves: Indexable collection of NumPy waveforms, one per item.
            mel_input_length: Per-item tensor of ground-truth mel lengths.
            ref_text: Text token tensor; its device is used for all new
                tensors.
            ref_lengths: Per-item tensor of text lengths.
            use_ind: When truthy, ``s_trg`` is used directly with
                probability 0.5 instead of sampling a style.
            s_trg: Target style; also fixes the shape of the sampler noise.
            ref_s: Optional reference passed to the sampler as ``features``.

        Returns:
            ``None`` when at most one item yields a usable clip. Otherwise a
            tuple ``(d_loss, gen_loss, y_pred)`` where ``d_loss`` is ``0`` on
            iterations that skip the discriminator and ``y_pred`` is the
            generated audio as a NumPy array.
        """
        device = ref_text.device
        text_mask = length_to_mask(ref_lengths).to(device)
        bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
        d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

        s_preds = self._sample_style(bert_dur, s_trg, use_ind, ref_s, device)
        # Channels from 128 onward feed the predictor; the first 128 feed the
        # decoder further down.
        s_dur = s_preds[:, 128:]

        # Only the duration output is kept from this first pass, so the
        # alignment argument is filled with random noise.
        d, _ = self.model.predictor(
            d_en,
            s_dur,
            ref_lengths,
            torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(device),
            text_mask,
        )

        # differentiable duration modeling
        attn_preds = []
        output_lengths = []
        for dur_logits, text_length in zip(d, ref_lengths):
            attn, n_frames = self._soft_alignment(dur_logits, text_length, device)
            attn_preds.append(attn)
            output_lengths.append(n_frames)

        with torch.no_grad():
            t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)

        # Zero-padded batch of per-item alignments.
        s2s_attn = torch.zeros(
            len(ref_lengths),
            int(ref_lengths.max()),
            max(output_lengths),
            device=device,
        )
        for idx, (attn, n_frames) in enumerate(zip(attn_preds, output_lengths)):
            s2s_attn[idx, : ref_lengths[idx], :n_frames] = attn

        asr_pred = t_en @ s2s_attn
        _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)

        mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
        mel_len = min(mel_len, self.max_len // 2)

        sp, en, p_en, wav = self._crop_clips(
            s_preds,
            asr_pred,
            p_pred,
            waves,
            mel_input_length,
            output_lengths,
            mel_len,
            device,
        )
        if len(sp) <= 1:
            return None

        sp = torch.stack(sp)
        wav = torch.stack(wav).float()
        en = torch.stack(en)
        p_en = torch.stack(p_en)

        F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
        y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])

        # discriminator loss
        if (iters + 1) % self.skip_update == 0:
            d_loss = self._discriminator_loss(wav, y_pred, y_rec_gt, y_rec_gt_pred)
        else:
            d_loss = 0

        # generator loss
        gen_loss = self.wl.generator(y_pred.squeeze()).mean()

        return d_loss, gen_loss, y_pred.detach().cpu().numpy()

    def _sample_style(self, bert_dur, s_trg, use_ind, ref_s, device):
        """Return ``s_trg`` or a style drawn from the sampler."""
        if use_ind and np.random.rand() < 0.5:
            return s_trg

        num_steps = np.random.randint(3, 5)
        sampler_kwargs = {
            "noise": torch.randn_like(s_trg).unsqueeze(1).to(device),
            "embedding": bert_dur,
            "embedding_scale": 1,
            "embedding_mask_proba": 0.1,
            "num_steps": num_steps,
        }
        if ref_s is not None:
            # reference from the same speaker as the embedding
            sampler_kwargs["features"] = ref_s
        return self.sampler(**sampler_kwargs).squeeze(1)

    def _soft_alignment(self, dur_logits, text_length, device):
        """Turn one item's duration logits into a soft alignment.

        Returns ``(attn, n_frames)`` where ``attn`` has one row per text token
        and one column per frame, normalized over tokens, and ``n_frames`` is
        the rounded sum of the sigmoid of the logits.
        """
        logits = dur_logits[:text_length, :]
        probs = torch.sigmoid(logits)
        token_durations = probs.sum(dim=-1)

        # NOTE(review): a rounded total of 0 gives an empty kernel below;
        # conv1d presumably rejects that - confirm whether it can occur.
        n_frames = int(torch.round(probs.sum()).item())

        frame_idx = (
            torch.arange(n_frames, device=device)
            .unsqueeze(0)
            .expand((len(probs), n_frames))
        )
        centers = torch.cumsum(token_durations, dim=0) - token_durations / 2
        kernel = torch.exp(
            -0.5
            * torch.square(frame_idx - (n_frames - centers.unsqueeze(-1)))
            / (self.sig) ** 2
        )

        # One group per token: each token's logits are convolved with that
        # token's own kernel.
        out = F.conv1d(
            logits.unsqueeze(0),
            kernel.unsqueeze(1),
            padding=kernel.shape[-1] - 1,
            groups=int(text_length),
        )[..., :n_frames]

        # Drop only the leading batch axis. A bare squeeze() would also drop
        # the token axis for one-token texts or the frame axis for one-frame
        # outputs, which breaks the softmax axis or the later slice assignment.
        return F.softmax(out.squeeze(0), dim=0), n_frames

    def _crop_clips(
        self,
        s_preds,
        asr_pred,
        p_pred,
        waves,
        mel_input_length,
        output_lengths,
        mel_len,
        device,
    ):
        """Cut random ``mel_len``-frame clips from predictions and waveforms."""
        sp, en, p_en, wav = [], [], [], []
        for idx, mel_length_pred in enumerate(output_lengths):
            mel_length_gt = int(mel_input_length[idx].item() / 2)
            # Items shorter than the clip on either side are skipped.
            if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
                continue

            sp.append(s_preds[idx])

            start = np.random.randint(0, mel_length_pred - mel_len)
            en.append(asr_pred[idx, :, start : start + mel_len])
            p_en.append(p_pred[idx, :, start : start + mel_len])

            # get ground truth clips
            start = np.random.randint(0, mel_length_gt - mel_len)
            # Frame index -> sample index uses a factor of 2 * 300
            # (presumably the hop size; verify against the data pipeline).
            y = waves[idx][(start * 2) * 300 : ((start + mel_len) * 2) * 300]
            wav.append(torch.from_numpy(y).to(device))

            # prevent OOM due to longer lengths
            if len(wav) >= self.batch_percentage * len(waves):
                break
        return sp, en, p_en, wav

    def _discriminator_loss(self, wav, y_pred, y_rec_gt, y_rec_gt_pred):
        """Discriminator loss, with regularizers when reconstructions are used."""
        use_rec = np.random.randint(0, 2) == 0
        if not use_rec:
            return self.wl.discriminator(
                wav.detach().squeeze(), y_pred.detach().squeeze()
            ).mean()

        # use reconstructed (shorter lengths), do length invariant regularization
        wav = y_rec_gt_pred
        crop_size = min(wav.size(-1), y_pred.size(-1))
        if wav.size(-1) > y_pred.size(-1):
            real_GP = wav[:, :, :crop_size]
            out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
            out_org = self.wl.discriminator_forward(wav.detach().squeeze())
            loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
            if np.random.randint(0, 2) == 0:
                d_loss = self.wl.discriminator(
                    real_GP.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()
            else:
                d_loss = self.wl.discriminator(
                    wav.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()
        else:
            real_GP = y_pred[:, :, :crop_size]
            out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
            out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
            loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
            if np.random.randint(0, 2) == 0:
                d_loss = self.wl.discriminator(
                    wav.detach().squeeze(), real_GP.detach().squeeze()
                ).mean()
            else:
                d_loss = self.wl.discriminator(
                    wav.detach().squeeze(), y_pred.detach().squeeze()
                ).mean()

        # regularization (ignore length variation)
        d_loss = d_loss + loss_reg

        out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
        out_rec = self.wl.discriminator_forward(y_rec_gt_pred.detach().squeeze())

        # regularization (ignore reconstruction artifacts)
        return d_loss + F.l1_loss(out_gt, out_rec)
def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D tensor with one length per sequence.

    Returns:
        Boolean tensor with ``lengths.shape[0]`` rows and ``lengths.max()``
        columns; entry ``[i, j]`` is ``True`` when position ``j`` lies past
        the end of sequence ``i`` (``j + 1 > lengths[i]``).
    """
    # Create the positions directly with the dtype and device of `lengths`
    # rather than allocating a default tensor and converting it afterwards.
    positions = torch.arange(
        lengths.max(), dtype=lengths.dtype, device=lengths.device
    )
    positions = positions.unsqueeze(0).expand(lengths.shape[0], -1)
    return torch.gt(positions + 1, lengths.unsqueeze(1))