Nicolas Denier
update readme
0388c00
from functools import lru_cache
from math import floor
from random import randint

import librosa
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from torchaudio.transforms import Resample
def get_dataloader(dataset, device, batch_size=16, shuffle=True, num_workers=4):
    """Wrap `dataset` in a DataLoader whose batches are built by `prepare_batch`.

    Args:
        dataset: object exposing `.with_format("torch", device=...)` --
            presumably a Hugging Face `datasets.Dataset` (TODO confirm).
        device: forwarded to `with_format` as the device of the returned tensors.
        batch_size: number of samples per batch; the last incomplete batch is
            dropped (`drop_last=True`).
        shuffle: whether the DataLoader shuffles the samples.
        num_workers: number of loader worker processes; 0 loads in the main
            process. Defaults to the previously hard-coded 4.

    Returns:
        A `torch.utils.data.DataLoader` yielding `prepare_batch` outputs.
    """
    return DataLoader(
        dataset.with_format("torch", device=device),
        batch_size=batch_size,
        collate_fn=prepare_batch,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=True,
        # DataLoader refuses persistent_workers=True without worker processes.
        persistent_workers=num_workers > 0,
    )
@lru_cache(maxsize=16)
def _cached_resampler(orig_freq, new_freq):
    """Build (once per frequency pair) the Kaiser-sinc resampler used by `resample`.

    Creating a `Resample` module computes its filter kernel up front, so the
    module is cached rather than rebuilt for every clip.
    """
    return Resample(
        orig_freq=orig_freq,
        new_freq=new_freq,
        resampling_method="sinc_interp_kaiser",
        lowpass_filter_width=16,
        rolloff=0.85,
        beta=8.555504641634386,
    )


def resample(x, sr, newsr):
    """Resample waveform `x` from `sr` Hz to `newsr` Hz.

    Uses a Kaiser-windowed sinc interpolator (filter width 16, rolloff 0.85).

    Args:
        x: waveform tensor accepted by `torchaudio.transforms.Resample`.
        sr: original sampling rate in Hz (an int, or presumably a 0-d tensor
            when coming from a torch-formatted dataset -- TODO confirm).
        newsr: target sampling rate in Hz.

    Returns:
        The resampled waveform tensor.
    """
    # Tensors hash by identity, so normalise the cache key to plain ints;
    # Resample truncates its frequencies to int internally anyway.
    transform = _cached_resampler(int(sr), int(newsr))
    return transform(x)
def fixlength(x, L):
    """Truncate or zero-pad `x` along its last axis to exactly `L` entries.

    Args:
        x: tensor whose last dimension is the one to fix (in this module, a
            1-D waveform). Leading dimensions, if any, are preserved.
        L: target length of the last dimension.

    Returns:
        Tensor with the same leading shape as `x` and last dimension `L`;
        missing entries are filled with zeros at the end.
    """
    x = x[..., :L]
    # `pad` takes (left, right) amounts for the last dimension: pad on the right only.
    return pad(x, (0, L - x.shape[-1]))
def preprocess(X, newsr, n_fft, win_length, hop_length, gain=0.8, bias=10, power=0.25):
    """Convert a batch of waveforms into PCEN-normalised magnitude spectrograms.

    Args:
        X: batch of waveforms; presumably shaped (batch, samples), as stacked
            by `prepare_batch` (TODO confirm for other callers).
        newsr: sampling rate of `X` in Hz, passed to `librosa.pcen`.
        n_fft: FFT size of the STFT.
        win_length: length of the Hann analysis window.
        hop_length: STFT hop in samples, also passed to `librosa.pcen`.
        gain, bias, power: PCEN parameters forwarded to `librosa.pcen`.

    Returns:
        bfloat16 CPU tensor stacking one PCEN spectrogram per item
        (`n_fft // 2 + 1` frequency rows, since the STFT is one-sided).
    """
    # Create the window next to the signal: torch.stft needs input and window
    # on the same device.
    window = torch.hann_window(win_length, device=X.device, dtype=X.dtype)
    spec = torch.stft(
        X,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        onesided=True,
        return_complex=True,
    )
    # librosa operates on NumPy arrays, which only live in host memory.
    magnitudes = torch.abs(spec).detach().cpu()
    normalised = [
        torch.from_numpy(
            librosa.pcen(
                m.numpy(),
                sr=newsr,
                hop_length=hop_length,
                gain=gain,
                bias=bias,
                power=power,
            )
        )
        for m in magnitudes
    ]
    return torch.stack(normalised, 0).to(torch.bfloat16)
