Spaces:
Running
Running
File size: 6,153 Bytes
26925fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import matplotlib
matplotlib.use('Agg')
import torch
import numpy as np
import os
from tasks.base_task import BaseDataset
from tasks.tts.fs2 import FastSpeech2Task
from modules.fastspeech.pe import PitchExtractor
import utils
from utils.indexed_datasets import IndexedDataset
from utils.hparams import hparams
from utils.plot import f0_to_figure
from utils.pitch_utils import norm_interp_f0, denorm_f0
class PeDataset(BaseDataset):
    """Dataset of binarized mel / f0 / pitch items for training the pitch extractor.

    Items are read lazily from ``{hparams['binary_data_dir']}/{prefix}`` through
    ``IndexedDataset``, and every per-frame sequence is truncated to
    ``hparams['max_frames']``.
    """

    def __init__(self, prefix, shuffle=False):
        """Load per-item lengths and, if available, the global f0 statistics.

        Args:
            prefix: split name used as the file prefix inside
                ``hparams['binary_data_dir']``. ``'test'`` enables optional subsetting.
            shuffle: forwarded to ``BaseDataset``.

        Side effects:
            Writes ``f0_mean`` / ``f0_std`` into the module-level ``hparams`` as
            Python floats, or ``None`` when ``train_f0s_mean_std.npy`` does not exist.
        """
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        # Opened lazily on first access in _get_item.
        self.indexed_ds = None

        # Global pitch statistics (consumed via hparams by norm_interp_f0 / denorm_f0).
        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(f0_stats_fn):
            f0_mean, f0_std = np.load(f0_stats_fn)
            # Keep the instance attributes and hparams identical (plain Python floats).
            self.f0_mean, self.f0_std = float(f0_mean), float(f0_std)
        else:
            self.f0_mean, self.f0_std = None, None
        hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std

        # Optionally restrict the test split to the first `num_test_samples`
        # items plus any explicitly listed `test_ids`.
        if prefix == 'test' and hparams['num_test_samples'] > 0:
            self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
        """Return the raw binarized record at dataset position ``index``.

        When ``__init__`` selected a test subset, ``index`` is first mapped through
        ``self.avail_idxs`` to the underlying record index.
        """
        avail_idxs = getattr(self, 'avail_idxs', None)
        if avail_idxs is not None:
            index = avail_idxs[index]
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build one sample dict for dataset position ``index``.

        Returns:
            dict with keys ``id`` (the dataset position, before any ``avail_idxs``
            remapping), ``item_name``, ``text``, ``mel``, ``pitch``, ``f0`` and ``uv``.
            ``mel``, ``pitch`` and ``f0`` are truncated to ``max_frames`` along their
            first axis.

        Raises:
            KeyError: if the binarized item has no ``'pitch'`` entry.
        """
        hp = self.hparams
        item = self._get_item(index)
        max_frames = hp['max_frames']
        spec = torch.Tensor(item['mel'])[:max_frames]
        # NOTE(review): norm_interp_f0 presumably returns (normalized/interpolated f0, uv mask); confirm in utils.pitch_utils.
        f0, uv = norm_interp_f0(item["f0"][:max_frames], hp)
        pitch_seq = item.get("pitch")
        if pitch_seq is None:
            # torch.LongTensor(None) would fail with an opaque TypeError; fail clearly instead.
            raise KeyError(f"binarized item {item['item_name']!r} has no 'pitch' entry")
        pitch = torch.LongTensor(pitch_seq)[:max_frames]
        # mel2ph is intentionally not loaded; the pitch extractor only consumes mels.
        return {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "mel": spec,
            "pitch": pitch,
            "f0": f0,
            "uv": uv,
        }

    def collater(self, samples):
        """Pad and stack a list of samples into a batch dict.

        Returns an empty dict for an empty list. ``f0`` and ``mels`` are padded with
        ``0.0``; ``pitch`` and ``uv`` use ``utils.collate_1d``'s default pad value.
        ``mel_lengths`` holds each sample's frame count before padding.
        """
        if not samples:
            return {}
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        text = [s['text'] for s in samples]
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples])
        uv = utils.collate_1d([s['uv'] for s in samples])
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
        return {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'text': text,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }
class PitchExtractionTask(FastSpeech2Task):
    """Train and validate the standalone ``PitchExtractor`` on ``PeDataset``.

    Builds on ``FastSpeech2Task``. ``add_f0_loss``, ``logger`` and ``global_step``
    used below are not defined here and are presumably inherited (confirm in the
    base class).
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = PeDataset

    def build_tts_model(self):
        """Create ``self.model`` as a PitchExtractor with the configured conv depth."""
        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])

    def _training_step(self, sample, batch_idx, _):
        """Run one training batch.

        Returns:
            ``(total_loss, loss_output)``. ``total_loss`` is the sum of the entries
            in ``loss_output`` that are tensors requiring grad (``0`` if there are
            none). ``loss_output`` additionally gets ``'batch_size'``.
        """
        loss_output = self.run_model(self.model, sample)
        total_loss = sum(
            v for v in loss_output.values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        loss_output['batch_size'] = sample['mels'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and plot f0 for the first ``num_valid_plots`` batches.

        Returns:
            dict with ``'losses'``, ``'total_loss'`` (sum of all losses) and
            ``'nsamples'``, converted via ``utils.tensors_to_scalars``.
        """
        losses, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            self.plot_pitch(batch_idx, model_out, sample)
        return outputs

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on the batch mels and collect the pitch losses.

        Args:
            model: callable taking ``sample['mels']``.
            sample: batch dict produced by ``PeDataset.collater``.
            return_output: if True, also return the raw model output.
            infer: accepted for interface compatibility; currently unused.

        Returns:
            ``losses`` dict, or ``(losses, output)`` when ``return_output`` is True.
        """
        output = model(sample['mels'])
        losses = {}
        self.add_pitch_loss(output, sample, losses)
        if return_output:
            return losses, output
        return losses

    def plot_pitch(self, batch_idx, model_out, sample):
        """Log a figure comparing ground-truth f0 with the predicted f0 for the first item."""
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        self.logger.experiment.add_figure(
            f'f0_{batch_idx}',
            f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
            self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add the f0 loss for ``output['pitch_pred']`` into ``losses`` (mutated in place)."""
        mel = sample['mels']
        f0 = sample['f0']
        uv = sample['uv']
        # Frames whose mel is entirely zero are treated as padding and masked out
        # (PeDataset.collater pads mels with 0.0).
        nonpadding = (mel.abs().sum(-1) > 0).float()
        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)