|
from functools import lru_cache
from math import floor
from random import randint

import librosa
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from torchaudio.transforms import Resample
|
|
|
|
|
def get_dataloader(dataset, device, batch_size=16, shuffle=True, num_workers=4):
    """Wrap ``dataset`` in a DataLoader that collates samples with ``prepare_batch``.

    Args:
        dataset: object exposing ``with_format("torch", device=...)``
            (presumably a Hugging Face ``datasets.Dataset`` -- confirm).
        device: device forwarded to ``with_format``. NOTE(review): the collate
            path hands tensors to librosa/NumPy, and PyTorch does not support
            CUDA inside forked worker processes, so a CPU device is presumably
            what is intended here -- confirm.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle the data every epoch.
        num_workers: number of loader worker processes; ``0`` loads batches
            in the main process.

    Returns:
        A ``DataLoader`` yielding whatever ``prepare_batch`` returns for each
        batch; an incomplete final batch is dropped.
    """
    return DataLoader(
        dataset.with_format("torch", device=device),
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0,
    )
|
|
|
@lru_cache(maxsize=32)
def _get_resampler(sr, newsr):
    """Build, once per (sr, newsr) pair, the Kaiser-windowed sinc resampler."""
    return Resample(
        orig_freq=sr,
        new_freq=newsr,
        resampling_method="sinc_interp_kaiser",
        lowpass_filter_width=16,
        rolloff=0.85,
        beta=8.555504641634386,
    )


def resample(x, sr, newsr):
    """Resample waveform ``x`` from ``sr`` Hz to ``newsr`` Hz.

    The underlying ``Resample`` module precomputes a sinc kernel, so it is
    cached per rate pair instead of being rebuilt on every call (this function
    runs once per sample inside the collate function).

    Args:
        x: waveform tensor accepted by ``torchaudio.transforms.Resample``.
        sr: original sample rate (Python number or 0-dim tensor).
        newsr: target sample rate (Python number or 0-dim tensor).

    Returns:
        The resampled waveform.
    """
    # Tensors hash by identity, so 0-dim tensor rates would never hit the
    # cache; convert them to plain Python scalars first.
    sr, newsr = (f.item() if isinstance(f, torch.Tensor) else f for f in (sr, newsr))
    return _get_resampler(sr, newsr)(x)
|
|
|
def fixlength(x, L):
    """Crop or zero-pad ``x`` along its last axis to exactly ``L`` samples.

    Works for 1-D signals as well as for tensors shaped ``(..., T)`` such as
    multichannel or batched audio; leading axes are left untouched.

    Args:
        x: tensor whose last axis is the time axis.
        L: required length of the last axis.

    Returns:
        Tensor with the same leading shape as ``x`` and last dimension ``L``.
    """
    # Crop and pad on the same (last) axis; ``pad``'s (left, right) pair
    # always refers to the last dimension.
    x = x[..., :L]
    return pad(x, (0, L - x.shape[-1]))
