ccolas's picture
Upload 174 files
93c029f
raw
history blame
12.9 kB
from torch.utils.data import Dataset
import numpy as np
import torch
# Target device for every tensor the dataset classes in this module create:
# 'cuda' when torch reports a usable GPU, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from src.music2cocktailrep.analysis.explore import get_alignment_dataset
# Add your custom dataset class here
class CocktailDataset(Dataset):
    """Unlabeled cocktail representations with egg/bubble rows oversampled.

    The requested split is cut from ``cocktail_reps`` in its given row order
    (first 90% for ``'train'``, remaining 10% for ``'test'``, every row for
    ``'all'``). Rows whose last column is positive (flagged as "egg") are
    appended ``3 * 4`` extra times and rows whose third-from-last column is
    positive (flagged as "bubbles") ``4`` extra times; the result is then
    shuffled with NumPy's global RNG. All tensors are stored on the
    module-level ``device``.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'all'``.
        cocktail_reps: 2-D NumPy array with one row per cocktail. Columns
            ``-1`` and ``-3`` are read, so at least three columns are needed.

    Raises:
        ValueError: If ``split`` is not one of the three accepted names.

    Each item is ``(cocktail_rep, label, contains_egg, contains_bubbles)``;
    ``label`` is always 0.
    """

    def __init__(self, split, cocktail_reps):
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape
        cut = int(0.9 * self.n_cocktails)
        if split == 'train':
            split_reps = cocktail_reps[:cut, :].copy()
        elif split == 'test':
            split_reps = cocktail_reps[cut:, :].copy()
        elif split == 'all':
            split_reps = cocktail_reps.copy()
        else:
            raise ValueError(
                f"split must be 'train', 'test' or 'all', got {split!r}"
            )
        reps = torch.FloatTensor(split_reps).to(device)

        # Oversample cocktails with eggs and bubbles. The rows are selected
        # with torch boolean masks rather than np.argwhere: NumPy cannot read
        # a tensor that lives on a CUDA device, so the NumPy route only
        # worked when `device` was 'cpu'.
        egg_rows = reps[reps[:, -1] > 0]
        bubble_rows = reps[reps[:, -3] > 0]
        n_copies = 4
        egg_copies = torch.tile(egg_rows, dims=(n_copies * 3, 1))
        bubbles_copies = torch.tile(bubble_rows, dims=(n_copies, 1))
        reps = torch.cat([reps, egg_copies, bubbles_copies], dim=0)

        # n_cocktails now counts the oversampled rows as well.
        self.n_cocktails = reps.shape[0]

        # The permutation still comes from NumPy's global RNG so that seeding
        # np.random gives the same order as before; it is moved to the
        # tensor's device before being used as an index.
        indexes = np.arange(self.n_cocktails)
        np.random.shuffle(indexes)
        order = torch.as_tensor(indexes, dtype=torch.long, device=reps.device)
        self.cocktail_reps = reps[order]

        self.labels = torch.zeros(self.n_cocktails, dtype=torch.long, device=device)
        self.contains_egg = self.cocktail_reps[:, -1] > 0
        self.contains_bubbles = self.cocktail_reps[:, -3] > 0

    def __len__(self):
        """Return the number of rows, oversampled copies included."""
        return self.cocktail_reps.shape[0]

    def __getitem__(self, idx):
        """Return ``(rep, label, contains_egg, contains_bubbles)`` at ``idx``."""
        return (
            self.cocktail_reps[idx],
            self.labels[idx],
            self.contains_egg[idx],
            self.contains_bubbles[idx],
        )
