# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import random

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.base.base_dataset import (
    BaseCollator,
    BaseDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence


def _utt_key(utt_info):
    """Return the ``"<Dataset>_<Uid>"`` key used to index per-utterance dicts."""
    return "{}_{}".format(utt_info["Dataset"], utt_info["Uid"])


def _read_phones(lab_path):
    """Return the first line of the label file at ``lab_path``, stripped.

    Raises:
        IndexError: if the label file is empty.
    """
    with open(lab_path, "r", encoding="utf-8") as f:
        return f.readlines()[0].strip()


def _load_speaker_map(processed_dir):
    """Load ``<processed_dir>/spk2id.json`` if it exists, else return ``{}``."""
    spk2id_path = os.path.join(processed_dir, "spk2id.json")
    if not os.path.exists(spk2id_path):
        return {}
    with open(spk2id_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _length_mask(lengths):
    """Return a zero-padded long tensor of shape ``(B, max(lengths), 1)``.

    Row ``i`` contains ``lengths[i]`` ones followed by zero padding.
    """
    masks = [torch.ones((length, 1), dtype=torch.long) for length in lengths]
    return pad_sequence(masks, batch_first=True, padding_value=0)


def _pack_fs2_batch(batch):
    """Collate a list of per-utterance feature dicts into batched tensors.

    The set of keys is taken from ``batch[0]``; each key is handled as:

    * ``target_len`` -> ``LongTensor`` of lengths, plus a ``mask`` of shape
      ``(B, max_target_len, 1)`` (ones for valid frames, zero padding).
    * ``text_len`` -> ``LongTensor`` of lengths, plus a ``text_mask`` of shape
      ``(B, max_text_len, 1)``.
    * ``spk_id`` -> ``LongTensor`` of speaker ids.
    * ``uid`` -> plain list of utterance ids (not tensorized).
    * any other key -> values converted with ``torch.from_numpy`` and
      zero-padded along their first axis with ``pad_sequence``.
    """
    packed_batch_features = dict()
    for key in batch[0].keys():
        if key == "target_len":
            lengths = [b["target_len"] for b in batch]
            packed_batch_features["target_len"] = torch.LongTensor(lengths)
            packed_batch_features["mask"] = _length_mask(lengths)
        elif key == "text_len":
            lengths = [b["text_len"] for b in batch]
            packed_batch_features["text_len"] = torch.LongTensor(lengths)
            packed_batch_features["text_mask"] = _length_mask(lengths)
        elif key == "spk_id":
            packed_batch_features["spk_id"] = torch.LongTensor(
                [b["spk_id"] for b in batch]
            )
        elif key == "uid":
            packed_batch_features["uid"] = [b["uid"] for b in batch]
        else:
            values = [torch.from_numpy(b[key]) for b in batch]
            packed_batch_features[key] = pad_sequence(
                values, batch_first=True, padding_value=0
            )
    return packed_batch_features


class FS2Dataset(BaseDataset):
    """Training/validation dataset for FastSpeech2.

    On top of the features produced by ``BaseDataset.__getitem__`` it adds
    phone durations, phone ids from the ``.txt`` label files, a speaker id,
    and (when enabled in ``cfg.preprocess``) pitch and energy.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess

        # utt2duration: path of each utterance's duration .npy file
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            self.utt2duration_path[_utt_key(utt_info)] = os.path.join(
                cfg.processed_dir,
                utt_info["Dataset"],
                cfg.duration_dir,
                utt_info["Uid"] + ".npy",
            )
        self.utt2dur = self.read_duration()

        # Energy: frame-level takes precedence over phone-level.
        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # Pitch: frame-level takes precedence over phone-level. Both use the
        # pitch log-scale flag (not an energy setting).
        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # utt2lab: path of each utterance's phone label .txt file
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            self.utt2lab_path[_utt_key(utt_info)] = os.path.join(
                cfg.processed_dir,
                utt_info["Dataset"],
                cfg.lab_dir,
                utt_info["Uid"] + ".txt",
            )

        self.speaker_map = _load_speaker_map(cfg.processed_dir)

        # Drop utterances whose duration or mel file is missing.
        self.metadata = self.check_metadata()

    def __getitem__(self, index):
        single_feature = BaseDataset.__getitem__(self, index)
        preprocess_cfg = self.cfg.preprocess

        utt_info = self.metadata[index]
        uid = utt_info["Uid"]
        utt = _utt_key(utt_info)

        duration = self.utt2dur[utt]

        # Text: use the configured cleaners so training matches
        # FS2TestDataset (which reads ``text_cleaners`` from the preprocess
        # config); fall back to the previously hard-coded English cleaners.
        phones = _read_phones(self.utt2lab_path[utt])
        cleaners = getattr(preprocess_cfg, "text_cleaners", ["english_cleaners"])
        phones_ids = np.array(text_to_sequence(phones, cleaners))
        text_len = len(phones_ids)

        pitch = None
        if preprocess_cfg.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif preprocess_cfg.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        energy = None
        if preprocess_cfg.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif preprocess_cfg.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        # Speaker: id from spk2id.json when available, otherwise 0.
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        features = {
            "durations": duration,
            "texts": phones_ids,
            "spk_id": speaker_id,
            "text_len": text_len,
        }
        # Only emit pitch/energy when a pitch/energy mode is enabled; the
        # collator pads whatever keys are present.
        if pitch is not None:
            features["pitch"] = pitch
        if energy is not None:
            features["energy"] = energy
        features["uid"] = uid
        single_feature.update(features)

        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        """Load per-utterance phone durations, keyed by ``_utt_key``.

        Utterances whose mel or duration file is missing are skipped. For the
        rest, asserts that the number of mel frames (axis 1 of the stored mel
        array) equals ``sum(duration)``.
        """
        utt2dur = {}
        for utt_info in self.metadata:
            utt = _utt_key(utt_info)
            mel_path = self.utt2mel_path[utt]
            duration_path = self.utt2duration_path[utt]
            if not os.path.exists(mel_path) or not os.path.exists(duration_path):
                continue
            mel = np.load(mel_path).transpose(1, 0)
            duration = np.load(duration_path)
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """Pick a random ``[start, end)`` window of length ``max_seq_len``.

        ending_ts: to avoid invalid whisper features for over 30s audios
            2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """
        sample :
            {
                'spk_id': (1,),
                'target_len': int
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,)
                'frame_energy': (seq_len,)
                'content_vector_feat': (seq_len, dim)
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        # NOTE(review): every key except spk_id/target_len is sliced with the
        # same frame window. For FS2 samples this includes the scalar
        # "text_len" (slicing an int raises TypeError) and phone-level entries
        # ("texts", "durations", phone pitch/energy), which would be cut by
        # frame indices. Clipping FS2 samples therefore needs a phone-aligned
        # implementation before it can be relied on.
        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        for k in sample.keys():
            if k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
        """Return metadata entries whose duration and mel files both exist."""
        new_metadata = []
        for utt_info in self.metadata:
            utt = _utt_key(utt_info)
            if os.path.exists(self.utt2duration_path[utt]) and os.path.exists(
                self.utt2mel_path[utt]
            ):
                new_metadata.append(utt_info)
        return new_metadata


class FS2Collator(BaseCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        """Collate a list of ``FS2Dataset`` samples into padded tensors."""
        return _pack_fs2_batch(batch)


class FS2TestDataset(BaseTestDataset):
    """Inference dataset for FastSpeech2: phone ids and speaker id per utterance."""

    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess
        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            # Sort a copy so the caller's config list is not mutated.
            datasets = sorted(datasets)
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        ######### Load source acoustic features #########
        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # utt2lab: path of each utterance's phone label .txt file
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            self.utt2lab_path[_utt_key(utt_info)] = os.path.join(
                cfg.processed_dir,
                utt_info["Dataset"],
                cfg.lab_dir,
                utt_info["Uid"] + ".txt",
            )

        self.speaker_map = _load_speaker_map(cfg.processed_dir)

    def __getitem__(self, index):
        utt_info = self.metadata[index]
        utt = _utt_key(utt_info)

        # text
        phones = _read_phones(self.utt2lab_path[utt])
        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        # speaker
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        return {
            "texts": phones_ids,
            "spk_id": speaker_id,
            "text_len": text_len,
        }

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        """Load the test metadata list from ``self.metafile_path`` (JSON)."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata


class FS2TestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """Collate a list of ``FS2TestDataset`` samples into padded tensors."""
        return _pack_fs2_batch(batch)