# DiffSpeech: tasks/tts/fs2_orig.py
# Author: RayeRen (commit d1b91e7, "init"; 5.38 kB)
import torch
import torch.nn.functional as F
from modules.tts.fs2_orig import FastSpeech2Orig
from tasks.tts.dataset_utils import FastSpeechDataset
from tasks.tts.fs import FastSpeechTask
from utils.commons.dataset_utils import collate_1d, collate_2d
from utils.commons.hparams import hparams
class FastSpeech2OrigDataset(FastSpeechDataset):
    """FastSpeech dataset extended with the extra targets used by FastSpeech 2.

    On top of what ``FastSpeechDataset`` yields, every sample gets a per-frame
    ``energy`` target derived from its mel spectrogram. When pitch embedding is
    enabled with ``pitch_type == 'cwt'``, the sample also carries
    ``cwt_spec``, ``f0_mean`` and ``f0_std``.
    """

    def __init__(self, prefix, shuffle=False, items=None, data_dir=None):
        super().__init__(prefix, shuffle, items, data_dir)
        # Cached once so __getitem__ and collater agree on the pitch format.
        self.pitch_type = hparams.get('pitch_type')

    def __getitem__(self, index):
        """Return sample ``index`` with ``energy`` (and CWT targets if enabled) added."""
        sample = super().__getitem__(index)
        item = self._get_item(index)
        hp = self.hparams  # local alias; avoids shadowing the module-level ``hparams``
        mel = sample['mel']
        n_frames = mel.shape[0]
        # Per-frame energy: L2 norm over the last axis of exp(mel).
        # NOTE(review): assumes ``mel`` is log-scaled so ``exp`` maps it back
        # to magnitudes — confirm against the binarizer.
        sample['energy'] = (mel.exp() ** 2).sum(-1).sqrt()
        if hp['use_pitch_embed'] and self.pitch_type == 'cwt':
            # Trim the stored CWT spectrogram to the mel length.
            cwt_spec = torch.Tensor(item['cwt_spec'])[:n_frames]
            # Fall back to the ``cwt_*`` key names when ``f0_*`` are absent.
            f0_mean = item.get('f0_mean', item.get('cwt_mean'))
            f0_std = item.get('f0_std', item.get('cwt_std'))
            sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
        return sample

    def collater(self, samples):
        """Collate ``samples`` into a batch with padded energy and optional CWT targets.

        Args:
            samples: list of dicts produced by ``__getitem__``.

        Returns:
            ``{}`` for an empty list. Otherwise the parent batch plus
            ``'energy'``, which is a zero-padded tensor when energy embedding is
            enabled and ``None`` otherwise. When CWT pitch targets were attached
            by ``__getitem__``, the batch also holds ``'cwt_spec'``,
            ``'f0_mean'`` and ``'f0_std'``.
        """
        if not samples:
            return {}
        batch = super().collater(samples)
        # Energy is computed for every sample; collate it when the energy
        # embedding consumes it. This used to be gated on ``use_pitch_embed``,
        # which left it ``None`` for energy-only configs and broke the energy loss.
        if hparams.get('use_energy_embed', False):
            energy = collate_1d([s['energy'] for s in samples], 0.0)
        else:
            energy = None
        batch['energy'] = energy
        # Mirror the condition under which __getitem__ attaches the CWT fields;
        # checking pitch_type alone raised KeyError when pitch embedding is off.
        if hparams['use_pitch_embed'] and self.pitch_type == 'cwt':
            cwt_spec = collate_2d([s['cwt_spec'] for s in samples])
            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
            f0_std = torch.Tensor([s['f0_std'] for s in samples])
            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
        return batch