class CocktailLabeledDataset(Dataset):
    """Cocktail representations paired with an alignment-dataset label.

    The labels are the sorted keys of ``get_alignment_dataset()['cocktail']``;
    the values under those keys are used as row indices into
    ``cocktail_reps``. A cocktail's label is the position of the first label
    (in sorted order) that lists it. Within every label the first 90% of its
    cocktails go to the train split and the rest to the test split. The
    selected rows are shuffled with NumPy's global RNG and stored as tensors
    on the module-level ``device``.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'all'``.
        cocktail_reps: 2-D NumPy array of cocktail representations.

    Raises:
        ValueError: If ``split`` is not one of the three accepted names.

    Each item is ``(cocktail_rep, label)``.
    """

    def __init__(self, split, cocktail_reps):
        alignment = get_alignment_dataset()
        cocktails_by_label = alignment['cocktail']
        labels = sorted(cocktails_by_label.keys())
        self.n_labels = len(labels)

        # Row indices of every labeled cocktail, concatenated label by label.
        all_cocktails = np.array(
            [cocktail for name in labels for cocktail in cocktails_by_label[name]]
        )
        cocktail_reps = cocktail_reps[all_cocktails]

        # Label of a cocktail = position of the first label that contains it.
        cocktail_labels = []
        for cocktail in all_cocktails:
            position = next(
                (pos for pos, name in enumerate(labels)
                 if cocktail in cocktails_by_label[name]),
                None,
            )
            if position is not None:
                cocktail_labels.append(position)
        cocktail_labels = np.array(cocktail_labels)
        assert len(cocktail_labels) == len(cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape

        # Per-label 90/10 split, taken in the order the rows appear.
        train_rows, test_rows = [], []
        for position in range(self.n_labels):
            members = np.flatnonzero(cocktail_labels == position)
            n_train = int(0.9 * len(members))
            train_rows.extend(members[:n_train])
            test_rows.extend(members[n_train:])
        train_rows = np.array(train_rows)
        test_rows = np.array(test_rows)
        assert set(train_rows).isdisjoint(set(test_rows))

        if split == 'all':
            self.cocktail_reps = cocktail_reps.copy()
            self.labels = cocktail_labels.copy()
        elif split == 'train':
            self.cocktail_reps = cocktail_reps[train_rows].copy()
            self.labels = cocktail_labels[train_rows].copy()
        elif split == 'test':
            self.cocktail_reps = cocktail_reps[test_rows].copy()
            self.labels = cocktail_labels[test_rows].copy()
        else:
            raise ValueError

        # Shuffle the chosen split, then move everything to the device.
        self.n_cocktails = self.cocktail_reps.shape[0]
        order = np.arange(self.n_cocktails)
        np.random.shuffle(order)
        shuffled_reps = self.cocktail_reps[order]
        shuffled_labels = self.labels[order]
        self.cocktail_reps = torch.FloatTensor(shuffled_reps).to(device)
        self.labels = torch.LongTensor(shuffled_labels).to(device)

    def __len__(self):
        """Return the number of cocktails in the selected split."""
        return len(self.cocktail_reps)

    def __getitem__(self, idx):
        """Return ``(cocktail_rep, label)`` at ``idx``."""
        rep = self.cocktail_reps[idx]
        label = self.labels[idx]
        return rep, label
class MusicDataset(Dataset):
    """Unlabeled music representations; every label is 0.

    The split is cut from ``music_reps`` in its given row order (first 90%
    for ``'train'``, remaining 10% for ``'test'``, every row for ``'all'``),
    shuffled with NumPy's global RNG and stored as tensors on the
    module-level ``device``.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'all'``.
        music_reps: 2-D NumPy array with one row per music piece.
        music_rep_paths: Accepted but not used by this class.

    Raises:
        ValueError: If ``split`` is not one of the three accepted names.

    Each item is ``(music_rep, label)``.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        self.n_music, self.dim_music = music_reps.shape
        zero_labels = np.zeros([self.n_music])
        boundary = int(0.9 * self.n_music)

        # Express each split as a row slice over the unshuffled input.
        if split == 'all':
            rows = slice(None)
        elif split == 'train':
            rows = slice(None, boundary)
        elif split == 'test':
            rows = slice(boundary, None)
        else:
            raise ValueError
        self.music_reps = music_reps[rows, :].copy()
        self.labels = zero_labels[rows].copy()

        # Shuffle the chosen rows, then move everything to the device.
        self.n_music = self.music_reps.shape[0]
        order = np.arange(self.n_music)
        np.random.shuffle(order)
        shuffled_reps = self.music_reps[order]
        shuffled_labels = self.labels[order]
        self.music_reps = torch.FloatTensor(shuffled_reps).to(device)
        self.labels = torch.LongTensor(shuffled_labels).to(device)

    def __len__(self):
        """Return the number of music pieces in the selected split."""
        return len(self.music_reps)

    def __getitem__(self, idx):
        """Return ``(music_rep, label)`` at ``idx``."""
        rep = self.music_reps[idx]
        label = self.labels[idx]
        return rep, label
class RegressedGroundingDataset(Dataset):
    """Music representations paired with a cocktail representation of the same label.

    The labels are the sorted keys of ``get_alignment_dataset()['cocktail']``.
    Every music filename listed under ``['music']`` is matched to the first
    entry of ``music_rep_paths`` that ends with the corresponding
    ``*_b256_r128_represented.txt`` name, and each music piece is paired with
    one cocktail representation drawn at random (NumPy global RNG) among the
    cocktails carrying the same label. Within every label the first 90% of
    its music pieces go to the train split and the rest to the test split.
    The selected pairs are shuffled and stored as tensors on the module-level
    ``device``.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'all'``.
        music_reps: 2-D NumPy array of music representations, row-aligned
            with ``music_rep_paths``.
        music_rep_paths: Sequence of path strings, one per row of
            ``music_reps``.
        cocktail_reps: 2-D NumPy array of cocktail representations, indexed
            by the values stored under ``['cocktail']``.

    Raises:
        ValueError: If ``split`` is not one of the three accepted names.
        AssertionError: If a music filename is listed twice or has no
            matching representation path.

    Each item is ``(music_rep, cocktail_rep)``.
    """

    def __init__(self, split, music_reps, music_rep_paths, cocktail_reps):
        alignment = get_alignment_dataset()
        cocktails_by_label = alignment['cocktail']
        labels = sorted(cocktails_by_label.keys())
        self.n_labels = len(labels)
        music_by_label = alignment['music']

        # All music filenames, concatenated label by label; none may repeat.
        all_music_filenames = [
            filename for name in labels for filename in music_by_label[name]
        ]
        assert len(all_music_filenames) == len(set(all_music_filenames))
        all_music_filenames = np.array(all_music_filenames)

        # Row indices of every labeled cocktail, concatenated label by label.
        all_cocktails = np.array(
            [cocktail for name in labels for cocktail in cocktails_by_label[name]]
        )

        # Locate each music file's representation: first path with that suffix.
        rep_rows = []
        for music_filename in all_music_filenames:
            rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt')
            hit = next(
                (pos for pos, rep_path in enumerate(music_rep_paths)
                 if rep_name == rep_path[-len(rep_name):]),
                None,
            )
            assert hit is not None
            rep_rows.append(hit)
        music_reps = music_reps[np.array(rep_rows)]

        # Label of a music piece = position of the first label that lists it.
        music_labels = []
        for music_filename in all_music_filenames:
            position = next(
                (pos for pos, name in enumerate(labels)
                 if music_filename in music_by_label[name]),
                None,
            )
            if position is not None:
                music_labels.append(position)
        assert len(music_labels) == len(music_reps)
        music_labels = np.array(music_labels)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels

        # Same labeling for the cocktails.
        cocktail_reps = cocktail_reps[all_cocktails]
        cocktail_labels = []
        for cocktail in all_cocktails:
            position = next(
                (pos for pos, name in enumerate(labels)
                 if cocktail in cocktails_by_label[name]),
                None,
            )
            if position is not None:
                cocktail_labels.append(position)
        cocktail_labels = np.array(cocktail_labels)
        assert len(cocktail_labels) == len(cocktail_reps)
        self.n_cocktails, self.dim_cocktail = cocktail_reps.shape

        # For every music piece, draw one cocktail that shares its label.
        cocktail_reps_matching_music_reps = np.array([
            cocktail_reps[np.random.choice(np.flatnonzero(cocktail_labels == music_label))]
            for music_label in music_labels
        ])

        # Per-label 90/10 split over the music pieces, in row order.
        train_rows, test_rows = [], []
        for position in range(self.n_labels):
            members = np.flatnonzero(music_labels == position)
            n_train = int(0.9 * len(members))
            train_rows.extend(members[:n_train])
            test_rows.extend(members[n_train:])
        train_rows = np.array(train_rows)
        test_rows = np.array(test_rows)
        assert set(train_rows).isdisjoint(set(test_rows))

        if split == 'all':
            self.music_reps = music_reps.copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps.copy()
        elif split == 'train':
            self.music_reps = music_reps[train_rows].copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps[train_rows].copy()
        elif split == 'test':
            self.music_reps = music_reps[test_rows].copy()
            self.cocktail_reps = cocktail_reps_matching_music_reps[test_rows].copy()
        else:
            raise ValueError

        # Shuffle the pairs together, then move everything to the device.
        self.n_music = self.music_reps.shape[0]
        order = np.arange(self.n_music)
        np.random.shuffle(order)
        shuffled_music = self.music_reps[order]
        shuffled_cocktails = self.cocktail_reps[order]
        self.music_reps = torch.FloatTensor(shuffled_music).to(device)
        self.cocktail_reps = torch.FloatTensor(shuffled_cocktails).to(device)

    def __len__(self):
        """Return the number of (music, cocktail) pairs in the selected split."""
        return len(self.music_reps)

    def __getitem__(self, idx):
        """Return ``(music_rep, cocktail_rep)`` at ``idx``."""
        music_rep = self.music_reps[idx]
        cocktail_rep = self.cocktail_reps[idx]
        return music_rep, cocktail_rep
class MusicLabeledDataset(Dataset):
    """Music representations paired with an alignment-dataset label.

    The labels are the sorted keys of ``get_alignment_dataset()['cocktail']``.
    Every music filename listed under ``['music']`` is matched to the first
    entry of ``music_rep_paths`` that ends with the corresponding
    ``*_b256_r128_represented.txt`` name. A piece's label is the position of
    the first label (in sorted order) that lists its filename. Within every
    label the first 90% of its pieces go to the train split and the rest to
    the test split. The selected rows are shuffled with NumPy's global RNG
    and stored as tensors on the module-level ``device``.

    Args:
        split: One of ``'train'``, ``'test'`` or ``'all'``.
        music_reps: 2-D NumPy array of music representations, row-aligned
            with ``music_rep_paths``.
        music_rep_paths: Sequence of path strings, one per row of
            ``music_reps``.

    Raises:
        ValueError: If ``split`` is not one of the three accepted names.
        AssertionError: If a music filename is listed twice or has no
            matching representation path.

    Each item is ``(music_rep, label)``.
    """

    def __init__(self, split, music_reps, music_rep_paths):
        alignment = get_alignment_dataset()
        labels = sorted(alignment['cocktail'].keys())
        self.n_labels = len(labels)
        music_by_label = alignment['music']

        # All music filenames, concatenated label by label; none may repeat.
        all_music_filenames = [
            filename for name in labels for filename in music_by_label[name]
        ]
        assert len(all_music_filenames) == len(set(all_music_filenames))
        all_music_filenames = np.array(all_music_filenames)

        # Locate each music file's representation: first path with that suffix.
        rep_rows = []
        for music_filename in all_music_filenames:
            rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt')
            hit = next(
                (pos for pos, rep_path in enumerate(music_rep_paths)
                 if rep_name == rep_path[-len(rep_name):]),
                None,
            )
            assert hit is not None
            rep_rows.append(hit)
        music_reps = music_reps[np.array(rep_rows)]

        # Label of a music piece = position of the first label that lists it.
        music_labels = []
        for music_filename in all_music_filenames:
            position = next(
                (pos for pos, name in enumerate(labels)
                 if music_filename in music_by_label[name]),
                None,
            )
            if position is not None:
                music_labels.append(position)
        assert len(music_labels) == len(music_reps)
        music_labels = np.array(music_labels)
        self.n_music, self.dim_music = music_reps.shape
        self.classes = labels

        # Per-label 90/10 split, taken in the order the rows appear.
        train_rows, test_rows = [], []
        for position in range(self.n_labels):
            members = np.flatnonzero(music_labels == position)
            n_train = int(0.9 * len(members))
            train_rows.extend(members[:n_train])
            test_rows.extend(members[n_train:])
        train_rows = np.array(train_rows)
        test_rows = np.array(test_rows)
        assert set(train_rows).isdisjoint(set(test_rows))

        if split == 'all':
            self.music_reps = music_reps.copy()
            self.labels = music_labels.copy()
        elif split == 'train':
            self.music_reps = music_reps[train_rows].copy()
            self.labels = music_labels[train_rows].copy()
        elif split == 'test':
            self.music_reps = music_reps[test_rows].copy()
            self.labels = music_labels[test_rows].copy()
        else:
            raise ValueError

        # Shuffle the chosen split, then move everything to the device.
        self.n_music = self.music_reps.shape[0]
        order = np.arange(self.n_music)
        np.random.shuffle(order)
        shuffled_reps = self.music_reps[order]
        shuffled_labels = self.labels[order]
        self.music_reps = torch.FloatTensor(shuffled_reps).to(device)
        self.labels = torch.LongTensor(shuffled_labels).to(device)

    def __len__(self):
        """Return the number of music pieces in the selected split."""
        return len(self.music_reps)

    def __getitem__(self, idx):
        """Return ``(music_rep, label)`` at ``idx``."""
        rep = self.music_reps[idx]
        label = self.labels[idx]
        return rep, label