import torch
from torch import nn

from tasks.tts.ps_adv import PortaSpeechAdvTask, FastSpeechTask
from text_to_speech.utils.commons.hparams import hparams
from text_to_speech.utils.nn.seq_utils import group_hidden_by_segs
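

# Adversarial PortaSpeech task extended with text-side self-supervised objectives:
# sentence-/word-level contrastive learning on the encoder's BERT (SimCSE-style,
# judging by the pooler/sim interfaces) and an auxiliary masked-language-model loss.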
class PortaSpeechAdvMLMTask(PortaSpeechAdvTask):
    def build_scheduler(self, optimizer):
        return [
            FastSpeechTask.build_scheduler(self, optimizer[0]),  # Generator Scheduler
            torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1],  # Discriminator Scheduler
                                            **hparams["discriminator_scheduler_params"]),
        ]
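
    # Clip gradients per parameter group: the duration predictor, BERT (when enabled)
    # and the remaining generator parameters are clipped separately; opt_idx 0/2 are
    # generator optimizers, anything else is the discriminator.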
    def on_before_optimization(self, opt_idx):
        if opt_idx in [0, 2]:
            nn.utils.clip_grad_norm_(self.dp_params, hparams['clip_grad_norm'])
            if self.use_bert:
                nn.utils.clip_grad_norm_(self.bert_params, hparams['clip_grad_norm'])
                nn.utils.clip_grad_norm_(self.gen_params_except_bert_and_dp, hparams['clip_grad_norm'])
            else:
                nn.utils.clip_grad_norm_(self.gen_params_except_dp, hparams['clip_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["clip_grad_norm"])
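
    # Step both LR schedulers with the number of optimizer updates
    # (global_step divided by the gradient-accumulation factor).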
    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        if self.scheduler is not None:
            self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches'])
            self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches'])
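
    # Alternating GAN-style training: optimizer_idx == 0 updates the generator
    # (reconstruction + adversarial + contrastive/MLM losses), optimizer_idx == 1
    # updates the mel discriminator once disc_start_steps has been reached.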
    def _training_step(self, sample, batch_idx, optimizer_idx):
        loss_output = {}
        loss_weights = {}
        disc_start = self.global_step >= hparams["disc_start_steps"] and hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            loss_output, model_out = self.run_model(sample, infer=False)
            self.model_out_gt = self.model_out = \
                {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
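            # LSGAN-style generator losses: push the discriminator's outputs for the
            # generated mel toward 1 (unconditional 'a' and conditional 'ac' heads).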
            if disc_start:
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                if p_ is not None:
                    loss_output['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    loss_output['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
            else:
                return None
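            # Auxiliary text-side objectives (contrastive learning and/or MLM),
            # computed by run_contrastive_learning below.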
            loss_output2, model_out2 = self.run_contrastive_learning(sample)
            loss_output.update(loss_output2)
            model_out.update(model_out2)
        elif optimizer_idx == 1:
            #######################
            #    Discriminator    #
            #######################
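            # LSGAN-style discriminator losses every disc_interval steps: real mels are
            # pushed toward 1 ('r'/'rc'), generator outputs cached from the last
            # generator step are pushed toward 0 ('f'/'fc').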
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                model_out = self.model_out_gt
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                o = self.mel_disc(mel_g)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                if p_ is not None:
                    loss_output["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    loss_output["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
                if pc_ is not None:
                    loss_output["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    loss_output["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))
        total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad])
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output
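
    # Self-supervised objectives on the text encoder. Several contrastive-learning
    # variants are selected via hparams['cl_version'] (sentence-level with the BERT
    # pooler, sentence-level with the phoneme encoder, word-level, Wiki data, and
    # NLI data with hard negatives), plus an optional masked-language-model loss.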
    def run_contrastive_learning(self, sample):
        losses = {}
        outputs = {}
        bert = self.model.encoder.bert.bert
        bert_for_mlm = self.model.encoder.bert
        pooler = self.model.encoder.pooler
        sim = self.model.encoder.sim
        tokenizer = self.model.encoder.tokenizer
        ph_encoder = self.model.encoder
        if hparams['lambda_cl'] > 0:
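            # v1: sentence-level contrastive learning in the SimCSE style (an assumption
            # based on the pooler/sim interfaces). cl_input_ids packs two views of each
            # sentence ([bs, 2, t]); both views are encoded by BERT and pooled to
            # embeddings z1/z2, and a cross-entropy over the in-batch cosine-similarity
            # matrix pulls the matching pairs (the diagonal) together.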
            if hparams.get("cl_version", "v1") == "v1":
                cl_feats = sample['cl_feats']
                bs, _, t = cl_feats['cl_input_ids'].shape
                cl_input_ids = cl_feats['cl_input_ids'].reshape([bs * 2, t])
                cl_attention_mask = cl_feats['cl_attention_mask'].reshape([bs * 2, t])
                cl_token_type_ids = cl_feats['cl_token_type_ids'].reshape([bs * 2, t])
                cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask, token_type_ids=cl_token_type_ids)
                pooler_output = pooler(cl_attention_mask, cl_output)
                pooler_output = pooler_output.reshape([bs, 2, -1])
                z1, z2 = pooler_output[:, 0], pooler_output[:, 1]
                cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0))
                labels = torch.arange(cos_sim.size(0)).long().to(z1.device)
                ce_fn = nn.CrossEntropyLoss()
                cl_loss = ce_fn(cos_sim, labels)
                losses['cl_v'] = cl_loss.detach()
                losses['cl'] = cl_loss * hparams['lambda_cl']
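            # v2: same in-batch objective, but the sentence embedding is the masked
            # mean of the phoneme-encoder output; two separate forward passes
            # (presumably differing through dropout at training time) provide the views.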
elif hparams['cl_version'] == "v2": | |
# use the output of ph encoder as sentence embedding | |
cl_feats = sample['cl_feats'] | |
bs, _, t = cl_feats['cl_input_ids'].shape | |
cl_input_ids = cl_feats['cl_input_ids'].reshape([bs*2, t]) | |
cl_attention_mask = cl_feats['cl_attention_mask'].reshape([bs*2, t]) | |
cl_token_type_ids = cl_feats['cl_token_type_ids'].reshape([bs*2, t]) | |
txt_tokens = sample['txt_tokens'] | |
bert_feats = sample['bert_feats'] | |
src_nonpadding = (txt_tokens > 0).float()[:, :, None] | |
ph_encoder_out1 = ph_encoder(txt_tokens, bert_feats=bert_feats, ph2word=sample['ph2word']) * src_nonpadding | |
ph_encoder_out2 = ph_encoder(txt_tokens, bert_feats=bert_feats, ph2word=sample['ph2word']) * src_nonpadding | |
# word_encoding1 = group_hidden_by_segs(ph_encoder_out1, sample['ph2word'], sample['ph2word'].max().item()) | |
# word_encoding2 = group_hidden_by_segs(ph_encoder_out2, sample['ph2word'], sample['ph2word'].max().item()) | |
z1 = ((ph_encoder_out1 * src_nonpadding).sum(1) / src_nonpadding.sum(1)) | |
z2 = ((ph_encoder_out2 * src_nonpadding).sum(1) / src_nonpadding.sum(1)) | |
cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0)) | |
labels = torch.arange(cos_sim.size(0)).long().to(z1.device) | |
ce_fn = nn.CrossEntropyLoss() | |
cl_loss = ce_fn(cos_sim, labels) | |
losses['cl_v'] = cl_loss.detach() | |
losses['cl'] = cl_loss * hparams['lambda_cl'] | |
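            # v3: word-level contrastive learning. Valid token states from the two views
            # are gathered with the attention mask and reshaped into pairs, then compared
            # within each sentence; per-sentence losses are length-weighted and averaged.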
elif hparams['cl_version'] == "v3": | |
# use the word-level contrastive learning | |
cl_feats = sample['cl_feats'] | |
bs, _, t = cl_feats['cl_input_ids'].shape | |
cl_input_ids = cl_feats['cl_input_ids'].reshape([bs*2, t]) | |
cl_attention_mask = cl_feats['cl_attention_mask'].reshape([bs*2, t]) | |
cl_token_type_ids = cl_feats['cl_token_type_ids'].reshape([bs*2, t]) | |
cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask,token_type_ids=cl_token_type_ids,) | |
cl_output = cl_output.last_hidden_state.reshape([-1, 768]) # [bs*2,t_w,768] ==> [bs*2*t_w, 768] | |
cl_word_out = cl_output[cl_attention_mask.reshape([-1]).bool()] # [num_word*2, 768] | |
cl_word_out = cl_word_out.view([-1, 2, 768]) | |
z1_total, z2_total = cl_word_out[:,0], cl_word_out[:,1] # [num_word, 768] | |
ce_fn = nn.CrossEntropyLoss() | |
start_idx = 0 | |
lengths = cl_attention_mask.sum(-1) | |
cl_loss_accu = 0 | |
for i in range(bs): | |
length = lengths[i] | |
z1 = z1_total[start_idx:start_idx + length] | |
z2 = z2_total[start_idx:start_idx + length] | |
start_idx += length | |
cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0)) | |
labels = torch.arange(cos_sim.size(0)).long().to(z1.device) | |
cl_loss_accu += ce_fn(cos_sim, labels) * length | |
cl_loss = cl_loss_accu / lengths.sum() | |
losses['cl_v'] = cl_loss.detach() | |
losses['cl'] = cl_loss * hparams['lambda_cl'] | |
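            # v4: identical objective to v1, but (per the original comment) cl_feats is
            # built from a Wiki dataset rather than the TTS transcripts.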
elif hparams['cl_version'] == "v4": | |
# with Wiki dataset | |
cl_feats = sample['cl_feats'] | |
bs, _, t = cl_feats['cl_input_ids'].shape | |
cl_input_ids = cl_feats['cl_input_ids'].reshape([bs*2, t]) | |
cl_attention_mask = cl_feats['cl_attention_mask'].reshape([bs*2, t]) | |
cl_token_type_ids = cl_feats['cl_token_type_ids'].reshape([bs*2, t]) | |
cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask,token_type_ids=cl_token_type_ids,) | |
pooler_output = pooler(cl_attention_mask, cl_output) | |
pooler_output = pooler_output.reshape([bs, 2, -1]) | |
z1, z2 = pooler_output[:,0], pooler_output[:,1] | |
cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0)) | |
labels = torch.arange(cos_sim.size(0)).long().to(z1.device) | |
ce_fn = nn.CrossEntropyLoss() | |
cl_loss = ce_fn(cos_sim, labels) | |
losses['cl_v'] = cl_loss.detach() | |
losses['cl'] = cl_loss * hparams['lambda_cl'] | |
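            # v5: supervised contrastive learning on NLI triplets (sent0/sent1/hard_neg,
            # as in supervised SimCSE, where sent1 is typically the entailment and
            # hard_neg the contradiction). Positives and hard negatives are concatenated
            # into a [n_sent, 2*n_sent] similarity matrix with targets on the diagonal.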
elif hparams['cl_version'] == "v5": | |
# with NLI dataset | |
cl_feats = sample['cl_feats'] | |
cl_input_ids = cl_feats['sent0']['cl_input_ids'] | |
cl_attention_mask = cl_feats['sent0']['cl_attention_mask'] | |
cl_token_type_ids = cl_feats['sent0']['cl_token_type_ids'] | |
cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask,token_type_ids=cl_token_type_ids,) | |
z1 = pooler_output_sent0 = pooler(cl_attention_mask, cl_output) | |
cl_input_ids = cl_feats['sent1']['cl_input_ids'] | |
cl_attention_mask = cl_feats['sent1']['cl_attention_mask'] | |
cl_token_type_ids = cl_feats['sent1']['cl_token_type_ids'] | |
cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask,token_type_ids=cl_token_type_ids,) | |
z2 = pooler_output_sent1 = pooler(cl_attention_mask, cl_output) | |
cl_input_ids = cl_feats['hard_neg']['cl_input_ids'] | |
cl_attention_mask = cl_feats['hard_neg']['cl_attention_mask'] | |
cl_token_type_ids = cl_feats['hard_neg']['cl_token_type_ids'] | |
cl_output = bert(cl_input_ids, attention_mask=cl_attention_mask,token_type_ids=cl_token_type_ids,) | |
z3 = pooler_output_neg = pooler(cl_attention_mask, cl_output) | |
cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0)) | |
z1_z3_cos = sim(z1.unsqueeze(1), z3.unsqueeze(0)) | |
cos_sim = torch.cat([cos_sim, z1_z3_cos], 1) # [n_sent, n_sent * 2] | |
labels = torch.arange(cos_sim.size(0)).long().to(cos_sim.device) # [n_sent, ] | |
ce_fn = nn.CrossEntropyLoss() | |
cl_loss = ce_fn(cos_sim, labels) | |
losses['cl_v'] = cl_loss.detach() | |
losses['cl'] = cl_loss * hparams['lambda_cl'] | |
else: | |
raise NotImplementedError() | |
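        # Masked language modeling on the same BERT: the cross-entropy is computed per
        # token and then averaged over the masked positions only (non-masked positions
        # carry a negative label, e.g. -100, and are dropped by the >= 0 filter).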
        if hparams['lambda_mlm'] > 0:
            cl_feats = sample['cl_feats']
            mlm_input_ids = cl_feats['mlm_input_ids']
            bs, t = mlm_input_ids.shape
            mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
            mlm_labels = cl_feats['mlm_labels']
            mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
            mlm_attention_mask = cl_feats['mlm_attention_mask']
            prediction_scores = bert_for_mlm(mlm_input_ids, mlm_attention_mask).logits
            ce_fn = nn.CrossEntropyLoss(reduction="none")
            mlm_loss = ce_fn(prediction_scores.view(-1, tokenizer.vocab_size), mlm_labels.view(-1))
            mlm_loss = mlm_loss[mlm_labels.view(-1) >= 0].mean()
            losses['mlm'] = mlm_loss * hparams['lambda_mlm']
            losses['mlm_v'] = mlm_loss.detach()
        return losses, outputs