# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import pickle
import glob
from collections import defaultdict
from tqdm import tqdm

from preprocessors import get_golden_samples_indexes

TRAIN_MAX_NUM_EVERY_PERSON = 250
TEST_MAX_NUM_EVERY_PERSON = 25


def select_sample_idxs():
    """Select a per-singer subset of the VCTK metadata.

    Keeps at most TRAIN_MAX_NUM_EVERY_PERSON train utterances and
    TEST_MAX_NUM_EVERY_PERSON test utterances per singer, always including
    the golden test samples. Relies on the module-level ``vctk_dir`` set in
    the ``__main__`` block below.
    """
    # =========== Train ===========
    with open(os.path.join(vctk_dir, "train.json"), "r") as f:
        raw_train = json.load(f)

    train_idxs = []
    train_nums = defaultdict(int)
    for utt in tqdm(raw_train):
        idx = utt["index"]
        singer = utt["Singer"]

        if train_nums[singer] < TRAIN_MAX_NUM_EVERY_PERSON:
            train_idxs.append(idx)
            train_nums[singer] += 1

    # =========== Test ===========
    with open(os.path.join(vctk_dir, "test.json"), "r") as f:
        raw_test = json.load(f)

    # Golden test samples are always kept and count towards each singer's quota
    test_idxs = get_golden_samples_indexes(
        dataset_name="vctk", split="test", dataset_dir=vctk_dir
    )
    test_nums = defaultdict(int)
    for idx in test_idxs:
        singer = raw_test[idx]["Singer"]
        test_nums[singer] += 1

    for utt in tqdm(raw_test):
        idx = utt["index"]
        singer = utt["Singer"]

        # Skip utterances already selected as golden samples so they are not
        # appended twice
        if idx in test_idxs:
            continue

        if test_nums[singer] < TEST_MAX_NUM_EVERY_PERSON:
            test_idxs.append(idx)
            test_nums[singer] += 1

    train_idxs.sort()
    test_idxs.sort()
    return train_idxs, test_idxs, raw_train, raw_test


if __name__ == "__main__":
    # Root directory containing the processed "vctk" folder; the sampled
    # subset is written next to it as "vctksample"
    root_path = ""
    vctk_dir = os.path.join(root_path, "vctk")
    sample_dir = os.path.join(root_path, "vctksample")
    os.makedirs(sample_dir, exist_ok=True)

    train_idxs, test_idxs, raw_train, raw_test = select_sample_idxs()
    print("#Train = {}, #Test = {}".format(len(train_idxs), len(test_idxs)))

    for split, chosen_idxs, utterances in zip(
        ["train", "test"], [train_idxs, test_idxs], [raw_train, raw_test]
    ):
        print(
            "#{} = {}, #chosen idx = {}\n".format(
                split, len(utterances), len(chosen_idxs)
            )
        )

        # Select features
        feat_files = glob.glob(
            "**/{}.pkl".format(split), root_dir=vctk_dir, recursive=True
        )
        for file in tqdm(feat_files):
            raw_file = os.path.join(vctk_dir, file)
            new_file = os.path.join(sample_dir, file)

            new_dir = "/".join(new_file.split("/")[:-1])
            os.makedirs(new_dir, exist_ok=True)

            # Dataset-level statistics (mel_min/mel_max) are not indexed per
            # utterance, so copy them verbatim
            if "mel_min" in file or "mel_max" in file:
                os.system("cp {} {}".format(raw_file, new_file))
                continue

            with open(raw_file, "rb") as f:
                raw_feats = pickle.load(f)

            print("file: {}, #raw_feats = {}".format(file, len(raw_feats)))
            new_feats = [raw_feats[idx] for idx in chosen_idxs]

            with open(new_file, "wb") as f:
                pickle.dump(new_feats, f)

        # Utterance re-index
        new_utts = [utterances[idx] for idx in chosen_idxs]
        for i, utt in enumerate(new_utts):
            utt["Dataset"] = "vctksample"
            utt["index"] = i

        with open(os.path.join(sample_dir, "{}.json".format(split)), "w") as f:
            json.dump(new_utts, f, indent=4)
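

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original pipeline and never called above):
# a quick consistency check of the written "vctksample" directory. The
# directory layout and the "index"/"Singer"/"Dataset" metadata keys are taken
# from this script; the helper name and the checks themselves are illustrative
# assumptions.
def _sanity_check_sample_dir(sample_dir):
    for split in ["train", "test"]:
        meta_path = os.path.join(sample_dir, "{}.json".format(split))
        if not os.path.exists(meta_path):
            continue
        with open(meta_path, "r") as f:
            utts = json.load(f)

        # After re-indexing, "index" should be 0..N-1 and the dataset renamed
        assert [utt["index"] for utt in utts] == list(range(len(utts)))
        assert all(utt["Dataset"] == "vctksample" for utt in utts)

        # Each copied per-utterance pickle should hold one entry per selected
        # utterance; the mel_min/mel_max statistics files are copied verbatim
        # and therefore skipped here
        for file in glob.glob(
            "**/{}.pkl".format(split), root_dir=sample_dir, recursive=True
        ):
            if "mel_min" in file or "mel_max" in file:
                continue
            with open(os.path.join(sample_dir, file), "rb") as f:
                feats = pickle.load(f)
            assert len(feats) == len(utts), file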