class FastSpeech2OrigTask(FastSpeechTask):
    """FastSpeech 2 task: ``FastSpeechTask`` plus CWT pitch and energy losses."""

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeech2OrigDataset

    def build_tts_model(self):
        """Create ``self.model`` as a ``FastSpeech2Orig`` sized to the token vocabulary."""
        dict_size = len(self.token_encoder)
        self.model = FastSpeech2Orig(dict_size, hparams)

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run the model on a collated batch.

        Training (``infer=False``):
            Feed ground-truth ``mel2ph``/``f0``/``uv``/``energy`` and compute
            the losses. Returns ``(losses, output)``.

        Inference (``infer=True``):
            Ground-truth duration, f0/uv and energy are fed only when the
            matching ``use_gt_*`` hparam is set. Each flag can be overridden by
            the ``infer_use_gt_dur``/``infer_use_gt_f0``/``infer_use_gt_energy``
            kwargs. Returns ``output`` only.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            target = sample['mels']  # [B, T_s, 80]
            mel2ph = sample['mel2ph']  # [B, T_s]
            f0 = sample.get('f0')
            uv = sample.get('uv')
            energy = sample.get('energy')
            output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                                f0=f0, uv=uv, energy=energy, infer=False)
            losses = {}
            self.add_mel_loss(output['mel_out'], target, losses)
            self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
            if hparams['use_pitch_embed']:
                self.add_pitch_loss(output, sample, losses)
            if hparams['use_energy_embed']:
                self.add_energy_loss(output, sample, losses)
            return losses, output

        mel2ph, uv, f0, energy = None, None, None, None
        use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
        use_gt_f0 = kwargs.get('infer_use_gt_f0', hparams['use_gt_f0'])
        use_gt_energy = kwargs.get('infer_use_gt_energy', hparams['use_gt_energy'])
        if use_gt_dur:
            mel2ph = sample['mel2ph']
        if use_gt_f0:
            f0 = sample['f0']
            uv = sample['uv']
        if use_gt_energy:
            energy = sample['energy']
        output = self.model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, spk_id=spk_id,
                            f0=f0, uv=uv, energy=energy, infer=True)
        return output

    def add_pitch_loss(self, output, sample, losses):
        """Add pitch losses to ``losses`` in place.

        For ``pitch_type == 'cwt'`` this adds:
          - ``'C'``: L1 loss on the CWT spectrogram.
          - ``'f0_mean'`` and ``'f0_std'``: L1 losses on those statistics.
          - ``'uv'``: masked BCE-with-logits, only when ``use_uv`` is set.

        Every other pitch type is delegated to ``FastSpeechTask``.
        """
        if hparams['pitch_type'] != 'cwt':
            super().add_pitch_loss(output, sample, losses)
            return
        cwt_spec = sample['cwt_spec']
        f0_mean = sample['f0_mean']
        uv = sample['uv']
        mel2ph = sample['mel2ph']
        f0_std = sample['f0_std']
        # The first 10 output channels are compared against cwt_spec; with
        # use_uv an 11th channel carries the uv logit.
        cwt_pred = output['cwt'][:, :, :10]
        f0_mean_pred = output['f0_mean']
        f0_std_pred = output['f0_std']
        nonpadding = (mel2ph != 0).float()
        # NOTE(review): unlike the uv loss, this L1 is not masked by
        # ``nonpadding``; kept as-is to preserve training behavior.
        losses['C'] = F.l1_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
        if hparams['use_uv']:
            assert output['cwt'].shape[-1] == 11
            uv_pred = output['cwt'][:, :, -1]
            uv_bce = F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none')
            # clamp: an all-padding batch yields 0 instead of NaN (0/0).
            losses['uv'] = (uv_bce * nonpadding).sum() / nonpadding.sum().clamp(min=1) \
                * hparams['lambda_uv']
        losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
        losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']

    def add_energy_loss(self, output, sample, losses):
        """Add the masked MSE energy loss as ``losses['e']``.

        Frames whose target energy is exactly 0 are treated as padding and
        excluded from the mean.
        """
        energy_pred, energy = output['energy_pred'], sample['energy']
        nonpadding = (energy != 0).float()
        sq_err = F.mse_loss(energy_pred, energy, reduction='none')
        # clamp: an all-padding batch yields 0 instead of NaN (0/0).
        loss = (sq_err * nonpadding).sum() / nonpadding.sum().clamp(min=1)
        losses['e'] = loss * hparams['lambda_energy']