Spaces:
Sleeping
Sleeping
""" | |
music_datasets.py | |
Desc: Contains the code for the music datasets. | |
""" | |
import torch | |
from torch.utils.data import Dataset | |
import torchaudio | |
import numpy as np | |
import pandas as pd | |
""" | |
MusicMelDataset: | |
Given pre-processed mel-spectrograms, return a chunk of audio from the mel, with a masked version of a defined length | |
Args: | |
audio_files: List of .npy files consisting of mel-specs | |
audio_len: length in seconds (roughly) of audio to be return | |
mask_ratio: Size of mask as a ration of audio_len | |
mask_start: Where the mask starts for learning | |
"midpoint": always mask out the second half of the mel-spec | |
crop_start: Where the starting point for the sample of audio is taken | |
"random": Random valid starting point from audio is taken | |
""" | |
class MusicMelDataset(Dataset):
    """Dataset yielding fixed-length mel-spectrogram crops and a masked copy.

    Each item is loaded from a ``.npy`` file, cropped along the first
    (frame) axis to ``audio_len`` frames, and returned together with a copy
    in which a contiguous span of frames has been zeroed out.

    Args:
        audio_files: Sequence of paths to ``.npy`` files holding mel-specs.
            The arrays are indexed as ``[frame, bin]``.
        audio_len: Approximate crop length in seconds. Converted to frames
            with a heuristic factor of 100 frames per second.
        mask_ratio: Size of the mask as a ratio of the crop length.
        mask_start: Where the mask starts. Only ``"midpoint"`` is
            implemented: the mask begins at the middle frame of the crop.
        crop_start: How the crop position is chosen. Only ``"random"`` is
            implemented: a uniformly random valid start frame.
    """

    def __init__(self, audio_files, audio_len=6, mask_ratio=0.5, mask_start="midpoint", crop_start="random"):
        self.audio_files = audio_files
        # Convert length in seconds to a number of frames.
        self.audio_len = int(audio_len * 100)  # 100 is a heuristic conversion factor
        self.mask_ratio = mask_ratio
        self.mask_len = int(np.floor(self.audio_len * mask_ratio))
        self.mask_start = mask_start
        self.crop_start = crop_start

    def __len__(self) -> int:
        """Return the number of mel-spectrogram files in the dataset."""
        return len(self.audio_files)

    def get_random_crop(self, mel: torch.Tensor) -> torch.Tensor:
        """Return ``audio_len`` consecutive frames from a random valid offset.

        Args:
            mel: 2-D tensor indexed as ``[frame, bin]``.

        Returns:
            The slice ``mel[start:start + audio_len, :]``.

        Raises:
            ValueError: If ``mel`` has fewer than ``audio_len`` frames.
        """
        num_frames = mel.shape[0]
        last_start = num_frames - self.audio_len
        if last_start < 0:
            raise ValueError(
                f"mel-spectrogram has {num_frames} frames but a crop of "
                f"{self.audio_len} frames was requested"
            )
        # randint's upper bound is exclusive, so +1 keeps the final valid
        # offset reachable and lets an exactly-sized mel through unchanged.
        start = int(torch.randint(0, last_start + 1, (1,)).item())
        return mel[start:start + self.audio_len, :]

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(mel, mel_mask)``.

        Returns:
            A tuple of the cropped mel-spectrogram and the same crop with the
            masked frames set to zero.

        Raises:
            NotImplementedError: If ``crop_start`` or ``mask_start`` holds an
                unsupported value.
            ValueError: If the stored mel is shorter than the crop length.
        """
        mel = torch.Tensor(np.load(self.audio_files[idx]))

        if self.crop_start == "random":
            mel = self.get_random_crop(mel)
        else:
            raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start")

        mask = torch.ones_like(mel)
        if self.mask_start == "midpoint":
            if self.mask_ratio == 0.5:
                # Half-length mask: everything from mask_len onward is hidden.
                mask[self.mask_len:, :] = 0
            else:
                # Zero a span of mask_len frames beginning at the midpoint.
                # Slicing clamps at the end of the crop, so ratios above 0.5
                # mask through the last frame instead of indexing out of range.
                midpoint = self.audio_len // 2
                mask[midpoint:midpoint + self.mask_len, :] = 0
        else:
            raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start")

        mel_mask = mel * mask
        return mel, mel_mask