File size: 13,526 Bytes
66a6dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import os
import sys
import csv
import glob
import torch
import random
from tqdm import tqdm
from typing import List, Any

from deepafx_st.data.audio import AudioFile
import deepafx_st.utils as utils
import deepafx_st.data.augmentations as augmentations


class AudioDataset(torch.utils.data.Dataset):
    """Audio dataset which returns an input and target pair.

    Pairs are generated on the fly from a RAM buffer of audio files: both
    signals start from the same random (non-silent) patch of the same file
    and the same optional augmentation, then receive independent EQ and
    dynamic range corruptions.

    Args:
        audio_dir (str): Path to the top level of the audio dataset.
        input_dirs (List[str], optional): Directories (relative to ``audio_dir``)
            containing input audio files. Default: ["cleanraw"]
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
        val_per (float, optional): Fraction of the files to use for validation subset. Default: 0.1
        buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 1.0
            Note: This is the buffer size PER DataLoader worker, so total RAM = buffer_size_gb * num_workers.
        buffer_reload_rate (int, optional): Number of items to generate before loading the next chunk of the dataset. Default: 1000
        half (bool, optional): Store audio samples as float16. Default: False
        num_examples_per_epoch (int, optional): Define an epoch as this number of audio examples. Default: 10000
        random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False
        random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False
        augmentations (dict, optional): Augmentation types (and their parameters) to apply to inputs. Default: {}
        freq_corrupt (bool, optional): Apply bad EQ filters. Default: False
        drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False
        ext (str, optional): Expected audio file extension. Default: "wav"
    """

    def __init__(
        self,
        audio_dir,
        input_dirs: List[str] = None,
        subset: str = "train",
        length: int = 65536,
        train_frac: float = 0.8,
        val_per: float = 0.1,
        buffer_size_gb: float = 1.0,
        buffer_reload_rate: float = 1000,
        half: bool = False,
        num_examples_per_epoch: int = 10000,
        random_scale_input: bool = False,
        random_scale_target: bool = False,
        augmentations: dict = None,
        freq_corrupt: bool = False,
        drc_corrupt: bool = False,
        ext: str = "wav",
    ):
        super().__init__()
        # None sentinels avoid shared mutable default arguments
        if input_dirs is None:
            input_dirs = ["cleanraw"]
        if augmentations is None:
            augmentations = {}

        self.audio_dir = audio_dir
        self.dataset_name = os.path.basename(audio_dir)
        self.input_dirs = input_dirs
        self.subset = subset
        self.length = length
        self.train_frac = train_frac
        self.val_per = val_per
        self.buffer_size_gb = buffer_size_gb
        self.buffer_reload_rate = buffer_reload_rate
        self.half = half
        self.num_examples_per_epoch = num_examples_per_epoch
        self.random_scale_input = random_scale_input  # NOTE(review): stored but unused in this class
        self.random_scale_target = random_scale_target  # NOTE(review): stored but unused in this class
        self.augmentations = augmentations
        self.freq_corrupt = freq_corrupt
        self.drc_corrupt = drc_corrupt
        self.ext = ext

        # gather candidate files across all input directories
        self.input_filepaths = []
        search_paths = []
        for input_dir in input_dirs:
            search_path = os.path.join(audio_dir, input_dir, f"*.{ext}")
            search_paths.append(search_path)
            self.input_filepaths += glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        # create dataset split based on subset
        self.input_filepaths = utils.split_dataset(
            self.input_filepaths,
            subset,
            train_frac,
        )

        # get details about input audio files (metadata only, no samples)
        input_files = {}
        input_dur_frames = 0
        for input_filepath in tqdm(self.input_filepaths, ncols=80):
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=False,
                half=half,
            )
            # skip files too short to supply a 2x-length input/target patch
            if audio_file.num_frames < (self.length * 2):
                continue
            input_files[file_id] = audio_file
            input_dur_frames += audio_file.num_frames

        if not input_files:
            raise RuntimeError(f"No files found in {search_paths}.")

        # take the sample rate from any retained file
        # (assumes all files share one sample rate — TODO confirm upstream)
        first_file = next(iter(input_files.values()))
        self.sample_rate = first_file.sample_rate

        input_dur_hr = (input_dur_frames / self.sample_rate) / 3600
        print(
            f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours."
        )

        # save a csv file with details about the train and test split
        splits_dir = os.path.join("configs", "splits")
        os.makedirs(splits_dir, exist_ok=True)
        csv_filepath = os.path.join(
            splits_dir, f"{self.dataset_name}_{self.subset}_set.csv"
        )

        with open(csv_filepath, "w") as fp:
            dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"])
            dw.writeheader()
            for input_filepath in self.input_filepaths:
                dw.writerow(
                    {
                        "file_id": self.get_file_id(input_filepath),
                        "filepath": input_filepath,
                        "type": "input",
                        "subset": self.subset,
                    }
                )

        # start at the reload threshold so the first __getitem__ fills the buffer
        self.items_since_load = self.buffer_reload_rate

    def __len__(self):
        """Epoch length is defined by configuration, not by file count."""
        return self.num_examples_per_epoch

    @staticmethod
    def _peak_normalize(audio, target_dbfs: float = -12.0):
        """Peak normalize ``audio`` so its maximum absolute value sits at
        ``target_dbfs``. The epsilon clamp guards against division by zero
        (NaNs) on silent audio."""
        peak = audio.abs().max().clamp(min=1e-8)
        return (audio / peak) * (10 ** (target_dbfs / 20.0))

    def load_audio_buffer(self):
        """(Re)fill the RAM buffer with randomly chosen audio files."""
        self.input_files_loaded = {}  # clear audio buffer
        self.items_since_load = 0  # reset iteration counter
        nbytes_loaded = 0  # counter for data in RAM

        # shuffle so each worker/reload buffers a different subset of files
        random.shuffle(self.input_filepaths)

        # load files into RAM until the buffer budget is exceeded
        for input_filepath in self.input_filepaths:
            file_id = os.path.basename(input_filepath)
            audio_file = AudioFile(
                input_filepath,
                preload=True,
                half=self.half,
            )

            # skip files too short to supply a 2x-length input/target patch
            if audio_file.num_frames < (self.length * 2):
                continue

            self.input_files_loaded[file_id] = audio_file

            nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
            nbytes_loaded += nbytes

            # check the size of loaded data
            if nbytes_loaded > self.buffer_size_gb * 1e9:
                break

    def generate_pair(self):
        """Generate one (input, target) pair from the RAM buffer.

        Both signals share a random non-silent patch and the same optional
        augmentation; the input may then be corrupted by a bad EQ and a
        dynamic range expander, the target by a (different) bad EQ and a
        dynamic range compressor.
        """
        # ------------------------ Input audio ----------------------
        # keep sampling until a non-silent patch is found
        while True:
            rand_input_file_id = self.get_random_file_id(self.input_files_loaded.keys())

            # use this random key to retrieve an input file
            input_file = self.input_files_loaded[rand_input_file_id]

            # the buffer only holds preloaded files; anything else is a bug
            if not input_file.loaded:
                raise RuntimeError("Audio not loaded.")

            # get a random patch of size `self.length` x 2
            start_idx, stop_idx = self.get_random_patch(
                input_file, int(self.length * 2)
            )
            if start_idx >= 0:  # get_random_patch returns (-1, -1) on failure
                break

        input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
        input_audio = input_audio.view(1, -1)

        # process in float32 even when stored as float16
        if self.half:
            input_audio = input_audio.float()

        # peak normalize to -12 dBFS
        input_audio = self._peak_normalize(input_audio)

        # apply configured augmentations to ~half of the examples;
        # short-circuit keeps RNG consumption identical to the original
        if self.augmentations and torch.rand(1).sum() < 0.5:
            input_audio_aug = augmentations.apply(
                [input_audio],
                self.sample_rate,
                self.augmentations,
            )[0]
        else:
            input_audio_aug = input_audio.clone()

        input_audio_corrupt = input_audio_aug.clone()
        # apply frequency corruption (bad EQ filters) to most inputs
        if self.freq_corrupt and torch.rand(1).sum() < 0.75:
            input_audio_corrupt = augmentations.frequency_corruption(
                [input_audio_corrupt], self.sample_rate
            )[0]

        # peak normalize again before passing through dynamic range expander
        input_audio_corrupt = self._peak_normalize(input_audio_corrupt)

        if self.drc_corrupt and torch.rand(1).sum() < 0.10:
            input_audio_corrupt = augmentations.dynamic_range_corruption(
                [input_audio_corrupt], self.sample_rate
            )[0]

        # ------------------------ Target audio ----------------------
        # use the same augmented audio clip, add different random EQ and compressor
        target_audio_corrupt = input_audio_aug.clone()
        # apply frequency corruption (bad EQ filters) to most targets
        if self.freq_corrupt and torch.rand(1).sum() < 0.75:
            target_audio_corrupt = augmentations.frequency_corruption(
                [target_audio_corrupt], self.sample_rate
            )[0]

        # peak normalize again before passing through dynamic range compressor
        # (bug fix: previously re-normalized the *input* here instead of the target)
        target_audio_corrupt = self._peak_normalize(target_audio_corrupt)

        if self.drc_corrupt and torch.rand(1).sum() < 0.75:
            target_audio_corrupt = augmentations.dynamic_range_compression(
                [target_audio_corrupt], self.sample_rate
            )[0]

        return input_audio_corrupt, target_audio_corrupt

    def __getitem__(self, _):
        """Return one (input_audio, target_audio) pair.

        The index is ignored; examples are generated randomly from the
        current RAM buffer, which is reloaded every ``buffer_reload_rate``
        items.
        """
        # increment counter
        self.items_since_load += 1

        # load next chunk into buffer if needed
        if self.items_since_load > self.buffer_reload_rate:
            self.load_audio_buffer()

        # generate pairs for style training
        input_audio, target_audio = self.generate_pair()

        # ------------------------ Conform length of files -------------------
        input_audio = utils.conform_length(input_audio, int(self.length * 2))
        target_audio = utils.conform_length(target_audio, int(self.length * 2))

        # ------------------------ Apply fade in and fade out ----------------
        input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate)
        target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate)

        # ------------------------ Final normalization -----------------------
        # always peak normalize both signals to -12 dBFS
        input_audio = self._peak_normalize(input_audio)
        target_audio = self._peak_normalize(target_audio)

        return input_audio, target_audio

    @staticmethod
    def get_random_file_id(keys):
        """Return a uniformly random key (file id) from ``keys``.

        ``torch.randint``'s upper bound is exclusive, so the bound must be
        ``len(keys)`` — the previous ``len(keys) - 1`` could never select
        the last file and raised when only one file was buffered.
        """
        keys = list(keys)
        rand_input_idx = torch.randint(0, len(keys), [1])[0]
        return keys[rand_input_idx]

    @staticmethod
    def get_random_patch(audio_file, length, check_silence=True):
        """Pick a random patch of ``length`` frames from ``audio_file``.

        Args:
            audio_file: Object exposing ``num_frames`` and a 2D ``audio`` tensor.
            length (int): Patch length in frames (must be <= num_frames).
            check_silence (bool, optional): Require both halves of the patch
                to have mean energy above 1e-5. Default: True

        Returns:
            (start_idx, stop_idx): Frame indices of the patch, or ``(-1, -1)``
            if no non-silent patch was found within 100 attempts.
        """
        max_attempts = 100
        for attempt in range(max_attempts):
            # randint's high is exclusive: start can be at most num_frames - length
            start_idx = torch.randint(0, audio_file.num_frames - length + 1, [1])[0]
            stop_idx = start_idx + length
            patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()

            if not check_silence:
                return start_idx, stop_idx

            # require signal energy in both halves of the patch
            half_len = patch.shape[-1] // 2
            first_half = patch[..., :half_len]
            second_half = patch[..., half_len:]
            if (first_half**2).mean() > 1e-5 and (second_half**2).mean() > 1e-5:
                return start_idx, stop_idx

        print("get_random_patch count", max_attempts)
        return -1, -1

    def get_file_id(self, filepath):
        """Given a filepath extract the DAPS file id.

        Args:
            filepath (str): Path to an audio file in the DAPS dataset.

        Returns:
            file_id (str): DAPS file id of the form <participant_id>_<script_id>
        """
        # e.g. "f1_script3_clean.wav" -> "f1_script3"
        return "_".join(os.path.basename(filepath).split("_")[:2])

    def get_file_set(self, filepath):
        """Given a filepath extract the DAPS file set name.

        Args:
            filepath (str): Path to an audio file in the DAPS dataset.

        Returns:
            file_set (str): The DAPS set to which the file belongs.
        """
        # e.g. "f1_script3_clean.wav" -> "clean"
        file_set = "_".join(os.path.basename(filepath).split("_")[2:])
        return file_set.replace(f".{self.ext}", "")