Spaces:
Running
Running
import torch | |
import utils | |
from utils.hparams import hparams | |
from .diff.net import DiffNet | |
from .diff.shallow_diffusion_tts import GaussianDiffusion | |
from .task import DiffFsTask | |
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder | |
from utils.pitch_utils import denorm_f0 | |
from tasks.tts.fs2_utils import FastSpeechDataset | |
# Registry of diffusion denoiser factories, keyed by hparams['diff_decoder_type'].
# Each factory is called with the hparams dict and returns the `denoise_fn`
# handed to GaussianDiffusion in DiffSpeechTask.build_tts_model.
DIFF_DECODERS = {
    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
}
class DiffSpeechTask(DiffFsTask):
    """DiffSpeech task: a ``GaussianDiffusion`` acoustic model built around a FastSpeech2 (``model.fs2``).

    Uses ``FastSpeechDataset`` for data and a vocoder chosen by ``get_vocoder_cls(hparams)``
    to render validation mel-spectrograms to audio.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeechDataset
        # In this class the vocoder is used by plot_wav to turn mels into audio for logging.
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()

    def build_tts_model(self):
        """Create ``self.model`` and freeze most of its FastSpeech2 sub-network.

        If ``hparams['fs2_ckpt']`` is a non-empty path, FastSpeech2 weights are restored
        from it (strict). Independently of that, every ``model.fs2`` parameter whose name
        does not contain ``'predictor'`` gets ``requires_grad = False``, so only the
        predictor parameters of fs2 (plus the rest of the diffusion model) stay trainable.
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        # Truthiness check: both '' and None mean "no pretrained FastSpeech2 checkpoint".
        if hparams['fs2_ckpt']:
            utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
        for param_name, param in self.model.fs2.named_parameters():
            if 'predictor' not in param_name:
                param.requires_grad = False

    def build_optimizer(self, model):
        """Create an AdamW optimizer over the trainable (``requires_grad``) parameters of ``model``.

        The optimizer is also stored on ``self.optimizer``.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on one batch and collect the losses.

        Args:
            model: the diffusion model; called with text tokens plus ground-truth
                ``mel2ph``/``f0``/``uv``/``energy`` and ``ref_mels=sample['mels']``.
            sample: batch dict. When ``hparams['pitch_type'] == 'cwt'`` this method also
                writes ``sample['f0_cwt']``.
            return_output: if true, also return the raw model output dict.
            infer: forwarded to ``model``.

        Returns:
            ``losses`` (dict name -> loss), or ``(losses, output)`` when ``return_output``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        # Speaker conditioning: ids when hparams['use_spk_id'], otherwise embeddings (may be None).
        spk_embed = sample.get('spk_ids') if hparams['use_spk_id'] else sample.get('spk_embed')
        if hparams['pitch_type'] == 'cwt':
            # Rebuild f0 from its CWT representation and cache it on the sample.
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if return_output:
            return losses, output
        return losses

    def validation_step(self, sample, batch_idx):
        """Compute validation losses; for the first ``num_valid_plots`` batches also log audio/mels.

        Returns:
            dict with ``losses``, ``total_loss`` and ``nsamples``, passed through
            ``utils.tensors_to_scalars``.
        """
        outputs = {}
        outputs['losses'] = self.run_model(self.model, sample, return_output=False, infer=False)
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            # Inference pass (ref_mels=None, infer=True) conditioned on ground-truth
            # durations, f0/uv and energy; its mel output is vocoded and plotted.
            txt_tokens = sample['txt_tokens']  # [B, T_t]
            spk_embed = sample.get('spk_ids') if hparams['use_spk_id'] else sample.get('spk_embed')
            model_out = self.model(
                txt_tokens, spk_embed=spk_embed, mel2ph=sample['mel2ph'],
                f0=sample['f0'], uv=sample['uv'], energy=sample['energy'],
                ref_mels=None, infer=True)
            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True,
                          gt_f0=gt_f0, f0=model_out.get('f0_denorm'))
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

    ############
    # validation plots
    ############
    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
        """Log ground-truth and predicted audio for the first item of a batch.

        Args:
            batch_idx: used in the logged tags ``gt_{batch_idx}`` and ``wav_{batch_idx}``.
            gt_wav: batched ground truth; only index 0 is used.
            wav_out: batched prediction; only index 0 is used.
            is_mel: if true, both inputs are spectrograms and are vocoded with
                ``self.vocoder.spec2wav`` before logging.
            gt_f0: optional batched f0 for the ground truth; forwarded to the vocoder.
            f0: optional batched f0 for the prediction; forwarded to the vocoder.
            name: unused; kept for signature compatibility.
        """
        gt_wav = gt_wav[0].cpu().numpy()
        wav_out = wav_out[0].cpu().numpy()
        # f0 can be absent (validation_step passes model_out.get('f0_denorm'), and the
        # defaults are None); convert only when present instead of crashing on None[0].
        gt_f0 = None if gt_f0 is None else gt_f0[0].cpu().numpy()
        f0 = None if f0 is None else f0[0].cpu().numpy()
        if is_mel:
            # NOTE(review): assumes the configured vocoder's spec2wav accepts f0=None — confirm.
            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
        sample_rate = hparams['audio_sample_rate']
        self.logger.experiment.add_audio(
            f'gt_{batch_idx}', gt_wav, sample_rate=sample_rate, global_step=self.global_step)
        self.logger.experiment.add_audio(
            f'wav_{batch_idx}', wav_out, sample_rate=sample_rate, global_step=self.global_step)