|
import torch |
|
import torch.utils.data |
|
import torch.nn.functional as F |
|
import torchaudio |
|
|
|
|
|
from pathlib import Path |
|
import numpy as np |
|
import random |
|
|
|
|
|
def files_to_list(filename):
    """
    Read a text file listing one filename per line and return those filenames.

    Trailing whitespace (including the newline) is stripped from every line.
    Lines that are empty after stripping are skipped. Without this, a blank
    line (e.g. an extra newline at the end of the file) would produce an
    empty-string entry that is not a usable filename.

    Args:
        filename: Path to a UTF-8 encoded text file.

    Returns:
        list[str]: The non-empty filenames, in the order they appear in the file.
    """
    with open(filename, encoding="utf-8") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        stripped = (line.rstrip() for line in f)
        return [name for name in stripped if name]
|
|
|
|
|
class AudioDataset(torch.utils.data.Dataset):
    """
    Dataset of fixed-length, single-channel raw-audio segments.

    The audio files are listed one per line in ``training_files``. Each entry
    is resolved relative to the directory that contains that list file.
    Every item is a waveform segment of exactly ``segment_length`` samples,
    returned with shape ``(1, segment_length)``. Clips that are long enough
    are randomly cropped; shorter clips are zero-padded on the right. Files
    that raise ``RuntimeError`` while loading are replaced by silence.

    Args:
        training_files: Path to a text file listing the audio files.
        segment_length: Number of samples in every returned segment.
        sampling_rate: Target sampling rate; files are resampled to it.
        augment: If True, scale each loaded clip by a random gain in
            [0.3, 1.0].
    """

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        list_dir = Path(training_files).parent
        self.audio_files = [list_dir / name for name in files_to_list(training_files)]
        # Deterministic shuffle using a private RNG. This gives the same order
        # as seeding the global RNG with 1234, but constructing a dataset no
        # longer resets the process-wide `random` state.
        random.Random(1234).shuffle(self.audio_files)
        self.augment = augment

    def __getitem__(self, index):
        """
        Return the segment for ``self.audio_files[index]``.

        Returns:
            torch.Tensor: Shape ``(1, segment_length)``.
        """
        filename = self.audio_files[index]
        try:
            audio, _ = self.load_wav_to_torch(filename)
        except RuntimeError:
            # Best effort: an undecodable file becomes silence (padded below),
            # so one bad entry does not abort the epoch.
            print(f'Found corrupted file: {filename}, use empty data instead')
            audio = torch.tensor([])

        num_samples = audio.size(0)
        if num_samples >= self.segment_length:
            # Random crop: choose a start so the whole window fits in the clip.
            start = random.randint(0, num_samples - self.segment_length)
            audio = audio[start : start + self.segment_length]
        else:
            # Too short: zero-pad on the right up to segment_length.
            audio = F.pad(audio, (0, self.segment_length - num_samples), "constant")

        return audio.unsqueeze(0)

    def __len__(self):
        return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
        """
        Load one audio file as a 1-D float tensor at ``self.sampling_rate``.

        Args:
            full_path: Path of the audio file (anything ``str()`` accepts).

        Returns:
            tuple[torch.Tensor, int]: The waveform and ``self.sampling_rate``.
        """
        data, orig_sampling_rate = torchaudio.load(str(full_path))

        if data.dim() > 1:
            # Multi-channel input: keep one randomly chosen channel. Selecting
            # before resampling avoids resampling channels that are discarded.
            data = data[random.randint(0, data.shape[0] - 1)]

        if orig_sampling_rate != self.sampling_rate:
            # Only resample when needed; building a resampler on every call
            # is wasted work when the rates already match.
            data = torchaudio.functional.resample(
                data, orig_freq=orig_sampling_rate, new_freq=self.sampling_rate
            )

        if self.augment:
            # Random gain drawn from Python's `random`. Per the PyTorch
            # DataLoader docs, worker processes reseed `random` and torch but
            # not NumPy, so NumPy draws would repeat across workers.
            data = data * random.uniform(0.3, 1.0)

        return data.float(), self.sampling_rate
|
|