SyntaSpeech / tasks /tts /ps_flow.py
yerfor's picture
init
22871e7
import torch
from modules.tts.portaspeech.portaspeech_flow import PortaSpeechFlow
from tasks.tts.fs import FastSpeechTask
from tasks.tts.ps import PortaSpeechTask
from utils.audio.pitch.utils import denorm_f0
from utils.commons.hparams import hparams
class PortaSpeechFlowTask(PortaSpeechTask):
def __init__(self):
super().__init__()
self.training_post_glow = False
def build_tts_model(self):
ph_dict_size = len(self.token_encoder)
word_dict_size = len(self.word_encoder)
self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)
def _training_step(self, sample, batch_idx, opt_idx):
self.training_post_glow = self.global_step >= hparams['post_glow_training_start'] \
and hparams['use_post_flow']
if hparams['two_stage'] and \
((opt_idx == 0 and self.training_post_glow) or (opt_idx == 1 and not self.training_post_glow)):
return None
loss_output, _ = self.run_model(sample)
total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
loss_output['batch_size'] = sample['txt_tokens'].size()[0]
if 'postflow' in loss_output and loss_output['postflow'] is None:
return None
return total_loss, loss_output
def run_model(self, sample, infer=False, *args, **kwargs):
if not infer:
training_post_glow = self.training_post_glow
spk_embed = sample.get('spk_embed')
spk_id = sample.get('spk_ids')
output = self.model(sample['txt_tokens'],
sample['word_tokens'],
ph2word=sample['ph2word'],
mel2word=sample['mel2word'],
mel2ph=sample['mel2ph'],
word_len=sample['word_lengths'].max(),
tgt_mels=sample['mels'],
pitch=sample.get('pitch'),
spk_embed=spk_embed,
spk_id=spk_id,
infer=False,
forward_post_glow=training_post_glow,
two_stage=hparams['two_stage'],
global_step=self.global_step)
losses = {}
self.add_mel_loss(output['mel_out'], sample['mels'], losses)
if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
losses['postflow'] = output['postflow']
losses['l1'] = losses['l1'].detach()
losses['ssim'] = losses['ssim'].detach()
if not training_post_glow or not hparams['two_stage'] or not self.training:
losses['kl'] = output['kl']
if self.global_step < hparams['kl_start_steps']:
losses['kl'] = losses['kl'].detach()
else:
losses['kl'] = torch.clamp(losses['kl'], min=hparams['kl_min'])
losses['kl'] = losses['kl'] * hparams['lambda_kl']
if hparams['dur_level'] == 'word':
self.add_dur_loss(
output['dur'], sample['mel2word'], sample['word_lengths'], sample['txt_tokens'], losses)
self.get_attn_stats(output['attn'], sample, losses)
else:
super().add_dur_loss(output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)
return losses, output
else:
use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
forward_post_glow = self.global_step >= hparams['post_glow_training_start'] + 1000 \
and hparams['use_post_flow']
spk_embed = sample.get('spk_embed')
spk_id = sample.get('spk_ids')
output = self.model(
sample['txt_tokens'],
sample['word_tokens'],
ph2word=sample['ph2word'],
word_len=sample['word_lengths'].max(),
pitch=sample.get('pitch'),
mel2ph=sample['mel2ph'] if use_gt_dur else None,
mel2word=sample['mel2word'] if hparams['profile_infer'] or hparams['use_gt_dur'] else None,
infer=True,
forward_post_glow=forward_post_glow,
spk_embed=spk_embed,
spk_id=spk_id,
two_stage=hparams['two_stage']
)
return output
def validation_step(self, sample, batch_idx):
self.training_post_glow = self.global_step >= hparams['post_glow_training_start'] \
and hparams['use_post_flow']
return super().validation_step(sample, batch_idx)
def save_valid_result(self, sample, batch_idx, model_out):
super(PortaSpeechFlowTask, self).save_valid_result(sample, batch_idx, model_out)
sr = hparams['audio_sample_rate']
f0_gt = None
if sample.get('f0') is not None:
f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
if self.global_step > 0:
# save FVAE result
if hparams['use_post_flow']:
wav_pred = self.vocoder.spec2wav(model_out['mel_out_fvae'][0].cpu(), f0=f0_gt)
self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
self.plot_mel(batch_idx, sample['mels'], model_out['mel_out_fvae'][0],
f'mel_fvae_{batch_idx}', f0s=f0_gt)
def build_optimizer(self, model):
if hparams['two_stage'] and hparams['use_post_flow']:
self.optimizer = torch.optim.AdamW(
[p for name, p in self.model.named_parameters() if 'post_flow' not in name],
lr=hparams['lr'],
betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
weight_decay=hparams['weight_decay'])
self.post_flow_optimizer = torch.optim.AdamW(
self.model.post_flow.parameters(),
lr=hparams['post_flow_lr'],
betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
weight_decay=hparams['weight_decay'])
return [self.optimizer, self.post_flow_optimizer]
else:
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=hparams['lr'],
betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
weight_decay=hparams['weight_decay'])
return [self.optimizer]
def build_scheduler(self, optimizer):
return FastSpeechTask.build_scheduler(self, optimizer[0])