# pflowtts_ukr_demo/pflow/data/text_mel_datamodule.py
import random
from typing import Any, Dict, Optional

import torch
import torchaudio as ta
from lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader

from pflow.text import text_to_sequence
from pflow.utils.audio import mel_spectrogram
from pflow.utils.model import fix_len_compatibility, normalize
from pflow.utils.utils import intersperse


def parse_filelist(filelist_path, split_char="|"):
    """Read a filelist where each line is `filepath|text` or, for
    multispeaker data, `filepath|speaker_id|text`."""
    with open(filelist_path, encoding="utf-8") as f:
        filepaths_and_text = [line.strip().split(split_char) for line in f]
    return filepaths_and_text
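
# Example filelist lines (illustrative only; the paths and text are hypothetical,
# matching the `filepath|text` / `filepath|speaker_id|text` layout parsed above):
#   single speaker: wavs/utt0001.wav|Hello there.
#   multispeaker  : wavs/utt0001.wav|3|Hello there.

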
class TextMelDataModule(LightningDataModule):
def __init__( # pylint: disable=unused-argument
self,
name,
train_filelist_path,
valid_filelist_path,
batch_size,
num_workers,
pin_memory,
cleaners,
add_blank,
n_spks,
n_fft,
n_feats,
sample_rate,
hop_length,
win_length,
f_min,
f_max,
data_statistics,
seed,
min_sample_size,
):
super().__init__()
        # this line allows access to init params via the `self.hparams` attribute
        # and ensures the init params are stored in the checkpoint
        self.save_hyperparameters(logger=False)

def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
"""
# load and split datasets only if not loaded already
self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.train_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.min_sample_size,
)
self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init
self.hparams.valid_filelist_path,
self.hparams.n_spks,
self.hparams.cleaners,
self.hparams.add_blank,
self.hparams.n_fft,
self.hparams.n_feats,
self.hparams.sample_rate,
self.hparams.hop_length,
self.hparams.win_length,
self.hparams.f_min,
self.hparams.f_max,
self.hparams.data_statistics,
self.hparams.seed,
self.hparams.min_sample_size,
)

    def train_dataloader(self):
        return DataLoader(
            dataset=self.trainset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.validset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass  # pylint: disable=unnecessary-pass

    def state_dict(self):  # pylint: disable=no-self-use
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass  # pylint: disable=unnecessary-pass


class TextMelDataset(torch.utils.data.Dataset):
def __init__(
self,
filelist_path,
n_spks,
cleaners,
add_blank=True,
n_fft=1024,
n_mels=80,
sample_rate=22050,
hop_length=256,
win_length=1024,
f_min=0.0,
f_max=8000,
data_parameters=None,
seed=None,
min_sample_size=4,
):
self.filepaths_and_text = parse_filelist(filelist_path)
self.n_spks = n_spks
self.cleaners = cleaners
self.add_blank = add_blank
self.n_fft = n_fft
self.n_mels = n_mels
self.sample_rate = sample_rate
self.hop_length = hop_length
self.win_length = win_length
self.f_min = f_min
self.f_max = f_max
self.min_sample_size = min_sample_size
if data_parameters is not None:
self.data_parameters = data_parameters
else:
self.data_parameters = {"mel_mean": 0, "mel_std": 1}
random.seed(seed)
random.shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
if self.n_spks > 1:
filepath, spk, text = (
filepath_and_text[0],
int(filepath_and_text[1]),
filepath_and_text[2],
)
else:
filepath, text = filepath_and_text[0], filepath_and_text[1]
spk = None
text = self.get_text(text, add_blank=self.add_blank)
mel, audio = self.get_mel(filepath)
        # TODO: keep a dictionary so that a different spectrogram of the same
        # speaker can be used as the prompt; for now the target mel is naively
        # reused for testing purposes.
        return {"x": text, "y": mel, "spk": spk, "wav": audio}

def get_mel(self, filepath):
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate, f"{filepath}: expected {self.sample_rate} Hz, got {sr} Hz"
mel = mel_spectrogram(
audio,
self.n_fft,
self.n_mels,
self.sample_rate,
self.hop_length,
self.win_length,
self.f_min,
self.f_max,
center=False,
).squeeze()
mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
return mel, audio

    def get_text(self, text, add_blank=True):
        text_norm = text_to_sequence(text, self.cleaners)
        if add_blank:
            # intersperse a blank token (id 0) between symbols
            text_norm = intersperse(text_norm, 0)
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        datapoint = self.get_datapoint(self.filepaths_and_text[index])
        if datapoint["wav"].shape[1] <= self.min_sample_size * self.sample_rate:
            # Skip datapoints shorter than `min_sample_size` seconds (< 4 s by
            # default; the prompt is 3 s) and resample a random index instead.
            # TODO: to avoid wasting data, wavs shorter than 3 s could be
            # concatenated and used; make the threshold a hyperparameter. A
            # multispeaker dataset could instead use another wav of the same speaker.
            return self.__getitem__(random.randint(0, len(self.filepaths_and_text) - 1))
        return datapoint

    def __len__(self):
        return len(self.filepaths_and_text)


class TextMelBatchCollate:
def __init__(self, n_spks):
self.n_spks = n_spks

    def __call__(self, batch):
B = len(batch)
        y_max_length = max(item["y"].shape[-1] for item in batch)
        # pad the mel length up to a value the model's downsampling stack accepts
        y_max_length = fix_len_compatibility(y_max_length)
        # NOTE: 256 hardcodes the hop length; keep it in sync with the mel settings
        wav_max_length = y_max_length * 256
        x_max_length = max(item["x"].shape[-1] for item in batch)
n_feats = batch[0]["y"].shape[-2]
y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
x = torch.zeros((B, x_max_length), dtype=torch.long)
wav = torch.zeros((B, 1, wav_max_length), dtype=torch.float32)
y_lengths, x_lengths = [], []
wav_lengths = []
spks = []
for i, item in enumerate(batch):
y_, x_ = item["y"], item["x"]
wav_ = item["wav"][:,:wav_max_length] if item["wav"].shape[-1] > wav_max_length else item["wav"]
y_lengths.append(y_.shape[-1])
x_lengths.append(x_.shape[-1])
wav_lengths.append(wav_.shape[-1])
y[i, :, : y_.shape[-1]] = y_
x[i, : x_.shape[-1]] = x_
wav[i, :, : wav_.shape[-1]] = wav_
spks.append(item["spk"])
y_lengths = torch.tensor(y_lengths, dtype=torch.long)
x_lengths = torch.tensor(x_lengths, dtype=torch.long)
wav_lengths = torch.tensor(wav_lengths, dtype=torch.long)
spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None
        return {
            "x": x,
            "x_lengths": x_lengths,
            "y": y,
            "y_lengths": y_lengths,
            "spks": spks,
            "wav": wav,
            "wav_lengths": wav_lengths,
            # the target mel doubles as the prompt for now (see TODO in get_datapoint)
            "prompt_spec": y,
            "prompt_lengths": y_lengths,
        }
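

# A minimal usage sketch (not part of the original module), assuming placeholder
# filelist paths and an illustrative cleaner name; the real project wires these
# hyperparameters up through its config files, so treat every value below as an
# assumption rather than the project's actual settings.
if __name__ == "__main__":
    datamodule = TextMelDataModule(
        name="demo",
        train_filelist_path="filelists/train.txt",  # hypothetical path
        valid_filelist_path="filelists/valid.txt",  # hypothetical path
        batch_size=16,
        num_workers=4,
        pin_memory=True,
        cleaners=["basic_cleaners"],  # assumed cleaner; use the project's own
        add_blank=True,
        n_spks=1,
        n_fft=1024,
        n_feats=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_statistics={"mel_mean": 0.0, "mel_std": 1.0},
        seed=1234,
        min_sample_size=4,
    )
    datamodule.setup()
    # pull one batch to sanity-check shapes (requires the filelists/wavs to exist)
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["x"].shape, batch["y"].shape, batch["wav"].shape)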