# Adapted from:
# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py
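"""Dataset and training utilities for audio models: loudness
normalization, random patch selection, file discovery, parameter
counting, cropping, train/val/test splitting, and fades."""
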
import os
import csv
import torch
import fnmatch
import numpy as np
import random
from enum import Enum
import pyloudnorm as pyln


class DSPMode(Enum):
    NONE = "none"
    TRAIN_INFER = "train_infer"
    INFER = "infer"

    def __str__(self):
        return self.value


def loudness_normalize(x, sample_rate, target_loudness=-24.0):
    # duplicate the mono signal so pyloudnorm sees a (samples, channels) array
    x = x.view(1, -1)
    stereo_audio = x.repeat(2, 1).permute(1, 0).numpy()
    meter = pyln.Meter(sample_rate)
    loudness = meter.integrated_loudness(stereo_audio)
    norm_x = pyln.normalize.loudness(
        stereo_audio,
        loudness,
        target_loudness,
    )
    # back to a mono torch tensor of shape (1, num_samples)
    x = torch.tensor(norm_x).permute(1, 0)
    x = x[0, :].view(1, -1)

    return x

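# Example (a minimal sketch, not from the original repo): normalize one
# second of noise at 44.1 kHz to the default -24 LUFS target.
#   x = torch.randn(1, 44100)
#   x = loudness_normalize(x, 44100, target_loudness=-24.0)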

def get_random_file_id(keys):
    # generate a random index into the keys of the input files
    # (torch.randint's upper bound is exclusive, so use len(keys)
    # to keep the last file reachable)
    rand_input_idx = torch.randint(0, len(keys), [1])[0]
    # find the key (file_id) corresponding to the random index
    rand_input_file_id = list(keys)[rand_input_idx]

    return rand_input_file_id


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

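# Typical wiring (the standard PyTorch reproducibility recipe; `dataset`
# here stands for any torch Dataset):
#   g = torch.Generator()
#   g.manual_seed(0)
#   loader = torch.utils.data.DataLoader(
#       dataset, num_workers=4, worker_init_fn=seed_worker, generator=g
#   )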

def getFilesPath(directory, extension):
    """Recursively collect paths under `directory` matching a glob pattern.

    Args:
        directory (str): Root directory to search.
        extension (str): Glob pattern, e.g. "*.wav".

    Returns:
        list: Sorted list of matching file paths.
    """
    n_path = []
    for path, subdirs, files in os.walk(directory):
        for name in files:
            if fnmatch.fnmatch(name, extension):
                n_path.append(os.path.join(path, name))
    n_path.sort()

    return n_path

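# Example (hypothetical directory layout):
#   wav_paths = getFilesPath("data/audio", "*.wav")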

def count_parameters(model, trainable_only=True):
    # sum() over an empty generator returns 0, so models without
    # parameters need no special case
    if trainable_only:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())


def system_summary(system):
    print(f"Encoder: {count_parameters(system.encoder)/1e6:0.2f} M")
    print(f"Processor: {count_parameters(system.processor)/1e6:0.2f} M")

    if hasattr(system, "adv_loss_fn"):
        for idx, disc in enumerate(system.adv_loss_fn.discriminators):
            print(f"Discriminator {idx+1}: {count_parameters(disc)/1e6:0.2f} M")


def center_crop(x, length: int):
    if x.shape[-1] != length:
        start = (x.shape[-1] - length) // 2
        stop = start + length
        x = x[..., start:stop]
    return x


def causal_crop(x, length: int):
    # keep the final `length` samples so the crop respects causality
    if x.shape[-1] != length:
        x = x[..., -length:]
    return x

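# Example (a minimal sketch): `center_crop` trims both ends equally,
# while `causal_crop` keeps only the most recent samples.
#   y = torch.randn(1, 1, 1100)
#   center_crop(y, 1000).shape  # (1, 1, 1000), 50 samples cut per side
#   causal_crop(y, 1000).shape  # (1, 1, 1000), first 100 samples cut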

def denormalize(norm_val, max_val, min_val):
    """Map a value in [0, 1] back to the range [min_val, max_val]."""
    return (norm_val * (max_val - min_val)) + min_val


def normalize(denorm_val, max_val, min_val):
    """Map a value in [min_val, max_val] to the range [0, 1]."""
    return (denorm_val - min_val) / (max_val - min_val)

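# Round trip: normalize() and denormalize() are inverses over
# [min_val, max_val].
#   n = normalize(7.0, max_val=10.0, min_val=2.0)  # 0.625
#   v = denormalize(n, max_val=10.0, min_val=2.0)  # 7.0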

def get_random_patch(audio_file, length, check_silence=True, energy_threshold=1e-4):
    """Produce sample indices for a random patch of size `length`.

    This function checks the energy of the selected patch to
    ensure that it is not complete silence. If silence is found,
    it continues searching for a non-silent patch.

    Args:
        audio_file (AudioFile): Audio file object.
        length (int): Number of samples in random patch.
        check_silence (bool): Reject patches whose mean energy falls
            below `energy_threshold`.
        energy_threshold (float): Mean energy below which a patch is
            considered silent.

    Returns:
        start_idx (int): Starting sample index.
        stop_idx (int): Stop sample index.
    """
    silent = True
    while silent:
        start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
        stop_idx = start_idx + length
        patch = audio_file.audio[:, start_idx:stop_idx]
        if not check_silence or (patch ** 2).mean() > energy_threshold:
            silent = False

    return start_idx, stop_idx

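# Example (assumes an object exposing `num_frames` and an `audio`
# tensor of shape (channels, num_frames), as the docstring implies):
#   start_idx, stop_idx = get_random_patch(audio_file, length=65536)
#   patch = audio_file.audio[:, start_idx:stop_idx]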

def split_dataset(file_list, subset, train_frac):
    """Given a list of files, split into train/val/test sets.

    Args:
        file_list (list): List of audio files.
        subset (str): One of "train", "val", or "test".
        train_frac (float): Fraction of the dataset to use for training.

    Returns:
        file_list (list): List of audio files corresponding to subset.
    """
    assert 0.1 < train_frac < 1.0

    total_num_examples = len(file_list)

    train_num_examples = int(total_num_examples * train_frac)
    val_num_examples = int(total_num_examples * (1 - train_frac) / 2)
    test_num_examples = total_num_examples - (train_num_examples + val_num_examples)

    if train_num_examples <= 0:
        raise ValueError(
            f"No examples in training set. Try increasing train_frac: {train_frac}."
        )
    elif val_num_examples <= 0:
        raise ValueError(
            f"No examples in validation set. Try decreasing train_frac: {train_frac}."
        )
    elif test_num_examples <= 0:
        raise ValueError(
            f"No examples in test set. Try decreasing train_frac: {train_frac}."
        )

    if subset == "train":
        start_idx = 0
        stop_idx = train_num_examples
    elif subset == "val":
        start_idx = train_num_examples
        stop_idx = start_idx + val_num_examples
    elif subset == "test":
        start_idx = train_num_examples + val_num_examples
        stop_idx = start_idx + test_num_examples
    else:
        raise ValueError(f"Invalid subset: {subset}.")

    return file_list[start_idx:stop_idx]

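# Example: an 80/10/10 split drawn from the same file list.
#   train_files = split_dataset(files, "train", 0.8)
#   val_files = split_dataset(files, "val", 0.8)
#   test_files = split_dataset(files, "test", 0.8)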

def rademacher(size):
    """Generate random samples from a Rademacher distribution (+1 or -1).

    Args:
        size (tuple or torch.Size): Shape of the returned tensor.
    """
    m = torch.distributions.binomial.Binomial(1, 0.5)
    x = m.sample(size)
    x[x == 0] = -1
    return x

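# Example:
#   v = rademacher((8,))  # eight random +/-1 values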

def get_subset(csv_file):
    """Read a subset CSV and return its unique "filepath" entries.

    Note that deduplication via `set` does not preserve row order.
    """
    subset_files = []
    with open(csv_file) as fp:
        reader = csv.DictReader(fp)
        for row in reader:
            subset_files.append(row["filepath"])

    return list(set(subset_files))


def conform_length(x: torch.Tensor, length: int):
    """Crop or pad input on last dim to match `length`."""
    if x.shape[-1] < length:
        padsize = length - x.shape[-1]
        x = torch.nn.functional.pad(x, (0, padsize))
    elif x.shape[-1] > length:
        x = x[..., :length]

    return x

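# Example:
#   conform_length(torch.randn(1, 900), 1000).shape   # (1, 1000), zero-padded
#   conform_length(torch.randn(1, 1100), 1000).shape  # (1, 1000), cropped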

def linear_fade(
    x: torch.Tensor,
    fade_ms: float = 50.0,
    sample_rate: float = 22050,
):
    """Apply a linear fade in and fade out to the last dim (in place)."""
    fade_samples = int(fade_ms * 1e-3 * sample_rate)

    fade_in = torch.linspace(0.0, 1.0, steps=fade_samples)
    fade_out = torch.linspace(1.0, 0.0, steps=fade_samples)

    # fade in
    x[..., :fade_samples] *= fade_in

    # fade out
    x[..., -fade_samples:] *= fade_out

    return x


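# Example (note that the fades are applied in place on `x`):
#   x = torch.randn(1, 44100)
#   x = linear_fade(x, fade_ms=50.0, sample_rate=44100)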