NaturalSpeech2 / models /tts /vits /vits_dataset.py
yuancwang
init
b725c5a
raw
history blame contribute delete
No virus
4.03 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import json
import numpy as np
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from models.tts.base.tts_dataset import (
TTSDataset,
TTSCollator,
TTSTestDataset,
TTSTestCollator,
)
class VITSDataset(TTSDataset):
    """Dataset used by VITS for training and validation.

    Adds no behavior of its own: construction, item access and length are all
    forwarded to ``TTSDataset``.
    """

    def __init__(self, cfg, dataset, is_valid):
        """Forward ``cfg`` and ``dataset`` positionally and ``is_valid`` by keyword."""
        super(VITSDataset, self).__init__(cfg, dataset, is_valid=is_valid)

    def __getitem__(self, index):
        """Return the parent's feature entry for ``index`` unchanged."""
        return super(VITSDataset, self).__getitem__(index)

    def __len__(self):
        """Return the length reported by the parent dataset."""
        num_items = super(VITSDataset, self).__len__()
        return num_items
class VITSCollator(TTSCollator):
    """Batch collator for VITS.

    Both construction and collation are delegated to ``TTSCollator``; this
    subclass does not alter the batch in any way.
    """

    def __init__(self, cfg):
        """Initialize the parent collator with ``cfg``."""
        super(VITSCollator, self).__init__(cfg)

    def __call__(self, batch):
        """Return whatever the parent collator produces for ``batch``."""
        return super(VITSCollator, self).__call__(batch)
class VITSTestDataset(TTSTestDataset):
    """Inference-time dataset for VITS.

    Depending on the ``cfg.preprocess`` flags, the constructor loads the
    speaker lookup tables and pre-computes one ID sequence per utterance, so
    that ``__getitem__`` only performs dictionary lookups.

    Attributes set by ``__init__`` (each only when its flag is enabled):
        spk2id: content of the ``cfg.preprocess.spk2id`` JSON file.
        utt2spk: utterance key -> speaker, parsed from the tab-separated
            ``cfg.preprocess.utt2spk`` file.
        utt2seq: ``"{Dataset}_{Uid}"`` -> ID sequence, built either from the
            utterance text or from its ``.phone`` file.

    NOTE(review): ``self.metadata``, ``self.cfg`` and ``self.metafile_path``
    are read here but never assigned in this class; presumably the parent
    initializer provides them -- confirm against ``TTSTestDataset``.
    """

    def __init__(self, args, cfg):
        """Load speaker tables and pre-compute the per-utterance sequences.

        Args:
            args: object whose ``dataset`` attribute names the sub-directory
                of ``cfg.preprocess.processed_dir`` to read from.
            cfg: configuration object; only ``cfg.preprocess`` is read here.
        """
        super().__init__(args, cfg)

        # Resolved unconditionally: the phone branch below needs this path
        # even when ``use_spkid`` is disabled. Binding it only inside the
        # ``use_spkid`` branch made that combination fail with
        # UnboundLocalError.
        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset)

        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            # UTF-8 is given explicitly, matching get_metadata(), so reading
            # does not depend on the platform's default locale.
            with open(spk2id_path, "r", encoding="utf-8") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = {}
            with open(utt2spk_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    utt, spk = line.split("\t")
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            # One collator per dataset name: the constructor arguments are
            # identical for every utterance of a dataset, so there is no need
            # to rebuild it for each utterance.
            collators = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    # Text takes precedence when both flags are enabled.
                    sequence = text_to_sequence(
                        utt_info["Text"], cfg.preprocess.text_cleaners
                    )
                else:
                    # Load the phoneme sequence from the phone file.
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r", encoding="utf-8") as fin:
                        phones = fin.readlines()
                    assert len(phones) == 1, "expected one line in {}, got {}".format(
                        phone_path, len(phones)
                    )
                    phones_seq = phones[0].strip().split(" ")

                    if dataset not in collators:
                        collators[dataset] = phoneIDCollation(cfg, dataset=dataset)
                    sequence = collators[dataset].get_phone_id_sequence(
                        cfg, phones_seq
                    )

                self.utt2seq[utt] = sequence

    def __getitem__(self, index):
        """Return the feature dict for the utterance at ``index``.

        Keys are present only when the matching flag is enabled:
            ``spk_id``: one-element ``np.int32`` array (``use_spkid``).
            ``phone_seq``: array built from the pre-computed sequence
                (``use_phone`` or ``use_text``).
            ``phone_len``: length of that sequence.
        """
        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = {}

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            sequence = self.utt2seq[utt]
            single_feature["phone_seq"] = np.array(sequence)
            single_feature["phone_len"] = len(sequence)

        return single_feature

    def get_metadata(self):
        """Return the JSON content of ``self.metafile_path`` (read as UTF-8)."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.metadata)
class VITSTestCollator(TTSTestCollator):
    """Batch collator used by VITS at inference time.

    Collation itself is delegated to ``TTSTestCollator``.
    """

    def __init__(self, cfg):
        """Store ``cfg`` on the instance.

        The parent initializer is not invoked here; only ``self.cfg`` is set.
        """
        self.cfg = cfg

    def __call__(self, batch):
        """Return the parent collator's result for ``batch`` unchanged."""
        collated = super(VITSTestCollator, self).__call__(batch)
        return collated