import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader


class AudioTransform:
    def __init__(self, device, sample_rate, n_mfcc, window_size, hop_length, desired_frames):
        self.device = device
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.hop_length = hop_length
        self.desired_frames = desired_frames
        self.n_mfcc = n_mfcc
        # FFT window length in samples, derived from the window size in seconds
        self.n_fft = int(sample_rate * window_size)
        self.mfcc_transform = transforms.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=self.n_mfcc,
            melkwargs={
                "n_fft": self.n_fft,
                "hop_length": self.hop_length,
                "mel_scale": "htk",
            },
        )
    def getMFCC(self, audioPath):
        # The file's sample rate is assumed to match self.sample_rate;
        # files at other rates would need resampling first.
        waveform, sample_rate = torchaudio.load(audioPath, normalize=True)
        # Number of samples that yields the desired number of MFCC frames
        num_samples = self.desired_frames * self.hop_length
        # Trim or zero-pad the waveform to that fixed length
        if waveform.shape[1] > num_samples:
            waveform = waveform[:, :num_samples]  # Trim
        elif waveform.shape[1] < num_samples:
            pad_size = num_samples - waveform.shape[1]
            waveform = nn.functional.pad(waveform, (0, pad_size))  # Pad
        mfcc_specgram = self.mfcc_transform(waveform)
        # The centered STFT can produce one extra frame; trim to the desired length
        if mfcc_specgram.shape[2] > self.desired_frames:
            mfcc_specgram = mfcc_specgram[:, :, :self.desired_frames]
        # Normalize each coefficient over time (zero mean, unit variance)
        mfcc_mean = mfcc_specgram.mean(dim=2, keepdim=True)
        mfcc_std = mfcc_specgram.std(dim=2, keepdim=True)
        # Avoid division by zero for silent or constant segments
        mfcc_std = mfcc_std.clamp(min=1e-5)
        normalized_mfcc = (mfcc_specgram - mfcc_mean) / mfcc_std
        return normalized_mfcc.to(self.device)
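
# Shape note: for a mono input, getMFCC returns a tensor of shape
# (1, n_mfcc, desired_frames), normalized per coefficient across the time axis.
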
class AudioDataset(Dataset):
    def __init__(self, audioPaths, labels, device, sample_rate, n_mfcc, window_size, hop_length, desired_frames):
        self.device = device
        self.audioPaths = audioPaths
        self.labels = labels
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.hop_length = hop_length
        self.desired_frames = desired_frames
        self.n_mfcc = n_mfcc
        self.n_fft = int(sample_rate * window_size)
        self.audio_transform = AudioTransform(device, sample_rate, n_mfcc, window_size, hop_length, desired_frames)

    def __getitem__(self, index):
        # Compute the normalized MFCC on the fly for the requested file
        mfcc = self.audio_transform.getMFCC(self.audioPaths[index])
        return mfcc.to(self.device), self.labels[index]

    def __len__(self):
        return len(self.audioPaths)
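
# Minimal usage sketch: the file paths, labels, and hyperparameter values
# below are hypothetical placeholders, not part of the original pipeline.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    audioPaths = ["clip0.wav", "clip1.wav"]  # hypothetical audio files
    labels = torch.tensor([0, 1])            # hypothetical integer labels
    dataset = AudioDataset(
        audioPaths, labels, device,
        sample_rate=16000,   # assumed sample rate of the input files
        n_mfcc=40,           # number of MFCC coefficients
        window_size=0.025,   # 25 ms analysis window -> n_fft = 400
        hop_length=160,      # 10 ms hop at 16 kHz
        desired_frames=200,  # fixed time dimension: 200 * 160 samples = 2 s
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for mfcc_batch, label_batch in loader:
        # Expected shape: (batch, channels, n_mfcc, desired_frames)
        print(mfcc_batch.shape, label_batch)
        break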