import torch
import torch.utils.data
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import numpy as np
import random


def files_to_list(filename):
    """Read a text file of filenames and return them as a list of strings.

    Args:
        filename: Path to a UTF-8 text file with one filename per line.

    Returns:
        list[str]: Lines of the file with trailing whitespace stripped.
    """
    with open(filename, encoding="utf-8") as f:
        return [line.rstrip() for line in f]


class AudioDataset(torch.utils.data.Dataset):
    """Dataset yielding fixed-length mono audio segments.

    Each item is a FloatTensor of shape (1, segment_length) at
    ``sampling_rate``. Longer recordings are randomly cropped; shorter
    (or unreadable) ones are zero-padded on the right.
    """

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
        """
        Args:
            training_files: Path to a text file listing audio paths, one per
                line, interpreted relative to the list file's own directory.
            segment_length: Number of samples per returned segment.
            sampling_rate: Target sampling rate; inputs are resampled to it.
            augment: If True, scale each loaded waveform by a random
                amplitude drawn uniformly from [0.3, 1.0).
        """
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        self.audio_files = files_to_list(training_files)
        # Entries in the list file are relative to the list file's directory.
        self.audio_files = [Path(training_files).parent / x for x in self.audio_files]
        # Fixed seed so the shuffled file order is reproducible across runs.
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.augment = augment

    def __getitem__(self, index):
        """Return one audio segment of shape (1, segment_length)."""
        filename = self.audio_files[index]
        try:
            audio, sampling_rate = self.load_wav_to_torch(filename)
        except RuntimeError:
            # Some corpora (e.g. FMA) contain corrupted files; substitute
            # silence instead of aborting the whole epoch.
            # Fix: the original f-string had no placeholder, so the message
            # always printed the literal "(unknown)" instead of the path.
            print(f"Found corrupted file: {filename}, use empty data instead")
            audio = torch.tensor([])

        # Crop a random segment, or right-pad with zeros if too short.
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start : audio_start + self.segment_length]
        else:
            audio = F.pad(
                audio, (0, self.segment_length - audio.size(0)), "constant"
            ).data

        # audio = audio / 32768.0
        return audio.unsqueeze(0)

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
        """Load an audio file as a 1-D float tensor at ``self.sampling_rate``.

        Args:
            full_path: Path to the audio file.

        Returns:
            tuple[torch.Tensor, int]: (mono waveform, self.sampling_rate).

        Raises:
            RuntimeError: If torchaudio cannot decode the file.
        """
        data, sampling_rate = torchaudio.load(str(full_path))
        if sampling_rate != self.sampling_rate:
            # Only build/apply the resampler when the rates actually differ;
            # the original constructed a Resample transform for every item.
            data = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self.sampling_rate
            )(data)
        if data.dim() > 1:
            # torchaudio.load returns (channels, samples); convert to mono
            # by picking one channel at random.
            data = data[random.randint(0, data.shape[0] - 1)]
        if self.augment:
            # Random gain in [0.3, 1.0) as a simple loudness augmentation.
            amplitude = np.random.uniform(low=0.3, high=1.0)
            data = data * amplitude
        return data.float(), self.sampling_rate