# NaturalSpeech2/models/tts/base/tts_dataset.py
# Last commit: b725c5a ("init") by yuancwang. File size: 14.1 kB.
# (Page labels captured with this file: raw | history | blame | No virus)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import torchaudio
import numpy as np
import torch
from utils.data_utils import *
from torch.nn.utils.rnn import pad_sequence
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from processors.acoustic_extractor import cal_normalized_mel
from models.base.base_dataset import (
BaseDataset,
BaseCollator,
BaseTestDataset,
BaseTestCollator,
)
from processors.content_extractor import (
ContentvecExtractor,
WenetExtractor,
WhisperExtractor,
)
class TTSDataset(BaseDataset):
    """Training/validation dataset serving pre-extracted TTS features.

    Which features are served is controlled by the boolean ``use_*`` flags in
    ``cfg.preprocess``. For every enabled file-backed feature the constructor
    builds a ``utt -> path`` map, where ``utt`` is ``"{Dataset}_{Uid}"``.
    ``__getitem__`` then loads those files for a single utterance.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name (sub-directory of ``cfg.preprocess.processed_dir``)
            is_valid: if True read ``cfg.preprocess.valid_file`` as the metadata
                file, otherwise ``cfg.preprocess.train_file``
        """
        assert isinstance(dataset, str)

        self.cfg = cfg

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        # spk2id (JSON):            {spk1: 0, spk2: 1, ...}
        # utt2spk (tab-separated):  {dataset_uid: spk1, ...}
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f:
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_uv:
            self.utt2uv_path = self._make_feature_path_map(cfg.preprocess.uv_dir, ".npy")

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = self._make_feature_path_map(
                cfg.preprocess.pitch_dir, ".npy"
            )

        if cfg.preprocess.use_frame_energy:
            self.utt2frame_energy_path = self._make_feature_path_map(
                cfg.preprocess.energy_dir, ".npy"
            )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = self._make_feature_path_map(cfg.preprocess.mel_dir, ".npy")

        if cfg.preprocess.use_linear:
            self.utt2linear_path = self._make_feature_path_map(
                cfg.preprocess.linear_dir, ".npy"
            )

        if cfg.preprocess.use_audio:
            if cfg.preprocess.extract_audio:
                self.utt2audio_path = self._make_feature_path_map(
                    cfg.preprocess.audio_dir, ".wav"
                )
            else:
                # Use the wav path recorded in the metadata as-is.
                self.utt2audio_path = {
                    self._utt_key(utt_info): utt_info["Path"] for utt_info in self.metadata
                }
        # NOTE(review): because of the elif chain, label / one-hot maps are only
        # built when use_audio is off; __getitem__ in this class reads neither.
        elif cfg.preprocess.use_label:
            self.utt2label_path = self._make_feature_path_map(
                cfg.preprocess.label_dir, ".npy"
            )
        elif cfg.preprocess.use_one_hot:
            self.utt2one_hot_path = self._make_feature_path_map(
                cfg.preprocess.one_hot_dir, ".npy"
            )

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            # One phoneIDCollation per dataset name, instead of one per utterance.
            phone_id_collators = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                elif cfg.preprocess.use_phone:
                    # Load the phoneme sequence from its single-line .phone file.
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r") as fin:
                        phones = fin.readlines()
                        assert len(phones) == 1
                        phones = phones[0].strip()
                    phones_seq = phones.split(" ")

                    if dataset not in phone_id_collators:
                        phone_id_collators[dataset] = phoneIDCollation(cfg, dataset=dataset)
                    sequence = phone_id_collators[dataset].get_phone_id_sequence(
                        cfg, phones_seq
                    )

                self.utt2seq[utt] = sequence

    @staticmethod
    def _utt_key(utt_info):
        """Return the ``"{Dataset}_{Uid}"`` key used by every per-utterance map."""
        return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])

    def _make_feature_path_map(self, feature_dir, ext):
        """Map each utterance key to ``processed_dir/<Dataset>/<feature_dir>/<Uid><ext>``."""
        processed_dir = self.cfg.preprocess.processed_dir
        return {
            self._utt_key(utt_info): os.path.join(
                processed_dir, utt_info["Dataset"], feature_dir, utt_info["Uid"] + ext
            )
            for utt_info in self.metadata
        }

    def __getitem__(self, index):
        """Load every enabled feature for the utterance at ``index``.

        Returns a dict that may contain ``spk_id``, ``mel`` ([T, n_mels]),
        ``linear`` ([T, n_linear]), ``frame_pitch``, ``frame_uv``,
        ``frame_energy``, ``audio``/``audio_len``, ``phone_seq``/``phone_len``,
        and ``target_len``. ``target_len`` is taken from the first frame-level
        feature loaded, and the later frame-level features are aligned to it.
        """
        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)
        preprocess = self.cfg.preprocess

        single_feature = dict()

        if preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])
            assert mel.shape[0] == preprocess.n_mel  # [n_mels, T]
            if preprocess.use_min_max_norm_mel:
                mel = cal_normalized_mel(mel, dataset, preprocess)
            single_feature.setdefault("target_len", mel.shape[1])
            single_feature["mel"] = mel.T  # [T, n_mels]

        if preprocess.use_linear:
            linear = np.load(self.utt2linear_path[utt])
            single_feature.setdefault("target_len", linear.shape[1])
            single_feature["linear"] = linear.T  # [T, n_linear]

        if preprocess.use_frame_pitch:
            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
            target_len = single_feature.setdefault("target_len", len(frame_pitch))
            single_feature["frame_pitch"] = align_length(frame_pitch, target_len)

        if preprocess.use_uv:
            frame_uv = np.load(self.utt2uv_path[utt])
            # Fall back to the uv length when no earlier feature set target_len
            # (previously a KeyError in that configuration).
            target_len = single_feature.setdefault("target_len", len(frame_uv))
            aligned_frame_uv = align_length(frame_uv, target_len)
            # Invert the stored flags: truthy entries -> 0, falsy entries -> 1.
            single_feature["frame_uv"] = np.array(
                [0 if flag else 1 for flag in aligned_frame_uv]
            )

        if preprocess.use_frame_energy:
            frame_energy = np.load(self.utt2frame_energy_path[utt])
            target_len = single_feature.setdefault("target_len", len(frame_energy))
            single_feature["frame_energy"] = align_length(frame_energy, target_len)

        if preprocess.use_audio:
            audio, _ = torchaudio.load(self.utt2audio_path[utt])
            audio = audio.cpu().numpy().squeeze()
            single_feature["audio"] = audio
            single_feature["audio_len"] = audio.shape[0]

        if preprocess.use_phone or preprocess.use_text:
            sequence = self.utt2seq[utt]
            single_feature["phone_seq"] = np.array(sequence)
            single_feature["phone_len"] = len(sequence)

        return single_feature

    def __len__(self):
        return super().__len__()

    def get_metadata(self):
        return super().get_metadata()
