tsm-net / tsmnet /dataset.py
ernestchu's picture
update
b6ef12a
raw
history blame contribute delete
No virus
2.51 kB
import torch
import torch.utils.data
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import numpy as np
import random
def files_to_list(filename):
    """
    Read a text file of filenames and return them as a list.

    Each line is stripped of trailing whitespace (including the newline).
    Lines that are empty after stripping are skipped, so stray blank lines
    in the list file do not turn into empty filename entries.

    Args:
        filename: path to a UTF-8 text file with one filename per line.

    Returns:
        List of filename strings, in file order.
    """
    with open(filename, encoding="utf-8") as f:
        stripped = (line.rstrip() for line in f)
        return [name for name in stripped if name]
class AudioDataset(torch.utils.data.Dataset):
    """
    Dataset of fixed-length mono waveform segments.

    Each item is one audio file, resampled to ``sampling_rate``, reduced to a
    single channel, and then either randomly cropped or zero-padded to
    ``segment_length`` samples. It is returned as a float tensor of shape
    ``(1, segment_length)``. No spectrogram is computed here.
    """

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
        """
        Args:
            training_files: path to a UTF-8 text file listing one audio file
                per line. Relative entries are resolved against the directory
                that contains this list file.
            segment_length: number of samples in every returned segment.
            sampling_rate: rate every file is resampled to on load.
            augment: if True, scale each loaded clip by a random gain drawn
                uniformly from [0.3, 1.0).
        """
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        list_dir = Path(training_files).parent
        self.audio_files = [list_dir / x for x in files_to_list(training_files)]
        # Shuffle deterministically with a private RNG. Calling random.seed()
        # here would reset the process-wide `random` state (and every later
        # crop / channel draw) each time a dataset is constructed.
        # random.Random(1234).shuffle gives the same order as
        # random.seed(1234) followed by random.shuffle.
        random.Random(1234).shuffle(self.audio_files)
        self.augment = augment

    def __getitem__(self, index):
        """
        Return the waveform segment for ``self.audio_files[index]``.

        Returns:
            Float tensor of shape ``(1, segment_length)``. A file whose load
            raises ``RuntimeError`` yields an all-zero segment instead of
            propagating the error.
        """
        filename = self.audio_files[index]
        try:
            audio, _ = self.load_wav_to_torch(filename)
        except RuntimeError:
            # there's lots of corrupted files in FMA
            print(f'Found corrupted file: {filename}, use empty data instead')
            audio = torch.tensor([])

        if audio.size(0) >= self.segment_length:
            # Random crop: any start offset that still leaves a full segment.
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start : audio_start + self.segment_length]
        else:
            # Too short (or empty): right-pad with zeros up to segment_length.
            audio = F.pad(
                audio, (0, self.segment_length - audio.size(0)), "constant"
            )
        # No int16 -> float rescaling is applied: torchaudio.load returns
        # normalized float samples by default.
        return audio.unsqueeze(0)

    def __len__(self):
        """Number of audio files listed for this dataset."""
        return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
        """
        Load one audio file as a mono float waveform at ``self.sampling_rate``.

        Multi-channel audio is reduced to mono by picking ONE channel at
        random (not by averaging). When ``self.augment`` is set, the waveform
        is also scaled by a random gain in [0.3, 1.0).

        Returns:
            ``(waveform, sampling_rate)``. ``waveform`` is a float tensor with
            the channel axis removed. ``sampling_rate`` is always
            ``self.sampling_rate``.
        """
        data, sampling_rate = torchaudio.load(str(full_path))
        if sampling_rate != self.sampling_rate:
            # Resample is an identity when the rates match; only build it
            # when a conversion is actually needed.
            data = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self.sampling_rate
            )(data)
        sampling_rate = self.sampling_rate
        if data.dim() > 1:
            # convert to mono by keeping one randomly chosen channel
            data = data[random.randint(0, data.shape[0] - 1)]
        if self.augment:
            amplitude = np.random.uniform(low=0.3, high=1.0)
            data = data * amplitude
        return data.float(), sampling_rate