# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from tqdm import tqdm
import glob
import json
import torchaudio

from utils.util import has_existed
from utils.io import save_audio


def get_splitted_utterances(
    raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
):
    res = []
    raw_song_files = glob.glob(
        os.path.join(raw_wav_dir, "**/pjs*_song.wav"), recursive=True
    )
    trimed_song_files = glob.glob(
        os.path.join(trimed_wav_dir, "**/*.wav"), recursive=True
    )

    if len(raw_song_files) * n_utterance_splits == len(trimed_song_files):
        # Splitting has already been done: load the existing split wavs
        print("Splitting already done, loading existing splits...")
        for wav_file in tqdm(trimed_song_files):
            uid = wav_file.split("/")[-1].split(".")[0]
            utt = {"Dataset": "pjs", "Singer": "male1", "Uid": uid, "Path": wav_file}

            waveform, sample_rate = torchaudio.load(wav_file)
            duration = waveform.size(-1) / sample_rate
            utt["Duration"] = duration

            res.append(utt)
    else:
        for wav_file in tqdm(raw_song_files):
            song_id = wav_file.split("/")[-1].split(".")[0]

            waveform, sample_rate = torchaudio.load(wav_file)

            # torchaudio.functional.vad only trims silence from the front,
            # so trim, flip, trim again, and flip back to remove silence
            # from both ends
            trimed_waveform = torchaudio.functional.vad(waveform, sample_rate)
            trimed_waveform = torchaudio.functional.vad(
                trimed_waveform.flip(dims=[1]), sample_rate
            ).flip(dims=[1])

            audio_len = trimed_waveform.size(-1)
            lapping_len = overlapping * sample_rate

            for i in range(n_utterance_splits):
                # Each split covers 1/n_utterance_splits of the trimmed audio,
                # plus `overlapping` seconds shared with the next split
                start = i * audio_len // n_utterance_splits
                end = start + audio_len // n_utterance_splits + lapping_len
                splitted_waveform = trimed_waveform[:, start:end]

                utt = {
                    "Dataset": "pjs",
                    "Singer": "male1",
                    "Uid": "{}_{}".format(song_id, i),
                }

                # Duration
                duration = splitted_waveform.size(-1) / sample_rate
                utt["Duration"] = duration

                # Save the trimmed, split wav
                splitted_waveform_file = os.path.join(
                    trimed_wav_dir, "{}.wav".format(utt["Uid"])
                )
                save_audio(splitted_waveform_file, splitted_waveform, sample_rate)

                # Path
                utt["Path"] = splitted_waveform_file

                res.append(utt)

    res = sorted(res, key=lambda x: x["Uid"])
    return res


def main(output_path, dataset_path, n_utterance_splits=3, overlapping=1):
    """
    1. Split each raw utterance into three splits (since some samples are too long)
    2. Adjacent splits overlap by 1 s
    """
    print("-" * 10)
    print("Preparing training dataset for PJS...")

    save_dir = os.path.join(output_path, "pjs")
    raw_wav_dir = os.path.join(dataset_path, "PJS_corpus_ver1.1")

    # Directory for the silence-trimmed wavs
    trimed_wav_dir = os.path.join(dataset_path, "trim")
    os.makedirs(trimed_wav_dir, exist_ok=True)

    # Total utterances
    utterances = get_splitted_utterances(
        raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
    )
    total_uids = [utt["Uid"] for utt in utterances]

    # Test uids: all splits of the first n_test_songs songs
    n_test_songs = 3
    test_uids = []
    for i in range(1, n_test_songs + 1):
        test_uids += [
            "pjs00{}_song_{}".format(i, split_id)
            for split_id in range(n_utterance_splits)
        ]

    # Train uids
    train_uids = [uid for uid in total_uids if uid not in test_uids]

    for dataset_type in ["train", "test"]:
        output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
        if has_existed(output_file):
            continue

        uids = train_uids if dataset_type == "train" else test_uids
        res = [utt for utt in utterances if utt["Uid"] in uids]
        for i in range(len(res)):
            res[i]["index"] = i

        time = sum([utt["Duration"] for utt in res])
        print(
            "{}, Total size: {}, Total Duration = {} s = {:.2f} hour\n".format(
                dataset_type, len(res), time, time / 3600
            )
        )

        # Save
        os.makedirs(save_dir, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(res, f, indent=4, ensure_ascii=False)
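

# A minimal standalone entry point, as a sketch only: the original module is
# meant to be imported by Amphion's preprocessing pipeline, so the flag names
# below are assumptions and not part of the upstream CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Prepare train/test JSON metadata for the PJS corpus."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Directory under which pjs/train.json and pjs/test.json are written",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Directory containing PJS_corpus_ver1.1",
    )
    parser.add_argument(
        "--n_utterance_splits",
        type=int,
        default=3,
        help="Number of splits per raw song",
    )
    parser.add_argument(
        "--overlapping",
        type=int,
        default=1,
        help="Overlap between adjacent splits, in seconds",
    )
    args = parser.parse_args()

    main(args.output_path, args.dataset_path, args.n_utterance_splits, args.overlapping)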