""" music_datasets.py Desc: Contains the code for the music datasets. """ import torch from torch.utils.data import Dataset import torchaudio import numpy as np import pandas as pd """ MusicMelDataset: Given pre-processed mel-spectrograms, return a chunk of audio from the mel, with a masked version of a defined length Args: audio_files: List of .npy files consisting of mel-specs audio_len: length in seconds (roughly) of audio to be return mask_ratio: Size of mask as a ration of audio_len mask_start: Where the mask starts for learning "midpoint": always mask out the second half of the mel-spec crop_start: Where the starting point for the sample of audio is taken "random": Random valid starting point from audio is taken """ class MusicMelDataset(Dataset): def __init__(self, audio_files, audio_len = 6, mask_ratio = 0.5, mask_start = "midpoint", crop_start = "random"): self.audio_files = audio_files # Convert length to number of frames self.audio_len = int(audio_len * 100) # 100 is heuristic conversion made self.mask_ratio = mask_ratio self.mask_len = int(np.floor(self.audio_len * mask_ratio)) self.mask_start = mask_start self.crop_start = crop_start def __len__(self): return len(self.audio_files) # Get a random crop using audio_length def get_random_crop(self, mel): crop_start = torch.randint(0, mel.shape[0] - self.audio_len - 1, (1,)) return mel[crop_start:crop_start + self.audio_len, :] def __getitem__(self, idx): mel = torch.Tensor(np.load(self.audio_files[idx])) if self.crop_start == "random": mel = self.get_random_crop(mel) else: raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start") mask = torch.ones_like(mel) if self.mask_start == "midpoint": if self.mask_ratio == 0.5: mask[self.mask_len:, :] = 0 else: mask[self.audio_len // 2 + self.mask_len, :] = 0 else: raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start") mel_mask = mel*mask return mel, mel_mask