File size: 4,378 Bytes
0d80816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from tqdm import tqdm
import glob
import json
import torchaudio

from utils.util import has_existed
from utils.io import save_audio


def get_splitted_utterances(
    raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
):
    """Collect metadata for the split PJS utterances, creating the splits if needed.

    Raw songs are located by recursively globbing ``pjs*_song.wav`` under
    ``raw_wav_dir``. If ``trimed_wav_dir`` already contains exactly
    ``n_utterance_splits`` wav files per raw song, the existing files are
    reused and only their durations are measured. Otherwise every raw song is
    VAD-trimmed at both ends, cut into ``n_utterance_splits`` segments, and
    each segment is written to ``trimed_wav_dir`` as ``<song_id>_<i>.wav``.

    Args:
        raw_wav_dir: Directory searched recursively for the raw song files.
        trimed_wav_dir: Directory holding (or receiving) the split wav files.
        n_utterance_splits: Number of segments each song is cut into.
        overlapping: Extra length, in seconds, appended to the end of every
            segment so that adjacent segments overlap.

    Returns:
        A list of dicts with the keys ``Dataset``, ``Singer``, ``Uid``,
        ``Duration`` (number of samples divided by the sample rate) and
        ``Path``, sorted by ``Uid``.
    """
    res = []
    raw_song_files = glob.glob(
        os.path.join(raw_wav_dir, "**/pjs*_song.wav"), recursive=True
    )
    trimed_song_files = glob.glob(
        os.path.join(trimed_wav_dir, "**/*.wav"), recursive=True
    )

    if len(raw_song_files) * n_utterance_splits == len(trimed_song_files):
        # Every expected split is already on disk: only rebuild the metadata.
        print("Splitted done...")
        for wav_file in tqdm(trimed_song_files):
            # basename() copes with whatever path separator glob returned.
            uid = os.path.basename(wav_file).split(".")[0]
            utt = {"Dataset": "pjs", "Singer": "male1", "Uid": uid, "Path": wav_file}

            waveform, sample_rate = torchaudio.load(wav_file)
            utt["Duration"] = waveform.size(-1) / sample_rate

            res.append(utt)

    else:
        for wav_file in tqdm(raw_song_files):
            song_id = os.path.basename(wav_file).split(".")[0]

            waveform, sample_rate = torchaudio.load(wav_file)
            # vad() only trims from the front, so the tail is trimmed by
            # running it a second time on the time-reversed signal.
            trimed_waveform = torchaudio.functional.vad(waveform, sample_rate)
            trimed_waveform = torchaudio.functional.vad(
                trimed_waveform.flip(dims=[1]), sample_rate
            ).flip(dims=[1])

            audio_len = trimed_waveform.size(-1)
            # Slice bounds must be integers even for a fractional overlap.
            lapping_len = int(overlapping * sample_rate)

            for i in range(n_utterance_splits):
                # Segment boundaries follow the requested number of splits
                # (previously hard-coded to thirds).
                start = i * audio_len // n_utterance_splits
                end = start + audio_len // n_utterance_splits + lapping_len
                splitted_waveform = trimed_waveform[:, start:end]

                utt = {
                    "Dataset": "pjs",
                    "Singer": "male1",
                    "Uid": "{}_{}".format(song_id, i),
                }

                # Duration
                utt["Duration"] = splitted_waveform.size(-1) / sample_rate

                # Save trimed wav
                splitted_waveform_file = os.path.join(
                    trimed_wav_dir, "{}.wav".format(utt["Uid"])
                )
                save_audio(splitted_waveform_file, splitted_waveform, sample_rate)

                # Path
                utt["Path"] = splitted_waveform_file

                res.append(utt)

    res = sorted(res, key=lambda x: x["Uid"])
    return res


def main(output_path, dataset_path, n_utterance_splits=3, overlapping=1):
    """Prepare the train/test metadata files for the PJS dataset.

    1. Split one raw utterance into ``n_utterance_splits`` splits (three by
       default, since some samples are too long).
    2. Adjacent splits overlap by ``overlapping`` seconds (1 s by default).

    The splits of the first three songs form the test set; every other
    utterance goes to the train set. Each set is written to
    ``<output_path>/pjs/{train,test}.json`` unless ``has_existed`` reports
    that the file is already there.

    Args:
        output_path: Root directory under which the ``pjs`` folder is written.
        dataset_path: Directory containing ``PJS_corpus_ver1.1``; the split
            wav files are stored in its ``trim`` sub-directory.
        n_utterance_splits: Number of segments each raw song is cut into.
        overlapping: Overlap between adjacent segments, in seconds.
    """
    print("-" * 10)
    print("Preparing training dataset for PJS...")

    save_dir = os.path.join(output_path, "pjs")
    raw_wav_dir = os.path.join(dataset_path, "PJS_corpus_ver1.1")

    # Trim for silence
    trimed_wav_dir = os.path.join(dataset_path, "trim")
    os.makedirs(trimed_wav_dir, exist_ok=True)

    # Total utterances
    utterances = get_splitted_utterances(
        raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
    )

    # Test uids: every split of the first `n_test_songs` songs. The song
    # number is zero-padded to three digits (pjs001, pjs002, ...).
    n_test_songs = 3
    test_uids = {
        "pjs{:03d}_song_{}".format(song_idx, split_id)
        for song_idx in range(1, n_test_songs + 1)
        for split_id in range(n_utterance_splits)
    }

    # Train uids: everything that is not held out for testing.
    train_uids = {
        utt["Uid"] for utt in utterances if utt["Uid"] not in test_uids
    }

    # Explicit mapping instead of looking the variables up by name with eval().
    uids_by_type = {"train": train_uids, "test": test_uids}

    for dataset_type, uids in uids_by_type.items():
        output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
        if has_existed(output_file):
            continue

        res = [utt for utt in utterances if utt["Uid"] in uids]
        for index, utt in enumerate(res):
            utt["index"] = index

        total_duration = sum([utt["Duration"] for utt in res])
        print(
            "{}, Total size: {}, Total Duraions = {} s = {:.2f} hour\n".format(
                dataset_type, len(res), total_duration, total_duration / 3600
            )
        )

        # Save
        os.makedirs(save_dir, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(res, f, indent=4, ensure_ascii=False)