class TTSCollator(BaseCollator):
    """Batch collator for TTSDataset samples.

    Zero-pads model inputs and targets based on number of frames per step.
    All of the batching work is delegated unchanged to ``BaseCollator``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        # Pure pass-through: the parent already returns the packed dict.
        return super().__call__(batch)
class TTSTestDataset(BaseTestDataset):
    """Inference-time TTS dataset.

    Metadata is taken from ``args.test_list_file`` when it is given (one text
    per line). Otherwise it is read from
    ``<cfg.preprocess.processed_dir>/<args.dataset>/<args.testing_set>.json``
    via ``get_metadata``. ``__getitem__`` returns an empty dict in this class.
    """

    def __init__(self, args, cfg):
        self.cfg = cfg

        if args.test_list_file is None:
            assert args.testing_set
            self.metafile_path = os.path.join(
                cfg.preprocess.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            self.metadata = self.get_metadata()
            return

        # Construct metadata straight from the plain-text list: the line index
        # becomes the Uid and every entry belongs to the "test" dataset.
        with open(args.test_list_file, "r") as test_list:
            raw_lines = test_list.readlines()
        self.metadata = [
            {"Dataset": "test", "Text": raw.strip(), "Uid": str(position)}
            for position, raw in enumerate(raw_lines)
        ]

    def __getitem__(self, index):
        return {}

    def __len__(self):
        return len(self.metadata)
class TTSTestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    @staticmethod
    def _lengths_and_mask(batch, key):
        """Return the per-sample ``key`` lengths as a LongTensor [b], plus a
        zero-padded long mask of shape [b, max_len, 1] (1 = valid position)."""
        lengths = [b[key] for b in batch]
        masks = [torch.ones((length, 1), dtype=torch.long) for length in lengths]
        return (
            torch.LongTensor(lengths),
            pad_sequence(masks, batch_first=True, padding_value=0),
        )

    def __call__(self, batch):
        """Pack a list of per-sample feature dicts into a dict of batched tensors.

        Keys are taken from ``batch[0]``. Length keys become LongTensors, and
        ``target_len`` / ``phone_len`` additionally produce ``mask`` /
        ``phn_mask``. Every other value must be a numpy array; these are
        converted with ``torch.from_numpy`` and zero-padded along the first
        axis. An empty batch yields an empty dict.
        """
        # Resulting shapes, assuming per-sample arrays shaped like TTSDataset's:
        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [b, T]
        # target_len: [b]
        # spk_id: [b, 1]
        # mask: [b, T, 1]
        packed_batch_features = dict()
        if not batch:
            return packed_batch_features

        for key in batch[0].keys():
            if key == "target_len":
                (
                    packed_batch_features["target_len"],
                    packed_batch_features["mask"],
                ) = self._lengths_and_mask(batch, "target_len")
            elif key == "phone_len":
                (
                    packed_batch_features["phone_len"],
                    packed_batch_features["phn_mask"],
                ) = self._lengths_and_mask(batch, "phone_len")
            elif key == "audio_len":
                # Only lengths are packed for audio. Building a sample-level
                # mask here would allocate one tensor per audio sample and
                # nothing in this collator emits it.
                packed_batch_features["audio_len"] = torch.LongTensor(
                    [b["audio_len"] for b in batch]
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features