def prepare_batch(samples):
    """Collate raw audio samples into a (spectrogram batch, label batch) pair.

    Each clip is brought to 4000 Hz, cut or zero-padded to 3 seconds, and
    turned into PCEN spectrograms by `preprocess`.

    Args:
        samples: sequence of dicts, each with a 'label' and an 'audio' dict
            containing 'array' (the waveform) and 'sampling_rate'.

    Returns:
        (batch, labels): the output of `preprocess` for the stacked clips and
        a float64 tensor of labels.
    """
    newsr = 4000
    n_fft = 2**10  # power of 2
    win_length = 2**10
    hop_length = floor(0.0505 * newsr)  # 202 samples, ~50.5 ms at 4 kHz
    clip_len = 3 * newsr  # 3 seconds at newsr
    labels = []
    signals = []
    for sample in samples:
        labels.append(sample['label'])
        sr = sample['audio']['sampling_rate']
        x = sample['audio']['array']
        # Resample whenever the rate differs, not only when it is higher:
        # a clip left at a lower rate would span more than 3 s after
        # fixlength and be analysed as if it were sampled at newsr.
        # Empty clips bypass resampling; fixlength zero-fills them.
        if sr != newsr and len(x) != 0:
            x = resample(x, sr, newsr)
        signals.append(fixlength(x, clip_len))
    signals = torch.stack(signals, 0)
    batch = preprocess(signals, newsr, n_fft, win_length, hop_length)
    # dtype=float in torch means float64.
    labels = torch.tensor(labels, dtype=torch.float64)
    return batch, labels
# Data augmentation
def random_mask(sample, n_masks=(3, 12), width=(5, 15), height=(10, 100)):
    """Zero out random rectangles in every item of a batch, in place.

    Args:
        sample: 3-D tensor indexed as sample[b, y, x] with shape (B, H, W).
        n_masks: inclusive (min, max) number of rectangles per item.
        width: inclusive (min, max) rectangle size along the last axis (W).
        height: inclusive (min, max) rectangle size along the middle axis (H).

    Returns:
        The same tensor object, modified in place.
    """
    B, H, W = sample.shape
    for b in range(B):
        for _ in range(randint(*n_masks)):
            # Clamp to the tensor so a small input cannot make the
            # position range below empty (randint would raise).
            w = min(randint(*width), W)
            h = min(randint(*height), H)
            x1 = randint(0, W - w)
            y1 = randint(0, H - h)
            sample[b, y1:y1 + h, x1:x1 + w] = 0
    return sample
def timeshift(sample, max_shift=6):
    """Delay the whole batch by a random number of steps along the last axis.

    A single shift in [0, max_shift] is drawn for the batch. That many
    float32 zeros are prepended along dim 2, and the result is cropped back
    to the original length.

    Args:
        sample: 3-D tensor; dim 2 is the shifted axis.
        max_shift: inclusive upper bound of the shift (default 6).

    Returns:
        Shifted tensor with the same shape as `sample`.
    """
    padsize = randint(0, max_shift)
    length = sample.size(2)
    # Allocate the padding on the sample's device so torch.cat works for GPU batches.
    randpad = torch.zeros(
        (sample.size(0), sample.size(1), padsize),
        dtype=torch.float32,
        device=sample.device,
    )
    shifted = torch.cat((randpad, sample), dim=2)
    return shifted[:, :, :length]
def add_noise(sample, scale=0.05):
    """Return `sample` plus Gaussian noise scaled by `scale * sample.max()`.

    Args:
        sample: tensor to perturb (not modified).
        scale: noise standard deviation relative to the tensor's maximum
            (default 0.05).

    Returns:
        New tensor `sample + noise`; the noise is float32.
    """
    # Draw the noise on the sample's device so the addition works for GPU batches.
    noise = scale * sample.max() * torch.randn(sample.shape, dtype=torch.float32, device=sample.device)
    return sample + noise
def augment(sample):
    """Run the training-time augmentation chain on a batch.

    The steps run in order: `timeshift`, then `random_mask`, then `add_noise`.
    Each step's output is the next step's input.
    """
    pipeline = (timeshift, random_mask, add_noise)
    for step in pipeline:
        sample = step(sample)
    return sample