|
|
|
def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Convert a batch of waveforms into bfloat16 PCEN magnitude spectrograms.

    Args:
        X: waveform batch; assumes shape (batch, samples) as built by
            ``prepare_batch`` -- each entry along dim 0 gets its own PCEN pass.
        newsr: sample rate of ``X`` in Hz (forwarded to ``librosa.pcen``).
        n_fft: FFT size for ``torch.stft``.
        win_length: length of the Hann analysis window.
        hop_length: STFT hop in samples (also forwarded to ``librosa.pcen``).
        gain, bias, power: ``librosa.pcen`` parameters.

    Returns:
        bfloat16 tensor stacking one PCEN spectrogram (one-sided STFT
        magnitude) per entry of ``X``, placed on ``X``'s device.
    """
    # torch.stft requires the window on the same device as the input.
    window = torch.hann_window(win_length, device=X.device)
    spec = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        onesided=True,
        return_complex=True,
    )
    # librosa operates on NumPy arrays, which must live in host memory.
    magnitude = spec.abs().detach().cpu().numpy()
    pcen = torch.stack(
        [
            torch.from_numpy(
                librosa.pcen(m, sr=newsr, hop_length=hop_length, gain=gain, bias=bias, power=power)
            )
            for m in magnitude
        ],
        0,
    )
    return pcen.to(device=X.device, dtype=torch.bfloat16)
|
|
|
|
|
def prepare_batch(
    samples,
    newsr=4000,
    n_fft=2**10,
    win_length=2**10,
    hop_seconds=0.0505,
    duration=3,
):
    """Collate raw audio records into a batch of spectrogram features and labels.

    Each non-empty clip whose rate differs from ``newsr`` is resampled to
    ``newsr``. Every clip is then cropped/zero-padded to ``duration`` seconds,
    and the stacked signals are passed through ``preprocess``.

    Args:
        samples: sequence of dict-like records read as ``sample['label']``,
            ``sample['audio']['sampling_rate']`` and ``sample['audio']['array']``.
        newsr: target sample rate in Hz.
        n_fft: STFT size in samples.
        win_length: STFT window length in samples.
        hop_seconds: STFT hop in seconds; converted as ``floor(hop_seconds * newsr)``.
        duration: clip length in seconds after cropping/padding.

    Returns:
        ``(batch, labels)``: the output of ``preprocess`` and a float64 tensor
        of labels.
    """
    hop_length = floor(hop_seconds * newsr)
    n_samples = duration * newsr

    labels = []
    signals = []
    for sample in samples:
        labels.append(sample['label'])
        sr = sample['audio']['sampling_rate']
        x = sample['audio']['array']
        # Resample in both directions so every clip is on the same ``newsr``
        # time grid before it is cut to a fixed number of samples; empty clips
        # are skipped because there is nothing to resample.
        if sr != newsr and len(x) != 0:
            x = resample(x, sr, newsr)
        signals.append(fixlength(x, n_samples))

    batch = preprocess(torch.stack(signals, 0), newsr, n_fft, win_length, hop_length)
    # ``dtype=float`` is the Python float type, i.e. torch.float64.
    labels = torch.tensor(labels, dtype=float)
    return batch, labels
|
|
|
|
|
|
|
def random_mask(sample):
    """Zero out random rectangles in every item of a 3-D batch, in place.

    For each item along dim 0, between 3 and 12 rectangles are blanked. Each
    rectangle is 5-15 positions wide along dim 2 and 10-100 positions tall
    along dim 1 (presumably time and frequency bins of a spectrogram -- confirm).
    Sizes are clipped to the tensor so small inputs do not raise; for inputs of
    at least 100 x 15 the random draws match the unclipped ranges exactly.

    Args:
        sample: tensor of shape (items, rows, cols); modified in place.

    Returns:
        The same tensor object, after masking.
    """
    n_items, n_rows, n_cols = sample.shape
    for b in range(n_items):
        for _ in range(randint(3, 12)):
            # Clip the size ranges so randint(0, n - size) never gets a
            # negative upper bound.
            w = randint(min(5, n_cols), min(15, n_cols))
            h = randint(min(10, n_rows), min(100, n_rows))
            x1 = randint(0, n_cols - w)
            y1 = randint(0, n_rows - h)
            sample[b, y1:y1 + h, x1:x1 + w] = 0
    return sample
|
|
|
def timeshift(sample, max_shift=6):
    """Delay a 3-D batch along its last axis by a random 0..``max_shift`` steps.

    Zeros are inserted at the start of dim 2 and the tail is cropped, so the
    result keeps the shape of ``sample``. The padding is created with
    ``sample``'s dtype and device, so bf16 or non-CPU inputs are neither
    upcast nor rejected.

    Args:
        sample: tensor of shape (items, rows, cols).
        max_shift: largest shift, inclusive; one shift is drawn for the
            whole batch.

    Returns:
        A new tensor (a view of a fresh concatenation) with the shifted data.
    """
    padsize = randint(0, max_shift)
    length = sample.size(2)
    lead = sample.new_zeros((sample.size(0), sample.size(1), padsize))
    return torch.cat((lead, sample), dim=2)[:, :, :length]
|
|
|
def add_noise(sample):
    """Return ``sample`` plus Gaussian noise scaled by 0.05 times its maximum.

    Noise is drawn as float32 on ``sample``'s device, and the sum follows
    PyTorch type promotion. An empty tensor is returned unchanged because
    ``max()`` is undefined for it.

    Args:
        sample: tensor to perturb; it is not modified.

    Returns:
        A new noisy tensor, or ``sample`` itself when it is empty.
    """
    if sample.numel() == 0:
        return sample
    scale = 0.05 * sample.max()
    noise = scale * torch.randn(sample.shape, dtype=torch.float32, device=sample.device)
    return sample + noise
|
|
|
def augment(sample):
    """Run the training-time augmentation chain over a batch.

    The steps are applied in a fixed order: ``timeshift`` first, then
    ``random_mask``, then ``add_noise``. Each step receives the previous
    step's output.

    Args:
        sample: batch tensor accepted by the three steps.

    Returns:
        The augmented batch produced by the final step.
    """
    pipeline = (timeshift, random_mask, add_noise)
    result = sample
    for step in pipeline:
        result = step(result)
    return result