ALeLacheur's picture
uploading audio diffusion attacks
5a9b731
"""
music_datasets.py
Desc: Contains the code for the music datasets.
"""
import torch
from torch.utils.data import Dataset
import torchaudio
import numpy as np
import pandas as pd
"""
MusicMelDataset:
Given pre-processed mel-spectrograms, return a chunk of audio from the mel, with a masked version of a defined length
Args:
audio_files: List of .npy files consisting of mel-specs
audio_len: length in seconds (roughly) of audio to be returned
mask_ratio: Size of mask as a ratio of audio_len
mask_start: Where the mask starts for learning
"midpoint": always mask out the second half of the mel-spec
crop_start: Where the starting point for the sample of audio is taken
"random": Random valid starting point from audio is taken
"""
class MusicMelDataset(Dataset):
    """Dataset of fixed-length mel-spectrogram crops paired with masked copies.

    Each item loads one ``.npy`` file, crops ``audio_len`` frames along the
    first axis, and returns that crop together with a copy in which a span of
    frames has been zeroed out.

    Args:
        audio_files: Sequence of ``.npy`` files, one mel-spectrogram each.
            Arrays are indexed as ``[frame, feature]`` (first axis is cropped).
        audio_len: Crop length in seconds (approximate); converted to a frame
            count with ``FRAMES_PER_SECOND``.
        mask_ratio: Mask length as a fraction of the crop length.
        mask_start: Where the mask begins. Only ``"midpoint"`` is implemented:
            with ``mask_ratio == 0.5`` everything from frame ``mask_len``
            onward is zeroed; otherwise ``mask_len`` frames starting at the
            crop midpoint are zeroed.
        crop_start: How the crop offset is chosen. Only ``"random"`` is
            implemented (uniform over all valid offsets).
    """

    # Heuristic seconds-to-frames conversion used for ``audio_len``.
    FRAMES_PER_SECOND = 100

    def __init__(self, audio_files, audio_len=6, mask_ratio=0.5, mask_start="midpoint", crop_start="random"):
        super().__init__()
        self.audio_files = audio_files
        # Convert length in seconds to a number of frames.
        self.audio_len = int(audio_len * self.FRAMES_PER_SECOND)
        self.mask_ratio = mask_ratio
        self.mask_len = int(np.floor(self.audio_len * mask_ratio))
        self.mask_start = mask_start
        self.crop_start = crop_start

    def __len__(self):
        """Return the number of mel-spectrogram files."""
        return len(self.audio_files)

    def get_random_crop(self, mel):
        """Return a random window of ``self.audio_len`` frames from ``mel``.

        Args:
            mel: Tensor indexed as ``[frame, feature]``.

        Returns:
            A slice of ``mel`` with ``self.audio_len`` rows.

        Raises:
            ValueError: If ``mel`` has fewer than ``self.audio_len`` frames.
        """
        n_frames = mel.shape[0]
        if n_frames < self.audio_len:
            raise ValueError(
                f"mel has {n_frames} frames but a crop of {self.audio_len} frames was requested"
            )
        # torch.randint's upper bound is exclusive, so the "+ 1" keeps the last
        # valid offset (n_frames - audio_len) reachable and allows a mel that
        # is exactly audio_len frames long.
        start = torch.randint(0, n_frames - self.audio_len + 1, (1,)).item()
        return mel[start:start + self.audio_len, :]

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(mel, mel_mask)``.

        Returns:
            A pair of tensors of identical shape: the cropped mel-spectrogram
            and the same crop with the masked frames set to zero.

        Raises:
            NotImplementedError: If ``crop_start`` or ``mask_start`` holds an
                unsupported value.
            ValueError: If the loaded mel is shorter than the crop length.
        """
        mel = torch.Tensor(np.load(self.audio_files[idx]))

        if self.crop_start != "random":
            raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start")
        mel = self.get_random_crop(mel)

        if self.mask_start != "midpoint":
            raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start")

        # 1 keeps a frame, 0 blanks it.
        mask = torch.ones_like(mel)
        if self.mask_ratio == 0.5:
            # Default case: blank everything from mask_len to the end.
            mask[self.mask_len:, :] = 0
        else:
            # Blank mask_len frames beginning at the midpoint; the slice end
            # is clamped to the crop length, so ratios above 0.5 are safe.
            begin = self.audio_len // 2
            mask[begin:begin + self.mask_len, :] = 0

        mel_mask = mel * mask
        return mel, mel_mask