lmzjms's picture
Upload 591 files
9206300
raw
history blame
12.9 kB
import torch
from torch import nn
from tasks.tts.ps_adv import PortaSpeechAdvTask, FastSpeechTask
from text_to_speech.utils.commons.hparams import hparams
from text_to_speech.utils.nn.seq_utils import group_hidden_by_segs
class PortaSpeechAdvMLMTask(PortaSpeechAdvTask):
    """PortaSpeech adversarial TTS task with auxiliary BERT losses.

    On top of the generator/discriminator training inherited from
    ``PortaSpeechAdvTask``, the generator step adds a contrastive-learning
    (CL) loss and/or a masked-language-model (MLM) loss computed with the
    BERT held by ``self.model.encoder``. The losses are enabled by
    ``hparams['lambda_cl'] > 0`` and ``hparams['lambda_mlm'] > 0``.
    """

    def build_scheduler(self, optimizer):
        """Build one learning-rate scheduler per optimizer.

        Args:
            optimizer: Sequence of optimizers. ``optimizer[0]`` (generator) is
                scheduled by ``FastSpeechTask.build_scheduler``.
                ``optimizer[1]`` (discriminator) uses a ``StepLR`` configured
                from ``hparams["discriminator_scheduler_params"]``.

        Returns:
            ``[generator_scheduler, discriminator_scheduler]``.
        """
        return [
            FastSpeechTask.build_scheduler(self, optimizer[0]),  # Generator Scheduler
            torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1],  # Discriminator Scheduler
                                            **hparams["discriminator_scheduler_params"]),
        ]

    def on_before_optimization(self, opt_idx):
        """Clip gradient norms to ``hparams['clip_grad_norm']`` before a step.

        For ``opt_idx`` 0 or 2, the generator-side groups are clipped
        separately:

        - ``self.dp_params``;
        - then, when ``self.use_bert``, ``self.bert_params`` and
          ``self.gen_params_except_bert_and_dp``;
        - otherwise, ``self.gen_params_except_dp``.

        Any other index clips ``self.disc_params``.
        """
        max_norm = hparams['clip_grad_norm']
        if opt_idx in (0, 2):
            nn.utils.clip_grad_norm_(self.dp_params, max_norm)
            if self.use_bert:
                nn.utils.clip_grad_norm_(self.bert_params, max_norm)
                nn.utils.clip_grad_norm_(self.gen_params_except_bert_and_dp, max_norm)
            else:
                nn.utils.clip_grad_norm_(self.gen_params_except_dp, max_norm)
        else:
            nn.utils.clip_grad_norm_(self.disc_params, max_norm)

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance both schedulers to ``global_step // accumulate_grad_batches``.

        Both schedulers are stepped on every call, whatever ``optimizer_idx`` is.
        """
        if self.scheduler is not None:
            self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches'])
            self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches'])

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Run one optimization sub-step and collect its losses.

        Args:
            sample: Batch dict. This method reads ``'mels'`` and
                ``'txt_tokens'``; the rest is consumed by ``run_model`` and
                ``run_contrastive_learning``.
            batch_idx: Unused here.
            optimizer_idx: 0 for the generator step, 1 for the discriminator
                step.

        Returns:
            ``(total_loss, loss_output)``, or ``None`` when the generator step
            is skipped (see the note in the generator branch).
            ``total_loss`` is the sum of every tensor in ``loss_output`` that
            requires grad, each scaled by ``loss_weights`` (default 1).
        """
        loss_output = {}
        loss_weights = {}
        disc_start = self.global_step >= hparams["disc_start_steps"] and hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            loss_output, model_out = self.run_model(sample, infer=False)
            # Detached tensor outputs; the discriminator step reuses
            # model_out_gt['mel_out'] as its fake input.
            self.model_out_gt = self.model_out = \
                {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                # Adversarial terms: push discriminator outputs on generated mels towards 1.
                if p_ is not None:
                    loss_output['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    loss_output['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
            else:
                # NOTE(review): this `else` is paired with `if disc_start`
                # (reconstructed; confirm against upstream). As written, the
                # whole generator step, including CL/MLM, is skipped until the
                # discriminator starts. It never runs when lambda_mel_adv == 0.
                return None
            loss_output2, model_out2 = self.run_contrastive_learning(sample)
            loss_output.update(loss_output2)
            model_out.update(model_out2)
        elif optimizer_idx == 1:
            #######################
            #    Discriminator    #
            #######################
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                model_out = self.model_out_gt
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                o = self.mel_disc(mel_g)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p)
                p_, pc_ = o_['y'], o_['y_c']
                # Real mels are pushed towards 1, generated mels towards 0.
                if p_ is not None:
                    loss_output["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    loss_output["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
                if pc_ is not None:
                    loss_output["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    loss_output["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in loss_output.items()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_contrastive_learning(self, sample):
        """Compute the auxiliary CL and MLM losses for the generator step.

        The CL variant is chosen by ``hparams.get("cl_version", "v1")``:

        - ``v1``/``v4``: pooled BERT embeddings of two views per sentence
          (v4 is the Wiki-data setup; same computation).
        - ``v2``: masked mean of two phoneme-encoder passes.
        - ``v3``: word-level contrast between the two views of each sentence.
        - ``v5``: NLI triplets (sent0, sent1, hard_neg).

        Args:
            sample: Batch dict. CL/MLM inputs are read from
                ``sample['cl_feats']``; v2 reads ``'txt_tokens'``,
                ``'bert_feats'`` and ``'ph2word'``.

        Returns:
            ``(losses, outputs)``. ``losses`` may hold ``'cl_v'`` (detached,
            unweighted) and ``'cl'`` (times ``hparams['lambda_cl']``), and
            ``'mlm'`` (times ``hparams['lambda_mlm']``) and ``'mlm_v'``
            (detached). ``outputs`` is always empty.

        Raises:
            NotImplementedError: for an unknown ``cl_version``.
        """
        losses = {}
        outputs = {}
        if hparams['lambda_cl'] > 0:
            cl_version = hparams.get("cl_version", "v1")
            cl_loss_fns = {
                "v1": self._sentence_cl_loss,
                "v2": self._ph_encoder_cl_loss,
                "v3": self._word_cl_loss,
                "v4": self._sentence_cl_loss,  # with Wiki dataset
                "v5": self._nli_cl_loss,       # with NLI dataset
            }
            if cl_version not in cl_loss_fns:
                raise NotImplementedError(f"Unknown cl_version: {cl_version}")
            cl_loss = cl_loss_fns[cl_version](sample)
            losses['cl_v'] = cl_loss.detach()
            losses['cl'] = cl_loss * hparams['lambda_cl']
        if hparams['lambda_mlm'] > 0:
            mlm_loss = self._mlm_loss(sample)
            losses['mlm'] = mlm_loss * hparams['lambda_mlm']
            losses['mlm_v'] = mlm_loss.detach()
        return losses, outputs

    @staticmethod
    def _flatten_cl_views(cl_feats):
        """Reshape the paired CL inputs from ``[bs, 2, t]`` to ``[bs * 2, t]``.

        Returns:
            ``(bs, t, input_ids, attention_mask, token_type_ids)``.
        """
        bs, _, t = cl_feats['cl_input_ids'].shape
        return (
            bs,
            t,
            cl_feats['cl_input_ids'].reshape([bs * 2, t]),
            cl_feats['cl_attention_mask'].reshape([bs * 2, t]),
            cl_feats['cl_token_type_ids'].reshape([bs * 2, t]),
        )

    @staticmethod
    def _in_batch_cl_loss(sim, z1, z2):
        """Cross-entropy over ``sim(z1_i, z2_j)``; the positive for row i is column i."""
        cos_sim = sim(z1.unsqueeze(1), z2.unsqueeze(0))
        labels = torch.arange(cos_sim.size(0), device=cos_sim.device)
        return nn.CrossEntropyLoss()(cos_sim, labels)

    def _sentence_cl_loss(self, sample):
        """Sentence-level CL on pooled BERT outputs of the two views (v1/v4)."""
        encoder = self.model.encoder
        bs, _, input_ids, attention_mask, token_type_ids = self._flatten_cl_views(sample['cl_feats'])
        cl_output = encoder.bert.bert(input_ids, attention_mask=attention_mask,
                                      token_type_ids=token_type_ids)
        pooled = encoder.pooler(attention_mask, cl_output).reshape([bs, 2, -1])
        return self._in_batch_cl_loss(encoder.sim, pooled[:, 0], pooled[:, 1])

    def _ph_encoder_cl_loss(self, sample):
        """Sentence-level CL on two phoneme-encoder passes (v2).

        Each sentence embedding is the mean of the encoder output over
        non-padding tokens (``txt_tokens > 0``). The two passes take identical
        inputs.
        """
        ph_encoder = self.model.encoder
        txt_tokens = sample['txt_tokens']
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        def embed():
            out = ph_encoder(txt_tokens, bert_feats=sample['bert_feats'],
                             ph2word=sample['ph2word']) * src_nonpadding
            return out.sum(1) / src_nonpadding.sum(1)

        z1 = embed()
        z2 = embed()
        return self._in_batch_cl_loss(ph_encoder.sim, z1, z2)

    def _word_cl_loss(self, sample):
        """Word-level CL between the two views of each sentence (v3).

        For each sentence, token k of view 0 is contrasted with every token of
        view 1, with token k as the positive. Per-sentence losses are averaged,
        weighted by the number of attended tokens.
        """
        encoder = self.model.encoder
        bs, t, input_ids, attention_mask, token_type_ids = self._flatten_cl_views(sample['cl_feats'])
        cl_output = encoder.bert.bert(input_ids, attention_mask=attention_mask,
                                      token_type_ids=token_type_ids)
        hidden = cl_output.last_hidden_state
        # Restore the (sentence, view) axes so tokens are paired across views,
        # not across neighbouring positions of the same view.
        hidden = hidden.reshape([bs, 2, t, hidden.size(-1)])
        token_mask = attention_mask.reshape([bs, 2, t]).bool()
        ce_fn = nn.CrossEntropyLoss()
        loss_sum = 0
        n_tokens = 0
        for i in range(bs):
            # Assumes both views of a sentence share the same attention mask
            # (same text encoded twice) -- TODO confirm with the data pipeline.
            z1 = hidden[i, 0][token_mask[i, 0]]
            z2 = hidden[i, 1][token_mask[i, 1]]
            n_i = z1.size(0)
            if n_i == 0:
                continue  # cross-entropy over zero rows would be NaN
            cos_sim = encoder.sim(z1.unsqueeze(1), z2.unsqueeze(0))
            labels = torch.arange(n_i, device=cos_sim.device)
            loss_sum = loss_sum + ce_fn(cos_sim, labels) * n_i
            n_tokens += n_i
        if n_tokens == 0:
            return hidden.new_zeros(())
        return loss_sum / n_tokens

    def _nli_cl_loss(self, sample):
        """Supervised CL on NLI triplets (v5).

        Each ``sent0`` embedding is scored against all ``sent1`` and all
        ``hard_neg`` embeddings. Its positive is the ``sent1`` at the same
        index.
        """
        encoder = self.model.encoder
        cl_feats = sample['cl_feats']

        def encode(key):
            feats = cl_feats[key]
            mask = feats['cl_attention_mask']
            out = encoder.bert.bert(feats['cl_input_ids'], attention_mask=mask,
                                    token_type_ids=feats['cl_token_type_ids'])
            return encoder.pooler(mask, out)

        z1 = encode('sent0')
        z2 = encode('sent1')
        z3 = encode('hard_neg')
        cos_sim = torch.cat([encoder.sim(z1.unsqueeze(1), z2.unsqueeze(0)),
                             encoder.sim(z1.unsqueeze(1), z3.unsqueeze(0))], 1)  # [n_sent, n_sent * 2]
        labels = torch.arange(cos_sim.size(0), device=cos_sim.device)  # [n_sent, ]
        return nn.CrossEntropyLoss()(cos_sim, labels)

    def _mlm_loss(self, sample):
        """Mean MLM cross-entropy over positions whose label is >= 0."""
        bert_for_mlm = self.model.encoder.bert
        cl_feats = sample['cl_feats']
        mlm_input_ids = cl_feats['mlm_input_ids']
        mlm_input_ids = mlm_input_ids.view(-1, mlm_input_ids.size(-1))
        mlm_attention_mask = cl_feats['mlm_attention_mask']
        mlm_attention_mask = mlm_attention_mask.view(-1, mlm_attention_mask.size(-1))
        mlm_labels = cl_feats['mlm_labels'].reshape(-1)
        prediction_scores = bert_for_mlm(mlm_input_ids, mlm_attention_mask).logits
        # Flatten with the logits' own width instead of tokenizer.vocab_size.
        # The two can differ, e.g. an HF tokenizer's vocab_size excludes added
        # tokens.
        token_loss = nn.CrossEntropyLoss(reduction="none")(
            prediction_scores.reshape(-1, prediction_scores.size(-1)), mlm_labels)
        valid = mlm_labels >= 0
        # sum / count rather than .mean(): an empty selection yields 0, not NaN.
        return token_loss[valid].sum() / valid.sum().clamp(min=1)