diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/__init__.py b/src/cocktails/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/__pycache__/__init__.cpython-39.pyc b/src/cocktails/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..382a21c5f338e749ead439a77583fe407fb842b0 Binary files /dev/null and b/src/cocktails/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/cocktails/__pycache__/config.cpython-39.pyc b/src/cocktails/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..553d60186e05f5acc35c844c8546e65b4b6c94d4 Binary files /dev/null and b/src/cocktails/__pycache__/config.cpython-39.pyc differ diff --git a/src/cocktails/config.py b/src/cocktails/config.py new file mode 100644 index 0000000000000000000000000000000000000000..bce5b65a666caf9972ea64933a4a74eb4e2532c0 --- /dev/null +++ b/src/cocktails/config.py @@ -0,0 +1,21 @@ +import os + +REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/' + +# QUADRUPLETS_PATH = REPO_PATH + 'checkpoints/cocktail_representation/quadruplets.pickle' +INGREDIENTS_LIST_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_list.csv' +# ING_MATCH_SCORE_Q_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_q.txt' +# ING_MATCH_SCORE_COUNT_PATH = REPO_PATH + 'checkpoints/cocktail_representation/ingredient_match_score_count.txt' +# COCKTAIL_DATA_FOLDER_PATH = REPO_PATH + 'checkpoints/cocktail_representation/' +COCKTAILS_CSV_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.csv' +# COCKTAILS_PKL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_data.pkl' +# COCKTAILS_URL_DATA = REPO_PATH + 'checkpoints/cocktail_representation/cocktails_names_urls.pkl' +EXPERIMENT_PATH = REPO_PATH + 'experiments/cocktails/representation_learning/' +# ANALYSIS_PATH = REPO_PATH + 'experiments/cocktails/representation_analysis/' +# REPRESENTATIONS_PATH = REPO_PATH + 'experiments/cocktails/learned_representations/' + +FULL_COCKTAIL_REP_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/cocktail_handcoded_reps_minmax_norm-1_1_dim13_customkeys.txt" +RECIPE2FEATURES_PATH = REPO_PATH + "/checkpoints/cocktail_representation/" # get this by running run_without_vae +COCKTAIL_REP_CHKPT_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/" +# FULL_COCKTAIL_REP_PATH = REPO_PATH + "experiments/cocktails/representation_analysis/affective_mapping/clustered_representations/all_cocktail_reps_norm-1_1_custom_keys_dim13.txt' +COCKTAIL_NN_PATH = REPO_PATH + "/checkpoints/cocktail_representation/handcoded_reps/nn_model.pickle" \ No newline at end of file diff --git a/src/cocktails/pipeline/__init__.py b/src/cocktails/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/pipeline/__pycache__/__init__.cpython-39.pyc b/src/cocktails/pipeline/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0706553feb0afa16ff8a8b843482a81c627aa770 Binary files /dev/null and b/src/cocktails/pipeline/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/src/cocktails/pipeline/__pycache__/cocktail2affect.cpython-39.pyc b/src/cocktails/pipeline/__pycache__/cocktail2affect.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96502ab55cfde3752380321c9b979c1b13409646 Binary files /dev/null and b/src/cocktails/pipeline/__pycache__/cocktail2affect.cpython-39.pyc differ diff --git a/src/cocktails/pipeline/__pycache__/cocktailrep2recipe.cpython-39.pyc b/src/cocktails/pipeline/__pycache__/cocktailrep2recipe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7757eedba8a847aebfed4347d8b802601d00d2d0 Binary files /dev/null and b/src/cocktails/pipeline/__pycache__/cocktailrep2recipe.cpython-39.pyc differ diff --git a/src/cocktails/pipeline/__pycache__/get_affect2affective_cluster.cpython-39.pyc b/src/cocktails/pipeline/__pycache__/get_affect2affective_cluster.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d96a3fe6619774e4a97ea771541921837c2dc2 Binary files /dev/null and b/src/cocktails/pipeline/__pycache__/get_affect2affective_cluster.cpython-39.pyc differ diff --git a/src/cocktails/pipeline/__pycache__/get_cocktail2affective_cluster.cpython-39.pyc b/src/cocktails/pipeline/__pycache__/get_cocktail2affective_cluster.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7844a780d4124da8cfcc4814224ad2f4e0a85570 Binary files /dev/null and b/src/cocktails/pipeline/__pycache__/get_cocktail2affective_cluster.cpython-39.pyc differ diff --git a/src/cocktails/pipeline/cocktail2affect.py b/src/cocktails/pipeline/cocktail2affect.py new file mode 100644 index 0000000000000000000000000000000000000000..ad272dbdd48ea1b3cbd8d4a972dc092e16d251e1 --- /dev/null +++ b/src/cocktails/pipeline/cocktail2affect.py @@ -0,0 +1,372 @@ +import pandas as pd +import numpy as np +import os +from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys +from src.cocktails.utilities.other_scrubbing_utilities import print_recipe +from src.cocktails.config import COCKTAILS_CSV_DATA +from src.music.config import CHECKPOINTS_PATH, EXPERIMENT_PATH +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans +from sklearn.mixture import GaussianMixture +from sklearn.neighbors import NearestNeighbors +import pickle +import random + +experiment_path = EXPERIMENT_PATH + '/cocktails/representation_analysis/affective_mapping/' +min_max_path = CHECKPOINTS_PATH + "/cocktail_representation/minmax/" +cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle" +affective_space_dimensions = ((-1, 1), (-1, 1), (-1, 1)) # valence, arousal, dominance +n_splits = (3, 3, 2) # number of bins per dimension +# dimensions_weights = [1, 1, 0.5] +dimensions_weights = [1, 1, 1] +total_n_clusters = np.prod(n_splits) # total number of bins +affective_boundaries = [np.arange(asd[0], asd[1]+1e-6, (asd[1] - asd[0]) / n_split) for asd, n_split in zip(affective_space_dimensions, n_splits)] +for af in affective_boundaries: + af[-1] += 1e-6 +all_keys = get_bunch_of_rep_keys()['custom'] +original_affective_keys = get_bunch_of_rep_keys()['affective'] +affective_keys = [a.split(' ')[1] for a in original_affective_keys] +random.seed(0) +cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(total_n_clusters)] + +clustering_method = 'k_means' # 'k_means', 'handcoded', 'agglo', 'spectral' +if clustering_method != 'handcoded': + total_n_clusters = 10 +min_arousal = np.loadtxt(min_max_path + 'min_arousal.txt') 
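When `clustering_method` is set to `'handcoded'`, the affective cube is partitioned into 3 × 3 × 2 = 18 bins by `affective_boundaries`, and the `find_cluster` helper defined further down in `get_clusters` flattens the three bin indices into a single cluster id. A minimal standalone sketch of that mapping, mirroring the boundary construction above:

```python
import numpy as np

# Same boundary construction as the module-level constants above.
affective_space_dimensions = ((-1, 1), (-1, 1), (-1, 1))  # valence, arousal, dominance
n_splits = (3, 3, 2)
boundaries = [np.arange(lo, hi + 1e-6, (hi - lo) / n)
              for (lo, hi), n in zip(affective_space_dimensions, n_splits)]
for b in boundaries:
    b[-1] += 1e-6  # make the upper edge inclusive

def handcoded_cluster_id(coord):
    """Map one (valence, arousal, dominance) point in [-1, 1]^3 to a bin id in [0, 17]."""
    bins = [np.argwhere(boundaries[j] <= coord[j]).flatten()[-1] for j in range(3)]
    # Row-major flattening: id = i * (3 * 2) + j * 2 + k
    return bins[0] * np.prod(n_splits[1:]) + bins[1] * n_splits[-1] + bins[2]

print(handcoded_cluster_id(np.array([0.9, -0.9, 0.1])))  # -> 13
```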
+max_arousal = np.loadtxt(min_max_path + 'max_arousal.txt') +min_val = np.loadtxt(min_max_path + 'min_valence.txt') +max_val = np.loadtxt(min_max_path + 'max_valence.txt') +min_dom = np.loadtxt(min_max_path + 'min_dominance.txt') +max_dom = np.loadtxt(min_max_path + 'max_dominance.txt') + +def get_cocktail_reps(path, save=False): + cocktail_data = pd.read_csv(path) + cocktail_reps = np.array([cocktail_data[k] for k in original_affective_keys]).transpose() + n_data, dim_rep = cocktail_reps.shape + # print(f'{n_data} data points of {dim_rep} dimensions: {affective_keys}') + cocktail_reps = normalize_cocktail_reps_affective(cocktail_reps, save=save) + if save: + np.savetxt(experiment_path + f'cocktail_reps_for_affective_mapping_-1_1_norm_sigmoid_rescaling_{dim_rep}_keys.txt', cocktail_reps) + return cocktail_reps + +def sigmoid(x, shift, beta): + return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 + +def normalize_cocktail_reps_affective(cocktail_reps, save=False): + if save: + min_cr = cocktail_reps.min(axis=0) + max_cr = cocktail_reps.max(axis=0) + np.savetxt(min_max_path + 'min_cocktail_reps_affective.txt', min_cr) + np.savetxt(min_max_path + 'max_cocktail_reps_affective.txt', max_cr) + else: + min_cr = np.loadtxt(min_max_path + 'min_cocktail_reps_affective.txt') + max_cr = np.loadtxt(min_max_path + 'max_cocktail_reps_affective.txt') + cocktail_reps = ((cocktail_reps - min_cr) / (max_cr - min_cr) - 0.5) * 2 + cocktail_reps[:, 0] = sigmoid(cocktail_reps[:, 0], shift=0.05, beta=4) + cocktail_reps[:, 1] = sigmoid(cocktail_reps[:, 1], shift=0.3, beta=5) + cocktail_reps[:, 2] = sigmoid(cocktail_reps[:, 2], shift=0.15, beta=3) + cocktail_reps[:, 3] = sigmoid(cocktail_reps[:, 3], shift=0.9, beta=20) + cocktail_reps[:, 4] = sigmoid(cocktail_reps[:, 4], shift=0, beta=4) + cocktail_reps[:, 5] = sigmoid(cocktail_reps[:, 5], shift=0.2, beta=3) + cocktail_reps[:, 6] = sigmoid(cocktail_reps[:, 6], shift=0.5, beta=5) + cocktail_reps[:, 7] = sigmoid(cocktail_reps[:, 7], shift=0.2, beta=6) + return cocktail_reps + +def plot(cocktail_reps): + dim_rep = cocktail_reps.shape[1] + for i in range(dim_rep): + for j in range(i+1, dim_rep): + plt.figure() + plt.scatter(cocktail_reps[:, i], cocktail_reps[:, j], s=150, alpha=0.5) + plt.xlabel(affective_keys[i]) + plt.ylabel(affective_keys[j]) + plt.savefig(experiment_path + f'scatters/{affective_keys[i]}_vs_{affective_keys[j]}.png', dpi=300) + plt.close('all') + plt.figure() + plt.hist(cocktail_reps[:, i]) + plt.xlabel(affective_keys[i]) + plt.savefig(experiment_path + f'hists/{affective_keys[i]}.png', dpi=300) + plt.close('all') + +def get_clusters(affective_coordinates, save=False): + if clustering_method in ['k_means', 'gmm',]: + if clustering_method == 'k_means': model = KMeans(n_clusters=total_n_clusters) + elif clustering_method == 'gmm': model = GaussianMixture(n_components=total_n_clusters, covariance_type="full") + model.fit(affective_coordinates * np.array(dimensions_weights)) + + def find_cluster(aff_coord): + if aff_coord.ndim == 1: + aff_coord = aff_coord.reshape(1, -1) + return model.predict(aff_coord * np.array(dimensions_weights)) + cluster_centers = model.cluster_centers_ if clustering_method == 'k_means' else [] + if save: + to_save = dict(cluster_model=model, + cluster_centers=cluster_centers, + nb_clusters=len(cluster_centers), + dimensions_weights=dimensions_weights) + with open(cluster_model_path, 'wb') as f: + pickle.dump(to_save, f) + stop= 1 + + elif clustering_method == 'handcoded': + def find_cluster(aff_coord): + if aff_coord.ndim == 
1: + aff_coord = aff_coord.reshape(1, -1) + cluster_coordinates = [] + for i in range(aff_coord.shape[0]): + cluster_coordinates.append([np.argwhere(affective_boundaries[j] <= aff_coord[i, j]).flatten()[-1] for j in range(3)]) + cluster_coordinates = np.array(cluster_coordinates) + cluster_ids = cluster_coordinates[:, 0] * np.prod(n_splits[1:]) + cluster_coordinates[:, 1] * n_splits[-1] + cluster_coordinates[:, 2] + return cluster_ids + # find cluster centers + cluster_centers = [] + for i in range(n_splits[0]): + asd = affective_space_dimensions[0] + x_coordinate = np.arange(asd[0] + 1 / n_splits[0], asd[1], (asd[1] - asd[0]) / n_splits[0])[i] + for j in range(n_splits[1]): + asd = affective_space_dimensions[1] + y_coordinate = np.arange(asd[0] + 1 / n_splits[1], asd[1], (asd[1] - asd[0]) / n_splits[1])[j] + for k in range(n_splits[2]): + asd = affective_space_dimensions[2] + z_coordinate = np.arange(asd[0] + 1 / n_splits[2], asd[1], (asd[1] - asd[0]) / n_splits[2])[k] + cluster_centers.append([x_coordinate, y_coordinate, z_coordinate]) + cluster_centers = np.array(cluster_centers) + else: + raise NotImplemented + cluster_ids = find_cluster(affective_coordinates) + return cluster_ids, cluster_centers, find_cluster + + +def cocktail2affect(cocktail_reps, save=False): + if cocktail_reps.ndim == 1: + cocktail_reps = cocktail_reps.reshape(1, -1) + + assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful'] + all_weights = [] + + # valence + # + sweet - bitter - booze + colorful + weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1]) + valence = (cocktail_reps * weights).sum(axis=1) + if save: + min_ = valence.min() + max_ = valence.max() + np.savetxt(min_max_path + 'min_valence.txt', np.array([min_])) + np.savetxt(min_max_path + 'max_valence.txt', np.array([max_])) + else: + min_ = min_val + max_ = max_val + valence = 2 * ((valence - min_) / (max_ - min_) - 0.5) + valence = sigmoid(valence, shift=0.1, beta=3.5) + valence = valence.reshape(-1, 1) + all_weights.append(weights.copy()) + + # arousal + # + fizzy + sour + complex - sweet + spicy + bitter + # weights = np.array([0, -1, 1, 1, 1, 1, 1, 0]) + weights = np.array([0.7, 0, 1.5, 1.5, 0.6, 0, 0.6, 0]) + arousal = (cocktail_reps * weights).sum(axis=1) + if save: + min_ = arousal.min() + max_ = arousal.max() + np.savetxt(min_max_path + 'min_arousal.txt', np.array([min_])) + np.savetxt(min_max_path + 'max_arousal.txt', np.array([max_])) + else: + min_, max_ = min_arousal, max_arousal + arousal = 2 * ((arousal - min_) / (max_ - min_) - 0.5) # normalize to -1, 1 + arousal = sigmoid(arousal, shift=0.3, beta=4) + arousal = arousal.reshape(-1, 1) + all_weights.append(weights.copy()) + + # dominance + # assert affective_keys == ['booze', 'sweet', 'sour', 'fizzy', 'complex', 'bitter', 'spicy', 'colorful'] + # + booze + fizzy - complex - bitter - sweet + weights = np.array([1.5, -0.8, 0, 0.7, -1, -1.5, 0, 0]) + dominance = (cocktail_reps * weights).sum(axis=1) + if save: + min_ = dominance.min() + max_ = dominance.max() + np.savetxt(min_max_path + 'min_dominance.txt', np.array([min_])) + np.savetxt(min_max_path + 'max_dominance.txt', np.array([max_])) + else: + min_, max_ = min_dom, max_dom + dominance = 2 * ((dominance - min_) / (max_ - min_) - 0.5) + dominance = sigmoid(dominance, shift=-0.05, beta=5) + dominance = dominance.reshape(-1, 1) + all_weights.append(weights.copy()) + + affective_coordinates = np.concatenate([valence, arousal, dominance], axis=1) + # if save: + # assert 
(affective_coordinates.min(axis=0) == np.array([ac[0] for ac in affective_space_dimensions])).all() + # assert (affective_coordinates.max(axis=0) == np.array([ac[1] for ac in affective_space_dimensions])).all() + return affective_coordinates, all_weights + +def save_reps(path, affective_cluster_ids): + cocktail_data = pd.read_csv(path) + rep_keys = get_bunch_of_rep_keys()['custom'] + cocktail_reps = np.array([cocktail_data[k] for k in rep_keys]).transpose() + np.savetxt(experiment_path + 'clustered_representations/' + f'min_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.min(axis=0)) + np.savetxt(experiment_path + 'clustered_representations/' + f'max_cocktail_reps_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps.max(axis=0)) + cocktail_reps = ((cocktail_reps - cocktail_reps.min(axis=0)) / (cocktail_reps.max(axis=0) - cocktail_reps.min(axis=0)) - 0.5) * 2 # normalize in -1, 1 + np.savetxt(experiment_path + 'clustered_representations/' + f'all_cocktail_reps_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', cocktail_reps) + np.savetxt(experiment_path + 'clustered_representations/' + 'affective_cluster_ids.txt', affective_cluster_ids) + for cluster_id in sorted(set(affective_cluster_ids)): + indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten() + reps = cocktail_reps[indexes, :] + np.savetxt(experiment_path + 'clustered_representations/' + f'rep_cluster{cluster_id}_norm-1_1_custom_keys_dim{cocktail_reps.shape[1]}.txt', reps) + +def study_affects(affective_coordinates, affective_cluster_ids): + plt.figure() + plt.hist(affective_cluster_ids, bins=total_n_clusters) + plt.xlabel('Affective cluster ids') + plt.xticks(np.arange(total_n_clusters)) + plt.savefig(experiment_path + 'affective_cluster_distrib.png') + fig = plt.gcf() + plt.close(fig) + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set_xlim([-1, 1]) + ax.set_ylim([-1, 1]) + ax.set_zlim([-1, 1]) + for cluster_id in sorted(set(affective_cluster_ids)): + indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten() + ax.scatter(affective_coordinates[indexes, 0], affective_coordinates[indexes, 1], affective_coordinates[indexes, 2], c=cluster_colors[cluster_id], s=150) + ax.set_xlabel('Valence') + ax.set_ylabel('Arousal') + ax.set_zlabel('Dominance') + stop = 1 + plt.savefig(experiment_path + 'scatters_affect/affective_mapping.png') + fig = plt.gcf() + plt.close(fig) + + affects = ['Valence', 'Arousal', 'Dominance'] + for i in range(3): + for j in range(i + 1, 3): + fig = plt.figure() + ax = fig.add_subplot() + for cluster_id in sorted(set(affective_cluster_ids)): + indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten() + ax.scatter(affective_coordinates[indexes, i], affective_coordinates[indexes, j], alpha=0.5, c=cluster_colors[cluster_id], s=150) + ax.set_xlabel(affects[i]) + ax.set_ylabel(affects[j]) + plt.savefig(experiment_path + f'scatters_affect/scatter_{affects[i]}_vs_{affects[j]}.png') + fig = plt.gcf() + plt.close(fig) + plt.figure() + plt.hist(affective_coordinates[:, i]) + plt.xlabel(affects[i]) + plt.savefig(experiment_path + f'hists_affect/hist_{affects[i]}.png') + fig = plt.gcf() + plt.close(fig) + plt.close('all') + stop = 1 + +def sample_clusters(path, cocktail_reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates, n_samples=4): + cocktail_data = pd.read_csv(path) + these_cocktail_reps = normalize_cocktail_reps_affective(np.array([cocktail_data[k] for k in original_affective_keys]).transpose()) 
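Each affect dimension produced by `cocktail2affect` above follows the same pattern: a weighted sum over the eight normalized flavor axes, min-max rescaling to [-1, 1], then the shifted sigmoid. A condensed sketch of that pattern with illustrative weights and bounds (the real min/max values are the ones saved to and loaded from `min_max_path`):

```python
import numpy as np

def sigmoid(x, shift, beta):
    # Shifted/scaled logistic used throughout cocktail2affect: maps R to (-1, 1).
    return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2

# Flavor axes: booze, sweet, sour, fizzy, complex, bitter, spicy, colorful
valence_weights = np.array([-1, 1, 0, 0, 0, -1, 0, 1])  # + sweet + colorful - booze - bitter

def affect_dimension(cocktail_reps, weights, min_, max_, shift, beta):
    """cocktail_reps: (n, 8) normalized reps -> (n,) affect scores in (-1, 1)."""
    raw = (cocktail_reps * weights).sum(axis=1)
    scaled = 2 * ((raw - min_) / (max_ - min_) - 0.5)  # min-max rescale to [-1, 1]
    return sigmoid(scaled, shift=shift, beta=beta)

# Illustrative usage with made-up bounds.
reps = np.random.uniform(-1, 1, size=(5, 8))
valence = affect_dimension(reps, valence_weights, min_=-3.0, max_=3.0, shift=0.1, beta=3.5)
```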
+ + names = cocktail_data['names'] + urls = cocktail_data['urls'] + ingr_str = cocktail_data['ingredients_str'] + for cluster_id in sorted(set(affective_cluster_ids)): + indexes = np.argwhere(affective_cluster_ids == cluster_id).flatten() + print('\n\n\n---------\n----------\n-----------\n') + cluster_str = '' + cluster_str += f'Affective cluster #{cluster_id}' + \ + f'\n\tSize: {len(indexes)}' + \ + f'\n\tCenter: ' + \ + f'\n\t\tVal: {affective_cluster_centers[cluster_id][0]:.2f}, ' + \ + f'\n\t\tArousal: {affective_cluster_centers[cluster_id][1]:.2f}, ' + \ + f'\n\t\tDominance: {affective_cluster_centers[cluster_id][2]:.2f}' + print(cluster_str) + if affective_cluster_centers[cluster_id][2] == np.max(affective_cluster_centers[:, 2]): + stop = 1 + sampled_idx = np.random.choice(indexes, size=min(len(indexes), n_samples), replace=False) + cocktail_str = '' + for i in sampled_idx: + assert np.sum(cocktail_reps[i] - these_cocktail_reps[i]) < 1e-9 + cocktail_str += f'\n\n-------------' + cocktail_str += print_recipe(ingr_str[i], name=names[i], to_print=False) + cocktail_str += f'\nUrl: {urls[i]}' + cocktail_str += '\n\nRepresentation: ' + ', '.join([f'{af}: {cr:.2f}' for af, cr in zip(affective_keys, cocktail_reps[i])]) + '\n' + cocktail_str += '\n' + generate_explanation(cocktail_reps[i], all_weights, affective_coordinates[i]) + print(cocktail_str) + stop = 1 + cluster_str += '\n' + cocktail_str + with open(f"/home/cedric/Documents/pianocktail/experiments/cocktails/representation_analysis/affective_mapping/clusters/cluster_{cluster_id}", 'w') as f: + f.write(cluster_str) + stop = 1 + +def explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord): + names = ['valence', 'arousal', 'dominance'] + weights = all_weights[i] + explanation_str = f'\n{names[i].capitalize()} explanation ({aff_coord[i]:.2f}):' + strengths = np.abs(weights * cocktail_rep) + strengths /= strengths.sum() + indexes = np.flip(np.argsort(strengths)) + for ind in indexes: + if strengths[ind] != 0: + if np.sign(weights[ind]) == np.sign(cocktail_rep[ind]): + keyword = 'high' if cocktail_rep[ind] > 0 else 'low' + explanation_str += f'\n\t{int(strengths[ind]*100)}%: higher {names[i]} because {keyword} {affective_keys[ind]}' + else: + keyword = 'high' if cocktail_rep[ind] > 0 else 'low' + explanation_str += f'\n\t{int(strengths[ind]*100)}%: low {names[i]} because {keyword} {affective_keys[ind]}' + return explanation_str + +def generate_explanation(cocktail_rep, all_weights, aff_coord): + explanation_str = '' + for i in range(3): + explanation_str += explanation_per_dimension(i, cocktail_rep, all_weights, aff_coord) + return explanation_str + +def cocktails2affect_clusters(cocktail_rep): + if cocktail_rep.ndim == 1: + cocktail_rep = cocktail_rep.reshape(1, -1) + affective_coordinates, _ = cocktail2affect(cocktail_rep) + affective_cluster_ids, _, _ = get_clusters(affective_coordinates) + return affective_cluster_ids + + +def setup_affective_space(path, save=False): + cocktail_data = pd.read_csv(path) + names = cocktail_data['names'] + recipes = cocktail_data['ingredients_str'] + urls = cocktail_data['urls'] + reps = get_cocktail_reps(path) + affective_coordinates, all_weights = cocktail2affect(reps) + affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates, save=save) + nn_model = NearestNeighbors(n_neighbors=1) + nn_model.fit(affective_coordinates) + def cocktail2affect_cluster(cocktail_rep): + affective_coordinates, _ = cocktail2affect(cocktail_rep) + return 
find_cluster(affective_coordinates) + + affective_clusters = dict(affective_coordinates=affective_coordinates, # coordinates of cocktail in affective space + affective_cluster_ids=affective_cluster_ids, # cluster id of cocktails + affective_cluster_centers=affective_cluster_centers, # cluster centers in affective space + affective_weights=all_weights, # weights to compute valence, arousal, dominance from cocktail representations + original_affective_keys=original_affective_keys, + cocktail_reps=reps, # cocktail representations from the dataset (normalized) + find_cluster=find_cluster, # function to retrieve a cluster from affective coordinates + nn_model=nn_model, # to predict the nearest neighbor affective space, + names=names, # names of cocktails in the dataset + urls=urls, # urls from the dataset + recipes=recipes, # recipes of the dataset + cocktail2affect=cocktail2affect, # function to compute affects from cocktail representations + cocktails2affect_clusters=cocktails2affect_clusters, + cocktail2affect_cluster=cocktail2affect_cluster + ) + + return affective_clusters + +if __name__ == '__main__': + reps = get_cocktail_reps(COCKTAILS_CSV_DATA, save=True) + # plot(reps) + affective_coordinates, all_weights = cocktail2affect(reps, save=True) + affective_cluster_ids, affective_cluster_centers, find_cluster = get_clusters(affective_coordinates) + save_reps(COCKTAILS_CSV_DATA, affective_cluster_ids) + study_affects(affective_coordinates, affective_cluster_ids) + sample_clusters(COCKTAILS_CSV_DATA, reps, all_weights, affective_cluster_ids, affective_cluster_centers, affective_coordinates) + setup_affective_space(COCKTAILS_CSV_DATA, save=True) diff --git a/src/cocktails/pipeline/cocktailrep2recipe.py b/src/cocktails/pipeline/cocktailrep2recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..17a4ae4b55205f0107f9cde64a4d459b655c3be6 --- /dev/null +++ b/src/cocktails/pipeline/cocktailrep2recipe.py @@ -0,0 +1,329 @@ +import matplotlib.pyplot as plt +import pickle +from src.cocktails.utilities.cocktail_generation_utilities.population import * +from src.cocktails.utilities.glass_and_volume_utilities import glass_volume +from src.cocktails.config import RECIPE2FEATURES_PATH + +def test_mutation_params(cocktail_reps): + indexes = np.arange(cocktail_reps.shape[0]) + np.random.shuffle(indexes) + perfs = [] + mutated_perfs = [] + pop_params = dict(mutation_params=dict(p_add_ing=0.7, + p_remove_ing=0.7, + p_switch_ing=0.5, + p_change_q=0.7, + delta_change_q=0.3, + asexual_rep=True, + crossover=True, + ingredient_addition=(0.1, 0.05)), + nb_generations=100, + pop_size=100, + nb_elites=10, + dist='mse', + n_neighbors=5) + + for i in indexes[:20]: + target = cocktail_reps[i] + for j in range(100): + parent = IndividualCocktail(pop_params=pop_params, + target_affective_cluster=None, + target=target.copy()) + perfs.append(parent.perf) + child = parent.get_child()[0] + # child.compute_cocktail_rep() + # child.compute_perf() + if perfs[-1] != child.perf: + mutated_perfs.append(child.perf) + else: + perfs.pop(-1) + filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten() + non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)!=-100, np.array(mutated_perfs)!=-100)).flatten() + print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%') + plt.figure() + plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5) + plt.xlabel('parent perf') 
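For reference, `setup_affective_space` (end of `cocktail2affect.py` above) returns a dictionary bundling the dataset, the cluster-assignment function and a 1-nearest-neighbour model over affective coordinates. A hedged usage sketch, assuming the CSV and checkpoint files referenced in `config.py` exist locally:

```python
from src.cocktails.config import COCKTAILS_CSV_DATA
from src.cocktails.pipeline.cocktail2affect import setup_affective_space, cocktail2affect

space = setup_affective_space(COCKTAILS_CSV_DATA, save=False)

# Map one dataset cocktail to affective coordinates and its affective cluster.
rep = space['cocktail_reps'][0]
coords, _ = cocktail2affect(rep)           # shape (1, 3): valence, arousal, dominance
cluster_id = space['find_cluster'](coords)

# Closest dataset cocktail in affective space.
_, idx = space['nn_model'].kneighbors(coords)
print(space['names'][idx.flatten()[0]], cluster_id)
```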
+ plt.ylabel('child perf') + print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1]) + plt.show() + stop = 1 + +def test_crossover(cocktail_reps): + indexes = np.arange(cocktail_reps.shape[0]) + np.random.shuffle(indexes) + perfs = [] + mutated_perfs = [] + pop_params = dict(mutation_params=dict(p_add_ing=0.7, + p_remove_ing=0.7, + p_switch_ing=0.5, + p_change_q=0.7, + delta_change_q=0.3, + asexual_rep=True, + crossover=True, + ingredient_addition=(0.1, 0.05)), + nb_generations=100, + pop_size=100, + nb_elites=10, + dist='mse', + n_neighbors=5) + for i in indexes[:20]: + for j in range(100): + target = cocktail_reps[i] + parent1 = IndividualCocktail(pop_params=pop_params, + target_affective_cluster=None, + target=target.copy()) + parent2 = IndividualCocktail(pop_params=pop_params, + target_affective_cluster=None, + target=target.copy()) + child = parent1.get_child_with(parent2)[0] + # child.compute_cocktail_rep() + # child.compute_perf() + perfs.append((parent1.perf + parent2.perf)/2) + if perfs[-1] != child.perf: + mutated_perfs.append(child.perf) + else: + perfs.pop(-1) + filtered_children = np.argwhere(np.array(mutated_perfs)==-100).flatten() + non_filtered_ids = np.argwhere(np.logical_and(np.array(perfs)>-45, np.array(mutated_perfs)!=-100)).flatten() + print(f'Proportion of filtered: {filtered_children.size} / {len(mutated_perfs)} = {int(filtered_children.size / len(mutated_perfs)*100)}%') + plt.figure() + plt.scatter(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids], s=100, alpha=0.5) + plt.xlabel('parent perf') + plt.ylabel('child perf') + print(np.corrcoef(np.array(perfs)[non_filtered_ids], np.array(mutated_perfs)[non_filtered_ids])[0, 1]) + plt.show() + stop = 1 + +def run_comparisons(): + np.random.seed(0) + indexes = np.arange(cocktail_reps.shape[0]) + np.random.shuffle(indexes) + for n_neighbors in [0, 5]: + id_str_neigh = '5neigh_' if n_neighbors == 5 else '0_neigh_' + for asexual_rep in [True, False]: + id_str_as = id_str_neigh + 'asexual_' if asexual_rep else id_str_neigh + for crossover in [True, False]: + id_str = id_str_as + 'crossover_' if crossover else id_str_as + if crossover or asexual_rep: + mutation_params = dict(p_add_ing = 0.5, + p_remove_ing = 0.5, + p_change_q = 0.5, + delta_change_q = 0.3, + asexual_rep=asexual_rep, + crossover=crossover, + ingredient_addition = (0.1, 0.05)) + nb_generations = 100 + pop_size=100 + nb_elites=10 + dist = 'mse' + results = dict() + print(id_str) + for i, ind in enumerate(indexes[:30]): + print(i+1) + target_ing_str = data['ingredients_str'][ind] + target = cocktail_reps[ind] + population = Population(nb_generations=nb_generations, pop_size=pop_size, nb_elite=nb_elites, + target=target, dist=dist, mutation_params=mutation_params, + n_neighbors=n_neighbors, target_ing_str=target_ing_str, true_prep_type=data['category'][ind]) + population.run_evolution(verbose=False) + best_scores, best_ind = population.get_best_score() + recipes = [ind.get_recipe()[3] for ind in best_ind[:5]] + results[str(ind)] = dict(best_scores=best_scores[:5], recipes=recipes, target=population.target_individual.get_recipe()[3]) + with open(f'/home/cedric/Desktop/ga_tests_{id_str}.pickle', 'wb') as f: + pickle.dump(results, f) + +def get_cocktail_distribution(cocktail_reps): + return (np.mean(cocktail_reps, axis=0), np.cov(cocktail_reps, rowvar=0)) + +def sample_cocktails(cocktail_reps, n=10, target_affective_cluster=None, to_print=True): + distrib = 
get_cocktail_distribution(cocktail_reps) + sampled_cocktail_reps = np.random.multivariate_normal(distrib[0], distrib[1], size=n) + recipes = [] + closest_recipes = [] + for i_c, cr in enumerate(sampled_cocktail_reps): + population = setup_recipe_generation(cr.copy(), target_affective_cluster=target_affective_cluster) + closest_recipes.append(population.nn_recipes[0]) + best_scores, best_individuals = population.run_evolution() + recipes.append(best_individuals[0].get_recipe()[3]) + if to_print: + print(f'Sample #{len(recipes)}:') + print(recipes[-1]) + print('Closest from dataset:') + print(closest_recipes[-1]) + stop = 1 + return recipes, closest_recipes + +def setup_recipe_generation(target, known_target_dict=None, target_affective_cluster=None): + # pop_params = dict(mutation_params=dict(p_add_ing=0.7, + # p_remove_ing=0.7, + # p_switch_ing=0.5, + # p_change_q=0.7, + # delta_change_q=0.3, + # asexual_rep=True, + # crossover=True, + # ingredient_addition=(0.1, 0.05)), + # nb_generations=2, #100 + # pop_size=5, #100 + # nb_elites=2, #10 + # dist='mse', + # n_neighbors=3) #5 + pop_params = dict(mutation_params=dict(p_add_ing=0.4, + p_remove_ing=1, + p_switch_ing=0.5, + p_change_q=1, + delta_change_q=0.3, + asexual_rep=True, + crossover=True, + ingredient_addition=(0.1, 0.05)), + nb_generations=100, # 100 + pop_size=100, # 100 + nb_elites=10, # 10 + dist='mse', + n_neighbors=5) # 5 + + population = Population(target=target, target_affective_cluster=target_affective_cluster, known_target_dict=known_target_dict, pop_params=pop_params) + return population + +def cocktailrep2recipe(cocktail_rep, unit='mL', target_affective_cluster=None, known_target_dict=None, n_output=1, return_ind=False, verbose=True, full_verbose=False, level=0): + init_time = time.time() + if verbose: print(' ' * level + 'Generating cocktail..') + if cocktail_rep.ndim > 1: + assert cocktail_rep.shape[0] == 1 + cocktail_rep = cocktail_rep.flatten() + # target_affective_cluster = target_affective_cluster[0] + population = setup_recipe_generation(cocktail_rep.copy(), known_target_dict=known_target_dict, target_affective_cluster=target_affective_cluster) + if full_verbose: + print(' ' * (level + 2) + '3 nearest neighbors:') + for i, recipe, score in zip(range(3), population.nn_recipes[:3], population.nn_scores[:3]): + print(' ' * (level + 4) + f'#{i+1}, score: {score:.2f}') + print(' ' * (level + 4) + recipe[1:].replace('None ()', '').replace('\t\t', ' ' * (level + 6))) + best_scores, best_individuals = population.run_evolution(verbose=full_verbose, level=level+2) + for i in range(n_output): + best_individuals[i].make_recipe_fit_the_glass() + instructions = [ind.get_instructions() for ind in best_individuals[:n_output]] + recipes = [ind.get_recipe(unit=unit)[3] for ind in best_individuals[:n_output]] + glasses = [ind.glass for ind in best_individuals[:n_output]] + prep_types = [ind.prep_type for ind in best_individuals[:n_output]] + for i, g, p, inst in zip(range(len(recipes)), glasses, prep_types, instructions): + recipes[i] = recipes[i].replace('Recipe', 'Ingredients') + f'Serve in:\n {g.capitalize()} glass.\n' + inst + if full_verbose: + print(f'\n--------------\n{n_output} best results:') + for i, recipe, score in zip(range(n_output), recipes, best_scores[:n_output]): + print(f'#{i+1}, score: {score:.2f}') + print(recipe) + if verbose: print(' ' * (level + 2) + f'Generated in {int(time.time() - init_time)} seconds.') + if return_ind: + return recipes, best_scores[:n_output], best_individuals[:n_output] + else: + return 
recipes, best_scores[:n_output] + + +def interpolate(cocktail_rep1, cocktail_rep2, alpha, verbose=False): + recipe, score = cocktailrep2recipe(alpha * cocktail_rep1 + (1 - alpha) * cocktail_rep2, verbose=verbose) + return recipe[0], score + +def interpolation_study(n_steps, cocktail_reps): + alphas = np.arange(0, 1 + 1e-6, 1/(n_steps + 1)) + indexes = np.random.choice(np.arange(cocktail_reps.shape[0]), size=2, replace=False) + target_ing_str1, target_ing_str2 = data['ingredients_str'][indexes[0]], data['ingredients_str'][indexes[1]] + cocktail_rep1, cocktail_rep2 = cocktail_reps[indexes[0]], cocktail_reps[indexes[1]] + recipes, scores = [], [] + for alpha in alphas: + recipe, score = interpolate(cocktail_rep1, cocktail_rep2, alpha) + recipes.append(recipe) + scores.append(score[0]) + print('Point A:') + print_recipe(ingredient_str=target_ing_str2) + for i, alpha in enumerate(alphas): + print(f'Alpha = {alpha}, score = {scores[i]}') + print(recipes[i]) + print('Point B:') + print_recipe(ingredient_str=target_ing_str1) + stop = 1 + +def test_robustness_affective_cluster(cocktail_reps): + indexes = np.arange(cocktail_reps.shape[0]) + np.random.shuffle(indexes) + matches = [] + for i in indexes: + target_ing_str = data['ingredients_str'][i] + true_prep_type = data['category'][i] + target = cocktail_reps[i] + # get affective cluster + recipes, best_scores, best_inds = cocktailrep2recipe(cocktail_rep=target, target_ing_str=target_ing_str, true_prep_type=true_prep_type, n_output=1, verbose=False, + return_ind=True) + + matches.append(best_inds[0].does_affective_cluster_match()) + print(np.mean(matches)) + +def test(cocktail_reps): + indexes = np.arange(these_cocktail_reps.shape[0]) + unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose() + + for i in indexes: + target_ing_str = data['ingredients_str'][i] + true_prep_type = data['category'][i] + target = these_cocktail_reps[i] + # print('preptype:', true_prep_type) + # print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i]) + # print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i])) + # print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i]) + # print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i]) + + population = setup_recipe_generation(target.copy(), target_ing_str=target_ing_str, target_affective_cluster=None, true_prep_type=true_prep_type) + target = population.target_individual + target.compute_perf() + if target.perf < -50: + print(i) + print_recipe(target_ing_str) + if not target.is_alcohol_present(): print('No alcohol') + if not target.is_total_volume_enough(): print('small volume') + if not target.does_fit_glass(): + print(target.end_volume) + print(glass_volume[target.get_glass_type()] * 0.81) + print('too much volume') + if not target.is_alcohol_reasonable(): + print(f'amount of alcohol too small or too large: {target.alcohol_precentage}') + stop = 1 + + +if __name__ == '__main__': + these_cocktail_reps = COCKTAIL_REPS.copy() + # test_crossover(these_cocktail_reps) + # test_mutation_params(these_cocktail_reps) + # test(these_cocktail_reps) + # recipes, closest_recipes = sample_cocktails(these_cocktail_reps, n=10) + # interpolation_study(n_steps=4, cocktail_reps=these_cocktail_reps) + # test_robustness_affective_cluster(these_cocktail_reps) + indexes = np.arange(these_cocktail_reps.shape[0]) + np.random.shuffle(indexes) + # test_crossover(mutation_params, dist) + # 
test_mutation_params(mutation_params, dist) + stop = 1 + unnormalized_cr = np.array([data[k] for k in rep_keys]).transpose() + for i in indexes: + print(i) + target_ing_str = data['ingredients_str'][i] + target_prep_type = data['category'][i] + target_glass = data['glass'][i] + + print('preptype:', target_prep_type) + print('cocktail unnormalized', np.sum(unnormalized_cr[i]), unnormalized_cr[i]) + print('cocktail hand normalized', np.sum(normalize_cocktail(unnormalized_cr[i])), normalize_cocktail(unnormalized_cr[i])) + print('cocktail rep normalized', np.sum(these_cocktail_reps[i]), these_cocktail_reps[i]) + print('cocktail rep normalized', np.sum(all_reps[i]), all_reps[i]) + print(i) + + print('___________Target') + nn_model = NearestNeighbors() + nn_model.fit(these_cocktail_reps) + dists, indexes = nn_model.kneighbors(these_cocktail_reps[i].reshape(1, -1)) + print(indexes) + print_recipe(target_ing_str) + target = these_cocktail_reps[i] + known_target_dict = dict(prep_type=target_prep_type, + ing_str=target_ing_str, + glass=target_glass) + recipes, best_scores = cocktailrep2recipe(cocktail_rep=target, known_target_dict=known_target_dict, n_output=1, verbose=True, full_verbose=True) + + stop = 1 \ No newline at end of file diff --git a/src/cocktails/pipeline/get_affect2affective_cluster.py b/src/cocktails/pipeline/get_affect2affective_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0cd8cc37195869643cd591b9cf4585d7ff3c4a --- /dev/null +++ b/src/cocktails/pipeline/get_affect2affective_cluster.py @@ -0,0 +1,23 @@ +from src.music.config import CHECKPOINTS_PATH +import pickle +import numpy as np + +# can be computed from cocktail2affect +cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle" + +def get_affect2affective_cluster(): + with open(cluster_model_path, 'rb') as f: + data = pickle.load(f) + model = data['cluster_model'] + dimensions_weights = data['dimensions_weights'] + def find_cluster(aff_coord): + if aff_coord.ndim == 1: + aff_coord = aff_coord.reshape(1, -1) + return model.predict(aff_coord * np.array(dimensions_weights)) + return find_cluster + +def get_affective_cluster_centers(): + with open(cluster_model_path, 'rb') as f: + data = pickle.load(f) + return data['cluster_centers'] + diff --git a/src/cocktails/pipeline/get_cocktail2affective_cluster.py b/src/cocktails/pipeline/get_cocktail2affective_cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..8f43cb36ac6832ab35dd26e1d710874851b93d64 --- /dev/null +++ b/src/cocktails/pipeline/get_cocktail2affective_cluster.py @@ -0,0 +1,9 @@ +from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster +from src.cocktails.pipeline.cocktail2affect import cocktail2affect + +def get_cocktail2affective_cluster(): + find_cluster = get_affect2affective_cluster() + def cocktail2affect_cluster(cocktail_rep): + affective_coordinates, _ = cocktail2affect(cocktail_rep) + return find_cluster(affective_coordinates) + return cocktail2affect_cluster diff --git a/src/cocktails/representation_learning/__init__.py b/src/cocktails/representation_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/representation_learning/__pycache__/__init__.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1aa59f4286765a4ef2394211799c40c2aaa7fe20 Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/dataset.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017542980131eaee88c6134094e06318c9e6067f Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/dataset.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/multihead_model.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/multihead_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd7c408f200cccbda4d1959fc0fffe434e32ae63 Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/multihead_model.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/run.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/run.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a8c7297dd777d5debf324010d364024a1acfc8e Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/run.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/run_without_vae.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/run_without_vae.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..645d794a788c75e304985e58040e1d475430cd34 Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/run_without_vae.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/simple_model.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/simple_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fab1d3e669c274e01c2c88179acf7dab612d7b0 Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/simple_model.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/__pycache__/vae_model.cpython-39.pyc b/src/cocktails/representation_learning/__pycache__/vae_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..add1e6e034ae039eeaddfe8345a646c948895bf2 Binary files /dev/null and b/src/cocktails/representation_learning/__pycache__/vae_model.cpython-39.pyc differ diff --git a/src/cocktails/representation_learning/dataset.py b/src/cocktails/representation_learning/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e48837ee01943ba1f1d3ffab54eb81315d834308 --- /dev/null +++ b/src/cocktails/representation_learning/dataset.py @@ -0,0 +1,324 @@ +from torch.utils.data import Dataset +import pickle +from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles, ingredients_per_type +from src.cocktails.utilities.other_scrubbing_utilities import print_recipe +import numpy as np + +def get_representation_from_ingredient(ingredients, quantities, max_q_per_ing, index, params): + assert len(ingredients) == len(quantities) + ing, q = ingredients[index], quantities[index] + proportion = q / np.sum(quantities) + index_ing = ingredient_list.index(ing) + # add keys of profile + rep_ingredient = [] + rep_ingredient += [ingredient_profiles[k][index_ing] for k in params['ing_keys']] + # add category encoding 
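`get_representation_from_ingredient` builds a flat per-ingredient vector: the selected profile features, two quantity features (the quantity normalized by that ingredient's maximum, and its proportion of the recipe), and a one-hot identity over `ingredient_list`. A small standalone sketch of that layout with hypothetical sizes:

```python
import numpy as np

# Hypothetical sizes: 13 profile keys, 50 known ingredients.
N_PROFILE_KEYS, N_INGREDIENTS = 13, 50

def ingredient_vector(profile_feats, q, max_q, total_q, ing_index):
    one_hot = np.zeros(N_INGREDIENTS)
    one_hot[ing_index] = 1
    # Layout: [profile features | q / max_q, q / total_q | one-hot identity]
    return np.concatenate([profile_feats, [q / max_q, q / total_q], one_hot])

vec = ingredient_vector(np.random.rand(N_PROFILE_KEYS), q=30.0, max_q=60.0, total_q=90.0, ing_index=4)
assert vec.size == N_PROFILE_KEYS + 2 + N_INGREDIENTS
```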
+ # rep_ingredient += list(params['category_encodings'][ingredient_profiles['type'][index_ing]]) + # add quantitiy and relative quantity + rep_ingredient += [q / max_q_per_ing[ing], proportion] + ing_one_hot = np.zeros(len(ingredient_list)) + ing_one_hot[index_ing] = 1 + rep_ingredient += list(ing_one_hot) + indexes_to_normalize = list(range(len(params['ing_keys']))) + #TODO: should we add ing one hot? Or make sure no 2 ing have same embedding + return np.array(rep_ingredient), indexes_to_normalize + +def get_max_n_ingredients(data): + max_count = 0 + ingredient_set = set() + alcohol_set = set() + liqueur_set = set() + ing_str = np.array(data['ingredients_str']) + for i in range(len(data['names'])): + ingredients, quantities = extract_ingredients(ing_str[i]) + max_count = max(max_count, len(ingredients)) + for ing in ingredients: + ingredient_set.add(ing) + if ing in ingredients_per_type['liquor']: + alcohol_set.add(ing) + if ing in ingredients_per_type['liqueur']: + liqueur_set.add(ing) + return max_count, ingredient_set, alcohol_set, liqueur_set + +# Add your custom dataset class here +class MyDataset(Dataset): + def __init__(self, split, params): + data = params['raw_data'] + self.dim_rep_ingredient = params['dim_rep_ingredient'] + n_data = len(data["names"]) + + preparation_list = sorted(set(data['category'])) + categories_list = sorted(set(data['subcategory'])) + glasses_list = sorted(set(data['glass'])) + + max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data) + ingredient_set = sorted(ingredient_set) + self.ingredient_set = ingredient_set + + ingredient_quantities = [] # output of our network + ingr_strs = np.array(data['ingredients_str']) + for i in range(n_data): + ingredients, quantities = extract_ingredients(ingr_strs[i]) + # get ingredient presence and quantity + ingredient_q_rep = np.zeros([len(ingredient_set)]) + for ing, q in zip(ingredients, quantities): + ingredient_q_rep[ingredient_set.index(ing)] = q + ingredient_quantities.append(ingredient_q_rep) + + # take care of ingredient quantities (OUTPUTS) + ingredient_quantities = np.array(ingredient_quantities) + ingredients_presence = (ingredient_quantities>0).astype(np.int) + + min_ing_quantities = np.min(ingredient_quantities, axis=0) + max_ing_quantities = np.max(ingredient_quantities, axis=0) + def normalize_ing_quantities(ing_quantities): + return ((ing_quantities - min_ing_quantities) / (max_ing_quantities - min_ing_quantities)).copy() + + def denormalize_ing_quantities(normalized_ing_quantities): + return (normalized_ing_quantities * (max_ing_quantities - min_ing_quantities) + min_ing_quantities).copy() + ing_q_when_present = ingredient_quantities.copy() + for i in range(len(ing_q_when_present)): + ing_q_when_present[i, np.where(ing_q_when_present[i, :] == 0)] = np.nan + self.min_when_present_ing_quantities = np.nanmin(ing_q_when_present, axis=0) + + + def filter_decoder_output(output): + output_unnormalized = output * max_ing_quantities + if output.ndim == 1: + output_unnormalized[np.where(output_unnormalized 2: + taste_rep_valid.append(True) + taste_rep_ground_truth.append([float(tr.split('[')[1].split(',')[0]), float(tr.split(']')[0].split(',')[1][1:])]) + else: + taste_rep_valid.append(False) + taste_rep_ground_truth.append([np.nan, np.nan]) + taste_rep_ground_truth = np.array(taste_rep_ground_truth) + taste_rep_valid = np.array(taste_rep_valid) + taste_rep_ground_truth /= 10 + + auxiliary_data = dict(categories=categories, + glasses=glasses, + prep_type=prep_type, + 
cocktail_reps=computed_cocktail_reps, + ingredients_presence=ingredients_presence, + taste_reps=taste_rep_ground_truth, + volume=volumes, + ingredients_quantities=ingredient_quantities) + self.auxiliary_keys = sorted(params['auxiliaries_dict'].keys()) + assert self.auxiliary_keys == sorted(auxiliary_data.keys()) + + data_preprocessing = dict(min_max_ing_quantities=(min_ing_quantities, max_ing_quantities), + min_max_ing_reps=(min_ing_reps, max_ing_reps), + min_max_vol=(min_vol, max_vol)) + + if split == 'train': + with open(params['save_path'] + 'normalization_funcs.pickle', 'wb') as f: + pickle.dump(data_preprocessing, f) + + n_data = len(input_data) + assert len(ingredient_quantities) == n_data + for aux in self.auxiliary_keys: + assert len(auxiliary_data[aux]) == n_data + + if split == 'train': + indexes = np.arange(int(0.9 * n_data)) + elif split == 'test': + indexes = np.arange(int(0.9 * n_data), n_data) + elif split == 'all': + indexes = np.arange(n_data) + else: + raise ValueError + + # np.random.shuffle(indexes) + self.taste_rep_valid = taste_rep_valid[indexes] + self.input_ingredients = input_data[indexes] + self.ingredient_quantities = ingredient_quantities[indexes] + self.computed_cocktail_reps = computed_cocktail_reps[indexes] + self.auxiliaries = dict() + for aux in self.auxiliary_keys: + self.auxiliaries[aux] = auxiliary_data[aux][indexes] + self.nb_ingredients = nb_ingredients[indexes] + + def __len__(self): + return len(self.input_ingredients) + + def get_auxiliary_data(self, idx): + out = dict() + for aux in self.auxiliary_keys: + out[aux] = self.auxiliaries[aux][idx] + return out + + def __getitem__(self, idx): + assert self.nb_ingredients[idx] == np.argwhere(~np.isnan(self.input_ingredients[idx])).flatten().size / self.dim_rep_ingredient + return [self.nb_ingredients[idx], self.input_ingredients[idx], self.ingredient_quantities[idx], self.computed_cocktail_reps[idx], self.get_auxiliary_data(idx), + self.taste_rep_valid[idx]] \ No newline at end of file diff --git a/src/cocktails/representation_learning/multihead_model.py b/src/cocktails/representation_learning/multihead_model.py new file mode 100644 index 0000000000000000000000000000000000000000..346ad3dbb7c6561192c5f9563e19943ceca02a19 --- /dev/null +++ b/src/cocktails/representation_learning/multihead_model.py @@ -0,0 +1,148 @@ +import torch; torch.manual_seed(0) +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.distributions +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 +from src.cocktails.representation_learning.simple_model import SimpleNet + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_activation(activation): + if activation == 'tanh': + activ = F.tanh + elif activation == 'relu': + activ = F.relu + elif activation == 'mish': + activ = F.mish + elif activation == 'sigmoid': + activ = F.sigmoid + elif activation == 'leakyrelu': + activ = F.leaky_relu + elif activation == 'exp': + activ = torch.exp + else: + raise ValueError + return activ + +class IngredientEncoder(nn.Module): + def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout): + super(IngredientEncoder, self).__init__() + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [input_dim] + hidden_dims + [deepset_latent_dim] + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.activation = get_activation(activation) + self.n_layers = 
len(self.linears) + self.layer_range = range(self.n_layers) + + def forward(self, x): + for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts): + x = layer(x) + if i_layer != self.n_layers - 1: + x = self.activation(dropout(x)) + return x # do not use dropout on last layer? + +class DeepsetCocktailEncoder(nn.Module): + def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation, + hidden_dims_cocktail, latent_dim, aggregation, dropout): + super(DeepsetCocktailEncoder, self).__init__() + self.input_dim = input_dim # dimension of ingredient representation + quantity + self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout) # encode each ingredient separately + self.deepset_latent_dim = deepset_latent_dim # dimension of the deepset aggregation + self.aggregation = aggregation + self.latent_dim = latent_dim + # post aggregation network + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [deepset_latent_dim] + hidden_dims_cocktail + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim) + self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim) + self.softplus = nn.Softplus() + + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + + def forward(self, nb_ingredients, x): + + # reshape x in (batch size * nb ingredients, dim_ing_rep) + batch_size = x.shape[0] + all_ingredients = [] + for i in range(batch_size): + for j in range(nb_ingredients[i]): + all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1)) + x = torch.cat(all_ingredients, dim=0) + # encode ingredients in parallel + ingredients_encodings = self.ingredient_encoder(x) + assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim) + + # aggregate + x = [] + index_first = 0 + for i in range(batch_size): + index_last = index_first + nb_ingredients[i] + # aggregate + if self.aggregation == 'sum': + x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1)) + elif self.aggregation == 'mean': + x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1)) + else: + raise ValueError + index_first = index_last + x = torch.cat(x, dim=0) + assert x.shape[0] == batch_size + + for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts): + x = self.activation(dropout(layer(x))) + mean = self.FC_mean(x) + logvar = self.FC_logvar(x) + return mean, logvar + + +class MultiHeadModel(nn.Module): + def __init__(self, encoder, auxiliaries_dict, activation, hidden_dims_decoder): + super(MultiHeadModel, self).__init__() + self.encoder = encoder + self.latent_dim = self.encoder.output_dim + self.auxiliaries_str = [] + self.auxiliaries = nn.ModuleList() + for aux_str in sorted(auxiliaries_dict.keys()): + if aux_str == 'taste_reps': + self.taste_reps_decoder = SimpleNet(input_dim=self.latent_dim, hidden_dims=[], output_dim=auxiliaries_dict[aux_str]['dim_output'], + activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ']) + else: + self.auxiliaries_str.append(aux_str) + if aux_str == 'ingredients_quantities': + hd = hidden_dims_decoder + else: + hd = [] + self.auxiliaries.append(SimpleNet(input_dim=self.latent_dim, hidden_dims=hd, 
output_dim=auxiliaries_dict[aux_str]['dim_output'], + activation=activation, dropout=0.0, final_activ=auxiliaries_dict[aux_str]['final_activ'])) + + def get_all_auxiliaries(self, x): + return [aux(x) for aux in self.auxiliaries] + + def get_auxiliary(self, z, aux_str): + if aux_str == 'taste_reps': + return self.taste_reps_decoder(z) + else: + index = self.auxiliaries_str.index(aux_str) + return self.auxiliaries[index](z) + + def forward(self, x, aux_str=None): + z = self.encoder(x) + if aux_str is not None: + return z, self.get_auxiliary(z, aux_str), [aux_str] + else: + return z, self.get_all_auxiliaries(z), self.auxiliaries_str + +def get_multihead_model(input_dim, activation, hidden_dims_cocktail, latent_dim, dropout, auxiliaries_dict, hidden_dims_decoder): + encoder = SimpleNet(input_dim, hidden_dims_cocktail, latent_dim, activation, dropout) + model = MultiHeadModel(encoder, auxiliaries_dict, activation, hidden_dims_decoder) + return model \ No newline at end of file diff --git a/src/cocktails/representation_learning/run.py b/src/cocktails/representation_learning/run.py new file mode 100644 index 0000000000000000000000000000000000000000..a1278ac80039d25130b4c05bc5670bcfe197d13a --- /dev/null +++ b/src/cocktails/representation_learning/run.py @@ -0,0 +1,557 @@ +import torch; torch.manual_seed(0) +import torch.utils +from torch.utils.data import DataLoader +import torch.distributions +import torch.nn as nn +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 +from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients +import json +import pandas as pd +import numpy as np +import os +from src.cocktails.representation_learning.vae_model import get_vae_model +from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH +from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys +from src.cocktails.utilities.ingredients_utilities import ingredient_profiles +from resource import getrusage +from resource import RUSAGE_SELF +import gc +gc.collect(2) +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_params(): + data = pd.read_csv(COCKTAILS_CSV_DATA) + max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data) + num_ingredients = len(ingredient_set) + rep_keys = get_bunch_of_rep_keys()['custom'] + ing_keys = [k.split(' ')[1] for k in rep_keys] + ing_keys.remove('volume') + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + + params = dict(trial_id='test', + save_path=EXPERIMENT_PATH + "/deepset_vae/", + nb_epochs=2000, + print_every=50, + plot_every=100, + batch_size=64, + lr=0.001, + dropout=0., + nb_epoch_switch_beta=600, + latent_dim=10, + beta_vae=0.2, + ing_keys=ing_keys, + nb_ingredients=len(ingredient_set), + hidden_dims_ingredients=[128], + hidden_dims_cocktail=[32], + hidden_dims_decoder=[32], + agg='mean', + activation='relu', + auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))), + glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))), + prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))), + cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13), + volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1), 
+ taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2), + ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients)), + category_encodings=category_encodings + ) + # params = dict(trial_id='test', + # save_path=EXPERIMENT_PATH + "/deepset_vae/", + # nb_epochs=1000, + # print_every=50, + # plot_every=100, + # batch_size=64, + # lr=0.001, + # dropout=0., + # nb_epoch_switch_beta=500, + # latent_dim=64, + # beta_vae=0.3, + # ing_keys=ing_keys, + # nb_ingredients=len(ingredient_set), + # hidden_dims_ingredients=[128], + # hidden_dims_cocktail=[128, 128], + # hidden_dims_decoder=[128, 128], + # agg='mean', + # activation='mish', + # auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))), + # glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))), + # prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))), + # cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13), + # volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1), + # taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2), + # ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)), + # category_encodings=category_encodings + # ) + water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1], + max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0, + params=params) + dim_rep_ingredient = water_rep.size + params['indexes_ing_to_normalize'] = indexes_to_normalize + params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients + params['input_dim'] = dim_rep_ingredient + params['dim_rep_ingredient'] = dim_rep_ingredient + params = compute_expe_name_and_save_path(params) + del params['category_encodings'] # to dump + with open(params['save_path'] + 'params.json', 'w') as f: + json.dump(params, f) + + params = complete_params(params) + return params + +def complete_params(params): + data = pd.read_csv(COCKTAILS_CSV_DATA) + cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH) + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + params['cocktail_reps'] = cocktail_reps + params['raw_data'] = data + params['category_encodings'] = category_encodings + return params + +def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data): + losses = dict() + accuracies = dict() + other_metrics = dict() + for i_k, k in enumerate(auxiliaries_str): + # get ground truth + # compute loss + if k == 'volume': + outputs[i_k] = outputs[i_k].flatten() + ground_truth = auxiliaries[k] + if ground_truth.dtype == torch.float64: + losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float() + elif ground_truth.dtype == torch.int64: + if str(loss_functions[k]) != "BCEWithLogitsLoss()": + losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float() + else: + losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float() + else: + losses[k] = loss_functions[k](outputs[i_k], ground_truth).float() + # compute accuracies + if str(loss_functions[k]) == 'CrossEntropyLoss()': + bs, n_options = outputs[i_k].shape + predicted = outputs[i_k].argmax(dim=1).detach().numpy() + true 
= ground_truth.int().detach().numpy() + confusion_matrix = np.zeros([n_options, n_options]) + for i in range(bs): + confusion_matrix[true[i], predicted[i]] += 1 + acc = confusion_matrix.diagonal().sum() / bs + for i in range(n_options): + if confusion_matrix[i].sum() != 0: + confusion_matrix[i] /= confusion_matrix[i].sum() + other_metrics[k + '_confusion'] = confusion_matrix + accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy()) + assert (acc - accuracies[k]) < 1e-5 + + elif str(loss_functions[k]) == 'BCEWithLogitsLoss()': + assert k == 'ingredients_presence' + outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + predicted_presence = (outputs_rescaled > 0).astype(bool) + presence = ground_truth.detach().numpy().astype(bool) + other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool))) + other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool))) + accuracies[k] = np.mean(predicted_presence == presence) # accuracy for multi class labeling + elif str(loss_functions[k]) == 'MSELoss()': + accuracies[k] = np.nan + else: + raise ValueError + return losses, accuracies, other_metrics + +def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat): + ing_q = ingredient_quantities.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + ing_presence = (ing_q > 0) + x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + # abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities + abs_diff = np.abs(ing_q - x_hat) + ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], [] + for i in range(ingredient_quantities.shape[0]): + ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])])) + ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])])) + aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present) + aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent) + return aux_other_metrics + +def run_epoch(opt, train, model, data, loss_functions, weights, params): + if train: + model.train() + else: + model.eval() + + # prepare logging of losses + losses = dict(kld_loss=[], + mse_loss=[], + vae_loss=[], + volume_loss=[], + global_loss=[]) + accuracies = dict() + other_metrics = dict() + for aux in params['auxiliaries_dict'].keys(): + losses[aux] = [] + accuracies[aux] = [] + if train: opt.zero_grad() + + for d in data: + nb_ingredients = d[0] + batch_size = nb_ingredients.shape[0] + x_ingredients = d[1].float() + ingredient_quantities = d[2] + cocktail_reps = d[3] + auxiliaries = d[4] + for k in auxiliaries.keys(): + if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float() + taste_valid = d[-1] + x = x_ingredients.to(device) + x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float()) + # get auxiliary losses and accuracies + aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data) + + # compute vae loss + mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float() + kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float() + vae_loss = mse_loss + 
params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss + # compute total volume loss to train decoder + # volume_loss = ((ingredient_quantities.sum(dim=1) - x_hat.sum(dim=1)) ** 2).mean().float() + volume_loss = torch.FloatTensor([0]) + + aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat) + + indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten() + if indexes_taste_valid.size > 0: + outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps') + gt = auxiliaries['taste_reps'][indexes_taste_valid] + factor_loss = indexes_taste_valid.size / (0.3 * batch_size)# factor on the loss: if same ratio as actual dataset factor = 1 if there is less data, then the factor decreases, more data, it increases + aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float() + else: + aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([]) + aux_accuracies['taste_reps'] = 0 + + # aggregate losses + global_loss = torch.sum(torch.cat([torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)] + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()])) + # for k in params['auxiliaries_dict'].keys(): + # global_loss += aux_losses[k] * weights[k] + + if train: + global_loss.backward() + opt.step() + opt.zero_grad() + + # logging + losses['global_loss'].append(float(global_loss)) + losses['mse_loss'].append(float(mse_loss)) + losses['vae_loss'].append(float(vae_loss)) + losses['volume_loss'].append(float(volume_loss)) + losses['kld_loss'].append(float(kld_loss)) + for k in params['auxiliaries_dict'].keys(): + losses[k].append(float(aux_losses[k])) + accuracies[k].append(float(aux_accuracies[k])) + for k in aux_other_metrics.keys(): + if k not in other_metrics.keys(): + other_metrics[k] = [aux_other_metrics[k]] + else: + other_metrics[k].append(aux_other_metrics[k]) + + for k in losses.keys(): + losses[k] = np.mean(losses[k]) + for k in accuracies.keys(): + accuracies[k] = np.mean(accuracies[k]) + for k in other_metrics.keys(): + other_metrics[k] = np.mean(other_metrics[k], axis=0) + return model, losses, accuracies, other_metrics + +def prepare_data_and_loss(params): + train_data = MyDataset(split='train', params=params) + test_data = MyDataset(split='test', params=params) + + train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True) + test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True) + + loss_functions = dict() + weights = dict() + for k in sorted(params['auxiliaries_dict'].keys()): + if params['auxiliaries_dict'][k]['type'] == 'classif': + if k == 'glasses': + classif_weights = train_data.glasses_weights + elif k == 'prep_type': + classif_weights = train_data.prep_types_weights + elif k == 'categories': + classif_weights = train_data.categories_weights + else: + raise ValueError + loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights)) + elif params['auxiliaries_dict'][k]['type'] == 'multiclassif': + loss_functions[k] = nn.BCEWithLogitsLoss() + elif params['auxiliaries_dict'][k]['type'] == 'regression': + loss_functions[k] = nn.MSELoss() + else: + raise ValueError + weights[k] = params['auxiliaries_dict'][k]['weight'] + + + return loss_functions, train_data_loader, test_data_loader, weights + +def print_losses(train, losses, accuracies, other_metrics): + keyword = 'Train' if train else 'Eval' + print(f'\t{keyword} logs:') + keys = 
['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss'] + for k in keys: + print(f'\t\t{k} - Loss: {losses[k]:.2f}') + for k in sorted(accuracies.keys()): + print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}') + for k in sorted(other_metrics.keys()): + if 'confusion' not in k: + print(f'\t\t{k} - {other_metrics[k]:.2f}') + + +def run_experiment(params, verbose=True): + loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params) + params['filter_decoder_output'] = train_data_loader.dataset.filter_decoder_output + + model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation", + "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict", + "filter_decoder_output"]] + model = get_vae_model(*model_params) + opt = torch.optim.AdamW(model.parameters(), lr=params['lr']) + + + all_train_losses = [] + all_eval_losses = [] + all_train_accuracies = [] + all_eval_accuracies = [] + all_eval_other_metrics = [] + all_train_other_metrics = [] + best_loss = np.inf + model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + all_eval_losses.append(eval_losses) + all_eval_accuracies.append(eval_accuracies) + all_eval_other_metrics.append(eval_other_metrics) + if verbose: print(f'\n--------\nEpoch #0') + if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics) + for epoch in range(params['nb_epochs']): + if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}') + model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics) + model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics) + if eval_losses['global_loss'] < best_loss: + best_loss = eval_losses['global_loss'] + if verbose: print(f'Saving new best model with loss {best_loss:.2f}') + torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save') + + # log + all_train_losses.append(train_losses) + all_train_accuracies.append(train_accuracies) + all_eval_losses.append(eval_losses) + all_eval_accuracies.append(eval_accuracies) + all_eval_other_metrics.append(eval_other_metrics) + all_train_other_metrics.append(train_other_metrics) + + # if epoch == params['nb_epoch_switch_beta']: + # params['beta_vae'] = 2.5 + # params['auxiliaries_dict']['prep_type']['weight'] /= 10 + # params['auxiliaries_dict']['glasses']['weight'] /= 10 + + if (epoch + 1) % params['plot_every'] == 0: + + plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics, + all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights) + + return model + +def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics, + 
all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights): + + steps = np.arange(len(all_eval_accuracies)) + + loss_keys = sorted(all_train_losses[0].keys()) + acc_keys = sorted(all_train_accuracies[0].keys()) + metrics_keys = sorted(all_train_other_metrics[0].keys()) + + plt.figure() + plt.title('Train losses') + for k in loss_keys: + factor = 1 if k == 'mse_loss' else 1 + if k not in weights.keys(): + plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k) + else: + if weights[k] != 0: + plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k) + + plt.legend() + plt.ylim([0, 4]) + plt.savefig(plot_path + 'train_losses.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train accuracies') + for k in acc_keys: + if weights[k] != 0: + plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'train_acc.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' in k: + plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' not in k: + plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k) + plt.legend() + plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval losses') + for k in loss_keys: + factor = 1 if k == 'mse_loss' else 1 + if k not in weights.keys(): + plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k) + else: + if weights[k] != 0: + plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k) + plt.legend() + plt.ylim([0, 4]) + plt.savefig(plot_path + 'eval_losses.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval accuracies') + for k in acc_keys: + if weights[k] != 0: + plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'eval_acc.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' in k: + plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' not in k: + plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k) + plt.legend() + plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + + for k in metrics_keys: + if 'confusion' in k: + plt.figure() + plt.title(k) + plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'eval_{k}.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + for k in metrics_keys: + if 'confusion' in k: + plt.figure() + plt.title(k) + 
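+ # each confusion matrix was row-normalized in compute_losses_and_accuracies: rows are true classes, columns are predicted classes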
plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'train_{k}.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.close('all') + + +def get_model(model_path): + + with open(model_path + 'params.json', 'r') as f: + params = json.load(f) + params['save_path'] = model_path + max_ing_quantities = np.loadtxt(params['save_path'] + 'max_ing_quantities.txt') + mean_ing_quantities = np.loadtxt(params['save_path'] + 'mean_ing_quantities.txt') + std_ing_quantities = np.loadtxt(params['save_path'] + 'std_ing_quantities.txt') + min_when_present_ing_quantities = np.loadtxt(params['save_path'] + 'min_when_present_ing_quantities.txt') + def filter_decoder_output(output): + output = output.detach().numpy() + output_unnormalized = output * std_ing_quantities + mean_ing_quantities + if output.ndim == 1: + output_unnormalized[np.where(output_unnormalized < min_when_present_ing_quantities)] = 0 + else: + for i in range(output.shape[0]): + output_unnormalized[i, np.where(output_unnormalized[i] < min_when_present_ing_quantities)] = 0 + return output_unnormalized.copy() + params['filter_decoder_output'] = filter_decoder_output + model_chkpt = model_path + "checkpoint_best.save" + model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation", + "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict", + "filter_decoder_output"]] + model = get_vae_model(*model_params) + model.load_state_dict(torch.load(model_chkpt)) + model.eval() + return model, filter_decoder_output, params + + +def compute_expe_name_and_save_path(params): + weights_str = '[' + for aux in params['auxiliaries_dict'].keys(): + weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, ' + weights_str = weights_str[:-2] + ']' + save_path = params['save_path'] + params["trial_id"] + save_path += f'_lr{params["lr"]}' + save_path += f'_betavae{params["beta_vae"]}' + save_path += f'_bs{params["batch_size"]}' + save_path += f'_latentdim{params["latent_dim"]}' + save_path += f'_hding{params["hidden_dims_ingredients"]}' + save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}' + save_path += f'_hddecoder{params["hidden_dims_decoder"]}' + save_path += f'_agg{params["agg"]}' + save_path += f'_activ{params["activation"]}' + save_path += f'_w{weights_str}' + counter = 0 + while os.path.exists(save_path + f"_{counter}"): + counter += 1 + save_path = save_path + f"_{counter}" + '/' + params["save_path"] = save_path + os.makedirs(save_path) + os.makedirs(save_path + 'plots/') + params['plot_path'] = save_path + 'plots/' + print(f'logging to {save_path}') + return params + + + +if __name__ == '__main__': + params = get_params() + run_experiment(params) + diff --git a/src/cocktails/representation_learning/run_simple_net.py b/src/cocktails/representation_learning/run_simple_net.py new file mode 100644 index 0000000000000000000000000000000000000000..4409df378eaaab94612078ce16d8ebfab096c306 --- /dev/null +++ b/src/cocktails/representation_learning/run_simple_net.py @@ -0,0 +1,302 @@ +import torch; torch.manual_seed(0) +import torch.utils +from torch.utils.data import DataLoader +import torch.distributions +import torch.nn as nn +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 +from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients +import json +import pandas as pd 
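+# baseline: a small fully-connected net (SimpleNet) trained to predict the single target selected by params['output_keyword'] (default: 'glasses') from ingredient quantities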
+import numpy as np +import os +from src.cocktails.representation_learning.simple_model import SimpleNet +from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH +from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys +from src.cocktails.utilities.ingredients_utilities import ingredient_profiles +from resource import getrusage +from resource import RUSAGE_SELF +import gc +gc.collect(2) +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_params(): + data = pd.read_csv(COCKTAILS_CSV_DATA) + max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data) + num_ingredients = len(ingredient_set) + rep_keys = get_bunch_of_rep_keys()['custom'] + ing_keys = [k.split(' ')[1] for k in rep_keys] + ing_keys.remove('volume') + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + + params = dict(trial_id='test', + save_path=EXPERIMENT_PATH + "/simple_net/", + nb_epochs=100, + print_every=50, + plot_every=50, + batch_size=128, + lr=0.001, + dropout=0.15, + output_keyword='glasses', + ing_keys=ing_keys, + nb_ingredients=len(ingredient_set), + hidden_dims=[16], + activation='sigmoid', + auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))), + glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))), + prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))), + cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13), + volume=dict(weight=0, type='regression', final_activ='relu', dim_output=1), + taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2), + ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients), + ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)), + + category_encodings=category_encodings + ) + params['output_dim'] = params['auxiliaries_dict'][params['output_keyword']]['dim_output'] + water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1], + max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0, + params=params) + dim_rep_ingredient = water_rep.size + params['indexes_ing_to_normalize'] = indexes_to_normalize + params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients + params['dim_rep_ingredient'] = dim_rep_ingredient + params['input_dim'] = params['nb_ingredients'] + params = compute_expe_name_and_save_path(params) + del params['category_encodings'] # to dump + with open(params['save_path'] + 'params.json', 'w') as f: + json.dump(params, f) + + params = complete_params(params) + return params + +def complete_params(params): + data = pd.read_csv(COCKTAILS_CSV_DATA) + cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH) + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + params['cocktail_reps'] = cocktail_reps + params['raw_data'] = data + params['category_encodings'] = category_encodings + return params + +def compute_confusion_matrix_and_accuracy(predictions, ground_truth): + bs, n_options = predictions.shape + predicted = predictions.argmax(dim=1).detach().numpy() + true = ground_truth.int().detach().numpy() + 
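+ # fill the confusion matrix (rows: true class, columns: predicted class); accuracy is its trace divided by the batch size, then each row is normalized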
confusion_matrix = np.zeros([n_options, n_options]) + for i in range(bs): + confusion_matrix[true[i], predicted[i]] += 1 + acc = confusion_matrix.diagonal().sum() / bs + for i in range(n_options): + if confusion_matrix[i].sum() != 0: + confusion_matrix[i] /= confusion_matrix[i].sum() + acc2 = np.mean(predicted == true) + assert (acc - acc2) < 1e-5 + return confusion_matrix, acc + + +def run_epoch(opt, train, model, data, loss_function, params): + if train: + model.train() + else: + model.eval() + + # prepare logging of losses + losses = [] + accuracies = [] + cf_matrices = [] + if train: opt.zero_grad() + + for d in data: + nb_ingredients = d[0] + batch_size = nb_ingredients.shape[0] + x_ingredients = d[1].float() + ingredient_quantities = d[2].float() + cocktail_reps = d[3].float() + auxiliaries = d[4] + for k in auxiliaries.keys(): + if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float() + taste_valid = d[-1] + predictions = model(ingredient_quantities) + loss = loss_function(predictions, auxiliaries[params['output_keyword']].long()).float() + cf_matrix, accuracy = compute_confusion_matrix_and_accuracy(predictions, auxiliaries[params['output_keyword']]) + if train: + loss.backward() + opt.step() + opt.zero_grad() + + losses.append(float(loss)) + cf_matrices.append(cf_matrix) + accuracies.append(accuracy) + + return model, np.mean(losses), np.mean(accuracies), np.mean(cf_matrices, axis=0) + +def prepare_data_and_loss(params): + train_data = MyDataset(split='train', params=params) + test_data = MyDataset(split='test', params=params) + + train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True) + test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True) + + + if params['auxiliaries_dict'][params['output_keyword']]['type'] == 'classif': + if params['output_keyword'] == 'glasses': + classif_weights = train_data.glasses_weights + elif params['output_keyword'] == 'prep_type': + classif_weights = train_data.prep_types_weights + elif params['output_keyword'] == 'categories': + classif_weights = train_data.categories_weights + else: + raise ValueError + # classif_weights = (np.array(classif_weights) * 2 + np.ones(len(classif_weights))) / 3 + loss_function = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights)) + # loss_function = nn.CrossEntropyLoss() + + elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'multiclassif': + loss_function = nn.BCEWithLogitsLoss() + elif params['auxiliaries_dict'][params['output_keyword']]['type'] == 'regression': + loss_function = nn.MSELoss() + else: + raise ValueError + + return loss_function, train_data_loader, test_data_loader + +def print_losses(train, loss, accuracy): + keyword = 'Train' if train else 'Eval' + print(f'\t{keyword} logs:') + print(f'\t\t Loss: {loss:.2f}, Acc: {accuracy:.2f}') + + +def run_experiment(params, verbose=True): + loss_function, train_data_loader, test_data_loader = prepare_data_and_loss(params) + + model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout']) + opt = torch.optim.AdamW(model.parameters(), lr=params['lr']) + + all_train_losses = [] + all_eval_losses = [] + all_eval_cf_matrices = [] + all_train_accuracies = [] + all_eval_accuracies = [] + all_train_cf_matrices = [] + best_loss = np.inf + model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params) + 
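+ # evaluate the untrained model once before the training loop so the curves include an epoch-0 baseline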
all_eval_losses.append(eval_loss) + all_eval_accuracies.append(eval_accuracy) + if verbose: print(f'\n--------\nEpoch #0') + if verbose: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss) + for epoch in range(params['nb_epochs']): + if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}') + model, train_loss, train_accuracy, train_cf_matrix = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_function=loss_function, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracy=train_accuracy, loss=train_loss) + model, eval_loss, eval_accuracy, eval_cf_matrix = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_function=loss_function, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracy=eval_accuracy, loss=eval_loss) + if eval_loss < best_loss: + best_loss = eval_loss + if verbose: print(f'Saving new best model with loss {best_loss:.2f}') + torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save') + + # log + all_train_losses.append(train_loss) + all_train_accuracies.append(train_accuracy) + all_eval_losses.append(eval_loss) + all_eval_accuracies.append(eval_accuracy) + all_eval_cf_matrices.append(eval_cf_matrix) + all_train_cf_matrices.append(train_cf_matrix) + + if (epoch + 1) % params['plot_every'] == 0: + + plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices, + all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, params['plot_path']) + + return model + +def plot_results(all_train_losses, all_train_accuracies, all_train_cf_matrices, + all_eval_losses, all_eval_accuracies, all_eval_cf_matrices, plot_path): + + steps = np.arange(len(all_eval_accuracies)) + + plt.figure() + plt.title('Losses') + plt.plot(steps[1:], all_train_losses, label='train') + plt.plot(steps, all_eval_losses, label='eval') + plt.legend() + plt.ylim([0, 4]) + plt.savefig(plot_path + 'losses.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Accuracies') + plt.plot(steps[1:], all_train_accuracies, label='train') + plt.plot(steps, all_eval_accuracies, label='eval') + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'accs.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + + plt.figure() + plt.title('Train confusion matrix') + plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_train_cf_matrices[-1], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'train_confusion_matrix.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval confusion matrix') + plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_eval_cf_matrices[-1], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'eval_confusion_matrix.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.close('all') + + +def get_model(model_path): + with open(model_path + 'params.json', 'r') as f: + params = json.load(f) + params['save_path'] = model_path + model_chkpt = model_path + "checkpoint_best.save" + model = SimpleNet(params['input_dim'], params['hidden_dims'], params['output_dim'], params['activation'], params['dropout']) + model.load_state_dict(torch.load(model_chkpt)) + model.eval() + return model, params + + +def compute_expe_name_and_save_path(params): + weights_str = '[' + for aux in params['auxiliaries_dict'].keys(): + weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, ' + weights_str = 
weights_str[:-2] + ']' + save_path = params['save_path'] + params["trial_id"] + save_path += f'_lr{params["lr"]}' + save_path += f'_bs{params["batch_size"]}' + save_path += f'_hd{params["hidden_dims"]}' + save_path += f'_activ{params["activation"]}' + save_path += f'_w{weights_str}' + counter = 0 + while os.path.exists(save_path + f"_{counter}"): + counter += 1 + save_path = save_path + f"_{counter}" + '/' + params["save_path"] = save_path + os.makedirs(save_path) + os.makedirs(save_path + 'plots/') + params['plot_path'] = save_path + 'plots/' + print(f'logging to {save_path}') + return params + + + +if __name__ == '__main__': + params = get_params() + run_experiment(params) + diff --git a/src/cocktails/representation_learning/run_without_vae.py b/src/cocktails/representation_learning/run_without_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe25c76d4cfd8adf93c3886390e0f22853eddc1 --- /dev/null +++ b/src/cocktails/representation_learning/run_without_vae.py @@ -0,0 +1,514 @@ +import torch; torch.manual_seed(0) +import torch.utils +from torch.utils.data import DataLoader +import torch.distributions +import torch.nn as nn +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 +from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients +import json +import pandas as pd +import numpy as np +import os +from src.cocktails.representation_learning.multihead_model import get_multihead_model +from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH +from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys +from src.cocktails.utilities.ingredients_utilities import ingredient_profiles +from resource import getrusage +from resource import RUSAGE_SELF +import gc +gc.collect(2) +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_params(): + data = pd.read_csv(COCKTAILS_CSV_DATA) + max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data) + num_ingredients = len(ingredient_set) + rep_keys = get_bunch_of_rep_keys()['custom'] + ing_keys = [k.split(' ')[1] for k in rep_keys] + ing_keys.remove('volume') + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + + params = dict(trial_id='test', + save_path=EXPERIMENT_PATH + "/multihead_model/", + nb_epochs=500, + print_every=50, + plot_every=50, + batch_size=128, + lr=0.001, + dropout=0., + nb_epoch_switch_beta=600, + latent_dim=10, + beta_vae=0.2, + ing_keys=ing_keys, + nb_ingredients=len(ingredient_set), + hidden_dims_ingredients=[128], + hidden_dims_cocktail=[64], + hidden_dims_decoder=[32], + agg='mean', + activation='relu', + auxiliaries_dict=dict(categories=dict(weight=5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))), #0.5 + glasses=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['glass']))), #0.1 + prep_type=dict(weight=0.1, type='classif', final_activ=None, dim_output=len(set(data['category']))),#1 + cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),#1 + volume=dict(weight=1, type='regression', final_activ='relu', dim_output=1),#1 + taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),#1 + ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients),#10 + 
ingredients_quantities=dict(weight=0, type='regression', final_activ=None, dim_output=num_ingredients)), + category_encodings=category_encodings + ) + water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1], + max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0, + params=params) + dim_rep_ingredient = water_rep.size + params['indexes_ing_to_normalize'] = indexes_to_normalize + params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients + params['dim_rep_ingredient'] = dim_rep_ingredient + params['input_dim'] = params['nb_ingredients'] + params = compute_expe_name_and_save_path(params) + del params['category_encodings'] # to dump + with open(params['save_path'] + 'params.json', 'w') as f: + json.dump(params, f) + + params = complete_params(params) + return params + +def complete_params(params): + data = pd.read_csv(COCKTAILS_CSV_DATA) + cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH) + nb_ing_categories = len(set(ingredient_profiles['type'])) + category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories))) + params['cocktail_reps'] = cocktail_reps + params['raw_data'] = data + params['category_encodings'] = category_encodings + return params + +def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data): + losses = dict() + accuracies = dict() + other_metrics = dict() + for i_k, k in enumerate(auxiliaries_str): + # get ground truth + # compute loss + if k == 'volume': + outputs[i_k] = outputs[i_k].flatten() + ground_truth = auxiliaries[k] + if ground_truth.dtype == torch.float64: + losses[k] = loss_functions[k](outputs[i_k], ground_truth.float()).float() + elif ground_truth.dtype == torch.int64: + if str(loss_functions[k]) != "BCEWithLogitsLoss()": + losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.long()).float() + else: + losses[k] = loss_functions[k](outputs[i_k].float(), ground_truth.float()).float() + else: + losses[k] = loss_functions[k](outputs[i_k], ground_truth).float() + # compute accuracies + if str(loss_functions[k]) == 'CrossEntropyLoss()': + bs, n_options = outputs[i_k].shape + predicted = outputs[i_k].argmax(dim=1).detach().numpy() + true = ground_truth.int().detach().numpy() + confusion_matrix = np.zeros([n_options, n_options]) + for i in range(bs): + confusion_matrix[true[i], predicted[i]] += 1 + acc = confusion_matrix.diagonal().sum() / bs + for i in range(n_options): + if confusion_matrix[i].sum() != 0: + confusion_matrix[i] /= confusion_matrix[i].sum() + other_metrics[k + '_confusion'] = confusion_matrix + accuracies[k] = np.mean(outputs[i_k].argmax(dim=1).detach().numpy() == ground_truth.int().detach().numpy()) + assert (acc - accuracies[k]) < 1e-5 + + elif str(loss_functions[k]) == 'BCEWithLogitsLoss()': + assert k == 'ingredients_presence' + outputs_rescaled = outputs[i_k].detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + predicted_presence = (outputs_rescaled > 0).astype(bool) + presence = ground_truth.detach().numpy().astype(bool) + other_metrics[k + '_false_positive'] = np.mean(np.logical_and(predicted_presence.astype(bool), ~presence.astype(bool))) + other_metrics[k + '_false_negative'] = np.mean(np.logical_and(~predicted_presence.astype(bool), presence.astype(bool))) + accuracies[k] = np.mean(predicted_presence == presence) # accuracy for multi class labeling + elif str(loss_functions[k]) == 'MSELoss()': + accuracies[k] = np.nan + else: + raise ValueError + 
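+ # losses: per-auxiliary loss tensors; accuracies: scalar accuracy per auxiliary (nan for regression targets); other_metrics: confusion matrices and ingredient-presence error rates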
return losses, accuracies, other_metrics + +def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat): + ing_q = ingredient_quantities.detach().numpy()# * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + ing_presence = (ing_q > 0) + x_hat = x_hat.detach().numpy() + # x_hat = x_hat.detach().numpy() * data.dataset.std_ing_quantities + data.dataset.mean_ing_quantities + abs_diff = np.abs(ing_q - x_hat) * data.dataset.max_ing_quantities + # abs_diff = np.abs(ing_q - x_hat) + ing_q_abs_loss_when_present, ing_q_abs_loss_when_absent = [], [] + for i in range(ingredient_quantities.shape[0]): + ing_q_abs_loss_when_present.append(np.mean(abs_diff[i, np.where(ing_presence[i])])) + ing_q_abs_loss_when_absent.append(np.mean(abs_diff[i, np.where(~ing_presence[i])])) + aux_other_metrics['ing_q_abs_loss_when_present'] = np.mean(ing_q_abs_loss_when_present) + aux_other_metrics['ing_q_abs_loss_when_absent'] = np.mean(ing_q_abs_loss_when_absent) + return aux_other_metrics + +def run_epoch(opt, train, model, data, loss_functions, weights, params): + if train: + model.train() + else: + model.eval() + + # prepare logging of losses + losses = dict(kld_loss=[], + mse_loss=[], + vae_loss=[], + volume_loss=[], + global_loss=[]) + accuracies = dict() + other_metrics = dict() + for aux in params['auxiliaries_dict'].keys(): + losses[aux] = [] + accuracies[aux] = [] + if train: opt.zero_grad() + + for d in data: + nb_ingredients = d[0] + batch_size = nb_ingredients.shape[0] + x_ingredients = d[1].float() + ingredient_quantities = d[2] + cocktail_reps = d[3] + auxiliaries = d[4] + for k in auxiliaries.keys(): + if auxiliaries[k].dtype == torch.float64: auxiliaries[k] = auxiliaries[k].float() + taste_valid = d[-1] + z, outputs, auxiliaries_str = model.forward(ingredient_quantities.float()) + # get auxiliary losses and accuracies + aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data) + + # compute vae loss + aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, outputs[auxiliaries_str.index('ingredients_quantities')]) + + indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten() + if indexes_taste_valid.size > 0: + outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps') + gt = auxiliaries['taste_reps'][indexes_taste_valid] + factor_loss = indexes_taste_valid.size / (0.3 * batch_size)# factor on the loss: if same ratio as actual dataset factor = 1 if there is less data, then the factor decreases, more data, it increases + aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float() + else: + aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([]) + aux_accuracies['taste_reps'] = 0 + + # aggregate losses + global_loss = torch.sum(torch.cat([torch.atleast_1d(aux_losses[k] * weights[k]) for k in params['auxiliaries_dict'].keys()])) + # for k in params['auxiliaries_dict'].keys(): + # global_loss += aux_losses[k] * weights[k] + + if train: + global_loss.backward() + opt.step() + opt.zero_grad() + + # logging + losses['global_loss'].append(float(global_loss)) + for k in params['auxiliaries_dict'].keys(): + losses[k].append(float(aux_losses[k])) + accuracies[k].append(float(aux_accuracies[k])) + for k in aux_other_metrics.keys(): + if k not in other_metrics.keys(): + other_metrics[k] = [aux_other_metrics[k]] + else: + other_metrics[k].append(aux_other_metrics[k]) + + for k in 
losses.keys(): + losses[k] = np.mean(losses[k]) + for k in accuracies.keys(): + accuracies[k] = np.mean(accuracies[k]) + for k in other_metrics.keys(): + other_metrics[k] = np.mean(other_metrics[k], axis=0) + return model, losses, accuracies, other_metrics + +def prepare_data_and_loss(params): + train_data = MyDataset(split='train', params=params) + test_data = MyDataset(split='test', params=params) + + train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True) + test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True) + + loss_functions = dict() + weights = dict() + for k in sorted(params['auxiliaries_dict'].keys()): + if params['auxiliaries_dict'][k]['type'] == 'classif': + if k == 'glasses': + classif_weights = train_data.glasses_weights + elif k == 'prep_type': + classif_weights = train_data.prep_types_weights + elif k == 'categories': + classif_weights = train_data.categories_weights + else: + raise ValueError + loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights)) + elif params['auxiliaries_dict'][k]['type'] == 'multiclassif': + loss_functions[k] = nn.BCEWithLogitsLoss() + elif params['auxiliaries_dict'][k]['type'] == 'regression': + loss_functions[k] = nn.MSELoss() + else: + raise ValueError + weights[k] = params['auxiliaries_dict'][k]['weight'] + + + return loss_functions, train_data_loader, test_data_loader, weights + +def print_losses(train, losses, accuracies, other_metrics): + keyword = 'Train' if train else 'Eval' + print(f'\t{keyword} logs:') + keys = ['global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss'] + for k in keys: + print(f'\t\t{k} - Loss: {losses[k]:.2f}') + for k in sorted(accuracies.keys()): + print(f'\t\t{k} (aux) - Loss: {losses[k]:.2f}, Acc: {accuracies[k]:.2f}') + for k in sorted(other_metrics.keys()): + if 'confusion' not in k: + print(f'\t\t{k} - {other_metrics[k]:.2f}') + + +def run_experiment(params, verbose=True): + loss_functions, train_data_loader, test_data_loader, weights = prepare_data_and_loss(params) + + model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]] + model = get_multihead_model(*model_params) + opt = torch.optim.AdamW(model.parameters(), lr=params['lr']) + + + all_train_losses = [] + all_eval_losses = [] + all_train_accuracies = [] + all_eval_accuracies = [] + all_eval_other_metrics = [] + all_train_other_metrics = [] + best_loss = np.inf + model, eval_losses, eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + all_eval_losses.append(eval_losses) + all_eval_accuracies.append(eval_accuracies) + all_eval_other_metrics.append(eval_other_metrics) + if verbose: print(f'\n--------\nEpoch #0') + if verbose: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics) + for epoch in range(params['nb_epochs']): + if verbose and (epoch + 1) % params['print_every'] == 0: print(f'\n--------\nEpoch #{epoch+1}') + model, train_losses, train_accuracies, train_other_metrics = run_epoch(opt=opt, train=True, model=model, data=train_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics) + model, eval_losses, 
eval_accuracies, eval_other_metrics = run_epoch(opt=opt, train=False, model=model, data=test_data_loader, loss_functions=loss_functions, + weights=weights, params=params) + if verbose and (epoch + 1) % params['print_every'] == 0: print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics) + if eval_losses['global_loss'] < best_loss: + best_loss = eval_losses['global_loss'] + if verbose: print(f'Saving new best model with loss {best_loss:.2f}') + torch.save(model.state_dict(), params['save_path'] + f'checkpoint_best.save') + + # log + all_train_losses.append(train_losses) + all_train_accuracies.append(train_accuracies) + all_eval_losses.append(eval_losses) + all_eval_accuracies.append(eval_accuracies) + all_eval_other_metrics.append(eval_other_metrics) + all_train_other_metrics.append(train_other_metrics) + + # if epoch == params['nb_epoch_switch_beta']: + # params['beta_vae'] = 2.5 + # params['auxiliaries_dict']['prep_type']['weight'] /= 10 + # params['auxiliaries_dict']['glasses']['weight'] /= 10 + + if (epoch + 1) % params['plot_every'] == 0: + + plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics, + all_eval_losses, all_eval_accuracies, all_eval_other_metrics, params['plot_path'], weights) + + return model + +def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics, + all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights): + + steps = np.arange(len(all_eval_accuracies)) + + loss_keys = sorted(all_train_losses[0].keys()) + acc_keys = sorted(all_train_accuracies[0].keys()) + metrics_keys = sorted(all_train_other_metrics[0].keys()) + + plt.figure() + plt.title('Train losses') + for k in loss_keys: + factor = 1 if k == 'mse_loss' else 1 + if k not in weights.keys(): + plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k) + else: + if weights[k] != 0: + plt.plot(steps[1:], [train_loss[k] * factor for train_loss in all_train_losses], label=k) + + plt.legend() + plt.ylim([0, 4]) + plt.savefig(plot_path + 'train_losses.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train accuracies') + for k in acc_keys: + if weights[k] != 0: + plt.plot(steps[1:], [train_acc[k] for train_acc in all_train_accuracies], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'train_acc.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' in k: + plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'train_ing_presence_errors.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Train other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' not in k: + plt.plot(steps[1:], [train_metric[k] for train_metric in all_train_other_metrics], label=k) + plt.legend() + plt.ylim([0, 15]) + plt.savefig(plot_path + 'train_ing_q_error.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval losses') + for k in loss_keys: + factor = 1 if k == 'mse_loss' else 1 + if k not in weights.keys(): + plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k) + else: + if weights[k] != 0: + plt.plot(steps, [eval_loss[k] * factor for eval_loss in all_eval_losses], label=k) + plt.legend() + plt.ylim([0, 4]) + 
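+ # same y-limits as the train-loss plot above (both clipped to [0, 4])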
plt.savefig(plot_path + 'eval_losses.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval accuracies') + for k in acc_keys: + if weights[k] != 0: + plt.plot(steps, [eval_acc[k] for eval_acc in all_eval_accuracies], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'eval_acc.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' in k: + plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k) + plt.legend() + plt.ylim([0, 1]) + plt.savefig(plot_path + 'eval_ing_presence_errors.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.figure() + plt.title('Eval other metrics') + for k in metrics_keys: + if 'confusion' not in k and 'presence' not in k: + plt.plot(steps, [eval_metric[k] for eval_metric in all_eval_other_metrics], label=k) + plt.legend() + plt.ylim([0, 15]) + plt.savefig(plot_path + 'eval_ing_q_error.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + + for k in metrics_keys: + if 'confusion' in k: + plt.figure() + plt.title(k) + plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_eval_other_metrics[-1][k], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'eval_{k}.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + for k in metrics_keys: + if 'confusion' in k: + plt.figure() + plt.title(k) + plt.ylabel('True') + plt.xlabel('Predicted') + plt.imshow(all_train_other_metrics[-1][k], vmin=0, vmax=1) + plt.colorbar() + plt.savefig(plot_path + f'train_{k}.png', dpi=200) + fig = plt.gcf() + plt.close(fig) + + plt.close('all') + + +def get_model(model_path): + + with open(model_path + 'params.json', 'r') as f: + params = json.load(f) + params['save_path'] = model_path + model_chkpt = model_path + "checkpoint_best.save" + model_params = [params[k] for k in ["input_dim", "activation", "hidden_dims_cocktail", "latent_dim", "dropout", "auxiliaries_dict", "hidden_dims_decoder"]] + model = get_multihead_model(*model_params) + model.load_state_dict(torch.load(model_chkpt)) + model.eval() + max_ing_quantities = np.loadtxt(model_path + 'max_ing_quantities.txt') + def predict(ing_qs, aux_str): + ing_qs /= max_ing_quantities + input_model = torch.FloatTensor(ing_qs).reshape(1, -1) + _, outputs, auxiliaries_str = model.forward(input_model, ) + if isinstance(aux_str, str): + return outputs[auxiliaries_str.index(aux_str)].detach().numpy() + elif isinstance(aux_str, list): + return [outputs[auxiliaries_str.index(aux)].detach().numpy() for aux in aux_str] + else: + raise ValueError + return predict, params + + +def compute_expe_name_and_save_path(params): + weights_str = '[' + for aux in params['auxiliaries_dict'].keys(): + weights_str += f'{params["auxiliaries_dict"][aux]["weight"]}, ' + weights_str = weights_str[:-2] + ']' + save_path = params['save_path'] + params["trial_id"] + save_path += f'_lr{params["lr"]}' + save_path += f'_betavae{params["beta_vae"]}' + save_path += f'_bs{params["batch_size"]}' + save_path += f'_latentdim{params["latent_dim"]}' + save_path += f'_hding{params["hidden_dims_ingredients"]}' + save_path += f'_hdcocktail{params["hidden_dims_cocktail"]}' + save_path += f'_hddecoder{params["hidden_dims_decoder"]}' + save_path += f'_agg{params["agg"]}' + save_path += f'_activ{params["activation"]}' + save_path += f'_w{weights_str}' + counter = 0 + while os.path.exists(save_path + f"_{counter}"): + counter += 1 + save_path = save_path + f"_{counter}" + '/' + 
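+ # the incrementing counter suffix gives every run a fresh directory, even when hyperparameters are identical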
params["save_path"] = save_path + os.makedirs(save_path) + os.makedirs(save_path + 'plots/') + params['plot_path'] = save_path + 'plots/' + print(f'logging to {save_path}') + return params + + + +if __name__ == '__main__': + params = get_params() + run_experiment(params) + diff --git a/src/cocktails/representation_learning/simple_model.py b/src/cocktails/representation_learning/simple_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9eeca3e546f0fde4bc5be6a8fadb24c171a88271 --- /dev/null +++ b/src/cocktails/representation_learning/simple_model.py @@ -0,0 +1,54 @@ +import torch; torch.manual_seed(0) +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.distributions +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_activation(activation): + if activation == 'tanh': + activ = F.tanh + elif activation == 'relu': + activ = F.relu + elif activation == 'mish': + activ = F.mish + elif activation == 'sigmoid': + activ = torch.sigmoid + elif activation == 'leakyrelu': + activ = F.leaky_relu + elif activation == 'exp': + activ = torch.exp + else: + raise ValueError + return activ + + +class SimpleNet(nn.Module): + def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None): + super(SimpleNet, self).__init__() + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.output_dim = output_dim + dims = [input_dim] + hidden_dims + [output_dim] + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + if final_activ != None: + self.final_activ = get_activation(final_activ) + self.use_final_activ = True + else: + self.use_final_activ = False + + def forward(self, x): + for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts): + x = layer(x) + if i_layer != self.n_layers - 1: + x = self.activation(dropout(x)) + if self.use_final_activ: x = self.final_activ(x) + return x + diff --git a/src/cocktails/representation_learning/vae_model.py b/src/cocktails/representation_learning/vae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..45ea70aadaba8c5274bd82e87be7c9bba5d43b9e --- /dev/null +++ b/src/cocktails/representation_learning/vae_model.py @@ -0,0 +1,238 @@ +import torch; torch.manual_seed(0) +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.distributions +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_activation(activation): + if activation == 'tanh': + activ = F.tanh + elif activation == 'relu': + activ = F.relu + elif activation == 'mish': + activ = F.mish + elif activation == 'sigmoid': + activ = F.sigmoid + elif activation == 'leakyrelu': + activ = F.leaky_relu + elif activation == 'exp': + activ = torch.exp + else: + raise ValueError + return activ + +class IngredientEncoder(nn.Module): + def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout): + super(IngredientEncoder, self).__init__() + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [input_dim] + hidden_dims + [deepset_latent_dim] + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, 
d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + + def forward(self, x): + for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts): + x = layer(x) + if i_layer != self.n_layers - 1: + x = self.activation(dropout(x)) + return x # do not use dropout on last layer? + +class DeepsetCocktailEncoder(nn.Module): + def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation, + hidden_dims_cocktail, latent_dim, aggregation, dropout): + super(DeepsetCocktailEncoder, self).__init__() + self.input_dim = input_dim # dimension of ingredient representation + quantity + self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, dropout) # encode each ingredient separately + self.deepset_latent_dim = deepset_latent_dim # dimension of the deepset aggregation + self.aggregation = aggregation + self.latent_dim = latent_dim + # post aggregation network + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [deepset_latent_dim] + hidden_dims_cocktail + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.FC_mean = nn.Linear(hidden_dims_cocktail[-1], latent_dim) + self.FC_logvar = nn.Linear(hidden_dims_cocktail[-1], latent_dim) + self.softplus = nn.Softplus() + + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + + def forward(self, nb_ingredients, x): + + # reshape x in (batch size * nb ingredients, dim_ing_rep) + batch_size = x.shape[0] + all_ingredients = [] + for i in range(batch_size): + for j in range(nb_ingredients[i]): + all_ingredients.append(x[i, self.input_dim * j: self.input_dim * (j + 1)].reshape(1, -1)) + x = torch.cat(all_ingredients, dim=0) + # encode ingredients in parallel + ingredients_encodings = self.ingredient_encoder(x) + assert ingredients_encodings.shape == (torch.sum(nb_ingredients), self.deepset_latent_dim) + + # aggregate + x = [] + index_first = 0 + for i in range(batch_size): + index_last = index_first + nb_ingredients[i] + # aggregate + if self.aggregation == 'sum': + x.append(torch.sum(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1)) + elif self.aggregation == 'mean': + x.append(torch.mean(ingredients_encodings[index_first:index_last], dim=0).reshape(1, -1)) + else: + raise ValueError + index_first = index_last + x = torch.cat(x, dim=0) + assert x.shape[0] == batch_size + + for i_layer, layer, dropout in zip(self.layer_range, self.linears, self.dropouts): + x = self.activation(dropout(layer(x))) + mean = self.FC_mean(x) + logvar = self.FC_logvar(x) + return mean, logvar + +class Decoder(nn.Module): + def __init__(self, latent_dim, hidden_dims, num_ingredients, activation, dropout, filter_output=None): + super(Decoder, self).__init__() + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [latent_dim] + hidden_dims + [num_ingredients] + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + self.filter = filter_output + + def forward(self, x, to_filter=False): + for i_layer, layer, dropout in zip(self.layer_range, self.linears, 
self.dropouts): + x = layer(x) + if i_layer != self.n_layers - 1: + x = self.activation(dropout(x)) + if to_filter: + x = self.filter(x) + return x + +class PredictorHead(nn.Module): + def __init__(self, latent_dim, dim_output, final_activ): + super(PredictorHead, self).__init__() + self.linear = nn.Linear(latent_dim, dim_output) + if final_activ != None: + self.final_activ = get_activation(final_activ) + self.use_final_activ = True + else: + self.use_final_activ = False + + def forward(self, x): + x = self.linear(x) + if self.use_final_activ: x = self.final_activ(x) + return x + + +class VAEModel(nn.Module): + def __init__(self, encoder, decoder, auxiliaries_dict): + super(VAEModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.latent_dim = self.encoder.latent_dim + self.auxiliaries_str = [] + self.auxiliaries = nn.ModuleList() + for aux_str in sorted(auxiliaries_dict.keys()): + if aux_str == 'taste_reps': + self.taste_reps_decoder = PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ']) + else: + self.auxiliaries_str.append(aux_str) + self.auxiliaries.append(PredictorHead(self.latent_dim, auxiliaries_dict[aux_str]['dim_output'], auxiliaries_dict[aux_str]['final_activ'])) + + def reparameterization(self, mean, logvar): + std = torch.exp(0.5 * logvar) + epsilon = torch.randn_like(std).to(device) # sampling epsilon + z = mean + std * epsilon # reparameterization trick + return z + + + def sample(self, n=1): + z = torch.randn(size=(n, self.latent_dim)) + return self.decoder(z) + + def get_all_auxiliaries(self, x): + return [aux(x) for aux in self.auxiliaries] + + def get_auxiliary(self, z, aux_str): + if aux_str == 'taste_reps': + return self.taste_reps_decoder(z) + else: + index = self.auxiliaries_str.index(aux_str) + return self.auxiliaries[index](z) + + def forward_direct(self, x, aux_str=None, to_filter=False): + mean, logvar = self.encoder(x) + z = self.reparameterization(mean, logvar) # takes exponential function (log var -> std) + x_hat = self.decoder(mean, to_filter=to_filter) + if aux_str is not None: + return x_hat, z, mean, logvar, self.get_auxiliary(z, aux_str), [aux_str] + else: + return x_hat, z, mean, logvar, self.get_all_auxiliaries(z), self.auxiliaries_str + + def forward(self, nb_ingredients, x, aux_str=None, to_filter=False): + assert False + mean, std = self.encoder(nb_ingredients, x) + z = self.reparameterization(mean, std) # takes exponential function (log var -> std) + x_hat = self.decoder(mean, to_filter=to_filter) + if aux_str is not None: + return x_hat, z, mean, std, self.get_auxiliary(z, aux_str), [aux_str] + else: + return x_hat, z, mean, std, self.get_all_auxiliaries(z), self.auxiliaries_str + + + + +class SimpleEncoder(nn.Module): + + def __init__(self, input_dim, hidden_dims, latent_dim, activation, dropout): + super(SimpleEncoder, self).__init__() + self.latent_dim = latent_dim + # post aggregation network + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + dims = [input_dim] + hidden_dims + for d_in, d_out in zip(dims[:-1], dims[1:]): + self.linears.append(nn.Linear(d_in, d_out)) + self.dropouts.append(nn.Dropout(dropout)) + self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim) + self.FC_logvar = nn.Linear(hidden_dims[-1], latent_dim) + # self.softplus = nn.Softplus() + + self.activation = get_activation(activation) + self.n_layers = len(self.linears) + self.layer_range = range(self.n_layers) + + def forward(self, x): + for i_layer, layer, dropout in 
zip(self.layer_range, self.linears, self.dropouts): + x = self.activation(dropout(layer(x))) + mean = self.FC_mean(x) + logvar = self.FC_logvar(x) + return mean, logvar + +def get_vae_model(input_dim, deepset_latent_dim, hidden_dims_ing, activation, + hidden_dims_cocktail, hidden_dims_decoder, num_ingredients, latent_dim, aggregation, dropout, auxiliaries_dict, + filter_decoder_output): + # encoder = DeepsetCocktailEncoder(input_dim, deepset_latent_dim, hidden_dims_ing, activation, + # hidden_dims_cocktail, latent_dim, aggregation, dropout) + encoder = SimpleEncoder(num_ingredients, hidden_dims_cocktail, latent_dim, activation, dropout) + decoder = Decoder(latent_dim, hidden_dims_decoder, num_ingredients, activation, dropout, filter_output=filter_decoder_output) + vae = VAEModel(encoder, decoder, auxiliaries_dict) + return vae \ No newline at end of file diff --git a/src/cocktails/utilities/__init__.py b/src/cocktails/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/utilities/__pycache__/__init__.cpython-39.pyc b/src/cocktails/utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92bce76c895553d4b6076b407ac2e109f50d28b2 Binary files /dev/null and b/src/cocktails/utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/cocktails/utilities/__pycache__/cocktail_category_detection_utilities.cpython-39.pyc b/src/cocktails/utilities/__pycache__/cocktail_category_detection_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60d3ad387b6f00ca9b357681707dd9882e97378d Binary files /dev/null and b/src/cocktails/utilities/__pycache__/cocktail_category_detection_utilities.cpython-39.pyc differ diff --git a/src/cocktails/utilities/__pycache__/cocktail_utilities.cpython-39.pyc b/src/cocktails/utilities/__pycache__/cocktail_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0398603bd15a36ce657cdffe53bd044ae489cb Binary files /dev/null and b/src/cocktails/utilities/__pycache__/cocktail_utilities.cpython-39.pyc differ diff --git a/src/cocktails/utilities/__pycache__/glass_and_volume_utilities.cpython-39.pyc b/src/cocktails/utilities/__pycache__/glass_and_volume_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd46db07d323b63faff5fcfe526260f19a0b5944 Binary files /dev/null and b/src/cocktails/utilities/__pycache__/glass_and_volume_utilities.cpython-39.pyc differ diff --git a/src/cocktails/utilities/__pycache__/ingredients_utilities.cpython-39.pyc b/src/cocktails/utilities/__pycache__/ingredients_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c56baf6a9b89cb7b24f4fed2e979f576a6340492 Binary files /dev/null and b/src/cocktails/utilities/__pycache__/ingredients_utilities.cpython-39.pyc differ diff --git a/src/cocktails/utilities/__pycache__/other_scrubbing_utilities.cpython-39.pyc b/src/cocktails/utilities/__pycache__/other_scrubbing_utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d1b496d57328ea556ed468b722848a007f8e811 Binary files /dev/null and b/src/cocktails/utilities/__pycache__/other_scrubbing_utilities.cpython-39.pyc differ diff --git a/src/cocktails/utilities/analysis_utilities.py b/src/cocktails/utilities/analysis_utilities.py new file mode 100644 index 
0000000000000000000000000000000000000000..f364e9df88a5900bc258bba3e8196a7da576543b --- /dev/null +++ b/src/cocktails/utilities/analysis_utilities.py @@ -0,0 +1,189 @@ +import numpy as np +import matplotlib.pyplot as plt + +from src.cocktails.utilities.ingredients_utilities import ingredient_list, extract_ingredients, ingredients_per_type + +color_codes = dict(ancestral='#000000', + spirit_forward='#2320D2', + duo='#6E20D2', + champagne_cocktail='#25FFCA', + complex_highball='#068F25', + simple_highball='#25FF57', + collins='#77FF96', + julep='#25B8FF', + simple_sour='#FBD756', + complex_sour='#DCAD07', + simple_sour_with_juice='#FF5033', + complex_sour_with_juice='#D42306', + # simple_sour_with_egg='#FF9C54', + # complex_sour_with_egg='#CF5700', + # almost_simple_sor='#FF5033', + # almost_sor='#D42306', + # almost_sor_with_egg='#D42306', + other='#9B9B9B' + ) + +def get_subcategories(data): + subcategories = np.array(data['subcategory']) + sub_categories_list = sorted(set(subcategories)) + subcat_count = dict(zip(sub_categories_list, [0] * len(sub_categories_list))) + for sc in data['subcategory']: + subcat_count[sc] += 1 + return subcategories, sub_categories_list, subcat_count + +def get_ingredient_count(data): + ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list))) + for ing_str in data['ingredients_str']: + ingredients, _ = extract_ingredients(ing_str) + for ing in ingredients: + ingredient_counts[ing] += 1 + return ingredient_counts + +def compute_eucl_dist(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def recipe_contains(ingredients, stuff): + if stuff in ingredient_list: + return stuff in ingredients + elif stuff == 'juice': + return any(['juice' in ing and 'lemon' not in ing and 'lime' not in ing for ing in ingredients]) + elif stuff == 'bubbles': + return any([ing in ['soda', 'tonic', 'cola', 'sparkling wine', 'ginger beer'] for ing in ingredients]) + elif stuff == 'acid': + return any([ing in ['lemon juice', 'lime juice'] for ing in ingredients]) + elif stuff == 'vermouth': + return any([ing in ingredients_per_type['vermouth'] for ing in ingredients]) + elif stuff == 'plain sweet': + plain_sweet = ingredients_per_type['sweeteners'] + return any([ing in plain_sweet for ing in ingredients]) + elif stuff == 'sweet': + sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc'] + return any([ing in sweet for ing in ingredients]) + elif stuff == 'spirit': + return any([ing in ingredients_per_type['liquor'] for ing in ingredients]) + else: + raise ValueError + + + +def radar_factory(num_vars, frame='circle'): + # from stackoverflow's post? Or matplotlib's blog + """ + Create a radar chart with `num_vars` axes. + + This function creates a RadarAxes projection and registers it. + + Parameters + ---------- + num_vars : int + Number of variables for radar chart. + frame : {'circle', 'polygon'} + Shape of frame surrounding axes. 
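+
+ Returns
+ -------
+ theta : np.ndarray
+ Evenly spaced axis angles in radians; as a side effect the 'radar'
+ projection used by the plotting code below is registered.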
+ + """ + import numpy as np + + from matplotlib.patches import Circle, RegularPolygon + from matplotlib.path import Path + from matplotlib.projections.polar import PolarAxes + from matplotlib.projections import register_projection + from matplotlib.spines import Spine + from matplotlib.transforms import Affine2D + # calculate evenly-spaced axis angles + theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False) + + class RadarAxes(PolarAxes): + + name = 'radar' + # use 1 line segment to connect specified points + RESOLUTION = 1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # rotate plot such that the first axis is at the top + self.set_theta_zero_location('N') + + def fill(self, *args, closed=True, **kwargs): + """Override fill so that line is closed by default""" + return super().fill(closed=closed, *args, **kwargs) + + def plot(self, *args, **kwargs): + """Override plot so that line is closed by default""" + lines = super().plot(*args, **kwargs) + for line in lines: + self._close_line(line) + + def _close_line(self, line): + x, y = line.get_data() + # FIXME: markers at x[0], y[0] get doubled-up + if x[0] != x[-1]: + x = np.append(x, x[0]) + y = np.append(y, y[0]) + line.set_data(x, y) + + def set_varlabels(self, labels): + self.set_thetagrids(np.degrees(theta), labels) + + def _gen_axes_patch(self): + # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5 + # in axes coordinates. + if frame == 'circle': + return Circle((0.5, 0.5), 0.5) + elif frame == 'polygon': + return RegularPolygon((0.5, 0.5), num_vars, + radius=.5, edgecolor="k") + else: + raise ValueError("Unknown value for 'frame': %s" % frame) + + def _gen_axes_spines(self): + if frame == 'circle': + return super()._gen_axes_spines() + elif frame == 'polygon': + # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'. + spine = Spine(axes=self, + spine_type='circle', + path=Path.unit_regular_polygon(num_vars)) + # unit_regular_polygon gives a polygon of radius 1 centered at + # (0, 0) but we want a polygon of radius 0.5 centered at (0.5, + # 0.5) in axes coordinates. 
+ spine.set_transform(Affine2D().scale(.5).translate(.5, .5) + + self.transAxes) + return {'polar': spine} + else: + raise ValueError("Unknown value for 'frame': %s" % frame) + + register_projection(RadarAxes) + return theta + +def plot_radar_cocktail(representation, labels_dim, labels_cocktails, save_path=None, to_show=False, to_save=False): + assert to_show or to_save, 'either show or save' + assert representation.ndim == 2 + n_data, dim_rep = representation.shape + assert len(labels_cocktails) == n_data + assert len(labels_dim) == dim_rep + assert n_data <= 5, 'max 5 representation_analysis please' + + theta = radar_factory(dim_rep, frame='circle') + + + fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar')) + fig.subplots_adjust(wspace=0.25, hspace=0.20, top=0.85, bottom=0.05) + + colors = ['b', 'r', 'g', 'm', 'y'] + # Plot the four cases from the example data on separate axes + ax.set_rgrids([0.2, 0.4, 0.6, 0.8]) + for d, color in zip(representation, colors): + ax.plot(theta, d, color=color) + for d, color in zip(representation, colors): + ax.fill(theta, d, facecolor=color, alpha=0.25) + ax.set_varlabels(labels_dim) + + # add legend relative to top-left plot + legend = ax.legend(labels_cocktails, loc=(0.9, .95), + labelspacing=0.1, fontsize='small') + + if to_save: + plt.savefig(save_path, bbox_artists=(legend,), dpi=200) + else: + plt.show() + diff --git a/src/cocktails/utilities/cocktail_category_detection_utilities.py b/src/cocktails/utilities/cocktail_category_detection_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..5846f86263a1b2a83554d9c2ea5039ccd83555cb --- /dev/null +++ b/src/cocktails/utilities/cocktail_category_detection_utilities.py @@ -0,0 +1,221 @@ +# The following functions check whether a cocktail belong to any of N categories +import numpy as np +from src.cocktails.utilities.ingredients_utilities import ingredient_profiles, ingredients_per_type, ingredient2ingredient_id, extract_ingredients + + +def is_ancestral(n, ingredient_indexes, ingredients, quantities): + # ancestrals have a strong spirit and some sweetness from sugar, syrup or liqueurs, no citrus. + # absinthe can be added up to 3 dashes. 
+ # Liqueurs are there to bring sweetness, thus must stay below 15ml (if not it's a duo) + if n['spirit'] > 0 and n['citrus'] == 0 and n['plain_sweet'] + n['liqueur'] <= 2: + if n['spirit'] > 1 and 'absinthe' in ingredients: + if quantities[ingredients.index('absinthe')] < 3: + pass + else: + return False + if n['sugar'] < 2 and n['liqueur'] < 3: + if n['all'] - n['spirit'] - n['sugar'] -n['syrup']- n['liqueur']- n['inconsequentials'] == 0: + if n['liqueur'] == 0: + return True + else: + q_liqueur = np.sum([quantities[i_ing] + for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients))) + if ingredient_profiles['type'][i_ind].lower() == 'liqueur']) + if q_liqueur <= 15: + return True + else: + return False + return False + + +def is_simple_sour(n, ingredient_indexes, ingredients, quantities): + # simple sours contain a citrus, at least 1 spirit and non-alcoholic sweetness + if n['citrus'] + n['coffee']> 0 and n['spirit'] > 0 and n['plain_sweet'] > 0 and n['juice'] == 0: + if n['all'] - n['citrus'] - n['coffee'] - n['spirit'] - n['plain_sweet'] - n['juice'] -n['egg'] - n['inconsequentials'] == 0: + return True + return False + +def is_complex_sour(n, ingredient_indexes, ingredients, quantities): + # complex sours are simple sours that use alcoholic sweetness, at least in part + if n['citrus'] + n['coffee'] > 0 and n['all_sweet'] > 0 and n['juice'] == 0: + if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0: + if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0: + if n['all'] -n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \ + - n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0: + return True + return False + +def is_spirit_forward(n, ingredient_indexes, ingredients, quantities): + # spirit forward contain at least a spirit and vermouth, no citrus. Can contain sweet (sugar, syrups, liqueurs) + if n['spirit'] > 0 and n['citrus'] == 0 and n['vermouth'] > 0: + if n['all'] - n['spirit'] - n['sugar'] - n['syrup'] - n['liqueur'] -n['egg'] - n['vermouth'] - n['inconsequentials']== 0: + return True + return False + +def is_duo(n, ingredient_indexes, ingredients, quantities): + # duos are made of one spirit and one liqueur (above 15ml), under it's an ancestral, no citrus. + if n['spirit'] >= 1 and n['citrus'] == 0 and n['sugar']==0 and n['liqueur'] > 0 and n['vermouth'] == 0: + if n['all'] - n['spirit'] - n['sugar'] - n['liqueur'] - n['vermouth'] - n['inconsequentials'] == 0: + q_liqueur = np.sum([quantities[i_ing] + for i_ind, i_ing in zip(ingredient_indexes, range(len(ingredients))) + if ingredient_profiles['type'][i_ind].lower() == 'liqueur']) + if q_liqueur > 15: + return True + else: + return False + return False + +def is_champagne_cocktail(n, ingredient_indexes, ingredients, quantities): + if n['sparkling'] > 0: + return True + else: + return False + +def is_simple_highball(n, ingredient_indexes, ingredients, quantities): + # simple highballs have one alcoholic ingredient and bubbles + if n['alcoholic'] == 1 and n['bubbles'] > 0: + if n['all'] - n['alcoholic'] - n['bubbles'] - n['inconsequentials']== 0: + return True + return False + +def is_complex_highball(n, ingredient_indexes, ingredients, quantities): + # complex highballs have at least one alcoholic ingredient and bubbles (possibly alcoholic). 
They also contain extra sugar under any form and juice + if n['alcoholic'] > 0 and (n['bubbles'] + n['sparkling']) == 1 and n['juice'] + n['all_sweet'] + n['sugar_bubbles']> 0: + if n['all'] - n['spirit'] - n['bubbles'] - n['sparkling'] - n['citrus'] - n['juice'] - n['liqueur'] \ + - n['syrup'] - n['sugar'] -n['vermouth'] -n['egg'] - n['inconsequentials'] == 0: + if not is_collins(n, ingredient_indexes, ingredients, quantities) and not is_simple_highball(n, ingredient_indexes, ingredients, quantities): + return True + return False + +def is_collins(n, ingredient_indexes, ingredients, quantities): + # collins are a particular kind of highball with sugar and citrus + if n['alcoholic'] == 1 and n['bubbles'] == 1 and n['citrus'] > 0 and n['plain_sweet'] + n['sugar_bubbles'] > 0: + if n['all'] - n['spirit'] - n['bubbles'] - n['citrus'] - n['sugar'] - n['inconsequentials'] == 0: + return True + return False + +def is_julep(n, ingredient_indexes, ingredients, quantities): + # juleps involve smashd mint, sugar and a spirit, no citrus. + if 'mint' in ingredients and n['sugar'] > 0 and n['spirit'] > 0 and n['vermouth'] == 0 and n['citrus'] == 0: + return True + return False + +def is_simple_sour_with_juice(n, ingredient_indexes, ingredients, quantities): + # almost sours are sours with juice + if n['juice'] > 0 and n['spirit'] > 0 and n['plain_sweet'] > 0: + if n['all'] - n['citrus'] - n['coffee'] - n['juice'] - n['spirit'] - n['sugar'] - n['syrup'] - n['egg'] - n['inconsequentials'] == 0: + return True + return False + + +def is_complex_sour_with_juice(n, ingredient_indexes, ingredients, quantities): + # almost sours are sours with juice + if n['juice'] > 0 and n['all_sweet'] > 0: + if (n['spirit'] == 0 and n['liqueur'] > 0) or n['spirit'] > 0: + if n['vermouth'] + n['liqueur'] <= 2 and n['vermouth'] + n['liqueur'] > 0: + if n['all'] -n['coffee'] - n['citrus'] - n['spirit'] - n['sugar'] - n['syrup'] \ + - n['liqueur'] - n['vermouth'] - n['egg'] - n['juice'] - n['inconsequentials'] == 0: + return True + return False + + +is_sub_category = [is_ancestral, is_complex_sour, is_simple_sour, is_duo, is_champagne_cocktail, + is_spirit_forward, is_simple_highball, is_complex_highball, is_collins, + is_julep, is_simple_sour_with_juice, is_complex_sour_with_juice] +sub_categories = ['ancestral', 'complex_sour', 'simple_sour', 'duo', 'champagne_cocktail', + 'spirit_forward', 'simple_highball', 'complex_highball', 'collins', + 'julep', 'simple_sour_with_juice', 'complex_sour_with_juice'] + + +# compute cocktail category as a function of ingredients and quantities, uses name to check match between name and cat (e.g. XXX Collins should be collins..) 
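+# Hypothetical example: a recipe of one liquor (gin), one acid (lemon juice) and one plain
+# sweetener (simple syrup) gives n['spirit']=1, n['citrus']=1, n['plain_sweet']=1 and n['juice']=0,
+# so only is_simple_sour() fires and find_cocktail_sub_category() returns 'simple_sour'.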
+# Categories definitions are based on https://www.seriouseats.com/cocktail-style-guide-categories-of-cocktails-glossary-families-of-drinks +def find_cocktail_sub_category(ingredients, quantities, name=None): + ingredient_indexes = [ingredient2ingredient_id[ing] for ing in ingredients] + n_spirit = np.sum([ingredient_profiles['type'][i].lower() == 'liquor' for i in ingredient_indexes ]) + n_citrus = np.sum([ingredient_profiles['type'][i].lower()== 'acid' for i in ingredient_indexes]) + n_sugar = np.sum([ingredient_profiles['ingredient'][i].lower() in ['double syrup', 'simple syrup', 'honey syrup'] for i in ingredient_indexes]) + plain_sweet = ingredients_per_type['sweeteners'] + all_sweet = ingredients_per_type['sweeteners'] + ingredients_per_type['liqueur'] + ['sweet vermouth', 'lillet blanc'] + n_plain_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in plain_sweet for i in ingredient_indexes]) + n_all_sweet = np.sum([ingredient_profiles['ingredient'][i].lower() in all_sweet for i in ingredient_indexes]) + n_sugar_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['cola', 'ginger beer', 'tonic'] for i in ingredient_indexes]) + n_juice = np.sum([ingredient_profiles['type'][i].lower() == 'juice' for i in ingredient_indexes]) + n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes]) + alcoholic = ingredients_per_type['liquor'] + ingredients_per_type['liqueur'] + ingredients_per_type['vermouth'] + n_alcoholic = np.sum([ingredient_profiles['ingredient'][i].lower() in alcoholic for i in ingredient_indexes]) + n_bitter = np.sum([ingredient_profiles['type'][i].lower() == 'bitters' for i in ingredient_indexes]) + n_egg = np.sum([ingredient_profiles['ingredient'][i].lower() == 'egg' for i in ingredient_indexes]) + n_vermouth = np.sum([ingredient_profiles['type'][i].lower() == 'vermouth' for i in ingredient_indexes]) + n_sparkling = np.sum([ingredient_profiles['ingredient'][i].lower() == 'sparkling wine' for i in ingredient_indexes]) + n_bubbles = np.sum([ingredient_profiles['ingredient'][i].lower() in ['soda', 'tonic', 'cola', 'ginger beer'] for i in ingredient_indexes]) + n_syrup = np.sum([ingredient_profiles['ingredient'][i].lower() in ['grenadine', 'raspberry syrup'] for i in ingredient_indexes]) + n_coffee = np.sum([ingredient_profiles['ingredient'][i].lower() == 'espresso' for i in ingredient_indexes]) + inconsequentials = ['water', 'salt', 'angostura', 'orange bitters', 'mint'] + n_inconsequentials = np.sum([ingredient_profiles['ingredient'][i].lower() in inconsequentials for i in ingredient_indexes]) + n = dict(all=len(ingredients), + inconsequentials=n_inconsequentials, + sugar_bubbles=n_sugar_bubbles, + bubbles=n_bubbles, + plain_sweet=n_plain_sweet, + all_sweet=n_all_sweet, + coffee=n_coffee, + alcoholic=n_alcoholic, + syrup=n_syrup, + sparkling=n_sparkling, + sugar=n_sugar, + spirit=n_spirit, + citrus=n_citrus, + juice=n_juice, + liqueur=n_liqueur, + bitter=n_bitter, + egg=n_egg, + vermouth=n_vermouth) + + sub_cats = [c for c, test_c in zip(sub_categories, is_sub_category) if test_c(n, ingredient_indexes, ingredients, quantities)] + if name != None: + name = name.lower() + keywords_to_test = ['julep', 'collins', 'highball', 'sour', 'champagne'] + for k in keywords_to_test: + if k in name and not any([k in cat for cat in sub_cats]): + print(k) + for ing, q in zip(ingredients, quantities): + print(f'{ing}: {q} ml') + print(n) + break + if sorted(sub_cats) == ['champagne_cocktail', 'complex_highball']: + sub_cats = 
['champagne_cocktail'] + elif sorted(sub_cats) == ['collins', 'complex_highball']: + sub_cats = ['collins'] + elif sorted(sub_cats) == ['champagne_cocktail', 'complex_highball', 'julep']: + sub_cats = ['champagne_cocktail'] + elif sorted(sub_cats) == ['ancestral', 'julep']: + sub_cats = ['julep'] + elif sorted(sub_cats) == ['complex_highball', 'julep']: + sub_cats = ['complex_highball'] + elif sorted(sub_cats) == ['julep', 'simple_sour_with_juice']: + sub_cats = ['simple_sour_with_juice'] + elif sorted(sub_cats) == ['complex_sour_with_juice', 'julep']: + sub_cats = ['complex_sour_with_juice'] + if len(sub_cats) != 1: + # print(sub_cats) + # for ing, q in zip(ingredients, quantities): + # print(f'{ing}: {q} ml') + # print(n) + # if len(sub_cats) == 0: + sub_cats = ['other'] + assert len(sub_cats) == 1, sub_cats + return sub_cats[0], n + +def get_cocktails_attributes(ing_strs): + attributes = dict() + cats = [] + for ing_str in ing_strs: + ingredients, quantities = extract_ingredients(ing_str) + cat, atts = find_cocktail_sub_category(ingredients, quantities) + for k in atts.keys(): + if k not in attributes.keys(): + attributes[k] = [atts[k]] + else: + attributes[k].append(atts[k]) + cats.append(cat) + return cats, attributes diff --git a/src/cocktails/utilities/cocktail_generation_utilities/__init__.py b/src/cocktails/utilities/cocktail_generation_utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/__init__.cpython-39.pyc b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a7569b42a51e32cd95929356e8f9df68bcbafa Binary files /dev/null and b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/individual.cpython-39.pyc b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/individual.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b2d4eb53392a1548eb100e25f260d858bd932f4 Binary files /dev/null and b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/individual.cpython-39.pyc differ diff --git a/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/population.cpython-39.pyc b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/population.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f40a6a51ba08b01a0c7ca3c0f2051fdc2236a34 Binary files /dev/null and b/src/cocktails/utilities/cocktail_generation_utilities/__pycache__/population.cpython-39.pyc differ diff --git a/src/cocktails/utilities/cocktail_generation_utilities/individual.py b/src/cocktails/utilities/cocktail_generation_utilities/individual.py new file mode 100644 index 0000000000000000000000000000000000000000..a56c2f2844a34274d03eca9205b4302d6a8c7ab3 --- /dev/null +++ b/src/cocktails/utilities/cocktail_generation_utilities/individual.py @@ -0,0 +1,587 @@ +from src.cocktails.utilities.ingredients_utilities import get_ingredients_info, format_ingredients, extract_ingredients, ingredients_per_type, bubble_ingredients +import numpy as np +from src.cocktails.utilities.other_scrubbing_utilities import print_recipe +from src.cocktails.utilities.cocktail_utilities import get_cocktail_rep, get_profile, get_bunch_of_rep_keys +from 
src.cocktails.utilities.glass_and_volume_utilities import glass_volume +from src.cocktails.representation_learning.run import get_model +from src.cocktails.pipeline.get_cocktail2affective_cluster import get_cocktail2affective_cluster +from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, REPO_PATH, COCKTAIL_REP_CHKPT_PATH, RECIPE2FEATURES_PATH +from src.cocktails.representation_learning.run_without_vae import get_model +from src.cocktails.utilities.cocktail_category_detection_utilities import find_cocktail_sub_category + +import pandas as pd +import torch +import time +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +density_ingredients = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'density_ingredients.txt') +max_ingredients, ingredient_list, ind_alcohol = get_ingredients_info() +min_ingredients = 2 +factor_max = 1.2 # generated recipes can go up to 1.2 times the max quantity of the ingredient found in the dataset + +prep_model = get_model(RECIPE2FEATURES_PATH + 'multi_predictor/')[0] + +all_rep_path = FULL_COCKTAIL_REP_PATH +all_reps = np.loadtxt(all_rep_path) +experiment_dir = REPO_PATH + '/experiments/cocktails/' +rep_keys = get_bunch_of_rep_keys()['custom'] +dict_weights_mse_computation = {'end volume': .1, 'end sour': 2, 'end sweet': 2, 'end booze': 4, 'end bitter': 2, 'end fruit': 1, 'end herb': 1, + 'end complex': 1, 'end spicy': 5, 'end oaky': 1, 'end fizzy': 10, 'end colorful': 1, 'end eggy': 10} +assert sorted(dict_weights_mse_computation.keys()) == sorted(rep_keys) +weights_mse_computation = np.array([dict_weights_mse_computation[k] for k in rep_keys]) +weights_mse_computation /= weights_mse_computation.sum() +data = pd.read_csv(COCKTAILS_CSV_DATA) +preparation_list = sorted(set(data['category'])) +glasses_list = sorted(set(data['glass'])) + +weights_perf_n_ing = {2:0.71, 3:0.81, 4:0.93, 5:1., 6:1.03, 7:1.08, 8:1.05} + +# weights_perf_n_ing = {2:0.75, 3:0.8, 4:0.95, 5:1.05, 6:1.05, 7:1.05, 8:1.05} +min_ingredients_quantities_when_present = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities_when_present.txt') +min_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'ingredients_min_quantities.txt') +max_ingredients_quantities = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'ingredients_max_quantities.txt') +min_cocktail_rep, max_cocktail_rep = np.loadtxt(COCKTAIL_REP_CHKPT_PATH +'cocktail_minmax_dim13_customkeys.txt') +distrib_nb_ings_2_8 = np.loadtxt(COCKTAIL_REP_CHKPT_PATH + 'distrib_nb_ing.txt')[2:] +def normalize_cocktail(cocktail_rep): + return ((cocktail_rep - min_cocktail_rep) / (max_cocktail_rep - min_cocktail_rep) - 0.5) * 2 + +def denormalize_cocktail(cocktail_rep): + return (cocktail_rep / 2 + 0.5) * (max_cocktail_rep - min_cocktail_rep) + min_cocktail_rep + +def normalize_ingredient_q_rep(ingredients_q): + return (ingredients_q - min_ingredients_quantities_when_present) / (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present) + +COCKTAIL_REPS = normalize_cocktail(np.array([data[k] for k in rep_keys]).transpose()) +assert np.abs(COCKTAIL_REPS - all_reps).sum() < 1e-8 + +cocktail2affective_cluster = get_cocktail2affective_cluster() + +original_affective_keys = get_bunch_of_rep_keys()['affective'] +def sigmoid(x, shift, beta): + return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 + +def get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(cocktail_rep): + indexes = np.array([rep_keys.index(key) for key in original_affective_keys]) + cocktail_rep = cocktail_rep[indexes] + 
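+ # each affective dimension is squashed with the shifted/scaled sigmoid defined above,
+ # sigmoid(x, shift, beta) = 2 * (1 / (1 + exp(-(x + shift) * beta)) - 0.5), keeping values
+ # in (-1, 1); the shift/beta pairs below appear hand-tuned per dimension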
cocktail_rep[0] = sigmoid(cocktail_rep[0], shift=0.05, beta=4) + cocktail_rep[1] = sigmoid(cocktail_rep[1], shift=0.3, beta=5) + cocktail_rep[2] = sigmoid(cocktail_rep[2], shift=0.15, beta=3) + cocktail_rep[3] = sigmoid(cocktail_rep[3], shift=0.9, beta=20) + cocktail_rep[4] = sigmoid(cocktail_rep[4], shift=0, beta=4) + cocktail_rep[5] = sigmoid(cocktail_rep[5], shift=0.2, beta=3) + cocktail_rep[6] = sigmoid(cocktail_rep[6], shift=0.5, beta=5) + cocktail_rep[7] = sigmoid(cocktail_rep[7], shift=0.2, beta=6) + return cocktail_rep + +class IndividualCocktail(): + def __init__(self, pop_params, target, target_affective_cluster, genes_presence=None, genes_quantity=None, + compute_perf=True, known_target_dict=None, run_hard_check=False): + + self.pop_params = pop_params + self.n_genes = len(ingredient_list) + self.max_ingredients = max_ingredients + self.min_ingredients = min_ingredients + self.mutation_params = pop_params['mutation_params'] + self.dist = pop_params['dist'] + self.target = target + self.is_known = known_target_dict is not None + self.known_target_dict = known_target_dict + self.perf = None + self.cocktail_rep = None + self.affective_cluster = None + self.target_affective_cluster = target_affective_cluster + self.ing_list = np.array(ingredient_list) + self.ing_set = set(ingredient_list) + + self.ing_ids_per_cat = dict(bubbles=set(self.get_ingredients_ids_from_list(bubble_ingredients)), + liquor=set(self.get_ingredients_ids_from_list(ingredients_per_type['liquor'])), + liqueur=set(self.get_ingredients_ids_from_list(ingredients_per_type['liqueur'])), + citrus=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'] + ['orange juice'])), + alcohol=set(ind_alcohol), + sweeteners=set(self.get_ingredients_ids_from_list(ingredients_per_type['sweeteners'])), + vermouth=set(self.get_ingredients_ids_from_list(ingredients_per_type['vermouth'])), + bitters=set(self.get_ingredients_ids_from_list(ingredients_per_type['bitters'])), + juice=set(self.get_ingredients_ids_from_list(ingredients_per_type['juice'])), + acid=set(self.get_ingredients_ids_from_list(ingredients_per_type['acid'])), + egg=set(self.get_ingredients_ids_from_list(['egg'])) + ) + + if genes_presence is not None: + assert len(genes_presence) == self.n_genes + assert len(genes_quantity) == self.n_genes + self.genes_presence = genes_presence + self.genes_quantity = genes_quantity + if compute_perf: + self.compute_cocktail_rep() + self.compute_perf() + else: + self.sample_initial_genes() + self.compute_cocktail_rep() + # self.make_recipe_fit_the_glass() + self.compute_perf() + + + # # # # # # # # # # # # # # # # # # # # # # # # + # Sample initial genes with smart rules + # # # # # # # # # # # # # # # # # # # # # # # # + + def sample_initial_genes(self): + # rules: + # - between min_ingredients and max_ingredients + # - at most one type of bubbles + # - at least one alcohol + # - no egg without lime or lemon + # - at most two liqueurs + # - at most three liquors + # - at most two sweetener + self.genes_quantity = np.random.uniform(0, 1, size=self.n_genes) # holds quantities for each ingredient + n_ingredients = np.random.choice(np.arange(min_ingredients, max_ingredients + 1), p=distrib_nb_ings_2_8) + self.genes_presence = np.zeros(self.n_genes) + # add one alchohol + self.genes_presence[np.random.choice(ind_alcohol)] = 1 + while self.get_ing_count() < n_ingredients: + candidate_ids = self.get_candidate_ingredients_ids(self.genes_presence) + probas = density_ingredients[candidate_ids] / 
np.sum(density_ingredients[candidate_ids]) + self.genes_presence[np.random.choice(candidate_ids, p=probas)] = 1 + + def get_candidate_ingredients_ids(self, genes_presence): + candidates = set(np.argwhere(genes_presence==0).flatten()) + present_ids = set(np.argwhere(genes_presence==1).flatten()) + + if self.count_in_genes(present_ids, 'bubbles') >= 1: # at most one type of bubbles + candidates = candidates - self.ing_ids_per_cat['bubbles'] + if self.count_in_genes(present_ids, 'liquor') >= 3: # at most three liquors + candidates = candidates - self.ing_ids_per_cat['liquor'] + if self.count_in_genes(present_ids, 'liqueur') >= 2: # at most two liqueurs + candidates = candidates - self.ing_ids_per_cat['liqueur'] + if self.count_in_genes(present_ids, 'sweeteners') >= 2: # at most two sweetener + candidates = candidates - self.ing_ids_per_cat['sweeteners'] + if self.count_in_genes(present_ids, 'citrus') == 0: # no egg without lime or lemon + candidates = candidates - self.ing_ids_per_cat['egg'] + return np.array(sorted(candidates)) + + def count_in_genes(self, present_ids, keyword): + if keyword == 'citrus': return len(present_ids & self.ing_ids_per_cat['citrus']) + elif keyword == 'bubbles': return len(present_ids & self.ing_ids_per_cat['bubbles']) + elif keyword == 'liquor': return len(present_ids & self.ing_ids_per_cat['liquor']) + elif keyword == 'liqueur': return len(present_ids & self.ing_ids_per_cat['liqueur']) + elif keyword == 'alcohol': return len(present_ids & self.ing_ids_per_cat['alcohol']) + elif keyword == 'sweeteners': return len(present_ids & self.ing_ids_per_cat['sweeteners']) + else: raise ValueError + + def get_ingredients_ids_from_list(self, ing_list): + return [ingredient_list.index(ing) for ing in ing_list] + + def get_ing_count(self): + return np.sum(self.genes_presence) + + # # # # # # # # # # # # # # # # # # # # # # # # + # Compute cocktail representations + # # # # # # # # # # # # # # # # # # # # # # # # + + def get_absent_ing(self): + return np.argwhere(self.genes_presence==0).flatten() + + def get_present_ing(self): + return np.argwhere(self.genes_presence==1).flatten() + + def get_ingredient_quantities(self): + # unnormalize quantities to get real ones + return (self.genes_quantity * (max_ingredients_quantities * factor_max - min_ingredients_quantities_when_present) + min_ingredients_quantities_when_present) * self.genes_presence + + def get_ing_and_q_from_genes(self): + present_ings = self.get_present_ing() + ing_quantities = self.get_ingredient_quantities() + ingredients, quantities = [], [] + for i_ing in present_ings: + ingredients.append(ingredient_list[i_ing]) + quantities.append(ing_quantities[i_ing]) + return ingredients, quantities, ing_quantities + + def compute_cocktail_rep(self): + # only call when genes have changes + init_time = time.time() + ingredients, quantities, ing_quantities = self.get_ing_and_q_from_genes() + # compute cocktail category + self.category = find_cocktail_sub_category(ingredients, quantities)[0] + # print(f't1: {time.time() - init_time}') + init_time = time.time() + self.prep_type = self.get_prep_type(ing_quantities) + # print(f't2: {time.time() - init_time}') + init_time = time.time() + cocktail_rep, self.end_volume, self.end_alcohol = get_cocktail_rep(self.prep_type, ingredients, quantities, keys=rep_keys[1:]) # volume is added later + # print(f't3: {time.time() - init_time}') + init_time = time.time() + self.cocktail_rep = normalize_cocktail(cocktail_rep) + # print(f't4: {time.time() - init_time}') + init_time = time.time() + 
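+ # infer the serving glass from the raw (unnormalized) ingredient quantities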
self.glass = self.get_glass_type(ing_quantities) + # print(f't5: {time.time() - init_time}') + init_time = time.time() + if self.is_known: + assert np.abs(self.cocktail_rep - self.target).sum() < 1e-6 + return self.cocktail_rep + + def get_prep_type(self, quantities=None): + if self.is_known: return self.known_target_dict['prep_type'] + else: + if quantities is None: + quantities = self.get_ingredient_quantities() + if quantities[ingredient_list.index('egg')] > 0: + prep_cat = 'egg_shaken' + elif self.category in ['spirit_forward', 'simple_sour_with_juice', 'julep', 'duo', 'ancestral', 'complex_sour_with_juice']: + # use hard coded rules for most obvious cases determined with the correlations_glass_cat_prep_script + if self.category in ['ancestral', 'spirit_forward', 'duo']: + prep_cat = 'stirred' + elif self.category in ['complex_sour_with_juice', 'julep', 'simple_sour_with_juice']: + prep_cat = 'shaken' + else: + raise ValueError + else: + output = prep_model(quantities, aux_str='prep_type').flatten() + output[preparation_list.index('egg_shaken')] = -np.inf + prep_cat = preparation_list[np.argmax(output)] + return prep_cat + + def get_glass_type(self, quantities=None): + if self.is_known: return self.known_target_dict['glass'] + else: + if self.category in ['collins', 'complex_highball', 'simple_highball', 'champagne_cocktail', 'complex_sour']: + # use hard coded rules for most obvious cases determined with the correlations_glass_cat_prep_script + if self.category in ['collins', 'complex_highball', 'simple_highball']: + glass = 'collins' + elif self.category in ['champagne_cocktail', 'complex_sour']: + glass = 'coupe' + else: + if quantities is None: + quantities = self.get_ingredient_quantities() + output = prep_model(quantities, aux_str='glasses').flatten() + glass = glasses_list[np.argmax(output)] + return glass + + # # # # # # # # # # # # # # # # # # # # # # # # + # Adapt recipe to fit the glass + # # # # # # # # # # # # # # # # # # # # # # # # + + def is_too_large_for_glass(self): + return self.end_volume > glass_volume[self.glass] * 0.80 + + def is_too_small_for_glass(self): + return self.end_volume < glass_volume[self.glass] * 0.3 + + def scale_ing_quantities(self, present_ings, factor): + qs = self.get_ingredient_quantities().copy() + qs[present_ings] *= factor + self.set_genes_from_quantities(present_ings, qs) + + def set_genes_from_quantities(self, present_ings, quantities): + genes_quantity = np.clip((quantities - min_ingredients_quantities_when_present) / + (factor_max * max_ingredients_quantities - min_ingredients_quantities_when_present), 0, 1) + self.genes_quantity[present_ings] = genes_quantity[present_ings] + + def make_recipe_fit_the_glass(self): + # check if citrus, if not remove egg + present_ids = np.argwhere(self.genes_presence == 1).flatten() + ing_list = self.ing_list[present_ids] + present_ids = set(present_ids) + if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list: + if self.genes_presence.sum() > 2: + i_egg = ingredient_list.index('egg') + self.genes_presence[i_egg] = 0. 
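+ # the egg was dropped, so refresh the cocktail representation before checking volumes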
+ self.compute_cocktail_rep() + + + i_trial = 0 + present_ings = self.get_present_ing() + while self.is_too_large_for_glass(): + i_trial += 1 + end_volume = self.end_volume + desired_volume = glass_volume[self.glass] * 0.80 + ratio = desired_volume / end_volume + self.scale_ing_quantities(present_ings, factor=ratio) + self.compute_cocktail_rep() + if end_volume == self.end_volume: break + if i_trial == 10: break + while self.is_too_small_for_glass(): + i_trial += 1 + end_volume = self.end_volume + desired_volume = glass_volume[self.glass] * 0.80 + ratio = desired_volume / end_volume + self.scale_ing_quantities(present_ings, factor=ratio) + self.compute_cocktail_rep() + if end_volume == self.end_volume: break + if i_trial == 10: break + + # # # # # # # # # # # # # # # # # # # # # # # # + # Compute performance + # # # # # # # # # # # # # # # # # # # # # # # # + + def passes_checks(self): + present_ids = np.argwhere(self.genes_presence==1).flatten() + # ing_list = self.ing_list[present_ids] + present_ids = set(present_ids) + if len(present_ids) < 2 or len(present_ids) > 8: return False + # if self.is_too_large_for_glass(): return False + # if self.is_too_small_for_glass(): return False + if self.end_alcohol < 0.05 or self.end_alcohol > 0.31: return False + if self.count_in_genes(present_ids, 'sweeteners') > 2: return False + if self.count_in_genes(present_ids, 'liqueur') > 2: return False + if self.count_in_genes(present_ids, 'liquor') > 3: return False + # if self.count_in_genes(present_ids, 'citrus') == 0 and 'egg' in ing_list: return False + if self.count_in_genes(present_ids, 'bubbles') > 1: return False + else: return True + + def get_affective_cluster(self): + cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(self.cocktail_rep) + self.affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0] + return self.affective_cluster + + def does_affective_cluster_match(self): + return True#self.get_affective_cluster() == self.target_affective_cluster + + def compute_perf(self): + if not self.passes_checks(): self.perf = -100 + else: + if self.dist == 'mse': + # self.perf = - np.sqrt(((self.cocktail_rep - self.target)**2).mean()) + self.perf = - np.sqrt(np.dot((self.cocktail_rep - self.target)**2, weights_mse_computation)) + self.perf *= weights_perf_n_ing[int(self.genes_presence.sum())] + if not self.does_affective_cluster_match(): + self.perf *= 2 + else: raise NotImplemented + + + # # # # # # # # # # # # # # # # # # # # # # # # + # Mutations and crossover + # # # # # # # # # # # # # # # # # # # # # # # # + + def get_child(self): + time_dict = dict() + init_time = time.time() + child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster, + target=self.target, genes_presence=self.genes_presence.copy(), + genes_quantity=self.genes_quantity.copy(), compute_perf=False) + time_dict[' asexual child creation'] = [time.time() - init_time] + init_time = time.time() + this_time_dict = child.mutate() + time_dict = self.update_time_dict(time_dict, this_time_dict) + time_dict[' asexual child mutation'] = [time.time() - init_time] + return child, time_dict + + def get_child_with(self, other_parent): + time_dict = dict() + init_time = time.time() + new_genes_presence = np.zeros(self.n_genes) + present_ing = self.get_present_ing() + other_present_ing = other_parent.get_present_ing() + new_genes_quantity = np.random.uniform(0, 1, size=self.n_genes) + shared_ingredients = sorted(set(present_ing) & 
set(other_present_ing)) + unique_ingredients_one = sorted(set(present_ing) - set(other_present_ing)) + unique_ingredients_two = sorted(set(other_present_ing) - set(present_ing)) + for i in shared_ingredients: + new_genes_presence[i] = 1 + new_genes_quantity[i] = (self.genes_quantity[i] + other_parent.genes_quantity[i]) / 2 + time_dict[' crossover child creation'] = [time.time() - init_time] + init_time = time.time() + # add one alcohol if none present + if len(set(np.argwhere(new_genes_presence==1).flatten()).intersection(ind_alcohol)) == 0: + new_genes_presence[np.random.choice(ind_alcohol)] = 1 + # up to here, we respect the constraints (assuming both parents do). + candidate_genes = np.array(unique_ingredients_one + unique_ingredients_two) + candidate_quantities = np.array([self.genes_quantity[i] for i in unique_ingredients_one] + [other_parent.genes_quantity[i] for i in unique_ingredients_two]) + indexes = np.arange(len(candidate_genes)) + np.random.shuffle(indexes) + candidate_genes = candidate_genes[indexes] + candidate_quantities = candidate_quantities[indexes] + time_dict[' crossover prepare selection'] = [time.time() - init_time] + init_time = time.time() + # now let's try to add each of them while respecting the constraints + for i in range(len(indexes)): + if np.random.rand() < 0.5 or np.sum(new_genes_presence) < self.min_ingredients: # only try to add one every two ingredient + ing_id = candidate_genes[i] + q = candidate_quantities[i] + new_genes_presence[ing_id] = 1 + new_genes_quantity[ing_id] = q + if np.sum(new_genes_presence) == self.max_ingredients: + break + time_dict[' crossover do selection'] = [time.time() - init_time] + init_time = time.time() + # create new child + child = IndividualCocktail(pop_params=self.pop_params, target_affective_cluster=self.target_affective_cluster, target=self.target, + genes_presence=new_genes_presence.copy(), genes_quantity=new_genes_quantity.copy(), compute_perf=False) + time_dict[' crossover create child'] = [time.time() - init_time] + init_time = time.time() + this_time_dict = child.mutate() + time_dict = self.update_time_dict(time_dict, this_time_dict) + time_dict[' crossover child mutation'] = [time.time() - init_time] + init_time = time.time() + return child, time_dict + + def mutate(self): + # self.print_recipe() + time_dict = dict() + # remove an ingredient + init_time = time.time() + present_ids = set(np.argwhere(self.genes_presence==1).flatten()) + + if np.random.rand() < self.mutation_params['p_remove_ing']: + if self.get_ing_count() > self.min_ingredients: + candidate_ings = self.get_present_ing() + if self.count_in_genes(present_ids, 'alcohol') == 1: # make sure we keep at least one liquor + candidate_ings = np.array(sorted(set(candidate_ings) - set(ind_alcohol))) + index_to_remove = np.random.choice(candidate_ings) + self.genes_presence[index_to_remove] = 0 + time_dict[' mutation remove ing'] = [time.time() - init_time] + init_time = time.time() + # add an ingredient + if np.random.rand() < self.mutation_params['p_add_ing']: + if self.get_ing_count() < self.max_ingredients: + candidate_ings = self.get_candidate_ingredients_ids(self.genes_presence.copy()) + index_to_add = np.random.choice(candidate_ings, p=density_ingredients[candidate_ings] / np.sum(density_ingredients[candidate_ings])) + self.genes_presence[index_to_add] = 1 + time_dict[' mutation add ing'] = [time.time() - init_time] + + init_time = time.time() + # replace ings by others from the same family + if np.random.rand() < self.mutation_params['p_switch_ing']: + i 
= np.random.choice(self.get_present_ing()) + ing_str = ingredient_list[i] + if ing_str not in ['sparkling wine', 'orange juice']: + if ing_str in bubble_ingredients: + candidates_ids = np.array(sorted(self.ing_ids_per_cat['bubbles'] - set([i]))) + new_bubble = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids])) + self.genes_presence[i] = 0 + self.genes_presence[new_bubble] = 1 + self.genes_quantity[new_bubble] = self.genes_quantity[i] # copy quantity + categories = ['acid', 'bitters', 'juice', 'liqueur', 'liquor', 'sweeteners', 'vermouth'] + for cat in categories: + if ing_str in ingredients_per_type[cat]: + present_ings = self.get_present_ing() + candidates_ids = np.array(sorted(self.ing_ids_per_cat[cat] - set([i]) - set(present_ings))) + if len(candidates_ids) > 0: + replacing_ing = np.random.choice(candidates_ids, p=density_ingredients[candidates_ids] / np.sum(density_ingredients[candidates_ids])) + self.genes_presence[i] = 0 + self.genes_presence[replacing_ing] = 1 + self.genes_quantity[replacing_ing] = self.genes_quantity[i] # copy quantity + break + time_dict[' mutation switch ing'] = [time.time() - init_time] + init_time = time.time() + # add noise on ing quantity + for i in self.get_present_ing(): + if np.random.rand() < self.mutation_params['p_change_q']: + self.genes_quantity[i] += np.random.randn() * self.mutation_params['delta_change_q'] + self.genes_quantity = np.clip(self.genes_quantity, 0, 1) + time_dict[' mutation change quantity'] = [time.time() - init_time] + + init_time = time.time() + self.compute_cocktail_rep() + time_dict[' mutation compute cocktail rep'] = [time.time() - init_time] + init_time = time.time() + # self.make_recipe_fit_the_glass() + time_dict[' mutation check glass fit'] = [time.time() - init_time] + init_time = time.time() + self.compute_perf() + time_dict[' mutation compute perf'] = [time.time() - init_time] + init_time = time.time() + stop = 1 + return time_dict + + + def update_time_dict(self, main_dict, new_dict): + for k in new_dict.keys(): + if k in main_dict.keys(): + main_dict[k].append(np.sum(new_dict[k])) + else: + main_dict[k] = [np.sum(new_dict[k])] + return main_dict + + # # # # # # # # # # # # # # # # # # # # # # # # + # Get recipe and print + # # # # # # # # # # # # # # # # # # # # # # # # + + def get_recipe(self, unit='mL', name=None): + ing_quantities = self.get_ingredient_quantities() + ingredients, quantities = [], [] + for i_ing, q_ing in enumerate(ing_quantities): + if q_ing > 0.8: + ingredients.append(ingredient_list[i_ing]) + quantities.append(round(q_ing)) + recipe_str = format_ingredients(ingredients, quantities) + recipe_str_readable = print_recipe(unit=unit, ingredient_str=recipe_str, name=name, to_print=False) + return ingredients, quantities, recipe_str, recipe_str_readable + + def get_instructions(self): + ing_quantities = self.get_ingredient_quantities() + ingredients, quantities = [], [] + for i_ing, q_ing in enumerate(ing_quantities): + if q_ing > 0.8: + ingredients.append(ingredient_list[i_ing]) + quantities.append(round(q_ing)) + str_out = 'Instructions:\n ' + + if 'mint' in ingredients: + i_mint = ingredients.index('mint') + n_leaves = quantities[i_mint] + str_out += f'Add {n_leaves} mint leaves to a shaker, followed by an ice cube.\n Muddle the mint and ice together with a muddler.\n ' + bubbles = ['sparkling wine', 'tonic', 'soda', 'ginger beer'] + other_ings = [ing for ing in ingredients if ing not in ['egg', 'angostura', 'orange bitters'] + bubbles] 
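+ # egg, bitters and bubbles get dedicated instructions below; everything else is poured first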
+ + if self.prep_type == 'built': + str_out += 'Add a large ice cube in the glass.\n ' + # add ingredients to pour + str_out += 'Pour' + for i, ing in enumerate(other_ings): + if i == len(other_ings) - 2: + str_out += f' {ing} and' + elif i == len(other_ings) - 1: + str_out += f' {ing}' + else: + str_out += f' {ing},' + + if self.prep_type in ['built'] and 'mint' not in ingredients: + str_out += ' into the glass.\n ' + else: + str_out += ' into the shaker.\n ' + + if self.prep_type == 'egg_shaken' and 'egg' in ingredients: + str_out += 'Add the egg white.\n Dry-shake for 15s (without ice), then fill with ice and shake for another 15s.\n Serve into the glass through a strainer.\n ' + elif 'shaken' in self.prep_type: + str_out += 'Fill with ice and shake for 15s.\n Serve into the glass through a strainer.\n ' + elif self.prep_type == 'stirred': + str_out += 'Add ice and stir the cocktail with a spoon for 15s.\n Serve into the glass through a strainer.\n ' + elif self.prep_type == 'built': + str_out += 'Stir two turns with a spoon.\n ' + + bubble_ing = [ing for ing in ingredients if ing in bubbles] + if len(bubble_ing) > 0: + str_out += f'Top up with ' + for ing in bubble_ing: + str_out += f'{ing}, ' + str_out = str_out[:-2] + '.\n ' + bitter_ing = [ing for ing in ingredients if ing in ['angostura', 'orange bitters']] + if len(bitter_ing) > 0: + if len(bitter_ing) == 1: + q = quantities[ingredients.index(bitter_ing[0])] + n_dashes = max(1, int(q / 0.6)) + str_out += f'Add {n_dashes} dash' + if n_dashes > 1: + str_out += 'es' + str_out += f' of {bitter_ing[0]}.\n ' + elif len(bitter_ing) == 2: + q = quantities[ingredients.index(bitter_ing[0])] + n_dashes = max(1, int(q / 0.6)) + str_out += f'Add {n_dashes} dash' + if n_dashes > 1: + str_out += 'es' + str_out += f' of {bitter_ing[0]} and ' + q = quantities[ingredients.index(bitter_ing[1])] + n_dashes = max(1, int(q / 0.6)) + str_out += f'{n_dashes} dash' + if n_dashes > 1: + str_out += 'es' + str_out += f' of {bitter_ing[1]}.\n ' + str_out += 'Enjoy!' 
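+ # Example output for a hypothetical shaken recipe (gin, lime juice, simple syrup):
+ # "Instructions:\n Pour gin, lime juice and simple syrup into the shaker.\n Fill with ice
+ # and shake for 15s.\n Serve into the glass through a strainer.\n Enjoy!"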
+ return str_out + + def print_recipe(self, name=None): + print(self.get_recipe(name)[3]) \ No newline at end of file diff --git a/src/cocktails/utilities/cocktail_generation_utilities/population.py b/src/cocktails/utilities/cocktail_generation_utilities/population.py new file mode 100644 index 0000000000000000000000000000000000000000..2a004f3146611efa8f9579c4e928c8dd335f7c9b --- /dev/null +++ b/src/cocktails/utilities/cocktail_generation_utilities/population.py @@ -0,0 +1,213 @@ +from src.cocktails.utilities.cocktail_generation_utilities.individual import * +from sklearn.neighbors import NearestNeighbors +import time +import pickle +from src.cocktails.config import COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA + +class Population: + def __init__(self, target, pop_params, target_affective_cluster=None, known_target_dict=None): + self.pop_params = pop_params + self.pop_size = pop_params['pop_size'] + self.nb_elite = pop_params['nb_elites'] + self.nb_generations = pop_params['nb_generations'] + self.target = target + self.mutation_params = pop_params['mutation_params'] + self.dist = pop_params['dist'] + self.n_neighbors = pop_params['n_neighbors'] + self.known_target_dict = known_target_dict + + + with open(COCKTAIL_NN_PATH, 'rb') as f: + data = pickle.load(f) + self.nn_model_cocktail = data['nn_model'] + self.dim_rep_cocktail = data['dim_rep_cocktail'] + self.n_cocktails = data['n_cocktails'] + self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA) + + if target_affective_cluster is None: + cocktail_rep_affective = get_normalized_affective_cocktail_rep_from_normalized_cocktail_rep(target) + self.target_affective_cluster = cocktail2affective_cluster(cocktail_rep_affective)[0] + else: + self.target_affective_cluster = target_affective_cluster + + self.pop_elite = [] + self.pop = [] + self.add_target_individual() # create a target individual (not in pop) + self.add_nearest_neighbors_in_pop() # add nearest neighbor from dataset into the population + + # fill population + while self.get_pop_size() < self.pop_size: + self.add_individual() + while len(self.pop_elite) < self.nb_elite: + self.pop_elite.append(IndividualCocktail(pop_params=self.pop_params, + target=self.target.copy(), + target_affective_cluster=self.target_affective_cluster)) + self.update_elite_and_get_next_pop() + + def add_target_individual(self): + if self.known_target_dict is not None: + genes_presence, genes_quantity = self.get_q_rep(*extract_ingredients(self.known_target_dict['ing_str'])) + self.target_individual = IndividualCocktail(pop_params=self.pop_params, + target=self.target.copy(), + known_target_dict=self.known_target_dict, + target_affective_cluster=self.target_affective_cluster, + genes_presence=genes_presence, + genes_quantity=genes_quantity + ) + else: + self.target_individual = None + + + def add_nearest_neighbors_in_pop(self): + # add nearest neighbor from dataset into the population + if self.n_neighbors > 0: + dists, indexes = self.nn_model_cocktail.kneighbors(self.target.reshape(1, -1)) + dists, indexes = dists.flatten(), indexes.flatten() + first = 1 if dists[0] == 0 else 0 # avoid taking the target when testing with known targets from the dataset + indexes = indexes[first:first + self.n_neighbors] + self.ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes] + recipes = [extract_ingredients(ing_str) for ing_str in self.ing_strs] + for r in recipes: + genes_presence, genes_quantity = self.get_q_rep(r[0], r[1]) + genes_presence[-1] = 0 # remove water ingredient + 
self.add_individual(genes_presence=genes_presence.copy(), genes_quantity=genes_quantity.copy()) + self.nn_recipes = [ind.get_recipe()[3] for ind in self.pop] + self.nn_scores = [ind.perf for ind in self.pop] + else: + self.ing_strs = None + + def add_individual(self, genes_presence=None, genes_quantity=None): + self.pop.append(IndividualCocktail(pop_params=self.pop_params, + target=self.target.copy(), + target_affective_cluster=self.target_affective_cluster, + genes_presence=genes_presence, + genes_quantity=genes_quantity)) + + def get_elite_perf(self): + return np.array([e.perf for e in self.pop_elite]) + + def get_pop_perf(self): + return np.array([ind.perf for ind in self.pop]) + + + def update_elite_and_get_next_pop(self): + time_dict = dict() + init_time = time.time() + elite_perfs = self.get_elite_perf() + pop_perfs = self.get_pop_perf() + all_perfs = np.concatenate([elite_perfs, pop_perfs]) + temp_list = self.pop_elite + self.pop + time_dict[' get pop perfs'] = [time.time() - init_time] + init_time = time.time() + # update elite population with new bests + indexes_sorted = np.flip(np.argsort(all_perfs)) + new_pop_elite = [IndividualCocktail(pop_params=self.pop_params, + target=self.target.copy(), + target_affective_cluster=self.target_affective_cluster, + genes_presence=temp_list[i_new_e].genes_presence.copy(), + genes_quantity=temp_list[i_new_e].genes_quantity.copy()) for i_new_e in indexes_sorted[:self.nb_elite]] + time_dict[' recreate elite individuals'] = [time.time() - init_time] + init_time = time.time() + # select parents + rank_perfs = np.flip(np.arange(len(temp_list))) + sampling_probs = rank_perfs / np.sum(rank_perfs) + if self.mutation_params['asexual_rep'] and not self.mutation_params['crossover']: + new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size) + self.pop = [temp_list[i].get_child() for i in new_pop_indexes] + elif self.mutation_params['crossover'] and not self.mutation_params['asexual_rep']: + self.pop = [] + while len(self.pop) < self.pop_size: + parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False) + self.pop.append(temp_list[parents[0]].get_child_with(temp_list[parents[1]])) + elif self.mutation_params['crossover'] and self.mutation_params['asexual_rep']: + new_pop_indexes = np.random.choice(indexes_sorted, p=sampling_probs, size=self.pop_size//2) + time_dict[' choose asexual parent indexes'] = [time.time() - init_time] + init_time = time.time() + self.pop = [] + for i in new_pop_indexes: + child, this_time_dict = temp_list[i].get_child() + self.pop.append(child) + time_dict = self.update_time_dict(time_dict, this_time_dict) + time_dict[' get asexual children'] = [time.time() - init_time] + init_time = time.time() + while len(self.pop) < self.pop_size: + parents = np.random.choice(indexes_sorted, p=sampling_probs, size=2, replace=False) + child, this_time_dict = temp_list[parents[0]].get_child_with(temp_list[parents[1]]) + self.pop.append(child) + time_dict = self.update_time_dict(time_dict, this_time_dict) + time_dict[' get sexual children'] = [time.time() - init_time] + self.pop_elite = new_pop_elite + return time_dict + + def get_pop_size(self): + return len(self.pop) + + def get_q_rep(self, ingredients, quantities): + ingredient_q_rep = np.zeros([len(ingredient_list)]) + genes_presence = np.zeros([len(ingredient_list)]) + for ing, q in zip(ingredients, quantities): + ingredient_q_rep[ingredient_list.index(ing)] = q + genes_presence[ingredient_list.index(ing)] = 1 + return 
genes_presence.copy(), normalize_ingredient_q_rep(ingredient_q_rep) + + def get_best_score(self, affective_cluster_check=False): + elite_perfs = self.get_elite_perf() + pop_perfs = self.get_pop_perf() + all_perfs = np.concatenate([elite_perfs, pop_perfs]) + temp_list = self.pop_elite + self.pop + if affective_cluster_check: + indexes = np.array([i for i in range(len(temp_list)) if temp_list[i].does_affective_cluster_match()]) + if indexes.size > 0: + temp_list = np.array(temp_list)[indexes] + all_perfs = all_perfs[indexes] + indexes_best = np.flip(np.argsort(all_perfs)) + return np.array(all_perfs)[indexes_best], np.array(temp_list)[indexes_best] + + def update_time_dict(self, main_dict, new_dict): + for k in new_dict.keys(): + if k in main_dict.keys(): + main_dict[k].append(np.sum(new_dict[k])) + else: + main_dict[k] = [np.sum(new_dict[k])] + return main_dict + + def run_one_generation(self, verbose=True, affective_cluster_check=False): + time_dict = dict() + init_time = time.time() + this_time_dict = self.update_elite_and_get_next_pop() + time_dict['update_elite_and_pop'] = [time.time() - init_time] + time_dict = self.update_time_dict(time_dict, this_time_dict) + init_time = time.time() + best_perfs, best_individuals = self.get_best_score(affective_cluster_check) + time_dict['get best scores'] = [time.time() - init_time] + return best_perfs[0], time_dict + + def run_evolution(self, verbose=False, print_every=10, affective_cluster_check=False, level=0): + best_score = -np.inf + time_dict = dict() + init_time = time.time() + for i in range(self.nb_generations): + best_score, this_time_dict = self.run_one_generation(verbose, affective_cluster_check=affective_cluster_check) + time_dict = self.update_time_dict(time_dict, this_time_dict) + if verbose and (i+1) % print_every == 0: + print(' ' * level + f'Gen #{i+1} - Current best perf: {best_score:.2f}, time: {time.time() - init_time:.4f}') + init_time = time.time() + # + # to_print = time_dict.copy() + # keys = sorted(to_print.keys()) + # values = [] + # for k in keys: + # to_print[k] = np.sum(to_print[k]) + # values.append(to_print[k]) + # sorted_inds = np.flip(np.argsort(values)) + # for i in sorted_inds: + # print(f'{keys[i]}: {values[i]:.4f}') + if verbose: print(' ' * level + f'Evolution over, best perf: {best_score:.2f}') + return self.get_best_score() + + def print_results(self, n=3): + best_scores, best_ind = self.get_best_score() + for i in range(n): + best_ind[i].print_recipe(f'Candidate #{i+1}, Score: {best_scores[i]:.2f}') + + diff --git a/src/cocktails/utilities/cocktail_utilities.py b/src/cocktails/utilities/cocktail_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2264fe67a3b7e75447c27817d42e7c135ac8b4 --- /dev/null +++ b/src/cocktails/utilities/cocktail_utilities.py @@ -0,0 +1,220 @@ +import numpy as np +from src.cocktails.utilities.ingredients_utilities import ingredient2ingredient_id, ingredient_profiles, ingredients_per_type, ingredient_list, find_ingredient_from_str +from src.cocktails.utilities.cocktail_category_detection_utilities import * +import time + +# representation_keys = ['pH', 'sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb', +# 'complex', 'spicy', 'strong', 'oaky', 'fizzy', 'colorful', 'eggy'] +representation_keys = ['sour', 'sweet', 'booze', 'bitter', 'fruit', 'herb', + 'complex', 'spicy', 'oaky', 'fizzy', 'colorful', 'eggy'] +representation_keys_linear = list(set(representation_keys) - set(['pH', 'complex'])) + +ing_reps = np.array([[ingredient_profiles[k][ing_id] for 
ing_id in ingredient2ingredient_id.values()] for k in representation_keys]).transpose() + + +def compute_cocktail_representation(profile, ingredients, quantities): + # computes representation of a cocktail from the recipe (ingredients, quantities) and volume + n = len(ingredients) + assert n == len(quantities) + quantities = np.array(quantities) + + weights = quantities / np.sum(quantities) + rep = dict() + + ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients]) + # compute features as linear combination of ingredient features + for k in representation_keys_linear: + k_ing = np.array([ingredient_profiles[k][ing_id] for ing_id in ing_ids]) + rep[k] = np.dot(weights, k_ing) + + # for ph + # ph = - log10 x + phs = np.array([ingredient_profiles['pH'][ing_id] for ing_id in ing_ids]) + concentrations = 10 ** (- phs) + mix_c = np.dot(weights, concentrations) + + rep['pH'] = - np.log10(mix_c) + + rep['complex'] = np.mean([ingredient_profiles['complex'][ing_id] for ing_id in ing_ids]) + len(ing_ids) + + # compute profile after dilution + volume_ratio = profile['mix volume'] / profile['end volume'] + for k in representation_keys: + rep['end ' + k] = rep[k] * volume_ratio + concentration = 10 ** (-rep['pH']) + end_concentration = concentration * volume_ratio + rep['end pH'] = - np.log10(end_concentration) + return rep + +def get_alcohol_profile(ingredients, quantities): + ingredients = ingredients.copy() + quantities = quantities.copy() + assert len(ingredients) == len(quantities) + if 'mint' in ingredients: + mint_ind = ingredients.index('mint') + ingredients.pop(mint_ind) + quantities.pop(mint_ind) + alcohol = [] + volume_mix = np.sum(quantities) + weights = quantities / volume_mix + assert np.abs(np.sum(weights) - 1) < 1e-4 + ingredients_list = [ing.lower() for ing in ingredient_list] + for ing, q in zip(ingredients, quantities): + id = ingredients_list.index(ing) + alcohol.append(ingredient_profiles['ethanol'][id]) + alcohol = np.dot(alcohol, weights) + return alcohol, volume_mix + +def get_mix_profile(ingredients, quantities): + ingredients = ingredients.copy() + quantities = quantities.copy() + assert len(ingredients) == len(quantities) + if 'mint' in ingredients: + mint_ind = ingredients.index('mint') + ingredients.pop(mint_ind) + quantities.pop(mint_ind) + alcohol, sugar, acid = [], [], [] + volume_mix = np.sum(quantities) + weights = quantities / volume_mix + assert np.abs(np.sum(weights) - 1) < 1e-4 + ingredients_list = [ing.lower() for ing in ingredient_list] + for ing, q in zip(ingredients, quantities): + id = ingredients_list.index(ing) + sugar.append(ingredient_profiles['sugar'][id]) + alcohol.append(ingredient_profiles['ethanol'][id]) + acid.append(ingredient_profiles['acid'][id]) + sugar = np.dot(sugar, weights) + acid = np.dot(acid, weights) + alcohol = np.dot(alcohol, weights) + return alcohol, sugar, acid + + +def extract_preparation_type(instructions, recipe): + flag = False + instructions = instructions.lower() + egg_in_recipe = any([find_ingredient_from_str(ing_str)[1]=='egg' for ing_str in recipe[1]]) + if 'shake' in instructions: + if egg_in_recipe: + prep_type = 'egg_shaken' + else: + prep_type = 'shaken' + elif 'stir' in instructions: + prep_type = 'stirred' + elif 'blend' in instructions: + prep_type = 'blended' + elif any([w in instructions for w in ['build', 'mix', 'pour', 'combine', 'place']]): + prep_type = 'built' + else: + prep_type = 'built' + if egg_in_recipe and 'shaken' not in prep_type: + stop = 1 + return flag, prep_type + +def 
get_dilution_ratio(category, alcohol): + # formulas from the Liquid Intelligence book + # The formula for built was invented + if category == 'stirred': + return -1.21 * alcohol**2 + 1.246 * alcohol + 0.145 + elif category in ['shaken', 'egg_shaken']: + return -1.567 * alcohol**2 + 1.742 * alcohol + 0.203 + elif category == 'built': + return (-1.21 * alcohol**2 + 1.246 * alcohol + 0.145) /2 + else: + return 1 + +def get_cocktail_rep(category, ingredients, quantities, keys): + ingredients = ingredients.copy() + quantities = quantities.copy() + assert len(ingredients) == len(quantities) + + volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint']) + + # compute alcohol content without mint ingredient + ingredients2 = [ing for ing in ingredients if ing != 'mint'] + quantities2 = [q for ing, q in zip(ingredients, quantities) if ing != 'mint'] + weights2 = quantities2 / np.sum(quantities2) + assert np.abs(np.sum(weights2) - 1) < 1e-4 + ing_ids2 = np.array([ingredient2ingredient_id[ing] for ing in ingredients2]) + alcohol = np.array([ingredient_profiles['ethanol'][ing_id] for ing_id in ing_ids2]) + alcohol = np.dot(alcohol, weights2) + dilution_ratio = get_dilution_ratio(category, alcohol) + end_volume = volume_mix + volume_mix * dilution_ratio + volume_ratio = volume_mix / end_volume + end_alcohol = alcohol * volume_ratio + + # computes representation of a cocktail from the recipe (ingredients, quantities) and volume + weights = quantities / np.sum(quantities) + assert np.abs(np.sum(weights) - 1) < 1e-4 + ing_ids = np.array([ingredient2ingredient_id[ing] for ing in ingredients]) + reps = ing_reps[ing_ids] + cocktail_rep = np.dot(weights, reps) + i_complex = keys.index('end complex') + cocktail_rep[i_complex] = np.mean(reps[:, i_complex]) + len(ing_ids) # complexity increases with number of ingredients + + # compute profile after dilution + cocktail_rep = cocktail_rep * volume_ratio + cocktail_rep = np.concatenate([[end_volume], cocktail_rep]) + return cocktail_rep, end_volume, end_alcohol + +def get_profile(category, ingredients, quantities): + + volume_mix = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint']) + alcohol, sugar, acid = get_mix_profile(ingredients, quantities) + dilution_ratio = get_dilution_ratio(category, alcohol) + end_volume = volume_mix + volume_mix * dilution_ratio + volume_ratio = volume_mix / end_volume + profile = {'mix volume': volume_mix, + 'mix alcohol': alcohol, + 'mix sugar': sugar, + 'mix acid': acid, + 'dilution ratio': dilution_ratio, + 'end volume': end_volume, + 'end alcohol': alcohol * volume_ratio, + 'end sugar': sugar * volume_ratio, + 'end acid': acid * volume_ratio} + cocktail_rep = compute_cocktail_representation(profile, ingredients, quantities) + profile.update(cocktail_rep) + return profile + +profile_keys = ['mix volume', 'end volume', + 'dilution ratio', + 'mix alcohol', 'end alcohol', + 'mix sugar', 'end sugar', + 'mix acid', 'end acid'] \ + + representation_keys \ + + ['end ' + k for k in representation_keys] + +def update_profile_in_datapoint(datapoint, category, ingredients, quantities): + profile = get_profile(category, ingredients, quantities) + for k in profile_keys: + datapoint[k] = profile[k] + return datapoint + +# define representation keys +def get_bunch_of_rep_keys(): + dict_rep_keys = dict() + # all + rep_keys = profile_keys + dict_rep_keys['all'] = rep_keys + # only_end + rep_keys = [k for k in profile_keys if 'end' in k ] + dict_rep_keys['only_end'] = rep_keys + 
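    # e.g. 'only_end' keeps the post-dilution keys: 'end volume', 'end alcohol', 'end sugar', 'end acid',
    # plus 'end <feature>' for each representation key ('end sour', 'end sweet', ...)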
# except_end + rep_keys = [k for k in profile_keys if 'end' not in k ] + dict_rep_keys['except_end'] = rep_keys + # custom + to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong'] + rep_keys = [k for k in profile_keys if 'end' in k ] + for k in to_remove: + if k in rep_keys: + rep_keys.remove(k) + dict_rep_keys['custom'] = rep_keys + # custom restricted + to_remove = ['end alcohol', 'end sugar', 'end acid', 'end pH', 'end strong', 'end spicy', 'end oaky'] + rep_keys = [k for k in profile_keys if 'end' in k ] + for k in to_remove: + if k in rep_keys: + rep_keys.remove(k) + dict_rep_keys['restricted'] = rep_keys + dict_rep_keys['affective'] = ['end booze', 'end sweet', 'end sour', 'end fizzy', 'end complex', 'end bitter', 'end spicy', 'end colorful'] + return dict_rep_keys \ No newline at end of file diff --git a/src/cocktails/utilities/glass_and_volume_utilities.py b/src/cocktails/utilities/glass_and_volume_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..f84927fe5cc74c8198752b858661afae2805a576 --- /dev/null +++ b/src/cocktails/utilities/glass_and_volume_utilities.py @@ -0,0 +1,42 @@ + + +glass_conversion = {'coupe':'coupe', + 'martini': 'martini', + 'collins': 'collins', + 'oldfashion': 'oldfashion', + 'Coupe glass': 'coupe', + 'Old-fashioned glass': 'oldfashion', + 'Martini glass': 'martini', + 'Nick & Nora glass': 'coupe', + 'Julep tin': 'oldfashion', + 'Collins or Pineapple shell glass': 'collins', + 'Collins glass': 'collins', + 'Rocks glass': 'oldfashion', + 'Highball (max 10oz/300ml)': 'collins', + 'Wine glass': 'coupe', + 'Flute glass': 'coupe', + 'Double old-fashioned': 'oldfashion', + 'Copa glass': 'coupe', + 'Toddy glass': 'oldfashion', + 'Sling glass': 'collins', + 'Goblet glass': 'oldfashion', + 'Fizz or Highball (8oz to 10oz)': 'collins', + 'Copper mug or Collins glass': 'collins', + 'Tiki mug or collins': 'collins', + 'Snifter glass': 'oldfashion', + 'Coconut shell or Collins glass': 'collins', + 'Martini (large 10oz) glass': 'martini', + 'Hurricane glass': 'collins', + 'Absinthe glass or old-fashioned glass': 'oldfashion' + } +glass_volume = dict(coupe = 200, + collins=350, + martini=200, + oldfashion=320) +assert set(glass_conversion.values()) == set(glass_volume.keys()) + +volume_ranges = dict(stirred=(90, 97), + built=(70, 75), + shaken=(98, 112), + egg_shaken=(130, 143), + carbonated=(150, 150)) \ No newline at end of file diff --git a/src/cocktails/utilities/ingredients_utilities.py b/src/cocktails/utilities/ingredients_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..a1142192a7c3eb1117e8145b75a18552cd3a152c --- /dev/null +++ b/src/cocktails/utilities/ingredients_utilities.py @@ -0,0 +1,209 @@ +# This script loads the list and profiles of our ingredients selection. +# It defines rules to recognize ingredients from the list in recipes and the function to extract that information from ingredient strings. 
+ +import pandas as pd +from src.cocktails.config import INGREDIENTS_LIST_PATH, COCKTAILS_CSV_DATA +import numpy as np + +ingredient_profiles = pd.read_csv(INGREDIENTS_LIST_PATH) +ingredient_list = [ing.lower() for ing in ingredient_profiles['ingredient']] +n_ingredients = len(ingredient_list) +ingredient2ingredient_id = dict(zip(ingredient_list, range(n_ingredients))) + +ingredients_types = sorted(set(ingredient_profiles['type'])) +# for each type, get all ingredients +ing_per_type = [[ing for ing in ingredient_list if ingredient_profiles['type'][ingredient_list.index(ing)] == type] for type in ingredients_types] +ingredients_per_type = dict(zip(ingredients_types, ing_per_type)) + +bubble_ingredients = ['soda', 'ginger beer', 'tonic', 'sparkling wine'] +# rules to recognize ingredients in recipes. +# in [] are separate rules with an OR relation: only one needs to be satisfied +# within [], rules apply with and AND relation: all rules need to be satisfied. +# ~ indicates that the following expression must NOT appear +# simple expression indicate that the expression MUST appear. +ingredient_search = {#'salt': ['salt'], + 'lime juice': [['lime', '~soda', '~lemonade', '~cordial']], + 'lemon juice': [['lemon', '~soda', '~lemonade']], + 'angostura': [['angostura', '~orange'], + ['bitter', '~campari', '~orange', '~red', '~italian', '~fernet']], + 'orange bitters': [['orange', 'bitter', '~bittersweet']], + 'orange juice': [['orange', '~bitter', '~jam', '~marmalade', '~liqueur', '~water'], + ['orange', 'squeeze']], + 'pineapple juice': [['pineapple']], + # 'apple juice': [['apple', 'juice', '~pine']], + 'cranberry juice': [['cranberry', 'juice']], + 'cointreau': ['cointreau', 'triple sec', 'grand marnier', 'curaçao', 'curacao'], + 'luxardo maraschino': ['luxardo', 'maraschino', 'kirsch'], + 'amaretto': ['amaretto'], + 'benedictine': ['benedictine', 'bénédictine', 'bénedictine', 'benédictine'], + 'campari': ['campari', ['italian', 'red', 'bitter'], 'aperol', 'bittersweet', 'aperitivo', 'orange-red'], + # 'campari': ['campari', ['italian', 'red', 'bitter']], + # 'crème de violette': [['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']], + # 'aperol': ['aperol', 'bittersweet', 'aperitivo', 'orange-red'], + 'green chartreuse': ['chartreuse'], + 'black raspberry liqueur': [['cassis', 'liqueur'], + ['black raspberry', 'liqueur'], + ['raspberry', 'liqueur'], + ['strawberry', 'liqueur'], + ['blackberry', 'liqueur'], + ['violette', 'crème'], ['crême', 'violette'], ['liqueur', 'violette']], + # 'simple syrup': [], + # 'drambuie': ['drambuie'], + # 'fernet branca': ['fernet', 'branca'], + 'gin': [['gin', '~sloe', '~ginger']], + 'vodka': ['vodka'], + 'cuban rum': [['rum', 'puerto rican'], ['light', 'rum'], ['white', 'rum'], ['rum', 'havana', '~7'], ['rum', 'bacardi']], + 'cognac': [['cognac', '~grand marnier', '~cointreau', '~orange']], + # 'bourbon': [['bourbon', '~liqueur']], + # 'tequila': ['tequila', 'pisco'], + # 'tequila': ['tequila'], + 'scotch': ['scotch'], + 'dark rum': [['rum', 'age', '~bacardi', '~havana'], + ['rum', 'dark', '~bacardi', '~havana'], + ['rum', 'old', '~bacardi', '~havana'], + ['rum', 'old', '7'], + ['rum', 'havana', '7'], + ['havana', 'rum', 'especial']], + 'absinthe': ['absinthe'], + 'rye whiskey': ['rye', ['bourbon', '~liqueur']], + # 'rye whiskey': ['rye'], + 'apricot brandy': [['apricot', 'brandy']], + # 'pisco': ['pisco'], + # 'cachaça': ['cachaça', 'cachaca'], + 'egg': [['egg', 'white', '~yolk', '~whole']], + 'soda': [['soda', 'water', '~lemon', '~lime']], + 
'mint': ['mint'], + 'sparkling wine': ['sparkling wine', 'prosecco', 'champagne'], + 'ginger beer': [['ginger', 'beer'], ['ginger', 'ale']], + 'tonic': [['tonic'], ['7up'], ['sprite']], + # 'espresso': ['espresso', 'expresso', ['café', '~liqueur', '~cream'], + # ['cafe', '~liqueur', '~cream'], + # ['coffee', '~liqueur', '~cream']], + # 'southern comfort': ['southern comfort'], + # 'cola': ['cola', 'coke', 'pepsi'], + 'double syrup': [['sugar','~raspberry'], ['simple', 'syrup'], ['double', 'syrup']], + # 'grenadine': ['grenadine', ['pomegranate', 'syrup']], + 'grenadine': ['grenadine', ['pomegranate', 'syrup'], ['raspberry', 'syrup', '~black']], + 'honey syrup': ['honey', ['maple', 'syrup']], + # 'raspberry syrup': [['raspberry', 'syrup', '~black']], + 'dry vermouth': [['vermouth', 'dry'], ['vermouth', 'white'], ['vermouth', 'french'], 'lillet'], + 'sweet vermouth': [['vermouth', 'sweet'], ['vermouth', 'red'], ['vermouth', 'italian']], + # 'lillet blanc': ['lillet'], + 'water': [['water', '~sugar', '~coconut', '~soda', '~tonic', '~honey', '~orange', '~melon']] + } +# check that there is a rule for all ingredients in the list +assert sorted(ingredient_list) == sorted(ingredient_search.keys()), 'ing search dict keys do not match ingredient list' + +def get_ingredients_info(): + data = pd.read_csv(COCKTAILS_CSV_DATA) + max_ingredients, ingredient_set, liquor_set, liqueur_set, vermouth_set = get_max_n_ingredients(data) + ingredient_list = sorted(ingredient_set) + alcohol = sorted(liquor_set.union(liqueur_set).union(vermouth_set).union(set(['sparkling wine']))) + ind_alcohol = [i for i in range(len(ingredient_list)) if ingredient_list[i] in alcohol] + return max_ingredients, ingredient_list, ind_alcohol + +def get_max_n_ingredients(data): + max_count = 0 + ingredient_set = set() + alcohol_set = set() + liqueur_set = set() + vermouth_set = set() + ing_str = np.array(data['ingredients_str']) + for i in range(len(data['names'])): + ingredients, quantities = extract_ingredients(ing_str[i]) + max_count = max(max_count, len(ingredients)) + for ing in ingredients: + ingredient_set.add(ing) + if ing in ingredients_per_type['liquor']: + alcohol_set.add(ing) + if ing in ingredients_per_type['liqueur']: + liqueur_set.add(ing) + if ing in ingredients_per_type['vermouth']: + vermouth_set.add(ing) + return max_count, ingredient_set, alcohol_set, liqueur_set, vermouth_set + +def find_ingredient_from_str(ing_str): + # function that assigns an ingredient string to one of the ingredient if possible, following the rules defined above. + # return a flag and the ingredient string. When flag is false, the ingredient has not been found and the cocktail is rejected. 
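    # illustrative examples (based on the rules above): 'Freshly squeezed lime juice' resolves to 'lime juice';
    # a string that matches no rule resolves to None so the recipe can be filtered out downstream.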
+ ing_str = ing_str.lower() + flags = [] + for k in ingredient_list: + or_flags = [] # get flag for each of several conditions + for i_p, pattern in enumerate(ingredient_search[k]): + or_flags.append(True) + if isinstance(pattern, str): + if pattern[0] == '~' and pattern[1:] in ing_str: + or_flags[-1] = False + elif pattern[0] != '~' and pattern not in ing_str: + or_flags[-1] = False + elif isinstance(pattern, list): + for element in pattern: + if element[0] == '~': + or_flags[-1] = or_flags[-1] and not element[1:] in ing_str + else: + or_flags[-1] = or_flags[-1] and element in ing_str + else: + raise ValueError + flags.append(any(or_flags)) + if sum(flags) > 1: + print(ing_str) + for i_f, f in enumerate(flags): + if f: + print(ingredient_list[i_f]) + stop = 1 + return True, ingredient_list[flags.index(True)] + elif sum(flags) == 0: + # if 'grape' not in ing_str: + # print('\t\t Not found:', ing_str) + return True, None + else: + return False, ingredient_list[flags.index(True)] + +def get_cocktails_per_ingredient(ing_strs): + cocktails_per_ing = dict(zip(ingredient_list, [[] for _ in range(len(ingredient_list))])) + for i_ing, ing_str in enumerate(ing_strs): + ingredients, _ = extract_ingredients(ing_str) + for ing in ingredients: + cocktails_per_ing[ing].append(i_ing) + return cocktails_per_ing + +def extract_ingredients(ingredient_str): + # extract list of ingredients and quantities from an formatted ingredient string (reverse of format_ingredients) + ingredient_str = ingredient_str[1: -1] + words = ingredient_str.split(',') + ingredients = [] + quantities = [] + for i in range(len(words)//2): + ingredients.append(words[2 * i][1:]) + quantities.append(float(words[2 * i + 1][:-1])) + return ingredients, quantities + +def format_ingredients(ingredients, quantities): + # format an ingredient string from the lists of ingredients and quantities (reverse of extract_ingredients) + out = '[' + for ing, q in zip(ingredients, quantities): + if ing[-1] == ' ': + ingre = ing[:-1] + else: + ingre = ing + out += f'({ingre},{q}),' + out = out[:-1] + ']' + return out + + +def get_ingredient_count(data): + # get count of ingredients in the whole dataset + ingredient_counts = dict(zip(ingredient_list, [0] * len(ingredient_list))) + for i in range(len(data['names'])): + if data['to_keep'][i]: + ingredients, _ = extract_ingredients(data['ingredients_str'][i]) + for i in ingredients: + ingredient_counts[i] += 1 + return ingredient_counts + +def add_counts_to_ingredient_list(data): + # update the list of ingredients to add their count of occurence in dataset. 
+ ingredient_counts = get_ingredient_count(data) + counts = [ingredient_counts[k] for k in ingredient_list] + ingredient_profiles['counts'] = counts + ingredient_profiles.to_csv(INGREDIENTS_LIST_PATH, index=False) \ No newline at end of file diff --git a/src/cocktails/utilities/other_scrubbing_utilities.py b/src/cocktails/utilities/other_scrubbing_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..eed580ca304ee3fbf1725e9b1166d432c5129068 --- /dev/null +++ b/src/cocktails/utilities/other_scrubbing_utilities.py @@ -0,0 +1,240 @@ +import numpy as np +import pickle +from src.cocktails.utilities.cocktail_utilities import get_profile, profile_keys +from src.cocktails.utilities.ingredients_utilities import extract_ingredients, ingredient_list, ingredient_profiles +from src.cocktails.utilities.glass_and_volume_utilities import glass_volume, volume_ranges + +one_dash = 1 +one_splash = 6 +one_tablespoon = 15 +one_barspoon = 5 +fill_rate = 0.8 +quantity_factors ={'ml':1, + 'cl':10, + 'splash':one_splash, + 'splashes':one_splash, + 'dash':one_dash, + 'dashes':one_dash, + 'spoon':one_barspoon, + 'spoons':one_barspoon, + 'tablespoon':one_tablespoon, + 'barspoons':one_barspoon, + 'barspoon':one_barspoon, + 'bar spoons': one_barspoon, + 'bar spoon': one_barspoon, + 'tablespoons':one_tablespoon, + 'teaspoon':5, + 'teaspoons':5, + 'drop':0.05, + 'drops':0.05} +quantitiy_keys = sorted(quantity_factors.keys()) +indexes_keys = np.flip(np.argsort([len(k) for k in quantitiy_keys])) +quantity_factors_keys = list(np.array(quantitiy_keys)[indexes_keys]) + +keys_to_track = ['names', 'urls', 'glass', 'garnish', 'recipe', 'how_to', 'review', 'taste_rep', 'valid'] +keys_to_add = ['category', 'subcategory', 'ingredients_str', 'ingredients', 'quantities', 'to_keep'] +keys_to_update = ['glass'] +keys_for_csv = ['names', 'category', 'subcategory', 'ingredients_str', 'urls', 'glass', 'garnish', 'how_to', 'review', 'taste_rep'] + profile_keys + +to_replace_q = {' fresh': ''} +to_replace_ing = {'maple syrup': 'honey syrup', + 'agave syrup': 'honey syrup', + 'basil': 'mint'} + +def print_recipe(unit='mL', ingredient_str=None, ingredients=None, quantities=None, name='', cat='', to_print=True): + str_out = '' + if ingredient_str is None: + assert len(ingredients) == len(quantities), 'provide either ingredient_str, or list ingredients and quantities' + else: + assert ingredients is None and quantities is None, 'provide either ingredient_str, or list ingredients and quantities' + ingredients, quantities = extract_ingredients(ingredient_str) + + str_out += f'\nRecipe:' + if name != '' and name is not None: str_out += f' {name}' + if cat != '': str_out += f' ({cat})' + str_out += '\n' + for i in range(len(ingredients)): + # get quantifier + if ingredients[i] == 'egg': + quantities[i] = 1 + ingredients[i] = 'egg white' + if unit == 'mL': + quantifier = ' (30 mL)' + elif unit == 'oz': + quantifier = ' (1 fl oz)' + else: + raise ValueError + elif ingredients[i] in ['angostura', 'orange bitters']: + quantities[i] = max(1, int(quantities[i] / 0.6)) + quantifier = ' dash' + if quantities[i] > 1: quantifier += 'es' + elif ingredients[i] == 'mint': + if quantities[i] > 1: quantifier = ' leaves' + else: quantifier = ' leaf' + else: + if unit == "oz": + quantities[i] = float(f"{quantities[i] * 0.033814:.3f}") # convert to fl oz + quantifier = ' fl oz' + else: + quantifier = ' mL' + str_out += f' {quantities[i]}{quantifier} - {ingredients[i]}\n' + + if to_print: + print(str_out) + return str_out + + +def 
test_datapoint(datapoint, category, ingredients, quantities):
    # run checks
    ingredient_indexes = [ingredient_list.index(ing) for ing in ingredients]
    profile = get_profile(category, ingredients, quantities)
    volume = profile['end volume']
    alcohol = profile['end alcohol']
    acid = profile['end acid']
    sugar = profile['end sugar']
    # check volume
    if datapoint['glass'] is not None:
        if volume > glass_volume[datapoint['glass']] * fill_rate:
            # recompute quantities so that the cocktail fits in the glass
            ratio = fill_rate * glass_volume[datapoint['glass']] / volume
            for i_q in range(len(quantities)):
                quantities[i_q] = float(f'{quantities[i_q] * ratio:.2f}')
    # check alcohol
    assert alcohol < 30, 'too boozy'
    assert alcohol > 5, 'not boozy enough'
    assert acid < 2, 'too much acid'
    assert sugar < 20, 'too much sugar'
    assert len(ingredients) > 1, 'only one ingredient'
    if len(set(ingredients)) != len(ingredients):
        i_doubles = []
        s_ing = set()
        for i, ing in enumerate(ingredients):
            if ing in s_ing:
                i_doubles.append(i)
            else:
                s_ing.add(ing)
        ingredient_double_ok = ['mint', 'cointreau', 'lemon juice', 'cuban rum', 'double syrup']
        if len(i_doubles) == 1 and ingredients[i_doubles[0]] in ingredient_double_ok:
            # merge the duplicated ingredient into a single entry
            ing_double = ingredients[i_doubles[0]]
            double_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == ing_double])
            ingredients.pop(i_doubles[0])
            quantities.pop(i_doubles[0])
            quantities[ingredients.index(ing_double)] = double_q
        else:
            assert False, f'double ingredient, not {ingredient_double_ok}'
    lemon_lime_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['lime juice', 'lemon juice']])
    assert lemon_lime_q <= 45, 'too much lemon and lime'
    salt_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'salt'])
    assert salt_q <= 8, 'too much salt'
    bitter_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] in ['angostura', 'orange bitters']])
    assert bitter_q <= 5 * one_dash, 'too much bitter'
    absinthe_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'absinthe'])
    if absinthe_q > 4 * one_dash:
        mix_volume = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] != 'mint'])
        assert absinthe_q < 0.5 * mix_volume, 'filter absinthe glasses'
    if any([w in datapoint['how_to'] or any([w in ing.lower() for ing in datapoint['recipe'][1]]) for w in ['warm', 'boil', 'hot']]) and 'shot' not in datapoint['how_to']:
        assert False, 'hot drink'
    water_q = np.sum([quantities[i] for i in range(len(ingredients)) if ingredients[i] == 'water'])
    assert water_q < 40, 'too much water'
    # n_liqueur = np.sum([ingredient_profiles['type'][i].lower() == 'liqueur' for i in ingredient_indexes])
    # assert n_liqueur <= 2
    n_liqueur_and_vermouth = np.sum([ingredient_profiles['type'][i].lower() in ['liqueur', 'vermouth'] for i in ingredient_indexes])
    assert n_liqueur_and_vermouth <= 3, 'too many liqueurs and vermouths'
    return ingredients, quantities

def run_battery_checks_difford(datapoint, category, ingredients, quantities):
    # returns a rejection flag and the (possibly adjusted) ingredients and quantities
    flag = False
    try:
        ingredients, quantities = test_datapoint(datapoint, category, ingredients, quantities)
    except:
        flag = True
        print(datapoint["names"])
        print(datapoint["urls"])
        ingredients, quantities = None, None

    return flag, ingredients, quantities

def tambouille(q, ingredients_scrubbed, quantities_scrubbed, cat):
    # ad-hoc conversion of unusual quantity strings (cubes, wedges, slices, 'top up with', ...) into ml
    ing_scrubbed = ingredients_scrubbed[len(quantities_scrubbed)]
    if q == '4 cube' and ing_scrubbed == 'pineapple juice':
        q = '20 
ml' + elif 'top up with' in q: + volume_so_far = np.sum([quantities_scrubbed[i] for i in range(len(quantities_scrubbed)) if ingredients_scrubbed[i] != 'mint']) + volume_mix = np.sum(volume_ranges[cat]) / 2 + if (volume_mix - volume_so_far) < 15: + q = '15 ml'# + else: + q = str(int(volume_mix - volume_so_far)) + ' ml' + elif q == '1 pinch' and ing_scrubbed == 'salt': + q = '2 drops' + elif 'cube' in q and ing_scrubbed == 'double syrup': + q = f'{float(q.split(" ")[0]) * 2 * 1.7:.2f} ml' #2g per cube, 1.7 is ratio solid / syrup + elif 'wedge' in q: + if ing_scrubbed == 'orange juice': + vol = 70 + elif ing_scrubbed == 'lime juice': + vol = 30 + elif ing_scrubbed == 'lemon juice': + vol = 45 + elif ing_scrubbed == 'pineapple juice': + vol = 140 + factor = float(q.split(' ')[0]) * 0.15 # consider a wedge to be 0.15*the fruit. + q = f'{factor * vol:.2f} ml' + elif 'slice' in q: + if ing_scrubbed == 'orange juice': + vol = 70 + elif ing_scrubbed == 'lime juice': + vol = 30 + elif ing_scrubbed == 'lemon juice': + vol = 45 + elif ing_scrubbed == 'pineapple juice': + vol = 140 + f = q.split(' ')[0] + if len(f.split('⁄')) > 1: + frac = f.split('⁄') + factor = float(frac[0]) / float(frac[1]) + else: + factor = float(f) + factor *= 0.1 # consider a slice to be 0.1*the fruit. + q = f'{factor * vol:.2f} ml' + elif q == '1 whole' and ing_scrubbed == 'luxardo maraschino': + q = '10 ml' + elif ing_scrubbed == 'egg' and 'ml' not in q: + q = f'{float(q) * 30:.2f} ml' # 30 ml per egg + return q + + +def compute_eucl_dist(a, b): + return np.sqrt(np.sum((a - b)**2)) + +def evaluate_with_quadruplets(representations, strategy='all'): + with open(QUADRUPLETS_PATH, 'rb') as f: + data = pickle.load(f) + data = list(data.values()) + quadruplets = [] + if strategy != 'all': + for d in data: + if d[0] == strategy: + quadruplets.append(d[1:]) + elif strategy == 'all': + for d in data: + quadruplets.append(d[1:]) + else: + raise ValueError + + scores = [] + for q in quadruplets: + close = q[0] + if len(close) == 2: + far = q[1] + distance_close = compute_eucl_dist(representations[close[0]], representations[close[1]]) + distances_far = [compute_eucl_dist(representations[far[i][0]], representations[far[i][1]]) for i in range(len(far))] + scores.append(distance_close < np.min(distances_far)) + if len(scores) == 0: + score = np.nan + else: + score = np.mean(scores) + return score + + diff --git a/src/debugger.py b/src/debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa54d4e6ae9fd7337535dd29b64cb040480ac84 --- /dev/null +++ b/src/debugger.py @@ -0,0 +1,180 @@ +import os.path + +# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob +from src.music.utils import load_audio +from src.music.config import FPS +import pretty_midi as pm +import numpy as np +from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH +from sklearn.neighbors import NearestNeighbors +from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA +# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers +from src.cocktails.utilities.other_scrubbing_utilities import print_recipe +from src.music.utils import get_all_subfiles_with_extension +import os +import pickle +import pandas as pd +import time + +keyword = 'b256_r128_represented' +def load_reps(rep_path, sample_size=None): + if sample_size: + with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f: + data = pickle.load(f) + else: + with 
open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f: + data = pickle.load(f) + reps = data['reps'] + # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']] + playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']] + n_data, dim_data = reps.shape + return reps, data['paths'], playlists, n_data, dim_data + +class Debugger(): + def __init__(self, verbose=True): + + if verbose: print('Setting up debugger.') + if not os.path.exists(MUSIC_NN_PATH): + reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle' + if not os.path.exists(reps_path): + all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt', current_depth=0) + all_data = [] + new_all_rep_path = [] + for i_r, r in enumerate(all_rep_path): + if 'mean_std' not in r: + all_data.append(np.loadtxt(r)) + assert len(all_data[-1]) == 128 + new_all_rep_path.append(r) + data = np.array(all_data) + to_save = dict(reps=data, + paths=new_all_rep_path) + with open(reps_path, 'wb') as f: + pickle.dump(to_save, f) + + reps, self.rep_paths, playlists, n_data, self.dim_rep_music = load_reps(MUSIC_REP_PATH) + self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine') + self.nn_model_music.fit(reps) + to_save = dict(nn_model=self.nn_model_music, + rep_paths=self.rep_paths, + dim_rep_music=self.dim_rep_music) + with open(MUSIC_NN_PATH, 'wb') as f: + pickle.dump(to_save, f) + else: + with open(MUSIC_NN_PATH, 'rb') as f: + data = pickle.load(f) + self.nn_model_music = data['nn_model'] + self.rep_paths = data['rep_paths'] + self.dim_rep_music = data['dim_rep_music'] + if verbose: print(f' {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}') + self.rep_paths = np.array(self.rep_paths) + if not os.path.exists(COCKTAIL_NN_PATH): + cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH) + # cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0) + self.nn_model_cocktail = NearestNeighbors(n_neighbors=6) + self.nn_model_cocktail.fit(cocktail_reps) + self.dim_rep_cocktail = cocktail_reps.shape[1] + self.n_cocktails = cocktail_reps.shape[0] + to_save = dict(nn_model=self.nn_model_cocktail, + dim_rep_cocktail=self.dim_rep_cocktail, + n_cocktails=self.n_cocktails) + with open(COCKTAIL_NN_PATH, 'wb') as f: + pickle.dump(to_save, f) + else: + with open(COCKTAIL_NN_PATH, 'rb') as f: + data = pickle.load(f) + self.nn_model_cocktail = data['nn_model'] + self.dim_rep_cocktail = data['dim_rep_cocktail'] + self.n_cocktails = data['n_cocktails'] + if verbose: print(f' {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}') + + self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA) + # self.affective_cluster_centers = get_affective_cluster_centers() + self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls', + 'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob', 'recipe_score', 'cocktail_rep'] + # 'affect', 'affective_cluster_id', 'affective_cluster_center', + + + def get_nearest_songs(self, music_rep): + dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1)) + indexes = indexes.flatten()[:5] + rep_paths = [r.split('/')[-1] for r in self.rep_paths[indexes[:5]]] + return rep_paths, dists.flatten().tolist() + + def get_nearest_cocktails(self, cocktail_rep): + dists, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1)) + indexes = indexes.flatten() + nn_names = 
np.array(self.cocktail_data['names'])[indexes].tolist() + nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist() + nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in np.array(self.cocktail_data['ingredients_str'])[indexes]] + nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist() + return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs + + def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=False, level=0): + if verbose: print(' ' * level + 'Extracting debug info..') + init_time = time.time() + debug_dict = dict() + debug_dict['all_paths'] = all_paths + debug_dict['recipe_score'] = recipe_score + + if all_paths['audio_path'] != None: + # is it piano? + debug_dict['piano_solo_prob'] = None#float(calculate_piano_solo_prob(all_paths['audio_path'])[0]) + # how long is the audio + (audio, _) = load_audio(all_paths['audio_path'], sr=FPS, mono=True) + debug_dict['audio_len'] = int(len(audio) / FPS) + else: + debug_dict['piano_solo_prob'] = None + debug_dict['audio_len'] = None + + # how many notes? + midi = pm.PrettyMIDI(all_paths['processed_path']) + debug_dict['nb_notes'] = len(midi.instruments[0].notes) + + # dimension of music rep + representation = np.loadtxt(all_paths['representation_path']) + debug_dict['dim_rep'] = representation.shape[0] + + # closest songs in dataset + debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation) + + # get affective cluster info + # debug_dict['affective_cluster_id'] = affective_cluster_id[0] + # debug_dict['affective_cluster_center'] = self.affective_cluster_centers[affective_cluster_id].flatten().tolist() + # debug_dict['affect'] = affect.flatten().tolist() + indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep) + debug_dict['cocktail_rep'] = cocktail_rep.copy().tolist() + debug_dict['nearest_cocktail_indexes'] = indexes.tolist() + debug_dict['nn_ing_strs'] = nn_ing_strs + debug_dict['nearest_cocktail_names'] = nn_names + debug_dict['nearest_cocktail_urls'] = nn_urls + debug_dict['nearest_cocktail_recipes'] = nn_recipes + + debug_dict['music_reconstruction'] = music_reconstruction.tolist() + debug_dict['mse_reconstruction'] = ((music_reconstruction - representation) ** 2).mean() + self.debug_dict = debug_dict + if verbose: print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.') + + return self.debug_dict + + def print_debug(self, level=0): + print(' ' * level + '__DEBUGGING INFO__') + for k in self.keys_to_print: + to_print = self.debug_dict[k] + if k == 'nearest_cocktail_recipes': + to_print = self.debug_dict[k].copy() + for i in range(len(to_print)): + to_print[i] = to_print[i].replace('\n', '').replace('\t', '').replace('()', '') + if k == "nn_music": + to_print = self.debug_dict[k].copy() + for i in range(len(to_print)): + to_print[i] = to_print[i].replace('encoded_new_structured_', '').replace('_represented.txt', '') + to_print_str = f'{to_print}' + if isinstance(to_print, float): + to_print_str = f'{to_print:.2f}' + elif isinstance(to_print, list): + if isinstance(to_print[0], float): + to_print_str = '[' + for element in to_print: + to_print_str += f'{element:.2f}, ' + to_print_str = to_print_str[:-2] + ']' + print(' ' * (level + 2) + f'{k} : ' + to_print_str) \ No newline at end of file diff --git a/src/music/__init__.py b/src/music/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/__pycache__/__init__.cpython-39.pyc b/src/music/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d86d795ca62af805eb70af86a0bba39448e4cc9 Binary files /dev/null and b/src/music/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/__pycache__/config.cpython-39.pyc b/src/music/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20932799567a8ffe6fc330ae1e3e18609ce06523 Binary files /dev/null and b/src/music/__pycache__/config.cpython-39.pyc differ diff --git a/src/music/__pycache__/utils.cpython-39.pyc b/src/music/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb85a75c7ccd124bb8b177b3781118eb20074370 Binary files /dev/null and b/src/music/__pycache__/utils.cpython-39.pyc differ diff --git a/src/music/config.py b/src/music/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f185a2f68021337de669b06dc422e0f02deb5354 --- /dev/null +++ b/src/music/config.py @@ -0,0 +1,72 @@ +import numpy as np +import os + +REPO_PATH = '/'.join(os.path.abspath(__file__).split('/')[:-3]) + '/' +AUDIO_PATH = REPO_PATH + 'data/music/audio/' +MIDI_PATH = REPO_PATH + 'data/music/midi/' +MUSIC_PATH = REPO_PATH + 'data/music/' +PROCESSED_PATH = REPO_PATH + 'data/music/processed/' +ENCODED_PATH = REPO_PATH + 'data/music/encoded/' +HANDCODED_REP_PATH = MUSIC_PATH + 'handcoded_reps/' +DATASET_PATH = REPO_PATH + 'data/music/encoded_new_structured/diverse_piano/' +SYNTH_RECORDED_AUDIO_PATH = AUDIO_PATH + 'synth_audio_recorded/' +SYNTH_RECORDED_MIDI_PATH = MIDI_PATH + 'synth_midi_recorded/' +CHECKPOINTS_PATH = REPO_PATH + 'checkpoints/' +EXPERIMENT_PATH = REPO_PATH + 'experiments/' +SEED = 0 + +# params for data download +ALL_URL_PATH = REPO_PATH + 'data/music/audio/all_urls.pickle' +ALL_FAILED_URL_PATH = REPO_PATH + 'data/music/audio/all_failed_urls.pickle' +RATE_AUDIO_SAVE = 16000 +FROM_URL_PATH = AUDIO_PATH + 'from_url/' + +# params transcription +CHKPT_PATH_TRANSCRIPTION = REPO_PATH + 'checkpoints/piano_transcription/note_F1=0.9677_pedal_F1=0.9186.pth' # transcriptor chkpt path +FPS = 16000 +RANDOM_CROP = True # whether to use random crops in case of cropped audio +CROP_LEN = 26 * 60 + +# params midi scrubbing and processing +MAX_DEPTH = 5 # max depth when searching in folders for audio files +MAX_GAP_IN_SONG = 10 # in secs +MIN_LEN = 20 # actual min len could go down to MIN_LEN - 2 * (REMOVE_FIRST_AND_LAST / 5) +MAX_LEN = 25 * 60 # maximum audio len for playlist downloads, and maximum audio length for transcription (in sec) +MIN_NB_NOTES = 80 # min nb of notes per minute of recording +REMOVE_FIRST_AND_LAST = 10 # will be divided by 5 if cutting this makes the song fall below min len + +# parameters encoding +NOISE_INJECTED = True +AUGMENTATION = True +NB_AUG = 4 if AUGMENTATION else 0 +RANGE_NOTE_ON = 128 +RANGE_NOTE_OFF = 128 +RANGE_VEL = 32 +RANGE_TIME_SHIFT = 100 +MAX_EMBEDDING = RANGE_VEL + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_NOTE_ON +MAX_TEST_SIZE = 1000 +CHECKSUM_PATH = REPO_PATH + 'data/music/midi/checksum.pickle' +CHUNK_SIZE = 512 + +ALL_AUGMENTATIONS = [] +for p in [-3, -2, -1, 1, 2, 3]: + ALL_AUGMENTATIONS.append((p)) +ALL_AUGMENTATIONS = np.array(ALL_AUGMENTATIONS) + +ALL_NOISE = [] +for s in [-5, -2.5, 0, 2.5, 5]: + for p in np.arange(-6, 7): + if not ((s == 0) and (p==0)): + ALL_NOISE.append((s, p)) 
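# 5 values of s x 13 values of p, minus the (0, 0) pair: 64 (s, p) noise settings in total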
+ALL_NOISE = np.array(ALL_NOISE) + +# music transformer params +REP_MODEL_NAME = REPO_PATH + "checkpoints/music_representation/sentence_embedding/smallbert_b256_r128_1/best_model" +MUSIC_REP_PATH = REPO_PATH + "checkpoints/b256_r128_represented/" +MUSIC_NN_PATH = REPO_PATH + "checkpoints/music_representation/b256_r128_represented/nn_model.pickle" + +TRANSLATION_VAE_CHKP_PATH = REPO_PATH + "checkpoints/music2cocktails/music2flavor/b256_r128_classif001_ld40_meanstd_regground2.5_egg_bubbles/" + +# piano solo evaluation +# META_DATA_PIANO_EVAL_PATH = REPO_PATH + 'data/music/audio/is_piano.csv' +# CHKPT_PATH_PIANO_EVAL = REPO_PATH + 'data/checkpoints/piano_detection/piano_solo_model_32k.pth' \ No newline at end of file diff --git a/src/music/pipeline/__init__.py b/src/music/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/pipeline/__pycache__/__init__.cpython-39.pyc b/src/music/pipeline/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d334faf2a6435408190b006096f34b57895ca2 Binary files /dev/null and b/src/music/pipeline/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/audio2midi.cpython-39.pyc b/src/music/pipeline/__pycache__/audio2midi.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de0ba4c0acceec3396869c54808d5aa5c55c59b9 Binary files /dev/null and b/src/music/pipeline/__pycache__/audio2midi.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/encoded2rep.cpython-39.pyc b/src/music/pipeline/__pycache__/encoded2rep.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303fd8d169ece6022298914168621eece8437cf9 Binary files /dev/null and b/src/music/pipeline/__pycache__/encoded2rep.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/midi2processed.cpython-39.pyc b/src/music/pipeline/__pycache__/midi2processed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef235b41551e8575b2f48d9c5d88305b83c62f7 Binary files /dev/null and b/src/music/pipeline/__pycache__/midi2processed.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/music_pipeline.cpython-39.pyc b/src/music/pipeline/__pycache__/music_pipeline.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e52626a2a76c6ed5095a8e94aafc09a2dcefb771 Binary files /dev/null and b/src/music/pipeline/__pycache__/music_pipeline.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/processed2encoded.cpython-39.pyc b/src/music/pipeline/__pycache__/processed2encoded.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..452a8b704c4142fb77ee055ea9b643a272ff03f0 Binary files /dev/null and b/src/music/pipeline/__pycache__/processed2encoded.cpython-39.pyc differ diff --git a/src/music/pipeline/__pycache__/url2audio.cpython-39.pyc b/src/music/pipeline/__pycache__/url2audio.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3780aa8d634f6cb7f62704de29a7beced5940c5 Binary files /dev/null and b/src/music/pipeline/__pycache__/url2audio.cpython-39.pyc differ diff --git a/src/music/pipeline/audio2midi.py b/src/music/pipeline/audio2midi.py new file mode 100644 index 0000000000000000000000000000000000000000..135a615a1687d896f011c6bbbc7f2c48afd96707 --- /dev/null +++ b/src/music/pipeline/audio2midi.py @@ -0,0 +1,52 @@ +import torch 
+import piano_transcription_inference +import numpy as np +import os +import sys +sys.path.append('../../') +from src.music.utils import get_out_path, load_audio +from src.music.config import CHKPT_PATH_TRANSCRIPTION, FPS, MIN_LEN, CROP_LEN +# import librosa +device = 'cuda' if torch.cuda.is_available() else 'cpu' +TRANSCRIPTOR = piano_transcription_inference.PianoTranscription(device=device, + checkpoint_path=CHKPT_PATH_TRANSCRIPTION) + +def audio2midi(audio_path, midi_path=None, crop=CROP_LEN, random_crop=True, verbose=False, level=0): + if verbose and crop < MIN_LEN + 2: + print('crop is inferior to the minimal length of a tune') + assert '.mp3' == audio_path[-4:] + if midi_path is None: + midi_path, _, _ = get_out_path(in_path=audio_path, in_word='audio', out_word='midi', out_extension='.mid') + + if verbose: print(' ' * level + f'Transcribing {audio_path}.') + if os.path.exists(midi_path): + if verbose: print(' ' * (level + 2) + 'Midi file already exists.') + return midi_path, '' + + error_msg = 'Error in transcription. ' + try: + error_msg += 'Maybe in audio loading?' + (audio, _) = load_audio(audio_path, + sr=FPS, + mono=True) + error_msg += ' Nope. Cropping?' + if isinstance(crop, int) and len(audio) > FPS * crop: + rc_str = ' (random crop)' if random_crop else ' (start crop)' + if verbose: print(' ' * (level + 2) + f'Cropping the song to {crop}s before transcription{rc_str}. ') + size_crop = FPS * crop + if random_crop: + index_begining = np.random.randint(len(audio) - size_crop - 1) + else: + index_begining = 0 + audio = audio[index_begining: index_begining + size_crop] + error_msg += ' Nope. Transcription?' + TRANSCRIPTOR.transcribe(audio, midi_path) + error_msg += ' Nope.' + extra = f' Saved to {midi_path}' if midi_path else '' + if verbose: print(' ' * (level + 2) + f'Success! {extra}') + return midi_path, '' + except: + if verbose: print(' ' * (level + 2) + 'Transcription failed.') + if os.path.exists(midi_path): + os.remove(midi_path) + return None, error_msg + ' Yes.' diff --git a/src/music/pipeline/audio2piano_solo_prob.py b/src/music/pipeline/audio2piano_solo_prob.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd9f2854d83baa8648c0ace88e65d66bd2f0f98 --- /dev/null +++ b/src/music/pipeline/audio2piano_solo_prob.py @@ -0,0 +1,47 @@ +import numpy as np +import librosa +import sys +sys.path.append('../../../data/') +from src.music.utilities.processing_models import piano_detection_model +from src.music.config import CHKPT_PATH_PIANO_EVAL + +PIANO_SOLO_DETECTOR = piano_detection_model.PianoSoloDetector(CHKPT_PATH_PIANO_EVAL) +exclude_playlist_folders = ['synth_audio_recorded', 'from_url'] + +def clean_start_and_end_blanks(probs): + if len(probs) > 20: + # clean up to 10s in each direction + n_zeros_start = 0 + for i in range(10): + if probs[i] <= 0.001: + n_zeros_start += 1 + else: + break + n_zeros_end = 0 + for i in range(10): + if probs[-(i + 1)] <= 0.001: + n_zeros_end += 1 + else: + break + if n_zeros_end == 0: + return probs[n_zeros_start:] + else: + return probs[n_zeros_start:-n_zeros_end] + else: + return probs + +def calculate_piano_solo_prob(audio_path, verbose=False): + """Calculate the piano solo probability of all downloaded mp3s, and append + the probability to the meta csv file. Code from https://github.com/bytedance/GiantMIDI-Piano + """ + try: + error_msg = 'Error in audio loading?' + (audio, _) = librosa.core.load(audio_path, sr=piano_detection_model.SR, mono=True) + error_msg += ' Nope. Error in solo prediction?' 
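        # predict() is expected to return one piano-solo probability per audio chunk; they are averaged below into a single score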
+ probs = PIANO_SOLO_DETECTOR.predict(audio) + # probs = clean_start_and_end_blanks(probs) # remove blanks at start and end (<=10s each way). If not piano, the rest of the song will be enough to tell. + piano_solo_prob = np.mean(probs) + error_msg += ' Nope. ' + return piano_solo_prob, '' + except: + return None, error_msg + 'Yes.' diff --git a/src/music/pipeline/encoded2rep.py b/src/music/pipeline/encoded2rep.py new file mode 100644 index 0000000000000000000000000000000000000000..600631488f6e93178e6a7e25ea59182b22c14f56 --- /dev/null +++ b/src/music/pipeline/encoded2rep.py @@ -0,0 +1,89 @@ +from src.music.utilities.representation_learning_utilities.constants import * +from src.music.config import REP_MODEL_NAME +from src.music.utils import get_out_path +import pickle +import numpy as np +# from transformers import AutoModel, AutoTokenizer +from torch import nn +from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +class Argument(object): + def __init__(self, adict): + self.__dict__.update(adict) + +class RepModel(nn.Module): + def __init__(self, model, model_name): + super().__init__() + if 't5' in model_name: + self.model = model.get_encoder() + else: + self.model = model + self.model.eval() + + def forward(self, inputs): + with torch.no_grad(): + out = self.model(inputs, output_hidden_states=True) + embeddings = out.hidden_states[-1] + return torch.mean(embeddings[0], dim=0) + +# def get_trained_music_LM(model_name): +# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True) +# model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name) +# +# return model, tokenizer + +def get_trained_sentence_embedder(model_name): + model = SentenceTransformer(model_name, device=device) + return model + +MODEL = get_trained_sentence_embedder(REP_MODEL_NAME) + +def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0): + if not rep_path: + rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt') + + error_msg = 'Error in music transformer mapping.' + if verbose: print(' ' * level + 'Mapping to final music representations') + try: + error_msg += ' Error in encoded file loading?' + with open(encoded_path, 'rb') as f: + data = pickle.load(f) + performance = [str(w) for w in data['main'] if w != 1] + assert len(performance) % 5 == 0 + if(len(performance) == 0): + error_msg += " Error: No midi messages in primer file" + assert False + error_msg += ' Nope, error in tokenization?' + perf = ' '.join(performance) + # tokenized = torch.IntTensor(TOKENIZER.encode(perf)).unsqueeze(dim=0) + error_msg += ' Nope. Maybe in performance encoding?' + # reps = [] + # for i_chunk in range(min(tokenized.shape[1] // 510 - 1, 8)): + # chunk_tokenized = tokenized[:, i_chunk * 510: (i_chunk + 1) * 510 + 2] + # rep = MODEL(chunk_tokenized) + # reps.append(rep.detach().numpy()) + # representation = np.mean(reps, axis=0) + p = [int(p) for p in perf.split(' ')] + # print('PERF:', np.sum(p), perf) + representation = MODEL.encode(perf) + # print('model weights sum: ', np.sum([param.detach().data.numpy().sum() for param in list(MODEL.parameters())])) + # print('reprep', representation) + error_msg += ' Nope. Saving performance?' + np.savetxt(rep_path, representation) + error_msg += ' Nope.' 
+ if verbose: print(' ' * (level + 2) + 'Success.') + if return_rep: + return rep_path, representation, '' + else: + return rep_path, '' + except: + if verbose: print(' ' * (level + 2) + f'Failed with error: {error_msg}') + if return_rep: + return None, None, error_msg + else: + return None, error_msg + +if __name__ == "__main__": + representation = encoded2rep("/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle") + stop = 1 diff --git a/src/music/pipeline/midi2processed.py b/src/music/pipeline/midi2processed.py new file mode 100644 index 0000000000000000000000000000000000000000..d1457c4ae7beca867f6338aaf2b629ba3db92b7b --- /dev/null +++ b/src/music/pipeline/midi2processed.py @@ -0,0 +1,152 @@ +import time +import os +import sys +sys.path.append('../../') +import pretty_midi as pm +import numpy as np + +from src.music.utils import get_out_path +from src.music.config import MIN_LEN, MIN_NB_NOTES, MAX_GAP_IN_SONG, REMOVE_FIRST_AND_LAST + + +def sort_notes(notes): + starts = np.array([n.start for n in notes]) + index_sorted = np.argsort(starts) + return [notes[i] for i in index_sorted].copy() + + +def delete_notes_end_after_start(notes): + indexes_to_keep = [i for i, n in enumerate(notes) if n.start < n.end] + return [notes[i] for i in indexes_to_keep].copy() + +def compute_largest_gap(notes): + gaps = [] + latest_note_end_so_far = notes[0].end + for i in range(len(notes) - 1): + note_start = notes[i + 1].start + if latest_note_end_so_far < note_start: + gaps.append(note_start - latest_note_end_so_far) + latest_note_end_so_far = max(latest_note_end_so_far, notes[i+1].end) + if len(gaps) > 0: + largest_gap = np.max(gaps) + else: + largest_gap = 0 + return largest_gap + +def analyze_instrument(inst): + # test that piano plays throughout + init = time.time() + notes = inst.notes.copy() + nb_notes = len(notes) + start = notes[0].start + end = inst.get_end_time() + duration = end - start + largest_gap = compute_largest_gap(notes) + return nb_notes, start, end, duration, largest_gap + +def remove_beginning_and_end(midi, end_time): + notes = midi.instruments[0].notes.copy() + new_notes = [n for n in notes if n.start > REMOVE_FIRST_AND_LAST and n.end < end_time - REMOVE_FIRST_AND_LAST] + midi.instruments[0].notes = new_notes + return midi + +def remove_blanks_beginning_and_end(midi): + # remove blanks and the beginning and the end + shift = midi.instruments[0].notes[0].start + for n in midi.instruments[0].notes: + n.start = max(0, n.start - shift) + n.end = max(0, n.end - shift) + for ksc in midi.key_signature_changes: + ksc.time = max(0, ksc.time - shift) + for tsc in midi.time_signature_changes: + tsc.time = max(0, tsc.time - shift) + for pb in midi.instruments[0].pitch_bends: + pb.time = max(0, pb.time - shift) + for cc in midi.instruments[0].control_changes: + cc.time = max(0, cc.time - shift) + return midi + +def is_valid_inst(largest_gap, duration, nb_notes, gap_counts=True): + error_msg = '' + valid = True + if largest_gap > MAX_GAP_IN_SONG and gap_counts: + valid = False + error_msg += f'wide gap ({largest_gap:.2f} secs), ' + if duration < (MIN_LEN + 2 * REMOVE_FIRST_AND_LAST): + valid = False + error_msg += f'too short ({duration:.2f} secs), ' + if nb_notes < MIN_NB_NOTES * duration / 60: # nb of notes needs to be superior to the minimum number / min * the duration in minute + valid = False + error_msg += f'too few notes ({nb_notes}), ' + return valid, error_msg + +def midi2processed(midi_path, processed_path=None, 
apply_filtering=True, verbose=False, level=0): + assert midi_path.split('.')[-1] in ['mid', 'midi'] + if not processed_path: + processed_path, _, _ = get_out_path(in_path=midi_path, in_word='midi', out_word='processed', out_extension='.mid') + + if verbose: print(' ' * level + f'Processing {midi_path}.') + + if os.path.exists(processed_path): + if verbose: print(' ' * (level + 2) + 'Processed midi file already exists.') + return processed_path, '' + error_msg = 'Error in scrubbing. ' + try: + inst_error_msg = '' + # load mid file + error_msg += 'Error in midi loading?' + midi = pm.PrettyMIDI(midi_path) + error_msg += ' Nope. Removing invalid notes?' + midi.remove_invalid_notes() # filter invalid notes + error_msg += ' Nope. Filtering instruments?' + # filter instruments + instruments = midi.instruments.copy() + new_instru = [] + instruments_data = [] + for i_inst, inst in enumerate(instruments): + if inst.program <= 7 and not inst.is_drum and len(inst.notes) > 5: + # inst is a piano + # check data + inst.notes = sort_notes(inst.notes) # sort notes + inst.notes = delete_notes_end_after_start(inst.notes) # delete invalid notes + nb_notes, start, end, duration, largest_gap = analyze_instrument(inst) + is_valid, err_msg = is_valid_inst(largest_gap=largest_gap, duration=duration, nb_notes=nb_notes, gap_counts='maestro' not in midi_path) + if is_valid or not apply_filtering: + new_instru.append(inst) + instruments_data.append([nb_notes, start, end, duration, largest_gap]) + else: + inst_error_msg += 'inst1: ' + err_msg + '\n' + instruments_data = np.array(instruments_data) + error_msg += ' Nope. Taking one instrument?' + + if len(new_instru) == 0: + error_msg = f'No piano instrument. {inst_error_msg}' + assert False + elif len(new_instru) > 1: + # take instrument playing the most notes + instrument = new_instru[np.argmax(instruments_data[:, 0])] + else: + instrument = new_instru[0] + instrument.program = 0 # set the instrument to Grand Piano. + midi.instruments = [instrument] # put instrument in midi file + error_msg += ' Nope. Removing blanks?' + # remove first and last REMOVE_FIRST_AND_LAST seconds (avoid clapping and jingles) + end_time = midi.get_end_time() + if apply_filtering: midi = remove_beginning_and_end(midi, end_time) + + # remove beginning and end + midi = remove_blanks_beginning_and_end(midi) + error_msg += ' Nope. Saving?' + + # save midi file + midi.write(processed_path) + error_msg += ' Nope.' + if verbose: + extra = f' Saved to {processed_path}' if midi_path else '' + print(' ' * (level + 2) + f'Success! {extra}') + return processed_path, '' + except: + if verbose: print(' ' * (level + 2) + 'Scrubbing failed.') + if os.path.exists(processed_path): + os.remove(processed_path) + return None, error_msg + ' Yes.' 
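Note on the filtering step above: the gap / duration / note-count checks decide whether a transcribed track is usable solo piano. Below is a minimal sketch of how the largest-gap computation and the density check behave on a synthetic pretty_midi instrument; the helper name `largest_gap` and the comments about thresholds are illustrative only, they are not the MIN_LEN / MIN_NB_NOTES / MAX_GAP_IN_SONG values defined in src.music.config.

import pretty_midi as pm

def largest_gap(notes):
    # same logic as compute_largest_gap above: a gap is the silence between the
    # latest note end seen so far and the next note's start
    gaps, latest_end = [], notes[0].end
    for nxt in notes[1:]:
        if latest_end < nxt.start:
            gaps.append(nxt.start - latest_end)
        latest_end = max(latest_end, nxt.end)
    return max(gaps) if gaps else 0.0

inst = pm.Instrument(program=0)  # program 0 = Acoustic Grand Piano, as enforced by midi2processed
inst.notes = [pm.Note(velocity=80, pitch=60 + (i % 12), start=2.0 * i, end=2.0 * i + 0.5)
              for i in range(15)]  # 15 short notes spread over ~29 seconds
gap = largest_gap(inst.notes)
duration = inst.notes[-1].end - inst.notes[0].start
print(f'{len(inst.notes)} notes, {duration:.1f}s long, largest gap {gap:.1f}s')
# is_valid_inst above compares exactly these three quantities (gap, duration,
# notes per minute) against the config thresholds before keeping an instrument.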
diff --git a/src/music/pipeline/music_pipeline.py b/src/music/pipeline/music_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b423e5c0274f82e0c8f2927188dabad6839e5bfa --- /dev/null +++ b/src/music/pipeline/music_pipeline.py @@ -0,0 +1,86 @@ +from src.music.pipeline.url2audio import url2audio +from src.music.pipeline.audio2midi import audio2midi +from src.music.pipeline.midi2processed import midi2processed +from src.music.pipeline.processed2encoded import processed2encoded +from src.music.pipeline.encoded2rep import encoded2rep +from src.music.config import RANDOM_CROP, NB_AUG, FROM_URL_PATH +# from src.music.pipeline.synth2audio import AudioRecorder +# from src.music.pipeline.processed2handcodedrep import processed2handcodedrep +import time +import hashlib + +VERBOSE = True +AUGMENTATION, NOISE_INJECTED = False, False +CROP = 10# crop 30s before transcription + +# AUDIO_RECORDER = AudioRecorder(place='home') + +def encode_music(url=None, + audio_path=None, + midi_path=None, + processed_path=None, + record=False, + crop=CROP, + random_crop=RANDOM_CROP, + augmentation=AUGMENTATION, + noise_injection=NOISE_INJECTED, + apply_filtering=True, + nb_aug=NB_AUG, + level=0, + verbose=VERBOSE): + if not record: assert url is not None or audio_path is not None or midi_path is not None or processed_path is not None + init_time = time.time() + error = '' + try: + if record: + assert audio_path is None and midi_path is None + if verbose: print(' ' * level + 'Processing music, recorded from mic.') + audio_path = AUDIO_RECORDER.record_one() + error = '' + if processed_path is None: + if midi_path is None: + if audio_path is None: + if verbose and not record: print(' ' * level + 'Processing music, from audio source.') + init_t = time.time() + audio_path, _, error = url2audio(playlist_path=FROM_URL_PATH, video_url=url, verbose=verbose, level=level+2) + if verbose: print(' ' * (level + 4) + f'Audio downloaded in {int(time.time() - init_t)} seconds.') + else: + if verbose and not record: print(' ' * level + 'Processing music, from midi source.') + init_t = time.time() + midi_path, error = audio2midi(audio_path, crop=crop, random_crop=random_crop, verbose=verbose, level=level+2) + if verbose: print(' ' * (level + 4) + f'Audio transcribed to midi in {int(time.time() - init_t)} seconds.') + init_t = time.time() + processed_path, error = midi2processed(midi_path, apply_filtering=apply_filtering, verbose=verbose, level=level+2) + if verbose: print(' ' * (level + 4) + f'Midi preprocessed in {int(time.time() - init_t)} seconds.') + init_t = time.time() + encoded_path, error = processed2encoded(processed_path, augmentation=augmentation, nb_aug=nb_aug, noise_injection=noise_injection, verbose=verbose, level=level+2) + if verbose: print(' ' * (level + 4) + f'Midi encoded in {int(time.time() - init_t)} seconds.') + init_t = time.time() + representation_path, representation, error = encoded2rep(encoded_path, return_rep=True, level=level+2, verbose=verbose) + if verbose: print(' ' * (level + 4) + f'Music representation computed in {int(time.time() - init_t)} seconds.') + init_t = time.time() + handcoded_rep_path, handcoded_rep, error = None, None, '' + # handcoded_rep_path, handcoded_rep, error = processed2handcodedrep(processed_path, return_rep=True, level=level+2, verbose=verbose) + if verbose: print(' ' * (level + 4) + f'Music handcoded representation computed in {int(time.time() - init_t)} seconds.') + # assert handcoded_rep_path is not None and representation_path is not None + 
all_paths = dict(url=url, audio_path=audio_path, midi_path=midi_path, processed_path=processed_path, encoded_path=encoded_path, + representation_path=representation_path, handcoded_rep_path=handcoded_rep_path) + print('audio hash: ', hashlib.md5(open(audio_path, 'rb').read()).hexdigest()) + print('midi hash: ', hashlib.md5(open(midi_path, 'rb').read()).hexdigest()) + print('processed hash: ', hashlib.md5(open(processed_path, 'rb').read()).hexdigest()) + print('encoded hash: ', hashlib.md5(open(encoded_path, 'rb').read()).hexdigest()) + print('rep hash: ', hashlib.md5(open(representation_path, 'rb').read()).hexdigest()) + print("rep:", representation[:10]) + if verbose: print(' ' * (level + 2) + f'Music processed in {int(time.time() - init_time)} seconds.') + except Exception as err: + print(err, error) + if verbose: print(' ' * (level + 2) + f'Music FAILED to process in {int(time.time() - init_time)} seconds.') + representation = None + handcoded_rep = None + all_paths = dict() + + return representation, handcoded_rep, all_paths, error + +if __name__ == '__main__': + representation = encode_music(url="https://www.youtube.com/watch?v=a2LFVWBmoiw")[0] + # representation = encode_music(record=True)[0] \ No newline at end of file diff --git a/src/music/pipeline/processed2encoded.py b/src/music/pipeline/processed2encoded.py new file mode 100644 index 0000000000000000000000000000000000000000..19aafec899c6a3648ab5e01ef2eafa015a5df7e3 --- /dev/null +++ b/src/music/pipeline/processed2encoded.py @@ -0,0 +1,52 @@ +import os +import sys +import numpy as np +import pickle +sys.path.append('../../') + +from src.music.utils import get_out_path +from src.music.config import ALL_NOISE, ALL_AUGMENTATIONS, NB_AUG, NOISE_INJECTED +from src.music.utilities.midi_processor import encode_midi_structured, encode_midi_chunks_structured + +nb_noise = ALL_NOISE.shape[0] +nb_aug = ALL_AUGMENTATIONS.shape[0] + +def sample_augmentations(n): + return ALL_AUGMENTATIONS[np.random.choice(np.arange(nb_aug), size=n, replace=False)] + +def sample_noise(): + return ALL_NOISE[np.random.choice(np.arange(nb_noise))] + +def processed2encoded(processed_path, encoded_path=None, augmentation=False, nb_aug=None, noise_injection=False, verbose=False, level=0): + assert processed_path.split('.')[-1] in ['mid', 'midi'] + if not encoded_path: + encoded_path, _, _ = get_out_path(in_path=processed_path, in_word='processed', out_word='encoded', out_extension='.pickle') + + if verbose: print(' ' * level + f'Encoding {processed_path}') + if os.path.exists(encoded_path): + if verbose: print(' ' * (level + 2) + 'Midi file is already encoded.') + return encoded_path, '' + + if augmentation: + assert isinstance(nb_aug, int) + error_msg = 'Error in encoding. ' + try: + error_msg = 'Error in encoding midi?' + nb_noise = 1 if noise_injection else 0 + encoded_main, encoded_aug, encoded_noisy = encode_midi_structured(processed_path, nb_aug, nb_noise) + + # make sure augmentations are not out of bounds + error_msg = ' Nope. Error in saving encoding?' + with open(encoded_path, 'wb') as f: + pickle.dump(dict(main=encoded_main, aug=encoded_aug, noisy=encoded_noisy), f) + error_msg = ' Nope.' + if verbose: + extra = f' Saved to {encoded_path}' if encoded_path else '' + print(' ' * (level + 2) + f'Success! {extra}') + return encoded_path, '' + except: + if verbose: print(' ' * (level + 2) + 'Transcription failed.') + if os.path.exists(encoded_path): + os.remove(encoded_path) + return None, error_msg + ' Yes.' 
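For reference, processed2encoded writes a pickle keyed by 'main', 'aug' and 'noisy', and encoded2rep above consumes 'main' as a flat token stream with five tokens per note after dropping the padding token 1. A small sketch of that contract, assuming such a pickle already exists; the file path is hypothetical.

import pickle

# hypothetical path to a file produced by processed2encoded
with open('data/music/encoded/example_encoded.pickle', 'rb') as f:
    data = pickle.load(f)

tokens = [t for t in data['main'] if t != 1]      # drop token 1, as encoded2rep does
assert len(tokens) % 5 == 0, 'expected 5 tokens per note'
print(f"{len(tokens) // 5} notes, {len(data['aug'])} augmented and {len(data['noisy'])} noisy variants")
performance = ' '.join(str(t) for t in tokens)    # the string handed to the sentence transformer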
+ diff --git a/src/music/pipeline/processed2handcodedrep.py b/src/music/pipeline/processed2handcodedrep.py new file mode 100644 index 0000000000000000000000000000000000000000..0f752d4d36988d3734127fdd8519cd24a039e4ad --- /dev/null +++ b/src/music/pipeline/processed2handcodedrep.py @@ -0,0 +1,343 @@ +import numpy as np +from music21 import * +from music21.features import native, jSymbolic, DataSet +import pretty_midi as pm +from src.music.utils import get_out_path +from src.music.utilities.handcoded_rep_utilities.tht import tactus_hypothesis_tracker, tracker_analysis +from src.music.utilities.handcoded_rep_utilities.loudness import get_loudness, compute_total_loudness, amplitude2db, velocity2amplitude, get_db_of_equivalent_loudness_at_440hz, pitch2freq +import json +import os +environment.set('musicxmlPath', '/home/cedric/Desktop/test/') +midi_path = "/home/cedric/Documents/pianocktail/data/music/processed/doug_mckenzie_processed/allthethings_reharmonized_processed.mid" + +FEATURES_DICT_SCORE = dict( + # strongest pulse: measures how fast the melody is + # stronger_pulse=jSymbolic.StrongestRhythmicPulseFeature, + # weights of the two strongest pulse, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature + pulse_strength_two=jSymbolic.CombinedStrengthOfTwoStrongestRhythmicPulsesFeature, + # weights of the strongest pulse, measures rhythmic consistency: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#combinedstrengthoftwostrongestrhythmicpulsesfeature + pulse_strength = jSymbolic.StrengthOfStrongestRhythmicPulseFeature, + # variability of attacks: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#variabilityoftimebetweenattacksfeature + +) +FEATURES_DICT = dict( + # bass register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature + # bass_register=jSymbolic.ImportanceOfBassRegisterFeature, + # high register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature + # high_register=jSymbolic.ImportanceOfHighRegisterFeature, + # medium register importance: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#importanceofbassregisterfeature + # medium_register=jSymbolic.ImportanceOfMiddleRegisterFeature, + # number of common pitches (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature + # common_pitches=jSymbolic.NumberOfCommonPitchesFeature, + # pitch class variety (used at least once): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#pitchvarietyfeature + # pitch_variety=jSymbolic.PitchVarietyFeature, + # attack_variability = jSymbolic.VariabilityOfTimeBetweenAttacksFeature, + # staccato fraction: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#staccatoincidencefeature + # staccato_score = jSymbolic.StaccatoIncidenceFeature, + # mode analysis: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesNative.html + av_melodic_interval = jSymbolic.AverageMelodicIntervalFeature, + # chromatic motion: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#chromaticmotionfeature + chromatic_motion = jSymbolic.ChromaticMotionFeature, + # direction of motion (fraction of rising intervals: 
https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#directionofmotionfeature + motion_direction = jSymbolic.DirectionOfMotionFeature, + # duration of melodic arcs: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#durationofmelodicarcsfeature + melodic_arcs_duration = jSymbolic.DurationOfMelodicArcsFeature, + # melodic arcs size: https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#sizeofmelodicarcsfeature + melodic_arcs_size = jSymbolic.SizeOfMelodicArcsFeature, + # number of common melodic interval (at least 9% of all): https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#numberofcommonmelodicintervalsfeature + # common_melodic_intervals = jSymbolic.NumberOfCommonMelodicIntervalsFeature, + # https://web.mit.edu/music21/doc/moduleReference/moduleFeaturesJSymbolic.html#amountofarpeggiationfeature + # arpeggiato=jSymbolic.AmountOfArpeggiationFeature, +) + + + + + + +def compute_beat_info(onsets): + onsets_in_ms = np.array(onsets) * 1000 + + tht = tactus_hypothesis_tracker.default_tht() + trackers = tht(onsets_in_ms) + top_hts = tracker_analysis.top_hypothesis(trackers, len(onsets_in_ms)) + beats = tracker_analysis.produce_beats_information(onsets_in_ms, top_hts, adapt_period=250 is not None, + adapt_phase=tht.eval_f, max_delta_bpm=250, avoid_quickturns=None) + tempo = 1 / (np.mean(np.diff(beats)) / 1000) * 60 # in bpm + conf_values = tracker_analysis.tht_tracking_confs(trackers, len(onsets_in_ms)) + pulse_clarity = np.mean(np.array(conf_values), axis=0)[1] + return tempo, pulse_clarity + +def dissonance_score(A): + """ + Given a piano-roll indicator matrix representation of a musical work (128 pitches x beats), + return the dissonance as a function of beats. + Input: + A - 128 x beats indicator matrix of MIDI pitch number + + """ + freq_rats = np.arange(1, 7) # Harmonic series ratios + amps = np.exp(-.5 * freq_rats) # Partial amplitudes + F0 = 8.1757989156 # base frequency for MIDI (note 0) + diss = [] # List for dissonance values + thresh = 1e-3 + for beat in A.T: + idx = np.where(beat>thresh)[0] + if len(idx): + freqs, mags = [], [] # lists for frequencies, mags + for i in idx: + freqs.extend(F0*2**(i/12.0)*freq_rats) + mags.extend(amps) + freqs = np.array(freqs) + mags = np.array(mags) + sortIdx = freqs.argsort() + d = compute_dissonance(freqs[sortIdx],mags[sortIdx]) + diss.extend([d]) + else: + diss.extend([-1]) # Null value + diss = np.array(diss) + return diss[np.where(diss != -1)] + +def compute_dissonance(freqs, amps): + """ + From https://notebook.community/soundspotter/consonance/week1_consonance + Compute dissonance between partials with center frequencies in freqs, uses a model of critical bandwidth. + and amplitudes in amps. 
Based on Sethares "Tuning, Timbre, Spectrum, Scale" (1998) after Plomp and Levelt (1965) + + inputs: + freqs - list of partial frequencies + amps - list of corresponding amplitudes [default, uniformly 1] + """ + b1, b2, s1, s2, c1, c2, Dstar = (-3.51, -5.75, 0.0207, 19.96, 5, -5, 0.24) + f = np.array(freqs) + a = np.array(amps) + idx = np.argsort(f) + f = f[idx] + a = a[idx] + N = f.size + D = 0 + for i in range(1, N): + Fmin = f[ 0 : N - i ] + S = Dstar / ( s1 * Fmin + s2) + Fdif = f[ i : N ] - f[ 0 : N - i ] + am = a[ i : N ] * a[ 0 : N - i ] + Dnew = am * (c1 * np.exp (b1 * S * Fdif) + c2 * np.exp(b2 * S * Fdif)) + D += Dnew.sum() + return D + + + + +def store_new_midi(notes, out_path): + midi = pm.PrettyMIDI() + midi.instruments.append(pm.Instrument(program=0, is_drum=False)) + midi.instruments[0].notes = notes + midi.write(out_path) + return midi + + +def processed2handcodedrep(midi_path, handcoded_rep_path=None, crop=30, verbose=False, save=True, return_rep=False, level=0): + try: + if not handcoded_rep_path: + handcoded_rep_path, _, _ = get_out_path(in_path=midi_path, in_word='processed', out_word='handcoded_reps', out_extension='.mid') + features = dict() + if verbose: print(' ' * level + 'Computing handcoded representations') + if os.path.exists(handcoded_rep_path): + with open(handcoded_rep_path.replace('.mid', '.json'), 'r') as f: + features = json.load(f) + rep = np.array([features[k] for k in sorted(features.keys())]) + if rep.size == 49: + os.remove(handcoded_rep_path) + else: + if verbose: print(' ' * (level + 2) + 'Already computed.') + if return_rep: + return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), '' + else: + return handcoded_rep_path, '' + midi = pm.PrettyMIDI(midi_path) # load midi with pretty midi + notes = midi.instruments[0].notes # get notes + notes.sort(key=lambda x: (x.start, x.pitch)) # sort notes per start and pitch + onsets, offsets, pitches, durations, velocities = [], [], [], [], [] + n_notes_cropped = len(notes) + for i_n, n in enumerate(notes): + onsets.append(n.start) + offsets.append(n.end) + durations.append(n.end-n.start) + pitches.append(n.pitch) + velocities.append(n.velocity) + if crop is not None: # find how many notes to keep + if n.start > crop and n_notes_cropped == len(notes): + n_notes_cropped = i_n + break + notes = notes[:n_notes_cropped] + midi = store_new_midi(notes, handcoded_rep_path) + # pianoroll = midi.get_piano_roll() # extract piano roll representation + + # compute loudness + amplitudes = velocity2amplitude(np.array(velocities)) + power_dbs = amplitude2db(amplitudes) + frequencies = pitch2freq(np.array(pitches)) + loudness_values = get_loudness(power_dbs, frequencies) + # compute average perceived loudness + # for each power, compute loudness, then compute power such that the loudness at 440 Hz would be equivalent. 
+ # equivalent_powers_dbs = get_db_of_equivalent_loudness_at_440hz(frequencies, power_dbs) + # then get the corresponding amplitudes + # equivalent_amplitudes = 10 ** (equivalent_powers_dbs / 20) + # not use a amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz + # av_total_loudness, std_total_loudness = compute_total_loudness(equivalent_amplitudes, onsets, offsets) + + end_time = np.max(offsets) + start_time = notes[0].start + + + score = converter.parse(handcoded_rep_path) + score.chordify() + notes_without_chords = stream.Stream(score.flatten().getElementsByClass('Note')) + + velocities_wo_chords, pitches_wo_chords, amplitudes_wo_chords, dbs_wo_chords = [], [], [], [] + frequencies_wo_chords, loudness_values_wo_chords, onsets_wo_chords, offsets_wo_chords, durations_wo_chords = [], [], [], [], [] + for i_n in range(len(notes_without_chords)): + n = notes_without_chords[i_n] + velocities_wo_chords.append(n.volume.velocity) + pitches_wo_chords.append(n.pitch.midi) + onsets_wo_chords.append(n.offset) + offsets_wo_chords.append(onsets_wo_chords[-1] + n.seconds) + durations_wo_chords.append(n.seconds) + + amplitudes_wo_chords = velocity2amplitude(np.array(velocities_wo_chords)) + power_dbs_wo_chords = amplitude2db(amplitudes_wo_chords) + frequencies_wo_chords = pitch2freq(np.array(pitches_wo_chords)) + loudness_values_wo_chords = get_loudness(power_dbs_wo_chords, frequencies_wo_chords) + # compute average perceived loudness + # for each power, compute loudness, then compute power such that the loudness at 440 Hz would be equivalent. + # equivalent_powers_dbs_wo_chords = get_db_of_equivalent_loudness_at_440hz(frequencies_wo_chords, power_dbs_wo_chords) + # then get the corresponding amplitudes + # equivalent_amplitudes_wo_chords = 10 ** (equivalent_powers_dbs_wo_chords / 20) + # not use a amplitude model across the sample to compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz + # av_total_loudness_wo_chords, std_total_loudness_wo_chords = compute_total_loudness(equivalent_amplitudes_wo_chords, onsets_wo_chords, offsets_wo_chords) + + ds = DataSet(classLabel='test') + f = list(FEATURES_DICT.values()) + ds.addFeatureExtractors(f) + ds.addData(notes_without_chords) + ds.process() + for k, f in zip(FEATURES_DICT.keys(), ds.getFeaturesAsList()[0][1:-1]): + features[k] = f + + ds = DataSet(classLabel='test') + f = list(FEATURES_DICT_SCORE.values()) + ds.addFeatureExtractors(f) + ds.addData(score) + ds.process() + for k, f in zip(FEATURES_DICT_SCORE.keys(), ds.getFeaturesAsList()[0][1:-1]): + features[k] = f + + # # # # # + # Register features + # # # # # + + # features['av_pitch'] = np.mean(pitches) + # features['std_pitch'] = np.std(pitches) + # features['range_pitch'] = np.max(pitches) - np.min(pitches) # aka ambitus + + # # # # # + # Rhythmic features + # # # # # + + # tempo, pulse_clarity = compute_beat_info(onsets[:n_notes_cropped]) + # features['pulse_clarity'] = pulse_clarity + # features['tempo'] = tempo + features['tempo_pm'] = midi.estimate_tempo() + + # # # # # + # Temporal features + # # # # # + + features['av_duration'] = np.mean(durations) + # features['std_duration'] = np.std(durations) + features['note_density'] = len(notes) / (end_time - start_time) + # intervals_wo_chords = np.diff(onsets_wo_chords) + # articulations = [max((i-d)/i, 0) for d, i in zip(durations_wo_chords, intervals_wo_chords) if i != 0] + # features['articulation'] = 
np.mean(articulations) + # features['av_duration_wo_chords'] = np.mean(durations_wo_chords) + # features['std_duration_wo_chords'] = np.std(durations_wo_chords) + + # # # # # + # Dynamics features + # # # # # + features['av_velocity'] = np.mean(velocities) + features['std_velocity'] = np.std(velocities) + features['av_loudness'] = np.mean(loudness_values) + # features['std_loudness'] = np.std(loudness_values) + features['range_loudness'] = np.max(loudness_values) - np.min(loudness_values) + # features['av_integrated_loudness'] = av_total_loudness + # features['std_integrated_loudness'] = std_total_loudness + # features['av_velocity_wo_chords'] = np.mean(velocities_wo_chords) + # features['std_velocity_wo_chords'] = np.std(velocities_wo_chords) + # features['av_loudness_wo_chords'] = np.mean(loudness_values_wo_chords) + # features['std_loudness_wo_chords'] = np.std(loudness_values_wo_chords) + features['range_loudness_wo_chords'] = np.max(loudness_values_wo_chords) - np.min(loudness_values_wo_chords) + # features['av_integrated_loudness'] = av_total_loudness_wo_chords + # features['std_integrated_loudness'] = std_total_loudness_wo_chords + # indices_with_intervals = np.where(intervals_wo_chords > 0.01) + # features['av_loudness_change'] = np.mean(np.abs(np.diff(np.array(loudness_values_wo_chords)[indices_with_intervals]))) # accentuation + # features['av_velocity_change'] = np.mean(np.abs(np.diff(np.array(velocities_wo_chords)[indices_with_intervals]))) # accentuation + + # # # # # + # Harmony features + # # # # # + + # get major_minor score: https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisDiscrete.html + music_analysis = score.analyze('key') + major_score = None + minor_score = None + for a in [music_analysis] + music_analysis.alternateInterpretations: + if 'major' in a.__str__() and a.correlationCoefficient > 0: + major_score = a.correlationCoefficient + elif 'minor' in a.__str__() and a.correlationCoefficient > 0: + minor_score = a.correlationCoefficient + if major_score is not None and minor_score is not None: + break + features['major_minor'] = major_score / (major_score + minor_score) + features['tonal_certainty'] = music_analysis.tonalCertainty() + # features['av_sensory_dissonance'] = np.mean(dissonance_score(pianoroll)) + #TODO only works for chords, do something with melodic intervals: like proportion that is not third, fifth or sevenths? 
+ + # # # # # + # Interval features + # # # # # + #https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisPatel.html + # features['melodic_interval_variability'] = analysis.patel.melodicIntervalVariability(notes_without_chords) + + # # # # # + # Suprize features + # # # # # + # https://web.mit.edu/music21/doc/moduleReference/moduleAnalysisMetrical.html + # analysis.metrical.thomassenMelodicAccent(notes_without_chords) + # melodic_accents = [n.melodicAccent for n in notes_without_chords] + # features['melodic_accent'] = np.mean(melodic_accents) + + if save: + for k, v in features.items(): + features[k] = float(features[k]) + with open(handcoded_rep_path.replace('.mid', '.json'), 'w') as f: + json.dump(features, f) + else: + print(features) + if os.path.exists(handcoded_rep_path): + os.remove(handcoded_rep_path) + if verbose: print(' ' * (level + 2) + 'Success.') + if return_rep: + return handcoded_rep_path, np.array([features[k] for k in sorted(features.keys())]), '' + else: + return handcoded_rep_path, '' + except: + if verbose: print(' ' * (level + 2) + 'Failed.') + if return_rep: + return None, None, 'error' + else: + return None, 'error' + + +if __name__ == '__main__': + processed2handcodedrep(midi_path, '/home/cedric/Desktop/test.mid', save=False) \ No newline at end of file diff --git a/src/music/pipeline/synth2audio.py b/src/music/pipeline/synth2audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8eee792b6a7b50837acdce34017f907cc9adda --- /dev/null +++ b/src/music/pipeline/synth2audio.py @@ -0,0 +1,170 @@ +import pynput +import sys +sys.path.append('../../') +from src.music.config import SYNTH_RECORDED_AUDIO_PATH, RATE_AUDIO_SAVE +from datetime import datetime +import numpy as np +import os +import wave + +from ctypes import * +from contextlib import contextmanager +import pyaudio + +ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p) + +def py_error_handler(filename, line, function, err, fmt): + pass +c_error_handler = ERROR_HANDLER_FUNC(py_error_handler) + +@contextmanager +def noalsaerr(): + asound = cdll.LoadLibrary('libasound.so') + asound.snd_lib_error_set_handler(c_error_handler) + yield + asound.snd_lib_error_set_handler(None) + +global KEY_PRESSED +KEY_PRESSED = None + +def on_press(key): + global KEY_PRESSED + try: + KEY_PRESSED = key.name + except: + pass + +def on_release(key): + global KEY_PRESSED + KEY_PRESSED = None + + +def is_pressed(key): + global KEY_PRESSED + return KEY_PRESSED == key + +# keyboard listener +listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release) +listener.start() + +LEN_RECORDINGS = 40 +class AudioRecorder: + def __init__(self, chunk=2**10, rate=44100, place='', len_recording=LEN_RECORDINGS, drop_beginning=0.5): + self.chunk = chunk + self.rate = rate + with noalsaerr(): + self.audio = pyaudio.PyAudio() + self.channels = 1 + self.format = pyaudio.paInt16 + self.stream = self.audio.open(format=self.format, + channels=self.channels, + rate=rate, + input=True, + frames_per_buffer=chunk) + self.stream.stop_stream() + self.drop_beginning_chunks = int(drop_beginning * self.rate / self.chunk) + self.place = place + self.len_recordings = len_recording + + def get_filename(self): + now = datetime.now() + return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mp3' + + def read_last_chunk(self): + return self.stream.read(self.chunk) + + def live_read(self): + if self.stream.is_stopped(): + self.stream.start_stream() + i = 0 + while not is_pressed('esc'): + data = 
np.frombuffer(self.stream.read(self.chunk), dtype=np.int16) + peak = np.average(np.abs(data)) * 2 + bars = "#"*int(50 * peak / 2 ** 16) + i += 1 + print("%04d %05d %s"%(i,peak,bars)) + self.stream.stop_stream() + + def record_next_N_seconds(self, n=None, saving_path=None): + if saving_path is None: + saving_path = SYNTH_RECORDED_AUDIO_PATH + self.get_filename() + if n is None: + n = self.len_recordings + + print(f'Recoding the next {n} secs.' + # f'\n\tRecording starts when the first key is pressed;' + f'\n\tPress Enter to end the recording;' + f'\n\tPress BackSpace (<--) to cancel the recording;' + f'\n\tSaving to {saving_path}') + try: + self.stream.start_stream() + backspace_pressed = False + self.recording = [] + i_chunk = 0 + while not is_pressed('enter') and self.chunk / self.rate * i_chunk < n: + self.recording.append(self.read_last_chunk()) + i_chunk += 1 + if is_pressed('backspace'): + backspace_pressed = True + print('\n \t--> Recording cancelled! (you pressed BackSpace)') + break + self.stream.stop_stream() + + # save the file + if not backspace_pressed: + self.recording = self.recording[self.drop_beginning_chunks:] # drop first chunks to remove keyboard sound + with wave.open(saving_path[:-4] + '.wav', 'wb') as waveFile: + waveFile.setnchannels(self.channels) + waveFile.setsampwidth(self.audio.get_sample_size(self.format)) + waveFile.setframerate(self.rate) + waveFile.writeframes(b''.join(self.recording)) + os.system(f'ffmpeg -i "{saving_path[:-4] + ".wav"}" -vn -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} -b:a 320k "{saving_path}" ') + os.remove(saving_path[:-4] + '.wav') + print(f'\n--> Recording saved, duration: {self.chunk / self.rate * i_chunk:.2f} secs.') + return saving_path + except: + print('\n --> The recording failed.') + return None + + def record_one(self): + ready_msg = False + print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)') + while True: + if not ready_msg: + print('-------\nReady to record!') + print('Press space to start a recording\n') + ready_msg = True + + if is_pressed('space'): + saving_path = self.record_next_N_seconds() + break + return saving_path + + def run(self): + # with pynput.Listener( + # on_press=self.on_press) as listener: + # listener.join() + ready_msg = False + print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)') + while True: + if not ready_msg: + print('-------\nReady to record!') + print('Press space to start a recording\n') + ready_msg = True + + if is_pressed('space'): + self.record_next_N_seconds() + ready_msg = False + if is_pressed('esc'): + print('End of the recording session. 
See you soon!') + self.close() + break + + def close(self): + self.stream.close() + self.audio.terminate() + +if __name__ == '__main__': + audio_recorder = AudioRecorder(place='home') + audio_recorder.record_one() + diff --git a/src/music/pipeline/synth2midi.py b/src/music/pipeline/synth2midi.py new file mode 100644 index 0000000000000000000000000000000000000000..b1257a3fa3ec84e51f2ef4bd9861b7d9ede68219 --- /dev/null +++ b/src/music/pipeline/synth2midi.py @@ -0,0 +1,146 @@ +import mido +mido.set_backend('mido.backends.pygame') +from mido import Message, MidiFile, MidiTrack +import time +import pynput +import sys +sys.path.append('../../') +from src.music.config import SYNTH_RECORDED_MIDI_PATH +from datetime import datetime + +#TODO: debug this with other cable, keyboard and sound card +global KEY_PRESSED +KEY_PRESSED = None + +def on_press(key): + global KEY_PRESSED + try: + KEY_PRESSED = key.name + except: + pass + +def on_release(key): + global KEY_PRESSED + KEY_PRESSED = None + + +def is_pressed(key): + global KEY_PRESSED + return KEY_PRESSED == key + +# keyboard listener +listener = pynput.keyboard.Listener(on_press=on_press, on_release=on_release) +listener.start() + +LEN_MIDI_RECORDINGS = 30 +class MidiRecorder: + def __init__(self, place='', len_midi_recordings=LEN_MIDI_RECORDINGS): + self.place = place + self.len_midi_recordings = len_midi_recordings + self.port = mido.open_input(mido.get_input_names()[0]) + + def get_filename(self): + now = datetime.now() + return self.place + '_' + now.strftime("%b_%d_%Y_%Hh%Mm%Ss") + '.mid' + + def read_last_midi_msgs(self): + return list(self.port.iter_pending()) + + def live_read(self): + while not is_pressed('esc'): + for msg in self.read_last_midi_msgs(): + print(msg) + + def check_if_recording_started(self, msgs, t_init): + started = False + if len(msgs) > 0: + for m in msgs: + if m.type == 'note_on': + started = True + t_init = time.time() + return started, t_init + + def create_empty_midi(self): + mid = MidiFile() + track = MidiTrack() + mid.tracks.append(track) + track.append(Message('program_change', program=0, time=0)) + return mid, track + + def record_next_N_seconds(self, n=None, saving_path=None): + if saving_path is None: + saving_path = SYNTH_RECORDED_PATH + self.get_filename() + if n is None: + n = self.len_midi_recordings + + print(f'Recoding the next {n} secs.' + f'\n\tRecording starts when the first key is pressed;' + f'\n\tPress Enter to end the recording;' + f'\n\tPress BackSpace (<--) to cancel the recording;' + f'\n\tSaving to {saving_path}') + try: + mid, track = self.create_empty_midi() + started = False + backspace_pressed = False + t_init = time.time() + while not is_pressed('enter') and (time.time() - t_init) < n: + msgs = self.read_last_midi_msgs() + if not started: + started, t_init = self.check_if_recording_started(msgs, t_init) + if started: + print("\n\t--> First note pressed, it's on!") + for m in msgs: + print(m) + if m.type == 'note_on' and m.velocity == 0: + m_off = Message(type='note_off', velocity=127, note=m.note, channel=m.channel, time=m.time) + track.append(m_off) + track.append(m) + if is_pressed('backspace'): + backspace_pressed = True + print('\n \t--> Recording cancelled! 
(you pressed BackSpace)') + break + # save the file + if not backspace_pressed and len(mid.tracks[0]) > 0: + mid.save(saving_path) + print(f'\n--> Recording saved, duration: {mid.length:.2f} secs, {len(mid.tracks[0])} events.') + except: + print('\n --> The recording failed.') + + + def run(self): + # with pynput.Listener( + # on_press=self.on_press) as listener: + # listener.join() + ready_msg = False + print('Starting the recording loop!\n\tPress BackSpace to cancel the current recording;\n\tPress Esc to quit the loop (only works while not recording)') + while True: + if not ready_msg: + print('-------\nReady to record!') + print('Press space to start a recording\n') + ready_msg = True + + if is_pressed('space'): + self.record_next_N_seconds() + ready_msg = False + if is_pressed('esc'): + print('End of the recording session. See you soon!') + break + + +midi_recorder = MidiRecorder(place='home') +midi_recorder.live_read() +# midi_recorder.run() + + +# try: +# controls[msg.control] = msg.value +# except: +# notes.append(msg.note) +# port = mido.open_input() +# while True: +# for msg in port.iter_pending(): +# print(msg) +# +# print('start pause') +# time.sleep(5) +# print('stop pause') \ No newline at end of file diff --git a/src/music/pipeline/url2audio.py b/src/music/pipeline/url2audio.py new file mode 100644 index 0000000000000000000000000000000000000000..6d34f6fa92f4651a08418cf89910cbd9514248b6 --- /dev/null +++ b/src/music/pipeline/url2audio.py @@ -0,0 +1,119 @@ +import os +from pytube import YouTube +from src.music.utils import RATE_AUDIO_SAVE, slugify +from src.music.config import MAX_LEN + +# define filtering keyworfds +start_keywords = [' ', '(', ',', ':'] +end_keywords = [')', ' ', '.', ',', '!', ':'] +def get_all_keywords(k): + all_keywords = [] + for s in start_keywords: + for e in end_keywords: + all_keywords.append(s + k + e) + return all_keywords +filtered_keywords = ['duet', 'duo', 'quartet', 'orchestre', 'orchestra', + 'quintet', 'sixtet', 'septet', 'octet', 'backing track', 'accompaniment', 'string', + 'contrebrasse', 'drums', 'guitar'] + get_all_keywords('live') + get_all_keywords('trio') + +# list of playlist for which no filtering should occur on keywords (they were prefiltered already, it's supposed to be only piano) +playlist_and_channel_not_to_filter = ["https://www.youtube.com/c/MySheetMusicTranscriptions", + "https://www.youtube.com/c/PianoNotion", + "https://www.youtube.com/c/PianoNotion", + "https://www.youtube.com/watch?v=3F5glYefwio&list=PLFv3ZQw-ZPxi2DH3Bau7lBC5K6zfPJZxc", + "https://www.youtube.com/user/Mercuziopianist", + "https://www.youtube.com/channel/UCy6NPK6-xeX7MZLaMARa5qg", + "https://www.youtube.com/channel/UCKMRNFV2dWTWIJnymtA9_Iw", + "https://www.youtube.com/c/pianomaedaful", + "https://www.youtube.com/c/FrancescoParrinoMusic", + "https://www.youtube.com/c/itsremco"] +playlist_ok = "https://www.youtube.com/watch?v=sYv_vk6bJtk&list=PLO9E3V4rGLD9-0BEd3t-AvvMcVF1zOJPj" + + +def should_be_filtered(title, length, url, playlist_url, max_length): + to_filter = False + reason = '' + lower_title = title.lower() + if length > max_length: + reason += f'it is too long (>{max_length/60:.1f} min), ' + to_filter = True + if any([f in lower_title for f in filtered_keywords]) \ + and playlist_url not in playlist_and_channel_not_to_filter \ + and 'to live' not in lower_title and 'alive' not in lower_title \ + and url not in playlist_ok: + reason += 'it contains a filtered keyword, ' + to_filter = True + return to_filter, reason + +def convert_mp4_to_mp3(path, 
verbose=True): + if verbose: print(f"Converting mp4 to mp3, in {path}\n") + assert '.mp4' == path[-4:] + os.system(f'ffmpeg -i "{path}" -loglevel panic -y -ac 1 -ar {int(RATE_AUDIO_SAVE)} "{path[:-4] + ".mp3"}" ') + os.remove(path) + if verbose: print('\tDone.') + +def pipeline_video(video, playlist_path, filename): + # extract best stream for this video + stream, kbps = extract_best_stream(video.streams) + stream.download(output_path=playlist_path, filename=filename + '.mp4') + # convert to mp3 + convert_mp4_to_mp3(playlist_path + filename + '.mp4', verbose=False) + return kbps + +def extract_best_stream(streams): + # extract best audio stream + stream_out = streams.get_audio_only() + kbps = int(stream_out.abr[:-4]) + return stream_out, kbps + +def get_title_and_length(video): + title = video.title + filename = slugify(title) + length = video.length + return title, filename, length, video.metadata + + +def url2audio(playlist_path, video_url=None, video=None, playlist_url='', apply_filters=False, verbose=False, level=0): + assert video_url is not None or video is not None, 'needs either video or url' + error_msg = 'Error in loading video?' + try: + if not video: + video = YouTube(video_url) + error_msg += ' Nope. In extracting title and length?' + title, filename, length, video_meta_data = get_title_and_length(video) + if apply_filters: + to_filter, reason = should_be_filtered(title, length, video_url, playlist_url, MAX_LEN) + else: + to_filter = False + if not to_filter: + audio_path = playlist_path + filename + ".mp3" + if verbose: print(' ' * level + f'Downloading {title}, Url: {video_url}') + if not os.path.exists(audio_path): + if length > MAX_LEN and verbose: print(' ' * (level + 2) + f'Long video ({int(length/60)} min), will be cut after {int(MAX_LEN/60)} min.') + error_msg += ' Nope. In pipeline video?' + kbps = None + for _ in range(5): + try: + kbps = pipeline_video(video, playlist_path, filename) + break + except: + pass + assert kbps is not None + error_msg += ' Nope. In dict filling?' + data = dict(title=title, filename=filename, length=length, kbps=kbps, url=video_url, meta=video_meta_data) + error_msg += ' Nope. ' + else: + if verbose: print(' ' * (level + 2) + 'Song already downloaded') + data = None + return audio_path, data, '' + else: + return None, None, f'Filtered because {reason}' + except: + if verbose: print(' ' * (level + 2) + f'Download failed with error {error_msg}') + if os.path.exists(audio_path): + os.remove(audio_path) + return None, None, error_msg + ' Yes.' 
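url2audio builds on pytube's audio-only stream plus an ffmpeg call to obtain a mono mp3 at RATE_AUDIO_SAVE. A stripped-down sketch of the download step follows, reusing the URL from music_pipeline's __main__ as an example; filtering, retries and the ffmpeg resampling flags are omitted, and the output path is illustrative.

from pytube import YouTube

video = YouTube('https://www.youtube.com/watch?v=a2LFVWBmoiw')
stream = video.streams.get_audio_only()   # same call as extract_best_stream above
print(video.title, stream.abr)            # abr is e.g. '128kbps'; extract_best_stream strips the 'kbps' suffix
stream.download(output_path='/tmp/', filename='example.mp4')
# url2audio then shells out to ffmpeg to convert the .mp4 to a mono .mp3 and removes the .mp4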
+ + + + diff --git a/src/music/representation_analysis/__init__.py b/src/music/representation_analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/representation_analysis/analyze_rep.py b/src/music/representation_analysis/analyze_rep.py new file mode 100644 index 0000000000000000000000000000000000000000..7b066ab8413b91b4a39dc83e2b06b240456cd7b9 --- /dev/null +++ b/src/music/representation_analysis/analyze_rep.py @@ -0,0 +1,146 @@ +import numpy as np +from sklearn.cluster import KMeans +from sklearn.neighbors import NearestNeighbors +from sklearn.manifold import TSNE +from src.music.utils import get_all_subfiles_with_extension +import matplotlib.pyplot as plt +import pickle +import random +# import umap +import os +from shutil import copy +# install numba =numba==0.51.2 +# keyword = '32_represented' +# rep_path = f"/home/cedric/Documents/pianocktail/data/music/{keyword}/" +# plot_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/plots/' +# neighbors_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/neighbors/' +interpolation_path = '/home/cedric/Documents/pianocktail/data/music/representation_analysis/interpolation/' +keyword = 'b256_r128_represented' +rep_path = f"/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/{keyword}/" +plot_path = '/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/plots/' +neighbors_path = f'/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/analysis/neighbors_{keyword}/' +os.makedirs(neighbors_path, exist_ok=True) +def extract_all_reps(rep_path): + all_rep_path = get_all_subfiles_with_extension(rep_path, max_depth=3, extension='.txt', current_depth=0) + all_data = [] + new_all_rep_path = [] + for i_r, r in enumerate(all_rep_path): + if 'mean_std' not in r: + all_data.append(np.loadtxt(r)) + assert len(all_data[-1]) == 128 + new_all_rep_path.append(r) + data = np.array(all_data) + to_save = dict(reps=data, + paths=new_all_rep_path) + with open(rep_path + 'music_reps_unnormalized.pickle', 'wb') as f: + pickle.dump(to_save, f) + for sample_size in [100, 200, 500, 1000, 2000, 5000]: + if sample_size < len(data): + inds = np.arange(len(data)) + np.random.shuffle(inds) + to_save = dict(reps=data[inds[:sample_size]], + paths=np.array(all_rep_path)[inds[:sample_size]]) + with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'wb') as f: + pickle.dump(to_save, f) + +def load_reps(rep_path, sample_size=None): + if sample_size: + with open(rep_path + f'all_reps_unnormalized_sample{sample_size}.pickle', 'rb') as f: + data = pickle.load(f) + else: + with open(rep_path + f'music_reps_unnormalized.pickle', 'rb') as f: + data = pickle.load(f) + reps = data['reps'] + # playlists = [r.split(f'_{keyword}')[0].split('/')[-1] for r in data['paths']] + playlists = [r.split(f'{keyword}')[1].split('/')[1] for r in data['paths']] + n_data, dim_data = reps.shape + return reps, data['paths'], playlists, n_data, dim_data + + +def plot_tsne(reps, playlist_indexes, playlist_colors): + tsne_reps = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(reps) + plt.figure() + keys_to_print = ['spot_piano_solo_blues', 'itsremco', 'piano_solo_classical', + 'piano_solo_pop', 'piano_jazz_unspecified','spot_piano_solo_jazz_1', 'piano_solo_jazz_latin'] + keys_to_print = 
playlist_indexes.keys() + for k in sorted(keys_to_print): + if k in playlist_indexes.keys(): + # plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, label=k, alpha=0.5) + plt.scatter(tsne_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5) + plt.legend() + plt.savefig(plot_path + f'tsne_{keyword}.png') + fig = plt.gcf() + plt.close(fig) + # umap_reps = umap.UMAP().fit_transform(reps) + # plt.figure() + # for k in sorted(keys_to_print): + # if k in playlist_indexes.keys(): + # plt.scatter(umap_reps[playlist_indexes[k], 0], tsne_reps[playlist_indexes[k], 1], s=100, c=playlist_colors[k], label=k, alpha=0.5) + # plt.legend() + # plt.savefig(plot_path + f'umap_{keyword}.png') + # fig = plt.gcf() + # plt.close(fig) + return tsne_reps#, umap_reps + +def get_playlist_indexes(playlists): + playlist_indexes = dict() + for i in range(n_data): + if playlists[i] not in playlist_indexes.keys(): + playlist_indexes[playlists[i]] = [i] + else: + playlist_indexes[playlists[i]].append(i) + for k in playlist_indexes.keys(): + playlist_indexes[k] = np.array(playlist_indexes[k]) + set_playlists = sorted(set(playlists)) + playlist_colors = dict(zip(set_playlists, ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(len(set_playlists))])) + return set_playlists, playlist_indexes, playlist_colors + +def convert_rep_path_midi_path(rep_path): + # playlist = rep_path.split(f'_{keyword}/')[0].split('/')[-1] + playlist = rep_path.split(f'{keyword}')[1].split('/')[1].replace('_represented', '') + midi_path = "/home/cedric/Documents/pianocktail/data/music/dataset_exploration/dataset_representation/processed/" + playlist + '_processed/' + filename = rep_path.split(f'{keyword}')[1].split(f'/')[2].split('_represented.txt')[0] + '_processed.mid' + # filename = rep_path.split(f'_{keyword}/')[-1].split(f'_{keyword}')[0] + '_processed.mid' + midi_path = midi_path + filename + assert os.path.exists(midi_path), midi_path + return midi_path + +def sample_nn(reps, rep_paths, playlists, n_samples=30): + nn_model = NearestNeighbors(n_neighbors=6, metric='cosine') + nn_model.fit(reps) + indexes = np.arange(len(reps)) + np.random.shuffle(indexes) + for i, ind in enumerate(indexes[:n_samples]): + out = nn_model.kneighbors(reps[ind].reshape(1, -1))[1][0][1:] + midi_path = convert_rep_path_midi_path(rep_paths[ind]) + copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[ind]}_target.mid') + for i_n, neighbor in enumerate(out): + midi_path = convert_rep_path_midi_path(rep_paths[neighbor]) + copy(midi_path, neighbors_path + f'sample_{i}_playlist_{playlists[neighbor]}_neighbor_{i_n}.mid') + +def interpolate(reps, rep_paths, path): + files = os.listdir(path) + bounds = [f for f in files if 'interpolation' not in f] + b_reps = [np.loadtxt(path + f) for f in bounds] + nn_model = NearestNeighbors(n_neighbors=6) + nn_model.fit(reps) + reps = [alpha * b_reps[0] + (1 - alpha) * b_reps[1] for alpha in np.linspace(0, 1., 5)] + copy(convert_rep_path_midi_path(path + bounds[1]), path + 'interpolation_0.mid') + copy(convert_rep_path_midi_path(path + bounds[0]), path + 'interpolation_1.mid') + for alpha, rep in zip(np.linspace(0, 1, 5)[1:-1], reps[1: -1]): + dists, indexes = nn_model.kneighbors(rep.reshape(1, -1)) + if dists.flatten()[0] == 0: + nn = indexes.flatten()[1] + else: + nn = indexes.flatten()[0] + midi_path = convert_rep_path_midi_path(rep_paths[nn]) + copy(midi_path, path + f'interpolation_{alpha}.mid') + +if __name__ == '__main__': 
+ extract_all_reps(rep_path) + reps, rep_paths, playlists, n_data, dim_data = load_reps(rep_path) + set_playlists, playlist_indexes, playlist_colors = get_playlist_indexes(playlists) + # interpolate(reps, rep_paths, interpolation_path + 'trial_1/') + sample_nn(reps, rep_paths, playlists) + tsne_reps, umap_reps = plot_tsne(reps, playlist_indexes, playlist_colors) + diff --git a/src/music/representation_learning/__init__.py b/src/music/representation_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/representation_learning/__pycache__/__init__.cpython-39.pyc b/src/music/representation_learning/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13016d789e2e90b85980923ba6b5fedbee8b26f8 Binary files /dev/null and b/src/music/representation_learning/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/representation_learning/mlm_pretrain/__init__.py b/src/music/representation_learning/mlm_pretrain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/representation_learning/mlm_pretrain/data_collators.py b/src/music/representation_learning/mlm_pretrain/data_collators.py new file mode 100644 index 0000000000000000000000000000000000000000..13971a9af4cc6373e1a926710c556a98a5e4c391 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/data_collators.py @@ -0,0 +1,180 @@ +from typing import Any, Dict, List, Optional, Tuple, Union +from transformers.data.data_collator import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, BatchEncoding, DataCollatorForPermutationLanguageModeling +from dataclasses import dataclass + + +def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): + """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" + import numpy as np + import torch + + # Tensorize if necessary. + if isinstance(examples[0], (list, tuple, np.ndarray)): + examples = [torch.tensor(e, dtype=torch.long) for e in examples] + + length_of_first = examples[0].size(0) + + # Check if padding is necessary. + + are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) + if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): + return torch.stack(examples, dim=0) + + # If yes, check if we have a `pad_token`. + if tokenizer._pad_token is None: + raise ValueError( + "You are attempting to pad samples but the tokenizer you are using" + f" ({tokenizer.__class__.__name__}) does not have a pad token." + ) + + # Creating the full tensor and filling it with our data. + max_length = max(x.size(0) for x in examples) + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) + for i, example in enumerate(examples): + if tokenizer.padding_side == "right": + result[i, : example.shape[0]] = example + else: + result[i, -example.shape[0] :] = example + return result + + +@dataclass +class DataCollatorForMusicModeling(DataCollatorForLanguageModeling): + """ + Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they + are not all of the same length. 
+ Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + mlm (`bool`, *optional*, defaults to `True`): + Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs + with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked + tokens and the value to predict for the masked token. + mlm_probability (`float`, *optional*, defaults to 0.15): + The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". + + For best performance, this data collator should be used with a dataset having items that are dictionaries or + BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a + [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`. + """ + + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + # Handle dict or lists with proper padding and conversion to tensor. + if isinstance(examples[0], (dict, BatchEncoding)): + batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of) + else: + batch = { + "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) + } + + # If special token mask has been preprocessed, pop it from the dict. + special_tokens_mask = batch.pop("special_tokens_mask", None) + if self.mlm: + batch["input_ids"], batch["labels"] = self.torch_mask_tokens( + batch["input_ids"], special_tokens_mask=special_tokens_mask + ) + else: + labels = batch["input_ids"].clone() + if self.tokenizer.pad_token_id is not None: + labels[labels == self.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + return batch + + def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
+ """ + import torch + + labels = inputs.clone() + # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) + notes_shape = (labels.shape[0], labels.shape[1] // 5) + probability_matrix = torch.full(notes_shape, self.mlm_probability) + # if special_tokens_mask is None: + # special_tokens_mask = [ + # self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + # ] + # special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + # else: + # special_tokens_mask = special_tokens_mask.bool() + + # probability_matrix.masked_fill_(special_tokens_mask, value=0.0) + masked_notes_indices = torch.bernoulli(probability_matrix).bool() + masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1) + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_notes_replaced = torch.bernoulli(torch.full(notes_shape, 0.8)).bool() & masked_notes_indices + indices_replaced = torch.repeat_interleave(indices_notes_replaced, repeats=5, dim=1) + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_notes_random = torch.bernoulli(torch.full(notes_shape, 0.5)).bool() & masked_notes_indices & ~indices_notes_replaced + indices_random = torch.repeat_interleave(indices_notes_random, repeats=5, dim=1) + random_words = torch.randint(3, len(self.tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + + +@dataclass +class DataCollatorForSpanMusicModeling(DataCollatorForLanguageModeling): + """ + Data collator used for permutation language modeling. + - collates batches of tensors, honoring their tokenizer's pad_token + - preprocesses batches for permutation language modeling with procedures specific to XLNet + """ + + + def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]: + """ + The masked tokens to be predicted for a particular sequence are determined by the following algorithm: + 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far). + 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked) + 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be + masked + 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - + span_length]` and mask tokens `start_index:start_index + span_length` + 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the + sequence to be processed), repeat from Step 1. + """ + + import torch + + labels = inputs.clone() + # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) + notes_shape = (labels.shape[0], labels.shape[1] // 5) + masked_notes_indices = torch.full(notes_shape, 0, dtype=torch.bool) + + for i in range(labels.size(0)): + # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far). 
+ cur_len = 0 + max_len = notes_shape[1] + + while cur_len < max_len: + # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked) + span_length = torch.randint(1, 5 + 1, (1,)).item() + # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked + context_length = int(span_length / self.mlm_probability) + # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length` + start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item() + masked_notes_indices[i, start_index: start_index + span_length] = 1 + # Set `cur_len = cur_len + context_length` + cur_len += context_length + + masked_indices = torch.repeat_interleave(masked_notes_indices, repeats=5, dim=1) + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + return inputs, labels + diff --git a/src/music/representation_learning/mlm_pretrain/models/music-bert/config.json b/src/music/representation_learning/mlm_pretrain/models/music-bert/config.json new file mode 100644 index 0000000000000000000000000000000000000000..068ac44ae7b050f4bbc49766f9d8685a3c97a2c5 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-bert/config.json @@ -0,0 +1,20 @@ +{ + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "relative_key_query", + "transformers_version": "4.8.2", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json b/src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..89a4d0d2360f7afcb38606dc684db7b76a438db2 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-bert/tokenizer.json @@ -0,0 +1 @@ 
+{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}} \ No newline at end of file diff --git a/src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json b/src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json new file mode 100644 index 0000000000000000000000000000000000000000..068ac44ae7b050f4bbc49766f9d8685a3c97a2c5 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-spanbert/config.json @@ -0,0 +1,20 @@ +{ + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "relative_key_query", + "transformers_version": "4.8.2", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json b/src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..89a4d0d2360f7afcb38606dc684db7b76a438db2 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-spanbert/tokenizer.json @@ -0,0 +1 @@ 
+{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"WordLevel","vocab":{"[PAD]":0,"[MASK]":1,"[UNK]":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":"[UNK]"}} \ No newline at end of file diff --git a/src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json b/src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json new file mode 100644 index 0000000000000000000000000000000000000000..9a0c221cca474df8753ba45b5544f019c8b59ab6 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-t5-small/config.json @@ -0,0 +1,56 @@ +{ + "architectures": [ + "T5WithLMHeadModel" + ], + "d_ff": 2048, + "d_kv": 64, + "d_model": 512, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 6, + "num_heads": 8, + "num_layers": 6, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "transformers_version": "4.8.2", + "use_cache": true, + "vocab_size": 32128 
+} diff --git a/src/music/representation_learning/mlm_pretrain/models/music-t5-small/tokenizer.json b/src/music/representation_learning/mlm_pretrain/models/music-t5-small/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..d17b4ea158dd4a89e127a79eb0a6f736964a8ecb --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/models/music-t5-small/tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":{"type":"Lowercase"},"pre_tokenizer":{"type":"Whitespace"},"post_processor":{"type":"TemplateProcessing","single":[{"Sequence":{"id":"A","type_id":0}},{"SpecialToken":{"id":"","type_id":0}}],"pair":[{"Sequence":{"id":"A","type_id":0}},{"Sequence":{"id":"B","type_id":1}}],"special_tokens":{"":{"id":"","ids":[1],"tokens":[""]}}},"decoder":null,"model":{"type":"WordLevel","vocab":{"":0,"":1,"":2,"2":3,"3":4,"4":5,"5":6,"6":7,"7":8,"8":9,"9":10,"10":11,"11":12,"12":13,"13":14,"14":15,"15":16,"16":17,"17":18,"18":19,"19":20,"20":21,"21":22,"22":23,"23":24,"24":25,"25":26,"26":27,"27":28,"28":29,"29":30,"30":31,"31":32,"32":33,"33":34,"34":35,"35":36,"36":37,"37":38,"38":39,"39":40,"40":41,"41":42,"42":43,"43":44,"44":45,"45":46,"46":47,"47":48,"48":49,"49":50,"50":51,"51":52,"52":53,"53":54,"54":55,"55":56,"56":57,"57":58,"58":59,"59":60,"60":61,"61":62,"62":63,"63":64,"64":65,"65":66,"66":67,"67":68,"68":69,"69":70,"70":71,"71":72,"72":73,"73":74,"74":75,"75":76,"76":77,"77":78,"78":79,"79":80,"80":81,"81":82,"82":83,"83":84,"84":85,"85":86,"86":87,"87":88,"88":89,"89":90,"90":91,"91":92,"92":93,"93":94,"94":95,"95":96,"96":97,"97":98,"98":99,"99":100,"100":101,"101":102,"102":103,"103":104,"104":105,"105":106,"106":107,"107":108,"108":109,"109":110,"110":111,"111":112,"112":113,"113":114,"114":115,"115":116,"116":117,"117":118,"118":119,"119":120,"120":121,"121":122,"122":123,"123":124,"124":125,"125":126,"126":127,"127":128,"128":129,"129":130,"130":131,"131":132,"132":133,"133":134,"134":135,"135":136,"136":137,"137":138,"138":139,"139":140,"140":141,"141":142,"142":143,"143":144,"144":145,"145":146,"146":147,"147":148,"148":149,"149":150,"150":151,"151":152,"152":153,"153":154,"154":155,"155":156,"156":157,"157":158,"158":159,"159":160,"160":161,"161":162,"162":163,"163":164,"164":165,"165":166,"166":167,"167":168,"168":169,"169":170,"170":171,"171":172,"172":173,"173":174,"174":175,"175":176,"176":177,"177":178,"178":179,"179":180,"180":181,"181":182,"182":183},"unk_token":""}} \ No newline at end of file diff --git a/src/music/representation_learning/mlm_pretrain/my_tokenizer.py b/src/music/representation_learning/mlm_pretrain/my_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..410f6860b378d4f1bd3d7d346663314f75b25354 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/my_tokenizer.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +from typing import Union + +from tokenizers import AddedToken, Tokenizer, normalizers, pre_tokenizers +from tokenizers.implementations.base_tokenizer import BaseTokenizer +from tokenizers.models import WordLevel +from tokenizers.processors import TemplateProcessing + + +class MyT5Tokenizer(BaseTokenizer): + """ + This class is a copy of `DeDLOC's tokenizer implementation `__ . 
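+    In this adaptation the wrapped model is a plain WordLevel vocabulary with Whitespace
+    pre-tokenization, Lowercase normalization and an EOS token appended by a TemplateProcessing
+    post-processor; the description below is kept from the original implementation.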
+ Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization + Represents the Unigram algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: dict, + replacement: str = "▁", + add_prefix_space: bool = True, + unk_token: Union[str, AddedToken] = "", + eos_token: Union[str, AddedToken] = "", + pad_token: Union[str, AddedToken] = "", + ): + self.special_tokens = { + "pad": {"id": 0, "token": pad_token}, + "eos": {"id": 1, "token": eos_token}, + "unk": {"id": 2, "token": unk_token}, + } + + self.special_tokens_list = [None] * len(self.special_tokens) + for token_dict in self.special_tokens.values(): + self.special_tokens_list[token_dict["id"]] = token_dict["token"] + + tokenizer = Tokenizer(WordLevel(vocab, unk_token=unk_token))#WordLevel(vocab=vocab, unk_token=unk_token)) + + # tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space) + tokenizer.pre_tokenizer = pre_tokenizers.Whitespace() + + + tokenizer.normalizer = normalizers.Lowercase() + + # tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space) + + tokenizer.post_processor = TemplateProcessing( + single=f"$A {self.special_tokens['eos']['token']}", + special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])], + ) + + parameters = { + "model": "SentencePieceUnigram", + "replacement": replacement, + "add_prefix_space": add_prefix_space, + } + + super().__init__(tokenizer, parameters) + +class MyBERTTokenizer(BaseTokenizer): + """ + This class is a copy of `DeDLOC's tokenizer implementation `__ . + Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization + Represents the Unigram algorithm, with the pretokenization used by SentencePiece + """ + + def __init__( + self, + vocab: dict, + replacement: str = "▁", + add_prefix_space: bool = True, + unk_token: Union[str, AddedToken] = "", + ): + + + tokenizer = Tokenizer(WordLevel(vocab, unk_token=unk_token))#WordLevel(vocab=vocab, unk_token=unk_token)) + + # tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space) + tokenizer.pre_tokenizer = pre_tokenizers.Whitespace() + + + tokenizer.normalizer = normalizers.Lowercase() + + # tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space) + + parameters = { + "model": "SentencePieceUnigram", + "replacement": replacement, + "add_prefix_space": add_prefix_space, + } + + super().__init__(tokenizer, parameters) \ No newline at end of file diff --git a/src/music/representation_learning/mlm_pretrain/pretrain_mlm.py b/src/music/representation_learning/mlm_pretrain/pretrain_mlm.py new file mode 100644 index 0000000000000000000000000000000000000000..805a14bdfdda30657e35f212487820bc21013f0d --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/pretrain_mlm.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=fill-mask +""" +# You can also adapt this script on your own mlm task. Pointers for this are left as comments. + +import argparse +import logging +import math +import os +from pathlib import Path +from accelerate import DistributedDataParallelKwargs +import numpy as np +import torch +import datasets +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from data_collators import DataCollatorForMusicModeling, DataCollatorForSpanMusicModeling + +import transformers +from accelerate import Accelerator, DistributedType +from huggingface_hub import Repository +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + SchedulerType, + get_scheduler, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils.versions import require_version +from src.music.config import DATASET_PATH, EXPERIMENT_PATH + + +logger = logging.getLogger(__name__) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=DATASET_PATH + "/small/train.txt", + help="A csv or a json file containing the training data." + ) + parser.add_argument("--validation_file", + type=str, + default=DATASET_PATH + "/small/test.txt", + help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + default="./models/music-bert", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default="./models/music-bert", + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default="./models/music-bert", + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--eval_steps", + type=int, + default=1000, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--nb_eval_steps", + type=int, + default=100, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=EXPERIMENT_PATH + './music/representation_learning/saved_models/pretraining_mlm/music-bert/', help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default='bert', + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=512, + help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.", + ) + parser.add_argument( + "--nb_tokens_per_note", + type=int, + default=5, + help="The maximum total input sequence length after tokenization. 
Sequences longer than this will be truncated.", + ) + parser.add_argument("--cache_dir", type=str, default=None, help="The token to use to push to the Model Hub.") + + parser.add_argument( + "--line_by_line", + type=bool, + default=False, + help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=1, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, default="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", help="The token to use to push to the Model Hub.") + args = parser.parse_args() + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + args.max_seq_length = (args.max_seq_length // args.nb_tokens_per_note) * args.nb_tokens_per_note + + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + logger.info(accelerator.device) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if accelerator.is_main_process: + save_dir = args.output_dir + 'run' + candidate_save_dir = save_dir + trial_id = 0 + while os.path.exists(candidate_save_dir): + trial_id += 1 + candidate_save_dir = save_dir + f'_{trial_id}' + save_dir = candidate_save_dir + os.makedirs(save_dir) + args.output_dir = save_dir + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if accelerator.is_main_process: + print("\nLoading Dataset") + data_files = {} + data_files["train"] = args.train_file + data_files["validation"] = args.validation_file + dataset = load_dataset("text", data_files=data_files, cache_dir=args.cache_dir) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir, use_fast=not args.use_slow_tokenizer) + + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # Since we make sure that all sequences are of the same length, no attention_mask is needed. + def tokenize_function(examples): + return tokenizer(examples['text'], return_attention_mask=False) + with accelerator.main_process_first(): + tokenized_datasets = dataset.map(tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_cache, + ) + + def group_texts(examples): + results = dict() + for k in examples.keys(): + results[k] = [] + ex = examples[k] + for e in ex: + results[k] += [e[i: i + args.max_seq_length] for i in range(0, len(e) - args.max_seq_length, args.max_seq_length)] + return results + + # chunk into groups of size args.max_sequence_len + with accelerator.main_process_first(): + tokenized_datasets = tokenized_datasets.map(group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + ) + if accelerator.is_main_process: + print(" len of train_loader", len(tokenized_datasets['train'])) + print(" len of valid_loader", len(tokenized_datasets['validation'])) + + if accelerator.is_main_process: + if torch.cuda.is_available(): + print("Use %d GPUS" % torch.cuda.device_count()) + else: + print('Use cpu.') + + + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
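+    # With the defaults this reads ./models/music-bert/config.json (a BERT-base configuration with
+    # relative_key_query position embeddings); no pretrained weights are loaded, the model is built
+    # from scratch below and its token embeddings are resized to the 184-entry WordLevel vocabulary.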
+ if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + + + logger.info("Training new model from scratch") + model = AutoModelForMaskedLM.from_config(config) + + model.resize_token_embeddings(len(tokenizer)) + + train_dataset = tokenized_datasets["train"] + eval_dataset = tokenized_datasets["validation"] + + # Data collator + # This one will take care of randomly masking the tokens. + data_collator = DataCollatorForMusicModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability) + + # DataLoaders creation: + train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) + print('eval dataloader len', len(eval_dataloader)) + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be + # shorter in multiprocess) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler(name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
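+    # Example of the arithmetic logged above (assuming 10,000 training chunks on a single GPU with
+    # the defaults): len(train_dataloader) = ceil(10000 / 8) = 1250, num_update_steps_per_epoch = 1250,
+    # max_train_steps = 3 * 1250 = 3750 and total_batch_size = 8 * 1 * 1 = 8.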
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + + n_steps = 0 + best_eval_loss = float("inf") + best_eval_acc = -float("inf") + + for epoch in range(args.num_train_epochs): + model.train() + train_losses = [] + for step, batch in enumerate(train_dataloader): + n_steps += 1 + outputs = model(**batch) + loss = outputs.loss + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + train_losses.append(accelerator.gather(loss.repeat(args.per_device_train_batch_size))) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if completed_steps >= args.max_train_steps: + break + + if n_steps % args.eval_steps == 0 or step == len(train_dataloader) - 1: + model.eval() + eval_losses = [] + eval_accuracies = [] + for step, batch in enumerate(eval_dataloader): + if step > args.nb_eval_steps: + break + with torch.no_grad(): + outputs = model(**batch) + eval_accuracies.append(accelerator.gather(compute_accuracy(outputs, batch))) + loss = outputs.loss + eval_losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) + + eval_losses = torch.cat(eval_losses) + eval_accuracies = torch.cat(eval_accuracies) + try: + perplexity = math.exp(torch.mean(eval_losses)) + except OverflowError: + perplexity = float("inf") + + if torch.mean(eval_accuracies) > best_eval_acc: + best_eval_acc = torch.mean(eval_accuracies) + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + if accelerator.is_main_process: + logger.info("New best score, saving checkpoint") + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + if accelerator.is_main_process: + logger.info(f"\nEval: epoch {epoch}, step: {step} (total step: {n_steps}), --- perplexity: {perplexity:.3f}, loss: {torch.mean(eval_losses):.3f}, " + f"acc: {torch.mean(eval_accuracies):.3f}") + + if n_steps % args.logging_steps == 0 or step == len(train_dataloader) - 1: + this_train_losses = torch.cat(train_losses) + this_train_losses = this_train_losses[-args.logging_steps:] + try: + perplexity = math.exp(torch.mean(this_train_losses)) + except OverflowError: + perplexity = float("inf") + if accelerator.is_main_process: + logger.info(f"\nTrain: epoch {epoch}, step: {step} (total step: {n_steps}) --- perplexity: {perplexity:.3f}, loss: {torch.mean(this_train_losses):.3f}") + +def compute_accuracy(outputs, batch): + predictions = torch.argmax(outputs['logits'], dim=2) + labels = batch['labels'] + inds = torch.where(labels != -100) + accuracy = torch.mean((predictions[inds] == labels[inds]).float()) + return torch.atleast_1d(accuracy) + + + +if __name__ == "__main__": + main() diff --git a/src/music/representation_learning/mlm_pretrain/pretrain_span.py b/src/music/representation_learning/mlm_pretrain/pretrain_span.py new file mode 100644 index 0000000000000000000000000000000000000000..513e8e2c6e6373eae9f80d5a21b90f8c37d9b0e0 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/pretrain_span.py @@ -0,0 +1,527 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=fill-mask +""" +# You can also adapt this script on your own mlm task. Pointers for this are left as comments. + +import argparse +import logging +import math +import os +from pathlib import Path +from accelerate import DistributedDataParallelKwargs +import numpy as np +import torch +import datasets +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from data_collators import DataCollatorForMusicModeling, DataCollatorForSpanMusicModeling + +import transformers +from accelerate import Accelerator, DistributedType +from huggingface_hub import Repository +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + SchedulerType, + get_scheduler, + set_seed, +) +from transformers.file_utils import get_full_repo_name +from transformers.utils.versions import require_version +from src.music.config import DATASET_PATH, EXPERIMENT_PATH + + +logger = logging.getLogger(__name__) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument("--train_file", + type=str, + default=DATASET_PATH + "/small/train.txt", + help="A csv or a json file containing the training data." + ) + parser.add_argument("--validation_file", + type=str, + default=DATASET_PATH + "/small/test.txt", + help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + default="./models/music-spanbert", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default="./models/music-spanbert", + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default="./models/music-spanbert", + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--eval_steps", + type=int, + default=1000, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--nb_eval_steps", + type=int, + default=100, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=EXPERIMENT_PATH + './music/representation_learning/saved_models/pretraining_mlm/music-spanbert/', help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default='bert', + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--max_seq_length", + type=int, + default=512, + help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated.", + ) + parser.add_argument( + "--nb_tokens_per_note", + type=int, + default=5, + help="The maximum total input sequence length after tokenization. 
Sequences longer than this will be truncated.", + ) + parser.add_argument("--cache_dir", type=str, default=None, help="The token to use to push to the Model Hub.") + + parser.add_argument( + "--line_by_line", + type=bool, + default=False, + help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=1, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, default="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", help="The token to use to push to the Model Hub.") + args = parser.parse_args() + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + args.max_seq_length = (args.max_seq_length // args.nb_tokens_per_note) * args.nb_tokens_per_note + + + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state) + logger.info(accelerator.device) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + if accelerator.is_main_process: + save_dir = args.output_dir + 'run' + candidate_save_dir = save_dir + trial_id = 0 + while os.path.exists(candidate_save_dir): + trial_id += 1 + candidate_save_dir = save_dir + f'_{trial_id}' + save_dir = candidate_save_dir + os.makedirs(save_dir) + args.output_dir = save_dir + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if accelerator.is_main_process: + print("\nLoading Dataset") + data_files = {} + data_files["train"] = args.train_file + data_files["validation"] = args.validation_file + dataset = load_dataset("text", data_files=data_files, cache_dir=args.cache_dir) + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir, use_fast=not args.use_slow_tokenizer) + + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # Since we make sure that all sequences are of the same length, no attention_mask is needed. + def tokenize_function(examples): + return tokenizer(examples['text'], return_attention_mask=False) + with accelerator.main_process_first(): + tokenized_datasets = dataset.map(tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_cache, + ) + + def group_texts(examples): + results = dict() + for k in examples.keys(): + results[k] = [] + ex = examples[k] + for e in ex: + results[k] += [e[i: i + args.max_seq_length] for i in range(0, len(e) - args.max_seq_length, args.max_seq_length)] + return results + + # chunk into groups of size args.max_sequence_len + with accelerator.main_process_first(): + tokenized_datasets = tokenized_datasets.map(group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + ) + if accelerator.is_main_process: + print(" len of train_loader", len(tokenized_datasets['train'])) + print(" len of valid_loader", len(tokenized_datasets['validation'])) + + if accelerator.is_main_process: + if torch.cuda.is_available(): + print("Use %d GPUS" % torch.cuda.device_count()) + else: + print('Use cpu.') + + + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ if args.config_name: + config = AutoConfig.from_pretrained(args.config_name) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + + + logger.info("Training new model from scratch") + model = AutoModelForMaskedLM.from_config(config) + + model.resize_token_embeddings(len(tokenizer)) + + train_dataset = tokenized_datasets["train"] + eval_dataset = tokenized_datasets["validation"] + + # Data collator + # This one will take care of randomly masking the tokens. + data_collator = DataCollatorForSpanMusicModeling(tokenizer=tokenizer, mlm_probability=args.mlm_probability) + + # DataLoaders creation: + train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size) + eval_dataloader = DataLoader(eval_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) + print('eval dataloader len', len(eval_dataloader)) + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be + # shorter in multiprocess) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler(name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
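+    # Note on the collator prepared above: unlike pretrain_mlm.py, DataCollatorForSpanMusicModeling
+    # masks contiguous spans of 1 to 5 whole notes and always replaces them with [MASK] tokens;
+    # there is no 80/10/10 replace/random/keep split.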
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + + n_steps = 0 + best_eval_loss = float("inf") + best_eval_acc = -float("inf") + for epoch in range(args.num_train_epochs): + model.train() + train_losses = [] + for step, batch in enumerate(train_dataloader): + n_steps += 1 + outputs = model(**batch) + loss = outputs.loss + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + train_losses.append(accelerator.gather(loss.repeat(args.per_device_train_batch_size))) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if completed_steps >= args.max_train_steps: + break + + if n_steps % args.eval_steps == 0 or step == len(train_dataloader) - 1: + model.eval() + eval_losses = [] + eval_accuracies = [] + for step, batch in enumerate(eval_dataloader): + if step > args.nb_eval_steps: + break + with torch.no_grad(): + outputs = model(**batch) + eval_accuracies.append(accelerator.gather(compute_accuracy(outputs, batch))) + loss = outputs.loss + eval_losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) + + eval_losses = torch.cat(eval_losses) + eval_accuracies = torch.cat(eval_accuracies) + try: + perplexity = math.exp(torch.mean(eval_losses)) + except OverflowError: + perplexity = float("inf") + + if torch.mean(eval_accuracies) > best_eval_acc: + best_eval_acc = torch.mean(eval_accuracies) + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save) + if accelerator.is_main_process: + logger.info("New best score, saving checkpoint") + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + if accelerator.is_main_process: + logger.info(f"\nEval: epoch {epoch}, step: {step} (total step: {n_steps}), --- perplexity: {perplexity:.3f}, loss: {torch.mean(eval_losses):.3f}, " + f"acc: {torch.mean(eval_accuracies):.3f}") + + if n_steps % args.logging_steps == 0 or step == len(train_dataloader) - 1: + this_train_losses = torch.cat(train_losses) + this_train_losses = this_train_losses[-args.logging_steps:] + try: + perplexity = math.exp(torch.mean(this_train_losses)) + except OverflowError: + perplexity = float("inf") + if accelerator.is_main_process: + logger.info(f"\nTrain: epoch {epoch}, step: {step} (total step: {n_steps}) --- perplexity: {perplexity:.3f}, loss: {torch.mean(this_train_losses):.3f}") + +def compute_accuracy(outputs, batch): + predictions = torch.argmax(outputs['logits'], dim=2) + labels = batch['labels'] + inds = torch.where(labels != -100) + accuracy = torch.mean((predictions[inds] == labels[inds]).float()) + return torch.atleast_1d(accuracy) + + +if __name__ == "__main__": + main() diff --git a/src/music/representation_learning/mlm_pretrain/pretrain_t5.py b/src/music/representation_learning/mlm_pretrain/pretrain_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..056b261ad2549beddd848779bdd0271311f225ba --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/pretrain_t5.py @@ -0,0 +1,975 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset. +Here is the full list of checkpoints on the hub that can be pretrained by this script: +Here is the full list of checkpoints on the hub that can be pretrained by this script: +https://huggingface.co/models?filter=t5 +""" +import json +import logging +import os +import sys +import time +from dataclasses import asdict, dataclass, field + +# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. +from enum import Enum +from itertools import chain +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from datasets import load_dataset +from tqdm import tqdm + +import flax +import jax +import jax.numpy as jnp +import optax +from flax import jax_utils, traverse_util +from flax.training import train_state +from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository, HfApi +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_MASKED_LM_MAPPING, + AutoTokenizer, + BatchEncoding, + FlaxT5ForConditionalGeneration, + HfArgumentParser, + PreTrainedTokenizerBase, + T5Config, + is_tensorboard_available, + set_seed, +) +from transformers.models.t5.modeling_flax_t5 import shift_tokens_right +from src.music.config import DATASET_PATH, EXPERIMENT_PATH + + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class TrainingArguments: + output_dir: str = field(default=EXPERIMENT_PATH + "./music/representation_learning/saved_models/pretraining_mlm/music-t5-small", metadata={"help": "The output directory where the model predictions and checkpoints will be written."},) + overwrite_output_dir: bool = field( + default=True, + metadata={ + "help": ( + "Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." 
+                )
+            },
+        )
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+    per_device_train_batch_size: int = field(
+        default=32, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+    learning_rate: float = field(default=0.005, metadata={"help": "The initial learning rate for AdamW."})
+    weight_decay: float = field(default=0.001, metadata={"help": "Weight decay for AdamW if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+    adafactor: bool = field(default=True, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+    num_train_epochs: float = field(default=100.0, metadata={"help": "Total number of training epochs to perform."})
+    warmup_steps: int = field(default=2000, metadata={"help": "Linear warmup over warmup_steps."})
+    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
+    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
+    eval_steps: int = field(default=500, metadata={"help": "Run an evaluation every X steps."})
+    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    push_to_hub: bool = field(
+        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+    )
+    hub_model_id: str = field(
+        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+    )
+    # Do not hardcode a real token here; pass it explicitly or via the HUGGING_FACE_HUB_TOKEN environment variable.
+    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+
+    def __post_init__(self):
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+
+    def to_dict(self):
+        """
+        Serializes this instance while replacing `Enum` by their values (for JSON serialization support). It obfuscates
+        the token values by removing their value.
+        """
+        d = asdict(self)
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+        return d
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The model checkpoint for weights initialization. "
+            "Don't set if you want to train a model from scratch."
+ }, + ) + model_type: Optional[str] = field( + default='t5', + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default="./music-t5-small", metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default="./music-t5-small", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=DATASET_PATH + "/small/train.txt", + metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=DATASET_PATH + "/small/test.txt", + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + train_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input train ref data file for whole word masking in Chinese."}, + ) + validation_ref_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + max_seq_length: Optional[int] = field( + default=512, # 101 notes * 5 tokens = 505 + 6 spans token + metadata={ + "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model." 
+ }, + ) + preprocessing_num_workers: Optional[int] = field( + default=3, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"} + ) + mean_noise_span_length: float = field( + default=3.0, + metadata={"help": "Mean span length of masked notes"}, + ) + nb_tokens_per_note: int = field( + default=5, + metadata={"help": "Nb of tokens representing a single note"}, + ) + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length): + """This function is copy of `random_spans_helper `__ . + Training parameters to avoid padding with random_spans_noise_mask. + When training a model with random_spans_noise_mask, we would like to set the other + training hyperparmeters in a way that avoids padding. + This function helps us compute these hyperparameters. + We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens, + and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens. + This function tells us the required number of tokens in the raw example (for split_tokens()) + as well as the length of the encoded targets. Note that this function assumes + the inputs and targets will have EOS appended and includes that in the reported length. + Args: + inputs_length: an integer - desired length of the tokenized inputs sequence + noise_density: a float + mean_noise_span_length: a float + Returns: + tokens_length: length of original text in tokens + targets_length: an integer - length in tokens of encoded targets sequence + """ + + def _tokens_length_to_inputs_length_targets_length(tokens_length): + num_noise_tokens = int(round(tokens_length * noise_density)) + num_nonnoise_tokens = tokens_length - num_noise_tokens + num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) + # inputs contain all nonnoise tokens, sentinels for all noise spans + # and one EOS token. + _input_length = num_nonnoise_tokens + num_noise_spans + 1 + _output_length = num_noise_tokens + num_noise_spans + 2 # add one to match the mess create by having 5 tokens per note + return _input_length, _output_length + + tokens_length = inputs_length + + while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length: + tokens_length += 1 + + inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length) + + # minor hack to get the targets length to be equal to inputs length + # which is more likely to have been set to a nice round number. + if noise_density == 0.5 and targets_length > inputs_length: + tokens_length -= 1 + targets_length -= 1 + return tokens_length, targets_length + + +@flax.struct.dataclass +class FlaxDataCollatorForT5MLM: + """ + Data collator used for T5 span-masked language modeling. 
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length. + For more information on how T5 span-masked language modeling works, one can take a look + at the `official paper `__ + or the `official code for preprocessing `__ . + Args: + tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`): + The tokenizer used for encoding the data. + noise_density (:obj:`float`): + The probability with which to (randomly) mask tokens in the input. + mean_noise_span_length (:obj:`float`): + The average span length of the masked tokens. + input_length (:obj:`int`): + The expected input length after masking. + target_length (:obj:`int`): + The expected target length after masking. + pad_token_id: (:obj:`int`): + The pad token id of the model + decoder_start_token_id: (:obj:`int): + The decoder start token id of the model + """ + + tokenizer: PreTrainedTokenizerBase + noise_density: float + mean_noise_span_length: float + nb_tokens_per_note: int + input_length: int + target_length: int + pad_token_id: int + decoder_start_token_id: int + + def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: + + # convert list to dict and tensorize input + batch = BatchEncoding( + {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()} + ) + + input_ids = batch["input_ids"] + batch_size, expandend_input_length = input_ids.shape + + mask_indices = np.asarray([self.music_random_spans_noise_mask(expandend_input_length) for i in range(batch_size)]) + labels_mask = ~mask_indices + + input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8)) + labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8)) + + batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel) + batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel) + + if batch["input_ids"].shape[-1] != self.input_length: + raise ValueError( + f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.input_length}." + ) + + if batch["labels"].shape[-1] != self.target_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}." + ) + + # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here... + batch["decoder_input_ids"] = shift_tokens_right( + batch["labels"], self.pad_token_id, self.decoder_start_token_id + ) + + return batch + + def create_sentinel_ids(self, mask_indices): + """ + Sentinel ids creation given the indices that should be masked. + The start indices of each mask are replaced by the sentinel ids in increasing + order. Consecutive mask indices to be deleted are replaced with `-1`. + """ + start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices + start_indices[:, 0] = mask_indices[:, 0] + + sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices) + sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0) + sentinel_ids -= mask_indices - start_indices + + return sentinel_ids + + def filter_input_ids(self, input_ids, sentinel_ids): + """ + Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. 
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`. + """ + batch_size = input_ids.shape[0] + + input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) + input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1)) + input_ids = np.concatenate( + [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1 + ) + return input_ids + + def random_spans_noise_mask(self, length): + + """This function is copy of `random_spans_helper `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. + Args: + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number + Returns: + a boolean tensor with shape [length] + """ + + orig_length = length + + num_noise_tokens = int(np.round(length * self.noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segments: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segments] containing positive integers that add + up to num_items + """ + mask_indices = np.arange(num_items - 1) < (num_segments - 1) + np.random.shuffle(mask_indices) + first_in_segment = np.pad(mask_indices, [[1, 0]]) + segment_id = np.cumsum(first_in_segment) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) + nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans) + + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2] + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros((length,), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] + + def music_random_spans_noise_mask(self, length): + + """This function is copy of `random_spans_helper `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. 
+ Subject to the above restrictions, all masks are equally likely. + Args: + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number + Returns: + a boolean tensor with shape [length] + """ + + orig_length = length + + length_notes = length // self.nb_tokens_per_note + num_noise_notes = int(np.round(length_notes * self.noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_notes = min(max(num_noise_notes, 1), length - 1) + num_noise_spans = int(np.round(num_noise_notes / self.mean_noise_span_length)) + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_notes = length_notes - num_noise_notes + + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segments: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segments] containing positive integers that add + up to num_items + """ + mask_indices = np.arange(num_items - 1) < (num_segments - 1) + np.random.shuffle(mask_indices) + first_in_segment = np.pad(mask_indices, [[1, 0]]) + segment_id = np.cumsum(first_in_segment) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lengths = _random_segmentation(num_noise_notes, num_noise_spans) + nonnoise_span_lengths = _random_segmentation(num_nonnoise_notes, num_noise_spans) + + interleaved_span_lengths = np.reshape(np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros((length_notes,), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + is_noise = np.repeat(is_noise, self.nb_tokens_per_note) + return is_noise[:orig_length] + +def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: + num_samples = len(samples_idx) + samples_to_remove = num_samples % batch_size + + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + batch_idx = np.split(samples_idx, sections_split) + return batch_idx + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.add_scalar("train_time", train_time, step) + + train_metrics = get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.add_scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step): + for metric_name, value in eval_metrics.items(): + summary_writer.add_scalar(f"eval_{metric_name}", value, step) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. 
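+    # Illustrative invocations (a sketch, not from the original script; the file name and flag values are assumptions):
+    #   python pretrain_t5.py my_training_config.json                              # read every argument from one JSON file
+    #   python pretrain_t5.py --do_train --per_device_train_batch_size 16 --num_train_epochs 10
+    # HfArgumentParser accepts either form, as handled just below.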
+ + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + level=logging.INFO, + datefmt="[%X]", + ) + + # Log on each process the small summary: + logger = logging.getLogger(__name__) + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + api = HfApi() + repo_name = api.get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
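+        # Illustrative sketch (the dataset name is an assumption): load_dataset("wikitext", "wikitext-2-raw-v1")
+        # returns a DatasetDict whose available splits ("train", "validation", ...) are what the code below
+        # checks when building a validation split from the train set.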
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if model_args.config_name: + config = T5Config.from_pretrained( + model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer) + ) + elif model_args.model_name_or_path: + config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + # Preprocessing the datasets. + # First we tokenize all the texts. + if training_args.do_train: + column_names = datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # Since we make sure that all sequences are of the same length, no attention_mask is needed. 
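+    # Illustrative input/output of the tokenization step on note-token text (the values below are made up, not real data):
+    #   tokenize_function({"text": ["60 72 100 12 5 61 73 100 12 5"]})
+    #   -> {"input_ids": [[...one id per whitespace-separated note token...]]}   # no attention_mask is returned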
+ def tokenize_function(examples): + return tokenizer(examples[text_column_name], return_attention_mask=False) + + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token. + # To ensure that the input length is `max_seq_length`, we need to increase the maximum length + # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly. + expanded_inputs_length, targets_length = compute_input_and_target_lengths( + inputs_length=513, + noise_density=data_args.mlm_probability, + mean_noise_span_length=data_args.mean_noise_span_length * data_args.nb_tokens_per_note, + ) + + nb_tokens_per_note = data_args.nb_tokens_per_note + nb_notes_in_expanded_inputs = expanded_inputs_length // nb_tokens_per_note + expanded_inputs_length = nb_notes_in_expanded_inputs * nb_tokens_per_note + # custom function to get input sequences of same length without concatenating different songs + def group_texts(examples): + results = dict() + for k in examples.keys(): + results[k] = [] + ex = examples[k] + for e in ex: + nb_chunks = len(e) // expanded_inputs_length + results[k] += [e[i: i + expanded_inputs_length] for i in range(0, len(e) - expanded_inputs_length, expanded_inputs_length)] + return results + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length. + # def group_texts(examples): + # # Concatenate all texts. + # concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + # total_length = len(concatenated_examples[list(examples.keys())[0]]) + # # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # # customize this part to your needs. + # if total_length >= expanded_inputs_length: + # total_length = (total_length // expanded_inputs_length) * expanded_inputs_length + # # Split by chunks of max_len. + # result = { + # k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)] + # for k, t in concatenated_examples.items() + # } + # return result + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a + # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value + # might be slower to preprocess. + # + # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + print(f'{len(tokenized_datasets["train"])} training examples') + print(f'{len(tokenized_datasets["validation"])} validation examples') + print(f'{jax.device_count()} devices: {jax.devices()}') + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from tensorboardX import SummaryWriter + # from flax.metrics.tensorboard import SummaryWriter + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + dropout_rngs = jax.random.split(rng, jax.local_device_count()) + + if model_args.model_name_or_path: + model = FlaxT5ForConditionalGeneration.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + config.vocab_size = len(tokenizer) + model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + + # Data collator + # This one will take care of randomly masking the tokens. + data_collator = FlaxDataCollatorForT5MLM( + tokenizer=tokenizer, + noise_density=data_args.mlm_probability, + mean_noise_span_length=data_args.mean_noise_span_length, + nb_tokens_per_note=nb_tokens_per_note, + input_length=max_seq_length, + target_length=targets_length, + pad_token_id=model.config.pad_token_id, + decoder_start_token_id=model.config.decoder_start_token_id, + ) + + # Store some constant + num_epochs = int(training_args.num_train_epochs) + train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() + eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + + num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs + + # Create learning rate schedule + warmup_fn = optax.linear_schedule( + init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps + ) + decay_fn = optax.linear_schedule( + init_value=training_args.learning_rate, + end_value=0, + transition_steps=num_train_steps - training_args.warmup_steps, + ) + linear_decay_lr_schedule_fn = optax.join_schedules( + schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps] + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. 
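+    # For example (parameter paths are assumptions based on typical Flax T5 naming):
+    #   (..., "layer_norm", "scale")        -> False  # excluded from weight decay
+    #   (..., "final_layer_norm", "scale")  -> False  # excluded from weight decay
+    #   (..., "bias")                       -> False  # excluded from weight decay
+    #   (..., "SelfAttention", "q", "kernel") -> True  # decayed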
+ def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = { + path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) + for path in flat_params + } + return traverse_util.unflatten_dict(flat_mask) + + # create adam optimizer + if training_args.adafactor: + # We use the default parameters here to initialize adafactor, + # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) + + # Setup train state + state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer) + + # Define gradient update step fn + def train_step(state, batch, dropout_rng): + dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) + + def loss_fn(params): + labels = batch.pop("labels") + + logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] + + # compute loss + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean() + + return loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grad = grad_fn(state.params) + grad = jax.lax.pmean(grad, "batch") + new_state = state.apply_gradients(grads=grad) + + metrics = jax.lax.pmean( + {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" + ) + + return new_state, metrics, new_dropout_rng + + # Create parallel version of the train step + p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + + logits = model(**batch, params=params, train=False)[0] + + # compute loss + loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) + + # compute accuracy + accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) + + # summarize metrics + metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()} + metrics = jax.lax.pmean(metrics, axis_name="batch") + + return metrics + + + p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,)) + + # Replicate the train state on each device + state = jax_utils.replicate(state) + + train_time = 0 + epochs = tqdm(range(num_epochs), desc="Epoch ... 
", position=0) + for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + train_metrics = [] + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # Generate an epoch by shuffling sampling indices from the train dataset + num_train_samples = len(tokenized_datasets["train"]) + train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) + + # Gather the indexes for creating the batch and do a training step + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): + samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + + # Model forward + model_inputs = shard(model_inputs.data) + state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) + train_metrics.append(train_metric) + + cur_step = epoch * (num_train_samples // train_batch_size) + step + + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_metric = jax_utils.unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" + ) + + train_metrics = [] + + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) + + # get eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + + # Update progress bar + epochs.write(f"Step... 
({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})") + + # Save metrics + if has_tensorboard and jax.process_index() == 0: + write_eval_metric(summary_writer, eval_metrics, cur_step) + + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir, legacy_format=False) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + + # Eval after training + if training_args.do_eval: + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) + + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) + + # get eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) + + if jax.process_index() == 0: + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/music/representation_learning/mlm_pretrain/saving_tokenizer_and_config.py b/src/music/representation_learning/mlm_pretrain/saving_tokenizer_and_config.py new file mode 100644 index 0000000000000000000000000000000000000000..593cef776ad2b39db4301aa247e1c56dc7cb413b --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/saving_tokenizer_and_config.py @@ -0,0 +1,41 @@ +from transformers import T5Config, BertConfig +from my_tokenizer import MyT5Tokenizer, MyBERTTokenizer + +models = ['t5', 'bert', 'spanbert'] # 't5' 'bert' 'spanbert' + +for model in models: + if model == 't5': + tokens = ['', '', ''] + [str(i) for i in range(2, 183)] + tokens_ids = list(range(len(tokens))) + vocab = dict(zip(tokens, tokens_ids)) + + tokenizer = MyT5Tokenizer(vocab=vocab, unk_token=tokens[2], eos_token=tokens[1], pad_token=tokens[0]) + assert tokenizer.decode(tokenizer.encode('0 2 3 182 183').ids) == ' 2 3 182 ' + tokenizer.save("./models/music-t5-small/tokenizer.json") + + # config = T5Config.from_pretrained("t5-small", vocab_size=tokenizer.get_vocab_size()) + config = T5Config.from_json_file(json_file="/home/cedric/Downloads/config.json") + config.save_pretrained("./models/music-t5-small") + elif model == 'bert': + + tokens = ['[PAD]', '[MASK]', '[UNK]'] + [str(i) for i in range(2, 183)] + tokens_ids = list(range(len(tokens))) + vocab = dict(zip(tokens, tokens_ids)) + + tokenizer = MyBERTTokenizer(vocab=vocab, unk_token=tokens[2]) + assert tokenizer.decode(tokenizer.encode('0 2 3 182 183').ids) == '[UNK] 2 3 182 [UNK]' + tokenizer.save("./models/music-bert/tokenizer.json") + + config = BertConfig(position_embedding_type='relative_key_query') + config.save_pretrained("./models/music-bert") + 
elif model == 'spanbert': + tokens = ['[PAD]', '[MASK]', '[UNK]'] + [str(i) for i in range(2, 183)] + tokens_ids = list(range(len(tokens))) + vocab = dict(zip(tokens, tokens_ids)) + + tokenizer = MyBERTTokenizer(vocab=vocab, unk_token=tokens[2]) + assert tokenizer.decode(tokenizer.encode('0 2 3 182 183').ids) == '[UNK] 2 3 182 [UNK]' + tokenizer.save("./models/music-spanbert/tokenizer.json") + + config = BertConfig(position_embedding_type='relative_key_query') + config.save_pretrained("./models/music-spanbert") diff --git a/src/music/representation_learning/mlm_pretrain/uploading_trained_model.py b/src/music/representation_learning/mlm_pretrain/uploading_trained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb89c4d5e893f8d1b58bd7c3595cb4ebf4ce963 --- /dev/null +++ b/src/music/representation_learning/mlm_pretrain/uploading_trained_model.py @@ -0,0 +1,24 @@ +from transformers import AutoModel, AutoTokenizer + +model_paths = ['music-bert/run', 'music-bert/run_1', 'music-spanbert/run', 'music-spanbert/run_1', 't5/t5_small_dataset', 't5/t5_huge_dataset'] +hf_paths = ["music-bert-base-small-data", "music-bert-base-large-data", "music-spanbert-base-small-data", "music-spanbert-base-large-data", "music-t5-small-small-data", + "music-t5-small-large-data"] + +model_paths = ['t5/t5_small_dataset', 't5/t5_huge_dataset'] +hf_paths = ["music-t5-small-small-data", "music-t5-small-large-data"] + +model_base_path = './experiments/' +hf_path = "ccolas/" + +for m, hf in zip(model_paths, hf_paths): + print(m) + hf_p = hf_path + hf + m_p = model_base_path + m + tokenizer = AutoTokenizer.from_pretrained(m_p) + if 't5' in hf_p: + model = AutoModel.from_pretrained(m_p, from_flax=True) + else: + model = AutoModel.from_pretrained(m_p) + model.push_to_hub(hf_p) + tokenizer.push_to_hub(hf_p) +stop = 1 \ No newline at end of file diff --git a/src/music/representation_learning/sentence_bert_finetuning/__init__.py b/src/music/representation_learning/sentence_bert_finetuning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/representation_learning/sentence_bert_finetuning/song_embedding.py b/src/music/representation_learning/sentence_bert_finetuning/song_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..de413cb0c1d54af7f9f8adc8559fe455d3b2688c --- /dev/null +++ b/src/music/representation_learning/sentence_bert_finetuning/song_embedding.py @@ -0,0 +1,239 @@ +import argparse +import os +from datasets import load_dataset +from ..sentence_transfo.sentence_transformers import SentenceTransformer +from ..sentence_transfo.sentence_transformers import models, losses +from torch.utils.data import DataLoader +import numpy as np +import logging +from accelerate import Accelerator, DistributedDataParallelKwargs +import datasets +import transformers +import torch +from torch import nn +import json +from src.music.config import DATASET_PATH, EXPERIMENT_PATH + +logger = logging.getLogger(__name__) + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a Masked Language Modeling task") + parser.add_argument("--train_file", + type=str, + default=DATASET_PATH + "/small/train_stacked_aug.txt", + help="A csv or a json file containing the training data." + ) + parser.add_argument("--expe_name", + type=str, + default="", + help="A csv or a json file containing the training data." 
+ ) + parser.add_argument("--validation_file", + type=str, + default=DATASET_PATH + "/small/test_stacked_aug.txt", + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--sentence_embedding_model", + type=str, + default="", + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--model_name", + type=str, + default='ccolas/music-bert-base-small-data', + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--pooling", + type=str, + default='mean', + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--overwrite_cache", + type=bool, + default=True, + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--preprocessing_num_workers", + type=int, + default=1, + help="A csv or a json file containing the validation data." + ) + parser.add_argument("--cache_dir", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument("--output_dir", type=str, default=EXPERIMENT_PATH + '/music/representation_learning/saved_models/sentence_embedding/local/', + help="Where to store the final model.") + parser.add_argument("--max_seq_length", type=int, default=512) + parser.add_argument("--nb_tokens_per_note", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=3) + parser.add_argument("--pair_per_song", type=int, default=10) + parser.add_argument("--rep_size", type=int, default=0) + + args = parser.parse_args() + return args + +def setup_sentence_transfo_model(args): + # Define your sentence transformer model using CLS pooling + word_embedding_model = models.Transformer(args.model_name, max_seq_length=args.max_seq_length) + if 't5' in args.model_name: + word_embedding_model.auto_model = word_embedding_model.auto_model.encoder + pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=args.pooling) + if args.rep_size > 0: + dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(), out_features=args.rep_size, activation_function=nn.Tanh()) + model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model]) + else: + model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) + return model + +class Argument(object): + def __init__(self, adict): + self.__dict__.update(adict) + +def setup_dataset(args, accelerator): + data_files = {} + data_files["train"] = args.train_file + data_files["validation"] = args.validation_file + dataset = load_dataset("text", data_files=data_files, cache_dir=args.cache_dir) + + def group_texts(examples): + results = dict() + results['texts'] = [] + results['label'] = [] + ex = examples['text'] + for e in ex: + pairs = [] + augs = e.split('&') + aug_chunks = [] + nb_aug = len(augs) + nb_chunks = [] + for aug in augs: + aug = aug.split(' ') + aug_chunk = [' '.join(aug[i: i + args.max_seq_length]) for i in range(0, len(aug) - args.max_seq_length, args.max_seq_length)] + nb_chunks.append(len(aug_chunk)) + aug_chunks.append(aug_chunk) + nb_chunks = np.min(nb_chunks) + if nb_chunks != 0: + if nb_chunks >= 2: + while len(pairs) < min(nb_aug * nb_chunks, args.pair_per_song): + chunk_ids = np.arange(nb_chunks) + np.random.shuffle(chunk_ids) + for index in range(0, nb_chunks - 1, 2): + aug_ids = np.random.choice(np.arange(nb_aug), size=2, replace=False) + chk_id = chunk_ids[index:index+2] + pairs.append([aug_chunks[aug_ids[0]][chk_id[0]], 
aug_chunks[aug_ids[1]][chk_id[1]]]) + if len(pairs) == min(nb_aug * nb_chunks, args.pair_per_song): + break + else: + # use same chunk (chunk 0) + for i in range(3): + aug_ids = np.random.choice(np.arange(nb_aug), size=2, replace=False) + pairs.append([aug_chunks[aug_ids[0]][0], aug_chunks[aug_ids[1]][0]]) + results['texts'] += pairs + results['label'] += [0 for _ in range(len(pairs))] + return results + + with accelerator.main_process_first(): + dataset = dataset.map(group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + # writer_batch_size=3_000, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_cache, + ) + + train_dataset = dataset['train'] + validation_dataset = dataset['validation'] + + # DataLoader to batch your data + train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) + validation_dataloader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=True) + + return train_dataloader, validation_dataloader + +def get_output_dir(args): + if args.expe_name == '': + args.expe_name = 'run' + save_dir = args.output_dir + args.expe_name + candidate_save_dir = save_dir + trial_id = 0 + while os.path.exists(candidate_save_dir): + trial_id += 1 + candidate_save_dir = save_dir + f'_{trial_id}' + save_dir = candidate_save_dir + '/' + os.makedirs(save_dir) + return save_dir + +def train(): + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + args = parse_args() + args.max_seq_length = (args.max_seq_length // args.nb_tokens_per_note) * args.nb_tokens_per_note + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) + logger.info(accelerator.state) + logger.info(accelerator.device) + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. 
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + + if accelerator.is_main_process: print('Setting up the model') + if args.sentence_embedding_model != '': + if accelerator.is_main_process: print(f' Loading pretrained model from {args.sentence_embedding_model}') + model = SentenceTransformer(args.sentence_embedding_model) + else: + model = setup_sentence_transfo_model(args) + print(model) + if accelerator.is_main_process: print('Building dataset') + train_dataloader, validation_dataloader = setup_dataset(args, accelerator) + if accelerator.is_main_process: + print(" len of train_loader", len(train_dataloader)) + print(" len of valid_loader", len(validation_dataloader)) + print(" total train data", len(train_dataloader.dataset)) + print(" total valid data", len(validation_dataloader.dataset)) + if accelerator.is_main_process: + args.output_dir = get_output_dir(args) + print(f'Saving results to {args.output_dir}') + + if accelerator.is_main_process: + if torch.cuda.is_available(): + print("Use %d GPUS" % torch.cuda.device_count()) + else: + print('Use cpu.') + params = vars(args) + with open(args.output_dir + 'params.json', 'w') as f: + json.dump(params, f) + + # Use the denoising auto-encoder loss + train_loss = losses.MultipleNegativesRankingLoss(model) + + accelerator.wait_for_everyone() + + # Call the fit method + model.fit(train_objectives=[(train_dataloader, train_loss)], + validation_dataloader=validation_dataloader, + epochs=100, + save_best_model=True, + gradient_accumulation=1, + output_path=args.output_dir, + evaluate_every_steps=1000, + log_every_steps=500, + nb_eval_steps=100, + show_progress_bar=True, + accelerator=accelerator) + + +if __name__ == '__main__': + train() \ No newline at end of file diff --git a/src/music/representation_learning/sentence_transfo/LICENSE b/src/music/representation_learning/sentence_transfo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..782dfcd04cd9ec8f8a6b02b6ab55127720e16d53 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and +limitations under the License. 
\ No newline at end of file diff --git a/src/music/representation_learning/sentence_transfo/README.md b/src/music/representation_learning/sentence_transfo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..09d13a0a31a2e87249b37e39facc5858d9360c34 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/README.md @@ -0,0 +1,168 @@ +# Sentence Transformers: Multilingual Sentence, Paragraph, and Image Embeddings using BERT & Co. + +This framework provides an easy method to compute dense vector representations for **sentences**, **paragraphs**, and **images**. The models are based on transformer networks like BERT / RoBERTa / XLM-RoBERTa etc. and achieve state-of-the-art performance on various tasks. Text is embedded in a vector space such that similar text is close together and can be found efficiently using cosine similarity. + + +We provide an increasing number of **[state-of-the-art pretrained models](https://www.sbert.net/docs/pretrained_models.html)** for more than 100 languages, fine-tuned for various use-cases. + +Further, this framework allows an easy **[fine-tuning of custom embedding models](https://www.sbert.net/docs/training/overview.html)**, to achieve maximal performance on your specific task. + + +For the **full documentation**, see **[www.SBERT.net](https://www.sbert.net)**. + +The following publications are integrated in this framework: +- [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019) +- [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020) +- [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (NAACL 2021) +- [The Curse of Dense Low-Dimensional Information Retrieval for Large Index Sizes](https://arxiv.org/abs/2012.14210) (arXiv 2020) +- [TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning](https://arxiv.org/abs/2104.06979) (arXiv 2021) +- [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663) (arXiv 2021) + + +## Installation +We recommend **Python 3.6** or higher, **[PyTorch 1.6.0](https://pytorch.org/get-started/locally/)** or higher and **[transformers v4.6.0](https://github.com/huggingface/transformers)** or higher. The code does **not** work with Python 2.7. + + + + +**Install with pip** + +Install *sentence-transformers* with `pip`: +``` +pip install -U sentence-transformers +``` + +**Install from sources** + +Alternatively, you can also clone the latest version from the [repository](https://github.com/UKPLab/sentence-transformers) and install it directly from the source code: +```` +pip install -e . +```` + +**PyTorch with CUDA** +If you want to use a GPU / CUDA, you must install PyTorch with the matching CUDA version. Follow +[PyTorch - Get Started](https://pytorch.org/get-started/locally/) for further details on how to install PyTorch. + + + +## Getting Started + +See [Quickstart](https://www.sbert.net/docs/quickstart.html) in our documentation. + + +[This example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/computing-embeddings/computing_embeddings.py) shows you how to use an already trained Sentence Transformer model to embed sentences for another task. + +First download a pretrained model.
+````python +from sentence_transformers import SentenceTransformer +model = SentenceTransformer('all-MiniLM-L6-v2') +```` +Then provide some sentences to the model. +````python +sentences = ['This framework generates embeddings for each input sentence', + 'Sentences are passed as a list of strings.', + 'The quick brown fox jumps over the lazy dog.'] +sentence_embeddings = model.encode(sentences) +```` +That's it: we now have a list of numpy arrays with the embeddings. +````python +for sentence, embedding in zip(sentences, sentence_embeddings): + print("Sentence:", sentence) + print("Embedding:", embedding) + print("") +```` + +## Pre-Trained Models + +We provide a large list of [Pretrained Models](https://www.sbert.net/docs/pretrained_models.html) for more than 100 languages. Some models are general purpose models, while others produce embeddings for specific use cases. Pre-trained models can be loaded by just passing the model name: `SentenceTransformer('model_name')`. + +[» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html) + + + +## Training +This framework allows you to fine-tune your own sentence embedding methods, so that you get task-specific sentence embeddings. You have various options to choose from in order to get sentence embeddings that are well suited to your specific task. + +See [Training Overview](https://www.sbert.net/docs/training/overview.html) for an introduction to training your own embedding models. We provide [various examples](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) showing how to train models on various datasets. + + +Some highlights are: +- Support of various transformer networks including BERT, RoBERTa, XLM-R, DistilBERT, Electra, BART, ... +- Multi-lingual and multi-task learning +- Evaluation during training to find the optimal model +- [10+ loss functions](https://www.sbert.net/docs/package_reference/losses.html) that allow you to tune models specifically for semantic search, paraphrase mining, semantic similarity comparison, clustering, triplet loss, and contrastive loss. + + + +## Performance + +Our models are evaluated extensively on 15+ datasets, including challenging domains like Tweets, Reddit posts, and emails. They achieve by far the **best performance** of all available sentence embedding methods. Further, we provide several **smaller models** that are **optimized for speed**. + +[» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html) + + + + +## Application Examples +You can use this framework for: +- [Computing Sentence Embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html) +- [Semantic Textual Similarity](https://www.sbert.net/docs/usage/semantic_textual_similarity.html) +- [Clustering](https://www.sbert.net/examples/applications/clustering/README.html) +- [Paraphrase Mining](https://www.sbert.net/examples/applications/paraphrase-mining/README.html) + - [Translated Sentence Mining](https://www.sbert.net/examples/applications/parallel-sentence-mining/README.html) + - [Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html) + - [Retrieve & Re-Rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) + - [Text Summarization](https://www.sbert.net/examples/applications/text-summarization/README.html) +- [Multilingual Image Search, Clustering & Duplicate Detection](https://www.sbert.net/examples/applications/image-search/README.html) + +and many more use-cases.
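+For instance, Semantic Textual Similarity and Semantic Search both come down to comparing embeddings, typically with cosine similarity. Below is a minimal sketch; the model name and example texts are illustrative placeholders, and `normalize_embeddings` is the `encode()` flag provided by this package, which makes a plain dot product equal to the cosine similarity:
+
+````python
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer('all-MiniLM-L6-v2')
+
+queries = ['How do I bake bread?']
+corpus = ['Instructions for baking a loaf of bread.',
+          'The stock market closed higher today.']
+
+# encode() returns unit-length numpy vectors when normalize_embeddings=True,
+# so the dot product between them equals the cosine similarity
+query_emb = model.encode(queries, normalize_embeddings=True)
+corpus_emb = model.encode(corpus, normalize_embeddings=True)
+
+scores = query_emb @ corpus_emb.T   # shape: (n_queries, n_corpus)
+best = scores.argmax(axis=1)        # index of the closest corpus entry per query
+for query, idx, row in zip(queries, best, scores):
+    print(f"{query!r} best matches {corpus[idx]!r} (score={row[idx]:.3f})")
+````
+
+For larger corpora, the dedicated Semantic Search and Paraphrase Mining utilities linked above are a better fit, since comparing all pairs in one dense matrix quickly becomes memory-bound.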
+ + +For all examples, see [examples/applications](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications). + +## Citing & Authors +If you find this repository helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084): +```bibtex +@inproceedings{reimers-2019-sentence-bert, + title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks", + author = "Reimers, Nils and Gurevych, Iryna", + booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing", + month = "11", + year = "2019", + publisher = "Association for Computational Linguistics", + url = "https://arxiv.org/abs/1908.10084", +} +``` + + +If you use one of the multilingual models, feel free to cite our publication [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813): +```bibtex +@inproceedings{reimers-2020-multilingual-sentence-bert, + title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation", + author = "Reimers, Nils and Gurevych, Iryna", + booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing", + month = "11", + year = "2020", + publisher = "Association for Computational Linguistics", + url = "https://arxiv.org/abs/2004.09813", +} +``` + +Please have a look at [Publications](https://www.sbert.net/docs/publications.html) for our different publications that are integrated into SentenceTransformers. + + +Contact person: [Nils Reimers](https://www.nils-reimers.de), [info@nils-reimers.de](mailto:info@nils-reimers.de) + +https://www.ukp.tu-darmstadt.de/ + + +Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions. + +> This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication. 
+ + + + + + + diff --git a/src/music/representation_learning/sentence_transfo/__init__.py b/src/music/representation_learning/sentence_transfo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/representation_learning/sentence_transfo/__pycache__/__init__.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab36b06d494b6d2a948a160f6db9bb0b21c1d137 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/LoggingHandler.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/LoggingHandler.py new file mode 100644 index 0000000000000000000000000000000000000000..8f73660b57a3dfbf7dae0173f8d7af3a7e752112 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/LoggingHandler.py @@ -0,0 +1,56 @@ +import logging +import tqdm + +class LoggingHandler(logging.Handler): + def __init__(self, level=logging.NOTSET): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.handleError(record) + + +def install_logger( + given_logger, level = logging.WARNING, fmt="%(levelname)s:%(name)s:%(message)s" +): + """ Configures the given logger; format, logging level, style, etc """ + import coloredlogs + + def add_notice_log_level(): + """ Creates a new 'notice' logging level """ + # inspired by: + # https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility + NOTICE_LEVEL_NUM = 25 + logging.addLevelName(NOTICE_LEVEL_NUM, "NOTICE") + + def notice(self, message, *args, **kws): + if self.isEnabledFor(NOTICE_LEVEL_NUM): + self._log(NOTICE_LEVEL_NUM, message, args, **kws) + + logging.Logger.notice = notice + + # Add an extra logging level above INFO and below WARNING + add_notice_log_level() + + # More style info at: + # https://coloredlogs.readthedocs.io/en/latest/api.html + field_styles = coloredlogs.DEFAULT_FIELD_STYLES.copy() + field_styles["asctime"] = {} + level_styles = coloredlogs.DEFAULT_LEVEL_STYLES.copy() + level_styles["debug"] = {"color": "white", "faint": True} + level_styles["notice"] = {"color": "cyan", "bold": True} + + coloredlogs.install( + logger=given_logger, + level=level, + use_chroot=False, + fmt=fmt, + level_styles=level_styles, + field_styles=field_styles, + ) diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/SentenceTransformer.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/SentenceTransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd3b1780b6d6bec30766eebff5439bd08eb09bf --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/SentenceTransformer.py @@ -0,0 +1,1076 @@ +import json +import logging +import os +import shutil +import stat +import warnings +from collections import OrderedDict +from functools import partial +from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional +import requests +import numpy as np +from numpy import ndarray +import transformers +from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_url, cached_download 
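+# Note: the Hugging Face Hub objects imported above (HfApi, HfFolder, Repository) are only needed by save_to_hub();
+# pretrained models are fetched through the snapshot_download helper imported from .util further below.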
+import torch +from torch import nn, Tensor, device +from torch.optim import Optimizer +from torch.utils.data import DataLoader +import torch.multiprocessing as mp +from tqdm.autonotebook import trange +import math +import queue +import tempfile +from distutils.dir_util import copy_tree +from accelerate import Accelerator + +from . import __MODEL_HUB_ORGANIZATION__ +from .evaluation import SentenceEvaluator +from .util import import_from_string, batch_to_device, fullname, snapshot_download, mismatched_sizes_all_gather +from .models import Transformer, Pooling, Dense +from .model_card_templates import ModelCardTemplate +from . import __version__ +import pandas as pd + +import sys +sys.path.append('../../') +import src.music.representation_learning.sentence_transfo as sentence_transfo +logger = logging.getLogger(__name__) + + +class SentenceTransformer(nn.Sequential): + """ + Loads or create a SentenceTransformer model, that can be used to map sentences / text to embeddings. + + :param model_name_or_path: If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from Huggingface models repository with that name. + :param modules: This parameter can be used to create custom SentenceTransformer models from scratch. + :param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used. + :param cache_folder: Path to store models + """ + + def __init__(self, model_name_or_path: Optional[str] = None, modules: Optional[Iterable[nn.Module]] = None, device: Optional[str] = None, cache_folder: Optional[str] = None, + **auto_model_kwargs): + self._model_card_vars = {} + self._model_card_text = None + self._model_config = {} + + if cache_folder is None: + cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME') + if cache_folder is None: + try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() + except ImportError: + torch_cache_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) + + cache_folder = os.path.join(torch_cache_home, 'sentence_transformers') + + if model_name_or_path is not None and model_name_or_path != "": + logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path)) + + # Old models that don't belong to any organization + basic_transformer_models = ['albert-base-v1', 'albert-base-v2', 'albert-large-v1', 'albert-large-v2', 'albert-xlarge-v1', 'albert-xlarge-v2', 'albert-xxlarge-v1', + 'albert-xxlarge-v2', 'bert-base-cased-finetuned-mrpc', 'bert-base-cased', 'bert-base-chinese', 'bert-base-german-cased', + 'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'bert-base-multilingual-cased', 'bert-base-multilingual-uncased', + 'bert-base-uncased', 'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-large-cased-whole-word-masking', 'bert-large-cased', + 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking', 'bert-large-uncased', 'camembert-base', + 'ctrl', 'distilbert-base-cased-distilled-squad', 'distilbert-base-cased', 'distilbert-base-german-cased', + 'distilbert-base-multilingual-cased', 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-finetuned-sst-2-english', + 'distilbert-base-uncased', 'distilgpt2', 'distilroberta-base', 'gpt2-large', 'gpt2-medium', 'gpt2-xl', 'gpt2', 'openai-gpt', + 
'roberta-base-openai-detector', 'roberta-base', 'roberta-large-mnli', 'roberta-large-openai-detector', 'roberta-large', 't5-11b', 't5-3b', + 't5-base', 't5-large', 't5-small', 'transfo-xl-wt103', 'xlm-clm-ende-1024', 'xlm-clm-enfr-1024', 'xlm-mlm-100-1280', 'xlm-mlm-17-1280', + 'xlm-mlm-en-2048', 'xlm-mlm-ende-1024', 'xlm-mlm-enfr-1024', 'xlm-mlm-enro-1024', 'xlm-mlm-tlm-xnli15-1024', 'xlm-mlm-xnli15-1024', + 'xlm-roberta-base', 'xlm-roberta-large-finetuned-conll02-dutch', 'xlm-roberta-large-finetuned-conll02-spanish', + 'xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large-finetuned-conll03-german', 'xlm-roberta-large', 'xlnet-base-cased', + 'xlnet-large-cased'] + + if os.path.exists(model_name_or_path): + # Load from path + model_path = model_name_or_path + else: + # Not a path, load from hub + if '\\' in model_name_or_path or model_name_or_path.count('/') > 1: + raise ValueError("Path {} not found".format(model_name_or_path)) + + if '/' not in model_name_or_path and model_name_or_path.lower() not in basic_transformer_models: + # A model from sentence_transfo + model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path + + model_path = os.path.join(cache_folder, model_name_or_path.replace("/", "_")) + + # Download from hub with caching + snapshot_download(model_name_or_path, + cache_dir=cache_folder, + library_name='sentence_transfo', + library_version=__version__, + ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) + + if os.path.exists(os.path.join(model_path, 'modules.json')): # Load as SentenceTransformer model + modules = self._load_sbert_model(model_path) + else: # Load with AutoModel + modules = self._load_auto_model(model_path, **auto_model_kwargs) + + if modules is not None and not isinstance(modules, OrderedDict): + modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)]) + + super().__init__(modules) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Use pytorch device: {}".format(device)) + + self._target_device = torch.device(device) + + def encode(self, sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = None, + output_value: str = 'sentence_embedding', + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, + num_proc=None) -> Union[List[Tensor], ndarray, Tensor]: + """ + Computes sentence embeddings + + :param sentences: the sentences to embed + :param batch_size: the batch size used for the computation + :param show_progress_bar: Output a progress bar when encode sentences + :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values + :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. + :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy + :param device: Which torch.device to use for the computation. By default, with + :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. + :param num_proc: How many processes to distribute the computation through. With `device=None`, will distribute the computation through all available GPUs. + + :return: + By default, a list of tensors is returned. 
If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned. + """ + self.eval() + if show_progress_bar is None: + show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG) + + if convert_to_tensor: + convert_to_numpy = False + + if output_value != 'sentence_embedding': + convert_to_tensor = False + convert_to_numpy = False + + input_was_string = False + if isinstance(sentences, str) or not hasattr(sentences, '__len__'): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True + + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + all_embeddings = [] + + # For distributed training + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + if device is None: + device = self._target_device + # defining the 0-dim sizes of the batches + sizes = [len(sentences_sorted) // world_size + (1 if rank < len(sentences_sorted) % world_size else 0) + for rank in range(world_size)] + # dividing the list of sentences into batches + limits = np.cumsum([0] + sizes) + local_sentences = sentences_sorted[limits[rank]:limits[rank + 1]] + # embedding + local_embeddings = [] + for start_index in trange(0, len(local_sentences), batch_size, desc="Batches", disable=not show_progress_bar): + sentences_batch = local_sentences[start_index:start_index + batch_size] + batch_embeddings = self._encode(sentences_batch, device=device, output_value=output_value, + convert_to_numpy=False, normalize_embeddings=normalize_embeddings, + multiprocessing=False) + local_embeddings.extend(batch_embeddings) + local_embeddings = torch.stack(local_embeddings) + # gathering everything thanks to the size information from earlier + all_embeddings = mismatched_sizes_all_gather(local_embeddings) + all_embeddings = torch.cat(all_embeddings) + if convert_to_numpy: + all_embeddings = all_embeddings.cpu() + + # Otherwise + else: + # Single-GPU/single-process + if num_proc is None or num_proc == 1: + if device is None: + device = self._target_device + self.to(device) + for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): + sentences_batch = sentences_sorted[start_index:start_index + batch_size] + embeddings = self._encode(sentences_batch, device=device, output_value=output_value, + convert_to_numpy=convert_to_numpy, normalize_embeddings=normalize_embeddings, + multiprocessing=False) + all_embeddings.extend(embeddings) + # Multi-GPU/multi-process + else: + # Allows for several CUDA processes + cuda_compatible_multiprocess = mp.get_context("spawn") + with cuda_compatible_multiprocess.Pool(num_proc) as p: + sentences_batches = [sentences_sorted[start_index:start_index + batch_size] + for start_index in trange(0, len(sentences), batch_size)] + for result in p.map(partial(self._encode, + device=device, + output_value=output_value, + convert_to_numpy=convert_to_numpy, + normalize_embeddings=normalize_embeddings), + sentences_batches): + all_embeddings.extend(result) + + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if convert_to_tensor: + all_embeddings = torch.stack(all_embeddings) + elif convert_to_numpy: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + + if input_was_string: + all_embeddings = all_embeddings[0] + + return 
all_embeddings + + def _encode(self, sentences_batch, device, output_value: str = 'sentence_embedding', convert_to_numpy: bool = False, + normalize_embeddings: bool = False, multiprocessing=False): + + if multiprocessing: + rank = mp.current_process()._identity[0] + if device is None and torch.cuda.is_available(): + device = f"cuda:{rank % torch.cuda.device_count()}" + + self.to(device) + features = self.tokenize(sentences_batch) + features = batch_to_device(features, device) + + with torch.no_grad(): + out_features = self.forward(features) + + if output_value == 'token_embeddings': + embeddings = [] + for token_emb, attention in zip(out_features[output_value], out_features['attention_mask']): + last_mask_id = len(attention) - 1 + while last_mask_id > 0 and attention[last_mask_id].item() == 0: + last_mask_id -= 1 + + embeddings.append(token_emb[0:last_mask_id + 1]) + elif output_value is None: # Return all outputs + embeddings = [] + for sent_idx in range(len(out_features['sentence_embedding'])): + row = {name: out_features[name][sent_idx] for name in out_features} + embeddings.append(row) + else: # Sentence embeddings + embeddings = out_features[output_value] + embeddings = embeddings.detach() + if normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + # fixes for #522 and #487 to avoid oom problems on gpu with large datasets + if convert_to_numpy: + embeddings = embeddings.cpu() + + return embeddings + + def start_multi_process_pool(self, target_devices: List[str] = None): + """ + Starts multi process to process the encoding with several, independent processes. + This method is recommended if you want to encode on multiple GPUs. It is advised + to start only one process per GPU. This method works together with encode_multi_process + + :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used + :return: Returns a dict with the target processes, an input queue and and output queue. + """ + if target_devices is None: + if torch.cuda.is_available(): + target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())] + else: + logger.info("CUDA is not available. Start 4 CPU worker") + target_devices = ['cpu'] * 4 + + logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices)))) + + ctx = mp.get_context('spawn') + input_queue = ctx.Queue() + output_queue = ctx.Queue() + processes = [] + + for cuda_id in target_devices: + p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(cuda_id, self, input_queue, output_queue), daemon=True) + p.start() + processes.append(p) + + return {'input': input_queue, 'output': output_queue, 'processes': processes} + + @staticmethod + def stop_multi_process_pool(pool): + """ + Stops all processes started with start_multi_process_pool + """ + for p in pool['processes']: + p.terminate() + + for p in pool['processes']: + p.join() + p.close() + + pool['input'].close() + pool['output'].close() + + def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None): + """ + This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages + and sent to individual processes, which encode these on the different GPUs. 
This method is only suitable + for encoding large sets of sentences + + :param sentences: List of sentences + :param pool: A pool of workers started with SentenceTransformer.start_multi_process_pool + :param batch_size: Encode sentences with batch size + :param chunk_size: Sentences are chunked and sent to the individual processes. If none, it determine a sensible size. + :return: Numpy matrix with all embeddings + """ + + if chunk_size is None: + chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000) + + logger.info("Chunk data into packages of size {}".format(chunk_size)) + + input_queue = pool['input'] + last_chunk_id = 0 + chunk = [] + + for sentence in sentences: + chunk.append(sentence) + if len(chunk) >= chunk_size: + input_queue.put([last_chunk_id, batch_size, chunk]) + last_chunk_id += 1 + chunk = [] + + if len(chunk) > 0: + input_queue.put([last_chunk_id, batch_size, chunk]) + last_chunk_id += 1 + + output_queue = pool['output'] + results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0]) + embeddings = np.concatenate([result[1] for result in results_list]) + return embeddings + + @staticmethod + def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue): + """ + Internal working process to encode sentences in multi-process setup + """ + while True: + try: + id, batch_size, sentences = input_queue.get() + embeddings = model.encode(sentences, device=target_device, show_progress_bar=False, convert_to_numpy=True, batch_size=batch_size) + results_queue.put([id, embeddings]) + except queue.Empty: + break + + def get_max_seq_length(self): + """ + Returns the maximal sequence length for input the model accepts. Longer inputs will be truncated + """ + if hasattr(self._first_module(), 'max_seq_length'): + return self._first_module().max_seq_length + + return None + + def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]): + """ + Tokenizes the texts + """ + return self._first_module().tokenize(texts) + + def get_sentence_features(self, *features): + return self._first_module().get_sentence_features(*features) + + def get_sentence_embedding_dimension(self): + for mod in reversed(self._modules.values()): + sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None) + if callable(sent_embedding_dim_method): + return sent_embedding_dim_method() + return None + + def _first_module(self): + """Returns the first module of this sequential embedder""" + return self._modules[next(iter(self._modules))] + + def _last_module(self): + """Returns the last module of this sequential embedder""" + return self._modules[next(reversed(self._modules))] + + def save(self, path: str, model_name: Optional[str] = None, create_model_card: bool = True): + """ + Saves all elements for this seq. 
sentence embedder into different sub-folders + :param path: Path on disc + :param model_name: Optional model name + :param create_model_card: If True, create a README.md with basic information about this model + """ + if path is None: + return + + os.makedirs(path, exist_ok=True) + + logger.info("Save model to {}".format(path)) + modules_config = [] + + # Save some model info + if '__version__' not in self._model_config: + self._model_config['__version__'] = { + 'sentence_transformers': __version__, + 'transformers': transformers.__version__, + 'pytorch': torch.__version__, + } + + with open(os.path.join(path, 'config_sentence_transformers.json'), 'w') as fOut: + json.dump(self._model_config, fOut, indent=2) + + # Save modules + for idx, name in enumerate(self._modules): + module = self._modules[name] + if idx == 0 and isinstance(module, Transformer): # Save transformer model in the main folder + model_path = path + "/" + else: + model_path = os.path.join(path, str(idx) + "_" + type(module).__name__) + + os.makedirs(model_path, exist_ok=True) + module.save(model_path) + modules_config.append({'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': type(module).__module__}) + + with open(os.path.join(path, 'modules.json'), 'w') as fOut: + json.dump(modules_config, fOut, indent=2) + + # Create model card + if create_model_card: + self._create_model_card(path, model_name) + + def _create_model_card(self, path: str, model_name: Optional[str] = None): + """ + Create an automatic model and stores it in path + """ + if self._model_card_text is not None and len(self._model_card_text) > 0: + model_card = self._model_card_text + else: + tags = ModelCardTemplate.__TAGS__.copy() + model_card = ModelCardTemplate.__MODEL_CARD__ + + if len(self._modules) == 2 and isinstance(self._first_module(), Transformer) and isinstance(self._last_module(), + Pooling) and self._last_module().get_pooling_mode_str() in ['cls', 'max', + 'mean']: + pooling_module = self._last_module() + pooling_mode = pooling_module.get_pooling_mode_str() + model_card = model_card.replace("{USAGE_TRANSFORMERS_SECTION}", ModelCardTemplate.__USAGE_TRANSFORMERS__) + pooling_fct_name, pooling_fct = ModelCardTemplate.model_card_get_pooling_function(pooling_mode) + model_card = model_card.replace("{POOLING_FUNCTION}", pooling_fct).replace("{POOLING_FUNCTION_NAME}", pooling_fct_name).replace("{POOLING_MODE}", pooling_mode) + tags.append('transformers') + + # Print full model + model_card = model_card.replace("{FULL_MODEL_STR}", str(self)) + + # Add tags + model_card = model_card.replace("{TAGS}", "\n".join(["- " + t for t in tags])) + + # Add dim info + self._model_card_vars["{NUM_DIMENSIONS}"] = self.get_sentence_embedding_dimension() + + # Replace vars we created while using the model + for name, value in self._model_card_vars.items(): + model_card = model_card.replace(name, str(value)) + + # Replace remaining vars with default values + for name, value in ModelCardTemplate.__DEFAULT_VARS__.items(): + model_card = model_card.replace(name, str(value)) + + if model_name is not None: + model_card = model_card.replace("{MODEL_NAME}", model_name.strip()) + + with open(os.path.join(path, "README.md"), "w", encoding='utf8') as fOut: + fOut.write(model_card.strip()) + + def save_to_hub(self, + repo_name: str, + organization: Optional[str] = None, + private: Optional[bool] = None, + commit_message: str = "Add new SentenceTransformer model.", + local_model_path: Optional[str] = None, + exist_ok: bool = False, + replace_model_card: bool = 
False): + """ + Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository. + + :param repo_name: Repository name for your model in the Hub. + :param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization). + :param private: Set to true, for hosting a prive model + :param commit_message: Message to commit while pushing. + :param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded + :param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible + :param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card + :return: The url of the commit of your model in the given repository. + """ + token = HfFolder.get_token() + if token is None: + raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.") + + if '/' in repo_name: + splits = repo_name.split('/', maxsplit=1) + if organization is None or organization == splits[0]: + organization = splits[0] + repo_name = splits[1] + else: + raise ValueError("You passed and invalid repository name: {}.".format(repo_name)) + + endpoint = "https://huggingface.co" + repo_url = HfApi(endpoint=endpoint).create_repo( + token, + repo_name, + organization=organization, + private=private, + repo_type=None, + exist_ok=exist_ok, + ) + full_model_name = repo_url[len(endpoint) + 1:].strip("/") + + with tempfile.TemporaryDirectory() as tmp_dir: + # First create the repo (and clone its content if it's nonempty). + logging.info("Create repository and clone it if it exists") + repo = Repository(tmp_dir, clone_from=repo_url) + + # If user provides local files, copy them. + if local_model_path: + copy_tree(local_model_path, tmp_dir) + else: # Else, save model directly into local repo. + create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md')) + self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card) + + # Find files larger 5M and track with git-lfs + large_files = [] + for root, dirs, files in os.walk(tmp_dir): + for filename in files: + file_path = os.path.join(root, filename) + rel_path = os.path.relpath(file_path, tmp_dir) + + if os.path.getsize(file_path) > (5 * 1024 * 1024): + large_files.append(rel_path) + + if len(large_files) > 0: + logging.info("Track files with git lfs: {}".format(", ".join(large_files))) + repo.lfs_track(large_files) + + logging.info("Push model to the hub. This might take a while") + push_return = repo.push_to_hub(commit_message=commit_message) + + def on_rm_error(func, path, exc_info): + # path contains the path of the file that couldn't be removed + # let's just assume that it's read-only and unlink it. + try: + os.chmod(path, stat.S_IWRITE) + os.unlink(path) + except: + pass + + # Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted + # Hence, try to set write permissions on error + try: + for f in os.listdir(tmp_dir): + shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error) + except Exception as e: + logging.warning("Error when deleting temp folder: {}".format(str(e))) + pass + + return push_return + + def smart_batching_collate(self, batch): + """ + Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model + Here, batch is a list of tuples: [(tokens, label), ...] 
+ + :param batch: + a batch from a SmartBatchingDataset + :return: + a batch of tensors for the model + """ + num_texts = len(batch[0]['texts']) + texts = [[] for _ in range(num_texts)] + labels = [] + + for example in batch: + for idx, text in enumerate(example['texts']): + texts[idx].append(text) + + labels.append(example['label']) + + labels = torch.tensor(labels).to(self._target_device) + + sentence_features = [] + for idx in range(num_texts): + tokenized = self.tokenize(texts[idx]) + batch_to_device(tokenized, self._target_device) + sentence_features.append(tokenized) + + return sentence_features, labels + + def _text_length(self, text: Union[List[int], List[List[int]]]): + """ + Help function to get the length for the input text. Text can be either a string (which means a single text) + a list of ints (which means a single tokenized text), or a tuple of list of ints + (representing several text inputs to the model). + """ + if isinstance(text, str) or isinstance(text[0], int) or len(text) == 0: # Single text, list of ints, or empty + return len(text) + if isinstance(text, dict): # {key: value} case + return len(next(iter(text.values()))) + elif not hasattr(text, '__len__'): # Object has no len() method + return 1 + else: + return sum([len(t) for t in text]) # Sum of length of individual strings + + def fit(self, + train_objectives: Iterable[Tuple[DataLoader, nn.Module]], + evaluator: SentenceEvaluator = None, + epochs: int = 1, + steps_per_epoch: int = None, + scheduler: str = 'WarmupLinear', + warmup_steps: int = 10000, + gradient_accumulation: int = 1, + optimizer_class: Type[Optimizer] = transformers.AdamW, + optimizer_params: Dict[str, object] = None, + weight_decay: float = 0.01, + evaluate_every_steps: int = 0, + nb_eval_steps: int = 0, + log_every_steps: int = 0, + output_path: str = None, + save_best_model: bool = True, + max_grad_norm: float = 1, + use_amp: bool = False, + callback: Callable[[float, int, int], None] = None, + show_progress_bar: bool = True, + checkpoint_path: str = None, + checkpoint_save_steps: int = 500, + checkpoint_save_total_limit: int = 0, + validation_dataloader=None, + accelerator: Accelerator = None + ): + """ + Train the model with the given training objective + Each training objective is sampled in turn for one batch. + We sample only as many batches from each objective as there are in the smallest one + to make sure of equal training with each dataset. + + :param train_objectives: Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning + :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc. + :param epochs: Number of epochs for training + :param steps_per_epoch: Number of training steps per epoch. If set to None (default), one epoch is equal the DataLoader size from train_objectives. + :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts + :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from o up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero. 
+ :param gradient_accumulation: number of steps to take before gradient updates + :param optimizer_class: Optimizer + :param optimizer_params: Optimizer parameters + :param weight_decay: Weight decay for model parameters#!/bin/bash + :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps + :param output_path: Storage path for the model and evaluation files + :param save_best_model: If true, the best model (according to evaluator) is stored at output_path + :param max_grad_norm: Used for gradient normalization. + :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0 + :param callback: Callback function that is invoked after each evaluation. + It must accept the following three parameters in this order: + `score`, `epoch`, `steps` + :param show_progress_bar: If True, output a tqdm progress bar + :param checkpoint_path: Folder to save checkpoints during training + :param checkpoint_save_steps: Will save a checkpoint after so many steps + :param checkpoint_save_total_limit: Total number of checkpoints to store + :param accelerator: Allows you to pass your own accelerator object defined beforehand. + """ + + # replacing mutable arguments + if optimizer_params is None: + optimizer_params = {'lr': 2e-5} + + ##Add info to model card + # info_loss_functions = "\n".join(["- {} with {} training examples".format(str(loss), len(dataloader)) for dataloader, loss in train_objectives]) + info_loss_functions = [] + for dataloader, loss in train_objectives: + info_loss_functions.extend(ModelCardTemplate.get_train_objective_info(dataloader, loss)) + info_loss_functions = "\n\n".join([text for text in info_loss_functions]) + info_fit_parameters = json.dumps({"evaluator": fullname(evaluator), "epochs": epochs, "steps_per_epoch": steps_per_epoch, + "scheduler": scheduler, "warmup_steps": warmup_steps, "optimizer_class": str(optimizer_class), + "optimizer_params": optimizer_params, "weight_decay": weight_decay, + "evaluate_every_steps": evaluate_every_steps, "nb_eval_steps": nb_eval_steps, "max_grad_norm": max_grad_norm}, indent=4, sort_keys=True) + self._model_card_text = None + self._model_card_vars['{TRAINING_SECTION}'] = ModelCardTemplate.__TRAINING_SECTION__.replace("{LOSS_FUNCTIONS}", info_loss_functions).replace("{FIT_PARAMETERS}", + info_fit_parameters) + + if evaluate_every_steps > 0: + assert nb_eval_steps > 0 + assert validation_dataloader is not None + + best_model_path = output_path + '/best_model/' + + # accelerate setup + if accelerator is None: + accelerator = Accelerator() + + if use_amp: + from torch.cuda.amp import autocast + scaler = torch.cuda.amp.GradScaler() + validation_dataloader.collate_fn = self.smart_batching_collate + dataloaders = [dataloader for dataloader, _ in train_objectives] + # Use smart batching + for dataloader in dataloaders: + dataloader.collate_fn = self.smart_batching_collate + # Calculate number of steps + if steps_per_epoch is None or steps_per_epoch == 0: + steps_per_epoch = min([len(dataloader) for dataloader in dataloaders]) / gradient_accumulation + if torch.distributed.is_initialized(): + steps_per_epoch = steps_per_epoch / torch.distributed.get_world_size() + steps_per_epoch = math.ceil(steps_per_epoch) + num_train_steps = int(steps_per_epoch * epochs) + + loss_models = [loss for _, loss in train_objectives] + # Prepare optimizers + optimizers = [] + schedulers = [] + for loss_model in loss_models: + param_optimizer = list(loss_model.named_parameters()) + + no_decay = ['bias', 
'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params) + scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps) + + optimizers.append(optimizer) + schedulers.append(scheduler_obj) + + n_dataloaders, n_loss_models, n_optimizers, n_valid_dataloader = len(dataloaders), len(loss_models), len(optimizers), 1 + prepared = accelerator.prepare(*dataloaders, *loss_models, *optimizers, validation_dataloader) + dataloaders = prepared[0:n_dataloaders] + loss_models = prepared[n_dataloaders:n_dataloaders + n_loss_models] + optimizers = prepared[n_dataloaders + n_loss_models:len(prepared) - 1] + validation_dataloader = prepared[-1] + + self.best_score = -9999999 + + global_step = 0 + data_iterators = [iter(dataloader) for dataloader in dataloaders] + validation_iterator = iter(validation_dataloader) + num_train_objectives = len(train_objectives) + training_losses = [] + training_accuracies = [] + index_last_train_loss = 0 + index_last_eval_loss = 0 + + results = dict(eval_loss=[], eval_acc=[], train_loss=[], train_acc=[], step=[]) + skip_scheduler = False + + # run initial eval + eval_loss, eval_acc = self.my_eval(loss_models[0], validation_iterator, validation_dataloader, nb_eval_steps, save_best_model, best_model_path, accelerator) + if accelerator.is_main_process: + # get new loss and acc since last eval step + train_loss = np.mean(training_losses[index_last_eval_loss:]) + train_acc = np.mean(training_accuracies[index_last_eval_loss:]) + logger.warning(f'\nEvaluation step {global_step}: eval loss: {eval_loss:.3f}, eval accuracy: {eval_acc:.3f}') + results['eval_acc'].append(float(eval_acc)) + results['eval_loss'].append(float(eval_loss)) + results['train_acc'].append(float(train_acc)) + results['train_loss'].append(float(train_loss)) + results['step'].append(global_step) + frame = pd.DataFrame.from_dict(results) + frame.to_csv(output_path + '/results.csv', index=False) + + for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar): + training_steps = 0 + + for loss_model in loss_models: + loss_model.zero_grad() + loss_model.train() + + for _ in trange(steps_per_epoch * gradient_accumulation, desc="Iteration", smoothing=0.05, disable=not show_progress_bar): + for train_idx in range(num_train_objectives): + loss_model = loss_models[train_idx] + optimizer = optimizers[train_idx] + scheduler = schedulers[train_idx] + data_iterator = data_iterators[train_idx] + + try: + data = next(data_iterator) + except StopIteration: + data_iterator = iter(dataloaders[train_idx]) + data_iterators[train_idx] = data_iterator + data = next(data_iterator) + + features, labels = data + # logger.warning(f'Rank: {accelerator.process_index}, features shape: {features.shape}') + if use_amp: + with autocast(): + loss_value = loss_model(features, labels) + + scale_before_step = scaler.get_scale() + accelerator.backward(scaler.scale(loss_value)) + training_steps += 1 + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm) + + if training_steps % gradient_accumulation == 0: + scaler.step(optimizer) + scaler.update() + skip_scheduler = scaler.get_scale() != scale_before_step + 
optimizer.zero_grad() + if not skip_scheduler: + scheduler.step() + global_step += 1 + else: + loss_value, accuracy = loss_model(features, labels) + accelerator.backward(loss_value) + training_losses.append(accelerator.gather(torch.atleast_1d(loss_value))) + training_accuracies.append(accelerator.gather(torch.atleast_1d(accuracy))) + torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm) + if training_steps % gradient_accumulation == 0: + optimizer.step() + optimizer.zero_grad() + if not skip_scheduler: + scheduler.step() + global_step += 1 + training_steps += 1 + + if log_every_steps > 0 and global_step % log_every_steps == 0: + if accelerator.is_main_process: + # get latest loss and acc + train_loss = torch.mean(torch.cat(training_losses[index_last_train_loss:])) + train_acc = torch.mean(torch.cat(training_accuracies[index_last_train_loss:])) + logger.warning(f'\nTraining step {global_step}: loss: {train_loss:.3f}, accuracy: {train_acc:.3f}') + # keep track of number of loss stored + index_last_train_loss = len(training_losses) + assert index_last_train_loss == len(training_accuracies) + + if evaluate_every_steps > 0 and global_step % evaluate_every_steps == 0: + eval_loss, eval_acc = self.my_eval(loss_models[0], validation_iterator, validation_dataloader, nb_eval_steps, save_best_model, best_model_path, accelerator) + if accelerator.is_main_process: + # get new loss and acc since last eval step + train_loss = torch.mean(torch.cat(training_losses[index_last_eval_loss:])) + train_acc = torch.mean(torch.cat(training_accuracies[index_last_eval_loss:])) + logger.warning(f'\nEvaluation step {global_step}: eval loss: {eval_loss:.3f}, eval accuracy: {eval_acc:.3f}') + results['eval_acc'].append(float(eval_acc)) + results['eval_loss'].append(float(eval_loss)) + results['train_acc'].append(float(train_acc)) + results['train_loss'].append(float(train_loss)) + results['step'].append(global_step) + frame = pd.DataFrame.from_dict(results) + frame.to_csv(output_path + '/results.csv', index=False) + + # keep track of number of loss stored + index_last_eval_loss = len(training_losses) + assert index_last_eval_loss == len(training_accuracies) + + # self._eval_during_training(evaluator, output_path, save_best_model, epoch, global_step, callback, + # main_process=accelerator.is_main_process) + + for loss_model in loss_models: + loss_model.zero_grad() + loss_model.train() + + if checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 \ + and global_step % checkpoint_save_steps == 0 and accelerator.is_main_process: + self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step) + + # self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback, + # main_process=accelerator.is_main_process) + + # if accelerator.is_main_process: + # if evaluator is None and output_path is not None: #No evaluator, but output path: save final model version + # self.save(output_path) + # + # if checkpoint_path is not None: + # self._save_checkpoint(checkpoint_path, checkpoint_save_total_limit, global_step) + + def my_eval(self, loss_model, validation_iterator, validation_dataloader, nb_eval_steps, save_best_model, output_path, accelerator): + loss_model.eval() + eval_losses = [] + eval_accuracies = [] + + for i in range(nb_eval_steps): + try: + data = next(validation_iterator) + except StopIteration: + validation_iterator = iter(validation_dataloader) + data = next(validation_iterator) + features, labels = data + with 
torch.no_grad(): + loss_value, accuracy = loss_model(features, labels) + eval_losses.append(accelerator.gather(torch.atleast_1d(loss_value))) + eval_accuracies.append(accelerator.gather(torch.atleast_1d(accuracy))) + eval_loss = torch.mean(torch.cat(eval_losses)) + eval_acc = torch.mean(torch.cat(eval_accuracies)) + if accelerator.is_main_process: + score = eval_acc + if score > self.best_score: + self.best_score = score + if save_best_model: + self.save(output_path) + return eval_loss, eval_acc + + # accelerator.backward(loss_value) + + def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None): + """ + Evaluate the model + + :param evaluator: + the evaluator + :param output_path: + the evaluator can write the results to this path + """ + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + return evaluator(self, output_path) + + def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback, main_process=True): + """Runs evaluation during the training""" + eval_path = output_path + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + eval_path = os.path.join(output_path, "eval") + os.makedirs(eval_path, exist_ok=True) + + if evaluator is not None: + score = evaluator(self, output_path=eval_path, epoch=epoch, steps=steps) + if callback is not None and main_process: + callback(score, epoch, steps) + if score > self.best_score and main_process: + self.best_score = score + if save_best_model: + self.save(output_path) + + def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step): + # Store new checkpoint + self.save(os.path.join(checkpoint_path, str(step))) + + # Delete old checkpoints + if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0: + old_checkpoints = [] + for subdir in os.listdir(checkpoint_path): + if subdir.isdigit(): + old_checkpoints.append({'step': int(subdir), 'path': os.path.join(checkpoint_path, subdir)}) + + if len(old_checkpoints) > checkpoint_save_total_limit: + old_checkpoints = sorted(old_checkpoints, key=lambda x: x['step']) + shutil.rmtree(old_checkpoints[0]['path']) + + def _load_auto_model(self, model_name_or_path, **auto_model_kwargs): + """ + Creates a simple Transformer + Mean Pooling model and returns the modules + """ + logging.warning("No sentence_transfo model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path)) + transformer_model = Transformer(model_name_or_path, **auto_model_kwargs) + pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean') + return [transformer_model, pooling_model] + + def _load_sbert_model(self, model_path): + """ + Loads a full sentence_transfo model + """ + # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework) + config_sentence_transformers_json_path = os.path.join(model_path, 'config_sentence_transformers.json') + if os.path.exists(config_sentence_transformers_json_path): + with open(config_sentence_transformers_json_path) as fIn: + self._model_config = json.load(fIn) + + if '__version__' in self._model_config and 'sentence_transformers' in self._model_config['__version__'] and self._model_config['__version__'][ + 'sentence_transformers'] > __version__: + logger.warning( + "You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. 
In that case, try to update to the latest version.\n\n\n".format( + self._model_config['__version__']['sentence_transformers'], __version__)) + + # Check if a readme exists + model_card_path = os.path.join(model_path, 'README.md') + if os.path.exists(model_card_path): + try: + with open(model_card_path, encoding='utf8') as fIn: + self._model_card_text = fIn.read() + except: + pass + + # Load the modules of sentence transformer + modules_json_path = os.path.join(model_path, 'modules.json') + with open(modules_json_path) as fIn: + modules_config = json.load(fIn) + + modules = OrderedDict() + for module_config in modules_config: + if module_config['type'][:16] == 'sentence_transfo': + module_config['type'] = 'src.music.representation_learning.' + module_config['type'] + module_class = import_from_string(module_config['type']) + module = module_class.load(os.path.join(model_path, module_config['path'])) + modules[module_config['name']] = module + + return modules + + @staticmethod + def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int): + """ + Returns the correct learning rate scheduler. Available scheduler: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts + """ + scheduler = scheduler.lower() + if scheduler == 'constantlr': + return transformers.get_constant_schedule(optimizer) + elif scheduler == 'warmupconstant': + return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps) + elif scheduler == 'warmuplinear': + return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) + elif scheduler == 'warmupcosine': + return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) + elif scheduler == 'warmupcosinewithhardrestarts': + return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total) + else: + raise ValueError("Unknown scheduler {}".format(scheduler)) + + @property + def device(self) -> device: + """ + Get torch.device from module, assuming that the whole module has one device. + """ + try: + return next(self.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + @property + def tokenizer(self): + """ + Property to get the tokenizer that is used by this model + """ + return self._first_module().tokenizer + + @tokenizer.setter + def tokenizer(self, value): + """ + Property to set the tokenizer that is should used by this model + """ + self._first_module().tokenizer = value + + @property + def max_seq_length(self): + """ + Property to get the maximal input sequence length for the model. Longer inputs will be truncated. + """ + return self._first_module().max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value): + """ + Property to set the maximal input sequence length for the model. Longer inputs will be truncated. 
+ """ + self._first_module().max_seq_length = value diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__init__.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..519696b5abe889c7b3989a041f2d5cc6eab72d90 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/__init__.py @@ -0,0 +1,4 @@ +__version__ = "2.1.0" +__MODEL_HUB_ORGANIZATION__ = 'sentence_transfo' +from .LoggingHandler import LoggingHandler +from .SentenceTransformer import SentenceTransformer diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/LoggingHandler.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/LoggingHandler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34a493a713778595723be852bb69046486f75db2 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/LoggingHandler.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/SentenceTransformer.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/SentenceTransformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fa563149617dc32a2ce392e5db14145dce2cd17 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/SentenceTransformer.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/__init__.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dfc9de3c5e9069bdd84c769fc7f7b8f1919ac2d Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/model_card_templates.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/model_card_templates.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17a9640eecc289bdc4524b6035184e39b3916b11 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/model_card_templates.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/util.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..832b70db3fe8f1fc399a88facf3adf21d70bb046 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/__pycache__/util.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/SentenceEvaluator.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/SentenceEvaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f4043f715e88ea19acc0b7e6edcd7c420a14f9 --- /dev/null +++ 
b/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/SentenceEvaluator.py @@ -0,0 +1,30 @@ +class SentenceEvaluator: + """ + Base class for all evaluators + + Extend this class and implement __call__ for custom evaluators. + """ + + def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, num_proc: int = None) -> float: + """ + This is called during training to evaluate the model. + It returns a score for the evaluation with a higher score indicating a better result. + + :param model: + the model to evaluate + :param output_path: + path where predictions and metrics are written to + :param epoch + the epoch where the evaluation takes place. + This is used for the file prefixes. + If this is -1, then we assume evaluation on test data. + :param steps + the steps in the current epoch at time of the evaluation. + This is used for the file prefixes. + If this is -1, then we assume evaluation at the end of the epoch. + :return: a score for the evaluation with a higher score indicating a better result + :param num_proc + the number of processes to use for evaluation. Allows for multi-GPU evaluation + :return: a score for the evaluation with a higher score indicating a better result + """ + pass diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/__pycache__/SentenceEvaluator.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/__pycache__/SentenceEvaluator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f26c37839fb697254301e6733c7a8aa0e39b6ec2 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/evaluation/__pycache__/SentenceEvaluator.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/losses/MultipleNegativesRankingLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..666bb3c26bda35d2475be77497e23b6665343304 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/losses/MultipleNegativesRankingLoss.py @@ -0,0 +1,86 @@ +import torch +from torch import nn, Tensor +from typing import Iterable, Dict +from ..SentenceTransformer import SentenceTransformer +from .. import util +from ..util import mismatched_sizes_all_gather + + +class MultipleNegativesRankingLoss(nn.Module): + """ + This loss expects as input a batch consisting of sentence pairs (a_1, p_1), (a_2, p_2)..., (a_n, p_n) + where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair. + + For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and + n-1 negative examples (p_j). It then minimizes the negative log-likehood for softmax normalized scores. + + This loss function works great to train embeddings for retrieval setups where you have positive pairs (e.g. (query, relevant_doc)) + as it will sample in each batch n-1 negative docs randomly. + + The performance usually increases with increasing batch sizes. 
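+
+    Concretely (a sketch of the objective as implemented in ``forward`` below, where sim is the
+    similarity function and scale its multiplier): for anchors a_1..a_n and the candidate set
+    c = (p_1..p_n plus any hard negatives), the loss is
+    -1/n * sum_i log( exp(scale * sim(a_i, p_i)) / sum_j exp(scale * sim(a_i, c_j)) ).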
+ + For more information, see: https://arxiv.org/pdf/1705.00652.pdf + (Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4) + + You can also provide one or multiple hard negatives per anchor-positive pair by structering the data like this: + (a_1, p_1, n_1), (a_2, p_2, n_2) + + Here, n_1 is a hard negative for (a_1, p_1). The loss will use for the pair (a_i, p_i) all p_j (j!=i) and all n_j as negatives. + + Example:: + + from sentence_transformers import SentenceTransformer, losses, InputExample + from torch.utils.data import DataLoader + + model = SentenceTransformer('distilbert-base-uncased') + train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']), + InputExample(texts=['Anchor 2', 'Positive 2'])] + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32) + train_loss = losses.MultipleNegativesRankingLoss(model=model) + """ + def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct = util.cos_sim): + """ + :param model: SentenceTransformer model + :param scale: Output of similarity function is multiplied by scale value + :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1) + """ + super(MultipleNegativesRankingLoss, self).__init__() + self.model = model + self.scale = scale + self.similarity_fct = similarity_fct + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor): + # reps = [torch.randn_like(torch.zeros([3, 768])), torch.randn_like(torch.zeros([3, 768]))] + reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features] + embeddings_a = reps[0] + + if torch.distributed.is_initialized(): + + embeddings_b = reps[1] + if len(reps) > 2: + embeddings_n = torch.cat(reps[2:]) + else: + embeddings_n = embeddings_b[:0, :] + full_embeddings_b = mismatched_sizes_all_gather(embeddings_b) + full_embeddings_b = torch.cat(full_embeddings_b) + full_embeddings_n = mismatched_sizes_all_gather(embeddings_n) + full_embeddings_n = torch.cat(full_embeddings_n) + candidates = torch.cat([full_embeddings_b, full_embeddings_n]) + + scores = self.similarity_fct(embeddings_a, candidates) * self.scale + labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)\ + + len(scores) * torch.distributed.get_rank() + acc = torch.mean((torch.argmax(scores, dim=1) == labels).float()) + return self.cross_entropy_loss(scores, labels), acc + + else: + candidates = torch.cat(reps[1:]) + scores = self.similarity_fct(embeddings_a, candidates) * self.scale + labels = torch.tensor(range(len(scores)), dtype=torch.long, + device=scores.device) # Example a[i] should match with b[i] + acc = torch.mean((torch.argmax(scores, dim=1) == labels).float()) + return self.cross_entropy_loss(scores, labels), acc + + def get_config_dict(self): + return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__} diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/model_card_templates.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/model_card_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..e67a904a3efc84cd807c94ec1d93d9fa63dcbb42 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/model_card_templates.py @@ -0,0 +1,173 @@ +import logging + +from .util import fullname + +class 
ModelCardTemplate: + __TAGS__ = ["sentence_transfo", "feature-extraction", "sentence-similarity"] + __DEFAULT_VARS__ = { + "{PIPELINE_TAG}": "sentence-similarity", + "{MODEL_DESCRIPTION}": "", + "{TRAINING_SECTION}": "", + "{USAGE_TRANSFORMERS_SECTION}": "", + "{EVALUATION}": "", + "{CITING}": "" + } + + __MODEL_CARD__ = """ +--- +pipeline_tag: {PIPELINE_TAG} +tags: +{TAGS} +--- + +# {MODEL_NAME} + +This is a [sentence_transfo](https://www.SBERT.net) model: It maps sentences & paragraphs to a {NUM_DIMENSIONS} dimensional dense vector space and can be used for tasks like clustering or semantic search. + +{MODEL_DESCRIPTION} + +## Usage (Sentence-Transformers) + +Using this model becomes easy when you have [sentence_transfo](https://www.SBERT.net) installed: + +``` +pip install -U sentence_transfo +``` + +Then you can use the model like this: + +```python +from sentence_transformers import SentenceTransformer +sentences = ["This is an example sentence", "Each sentence is converted"] + +model = SentenceTransformer('{MODEL_NAME}') +embeddings = model.encode(sentences) +print(embeddings) +``` + +{USAGE_TRANSFORMERS_SECTION} + +## Evaluation Results + +{EVALUATION} + +For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name={MODEL_NAME}) + +{TRAINING_SECTION} + +## Full Model Architecture +``` +{FULL_MODEL_STR} +``` + +## Citing & Authors + +{CITING} + +""" + + + + __TRAINING_SECTION__ = """ +## Training +The model was trained with the parameters: + +{LOSS_FUNCTIONS} + +Parameters of the fit()-Method: +``` +{FIT_PARAMETERS} +``` +""" + + + __USAGE_TRANSFORMERS__ = """\n +## Usage (HuggingFace Transformers) +Without [sentence_transfo](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings. + +```python +from transformers import AutoTokenizer, AutoModel +import torch + +{POOLING_FUNCTION} + +# Sentences we want sentence embeddings for +sentences = ['This is an example sentence', 'Each sentence is converted'] + +# Load model from HuggingFace Hub +tokenizer = AutoTokenizer.from_pretrained('{MODEL_NAME}') +model = AutoModel.from_pretrained('{MODEL_NAME}') + +# Tokenize sentences +encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') + +# Compute token embeddings +with torch.no_grad(): + model_output = model(**encoded_input) + +# Perform pooling. In this case, {POOLING_MODE} pooling. +sentence_embeddings = {POOLING_FUNCTION_NAME}(model_output, encoded_input['attention_mask']) + +print("Sentence embeddings:") +print(sentence_embeddings) +``` + +""" + + + + @staticmethod + def model_card_get_pooling_function(pooling_mode): + if pooling_mode == 'max': + return "max_pooling", """ +# Max Pooling - Take the max value over time for every dimension. 
+def max_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value
+    return torch.max(token_embeddings, 1)[0]
+"""
+        elif pooling_mode == 'mean':
+            return "mean_pooling", """
+#Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+"""
+
+        elif pooling_mode == 'cls':
+            return "cls_pooling", """
+def cls_pooling(model_output, attention_mask):
+    return model_output[0][:,0]
+"""
+
+    @staticmethod
+    def get_train_objective_info(dataloader, loss):
+        try:
+            if hasattr(dataloader, 'get_config_dict'):
+                train_loader = dataloader.get_config_dict()
+            else:
+                loader_params = {}
+                loader_params['batch_size'] = dataloader.batch_size if hasattr(dataloader, 'batch_size') else 'unknown'
+                if hasattr(dataloader, 'sampler'):
+                    loader_params['sampler'] = fullname(dataloader.sampler)
+                if hasattr(dataloader, 'batch_sampler'):
+                    loader_params['batch_sampler'] = fullname(dataloader.batch_sampler)
+
+            dataloader_str = """**DataLoader**:\n\n`{}` of length {} with parameters:
+```
+{}
+```""".format(fullname(dataloader), len(dataloader), loader_params)
+
+            loss_str = "**Loss**:\n\n`{}` {}".format(fullname(loss),
+                        """with parameters:
+                        ```
+                        {}
+                        ```""".format(loss.get_config_dict()) if hasattr(loss, 'get_config_dict') else "")
+
+            return [dataloader_str, loss_str]
+
+        except Exception as e:
+            logging.warning("Exception when creating get_train_objective_info: {}".format(str(e)))
+            return ""
\ No newline at end of file
diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Dense.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25a7d5d7c7499d482019822b5732f8a08494eae
--- /dev/null
+++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Dense.py
@@ -0,0 +1,63 @@
+import torch
+from torch import Tensor
+from torch import nn
+from torch import functional as F
+from typing import Union, Tuple, List, Iterable, Dict
+import os
+import json
+from ..util import fullname, import_from_string
+
+
+class Dense(nn.Module):
+    """Feed-forward function with activation function.
+
+    This layer takes a fixed-sized sentence embedding and passes it through a feed-forward layer. Can be used to generate deep averaging networks (DAN).
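+
+    Illustrative sketch (not part of the original file): Dense(in_features=768, out_features=256)
+    replaces features['sentence_embedding'] by its 256-dimensional projection passed through the
+    activation function (nn.Tanh() by default), as implemented in forward() below.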
+ + :param in_features: Size of the input dimension + :param out_features: Output size + :param bias: Add a bias vector + :param activation_function: Pytorch activation function applied on output + :param init_weight: Initial value for the matrix of the linear layer + :param init_bias: Initial value for the bias of the linear layer + """ + def __init__(self, in_features: int, out_features: int, bias: bool = True, activation_function=nn.Tanh(), init_weight: Tensor = None, init_bias: Tensor = None): + super(Dense, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.activation_function = activation_function + self.linear = nn.Linear(in_features, out_features, bias=bias) + + if init_weight is not None: + self.linear.weight = nn.Parameter(init_weight) + + if init_bias is not None: + self.linear.bias = nn.Parameter(init_bias) + + def forward(self, features: Dict[str, Tensor]): + features.update({'sentence_embedding': self.activation_function(self.linear(features['sentence_embedding']))}) + return features + + def get_sentence_embedding_dimension(self) -> int: + return self.out_features + + def get_config_dict(self): + return {'in_features': self.in_features, 'out_features': self.out_features, 'bias': self.bias, 'activation_function': fullname(self.activation_function)} + + def save(self, output_path): + with open(os.path.join(output_path, 'config.json'), 'w') as fOut: + json.dump(self.get_config_dict(), fOut) + + torch.save(self.state_dict(), os.path.join(output_path, 'pytorch_model.bin')) + + def __repr__(self): + return "Dense({})".format(self.get_config_dict()) + @staticmethod + def load(input_path): + with open(os.path.join(input_path, 'config.json')) as fIn: + config = json.load(fIn) + + config['activation_function'] = import_from_string(config['activation_function'])() + model = Dense(**config) + model.load_state_dict(torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location=torch.device('cpu'))) + return model diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Pooling.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..f75b16509d40d697de5a349cf261b0eceb4a4e87 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Pooling.py @@ -0,0 +1,120 @@ +import torch +from torch import Tensor +from torch import nn +from typing import Union, Tuple, List, Iterable, Dict +import os +import json + + +class Pooling(nn.Module): + """Performs pooling (max or mean) on the token embeddings. + + Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding. This layer also allows to use the CLS token if it is returned by the underlying word embedding model. + You can concatenate multiple poolings together. + + :param word_embedding_dimension: Dimensions for the word embeddings + :param pooling_mode: Can be a string: mean/max/cls. If set, overwrites the other pooling_mode_* settings + :param pooling_mode_cls_token: Use the first token (CLS token) as text representations + :param pooling_mode_max_tokens: Use max in each dimension over all tokens. + :param pooling_mode_mean_tokens: Perform mean-pooling + :param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but devide by sqrt(input_length). 
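+
+    Example (illustrative sketch, not part of the original file; token_embeddings of shape
+    (batch, seq_len, 768) and attention_mask of shape (batch, seq_len) are assumed tensors
+    as produced by the Transformer module of this package)::
+
+        pooling = Pooling(word_embedding_dimension=768, pooling_mode='mean')
+        features = pooling({'token_embeddings': token_embeddings, 'attention_mask': attention_mask})
+        sentence_embedding = features['sentence_embedding']  # shape (batch, 768)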
+ """ + def __init__(self, + word_embedding_dimension: int, + pooling_mode: str = None, + pooling_mode_cls_token: bool = False, + pooling_mode_max_tokens: bool = False, + pooling_mode_mean_tokens: bool = True, + pooling_mode_mean_sqrt_len_tokens: bool = False, + ): + super(Pooling, self).__init__() + + self.config_keys = ['word_embedding_dimension', 'pooling_mode_cls_token', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens', 'pooling_mode_mean_sqrt_len_tokens'] + + if pooling_mode is not None: #Set pooling mode by string + pooling_mode = pooling_mode.lower() + assert pooling_mode in ['mean', 'max', 'cls'] + pooling_mode_cls_token = (pooling_mode == 'cls') + pooling_mode_max_tokens = (pooling_mode == 'max') + pooling_mode_mean_tokens = (pooling_mode == 'mean') + + self.word_embedding_dimension = word_embedding_dimension + self.pooling_mode_cls_token = pooling_mode_cls_token + self.pooling_mode_mean_tokens = pooling_mode_mean_tokens + self.pooling_mode_max_tokens = pooling_mode_max_tokens + self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens + + pooling_mode_multiplier = sum([pooling_mode_cls_token, pooling_mode_max_tokens, pooling_mode_mean_tokens, pooling_mode_mean_sqrt_len_tokens]) + self.pooling_output_dimension = (pooling_mode_multiplier * word_embedding_dimension) + + + def __repr__(self): + return "Pooling({})".format(self.get_config_dict()) + + def get_pooling_mode_str(self) -> str: + """ + Returns the pooling mode as string + """ + modes = [] + if self.pooling_mode_cls_token: + modes.append('cls') + if self.pooling_mode_mean_tokens: + modes.append('mean') + if self.pooling_mode_max_tokens: + modes.append('max') + if self.pooling_mode_mean_sqrt_len_tokens: + modes.append('mean_sqrt_len_tokens') + + return "+".join(modes) + + def forward(self, features: Dict[str, Tensor]): + token_embeddings = features['token_embeddings'] + attention_mask = features['attention_mask'] + + ## Pooling strategy + output_vectors = [] + if self.pooling_mode_cls_token: + cls_token = features.get('cls_token_embeddings', token_embeddings[:, 0]) # Take first token by default + output_vectors.append(cls_token) + if self.pooling_mode_max_tokens: + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value + max_over_time = torch.max(token_embeddings, 1)[0] + output_vectors.append(max_over_time) + if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens: + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) + + #If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present + if 'token_weights_sum' in features: + sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size()) + else: + sum_mask = input_mask_expanded.sum(1) + + sum_mask = torch.clamp(sum_mask, min=1e-9) + + if self.pooling_mode_mean_tokens: + output_vectors.append(sum_embeddings / sum_mask) + if self.pooling_mode_mean_sqrt_len_tokens: + output_vectors.append(sum_embeddings / torch.sqrt(sum_mask)) + + output_vector = torch.cat(output_vectors, 1) + features.update({'sentence_embedding': output_vector}) + return features + + def get_sentence_embedding_dimension(self): + return self.pooling_output_dimension + + def get_config_dict(self): + return {key: self.__dict__[key] for key in self.config_keys} + + def save(self, 
output_path): + with open(os.path.join(output_path, 'config.json'), 'w') as fOut: + json.dump(self.get_config_dict(), fOut, indent=2) + + @staticmethod + def load(input_path): + with open(os.path.join(input_path, 'config.json')) as fIn: + config = json.load(fIn) + + return Pooling(**config) diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Transformer.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d0675dfe6f2fafbdd665c38d9619eacdab4fa48b --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/Transformer.py @@ -0,0 +1,129 @@ +from torch import nn +from transformers import AutoModel, AutoTokenizer, AutoConfig +import json +from typing import List, Dict, Optional, Union, Tuple +import os + + +class Transformer(nn.Module): + """Huggingface AutoModel to generate token embeddings. + Loads the correct class, e.g. BERT / RoBERTa etc. + + :param model_name_or_path: Huggingface models name (https://huggingface.co/models) + :param max_seq_length: Truncate any inputs longer than max_seq_length + :param model_args: Arguments (key, value pairs) passed to the Huggingface Transformers model + :param cache_dir: Cache dir for Huggingface Transformers to store/load models + :param tokenizer_args: Arguments (key, value pairs) passed to the Huggingface Tokenizer model + :param do_lower_case: If true, lowercases the input (independent if the model is cased or not) + :param tokenizer_name_or_path: Name or path of the tokenizer. When None, then model_name_or_path is used + """ + def __init__(self, model_name_or_path: str, max_seq_length: Optional[int] = None, + model_args: Dict = {}, cache_dir: Optional[str] = None, + tokenizer_args: Dict = {}, do_lower_case: bool = False, + tokenizer_name_or_path : str = None, **auto_model_kwargs): + super(Transformer, self).__init__() + self.config_keys = ['max_seq_length', 'do_lower_case'] + self.do_lower_case = do_lower_case + + config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir) + self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **auto_model_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path, cache_dir=cache_dir, **tokenizer_args) + + #No max_seq_length set. 
Try to infer from model + if max_seq_length is None: + if hasattr(self.auto_model, "config") and hasattr(self.auto_model.config, "max_position_embeddings") and hasattr(self.tokenizer, "model_max_length"): + max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length) + + self.max_seq_length = max_seq_length + + if tokenizer_name_or_path is not None: + self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__ + + def __repr__(self): + return "Transformer({}) with Transformer model: {} ".format(self.get_config_dict(), self.auto_model.__class__.__name__) + + def forward(self, features): + """Returns token_embeddings, cls_token""" + trans_features = {'input_ids': features['input_ids'], 'attention_mask': features['attention_mask']} + if 'token_type_ids' in features: + trans_features['token_type_ids'] = features['token_type_ids'] + + output_states = self.auto_model(**trans_features, return_dict=False) + output_tokens = output_states[0] + + features.update({'token_embeddings': output_tokens, 'attention_mask': features['attention_mask']}) + + if self.auto_model.config.output_hidden_states: + all_layer_idx = 2 + if len(output_states) < 3: #Some models only output last_hidden_states and all_hidden_states + all_layer_idx = 1 + + hidden_states = output_states[all_layer_idx] + features.update({'all_layer_embeddings': hidden_states}) + + return features + + def get_word_embedding_dimension(self) -> int: + return self.auto_model.config.hidden_size + + def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]): + """ + Tokenizes a text and maps tokens to token-ids + """ + output = {} + if isinstance(texts[0], str): + to_tokenize = [texts] + elif isinstance(texts[0], dict): + to_tokenize = [] + output['text_keys'] = [] + for lookup in texts: + text_key, text = next(iter(lookup.items())) + to_tokenize.append(text) + output['text_keys'].append(text_key) + to_tokenize = [to_tokenize] + else: + batch1, batch2 = [], [] + for text_tuple in texts: + batch1.append(text_tuple[0]) + batch2.append(text_tuple[1]) + to_tokenize = [batch1, batch2] + + #strip + to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] + + #Lowercase + if self.do_lower_case: + to_tokenize = [[s.lower() for s in col] for col in to_tokenize] + + + output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length)) + return output + + + def get_config_dict(self): + return {key: self.__dict__[key] for key in self.config_keys} + + def save(self, output_path: str): + self.auto_model.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + with open(os.path.join(output_path, 'sentence_bert_config.json'), 'w') as fOut: + json.dump(self.get_config_dict(), fOut, indent=2) + + @staticmethod + def load(input_path: str): + #Old classes used other config names than 'sentence_bert_config.json' + for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']: + sbert_config_path = os.path.join(input_path, config_name) + if os.path.exists(sbert_config_path): + break + + with open(sbert_config_path) as fIn: + config = json.load(fIn) + return Transformer(model_name_or_path=input_path, **config) + + + + + + diff --git 
a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Dense.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa7c59ef55f15d9527fc57c410a89fd5a726928 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Dense.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Pooling.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Pooling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbfe368bfc05f486b7939db7d6eda79cf811e4a Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Pooling.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Transformer.cpython-39.pyc b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa83e067638a804171b8a554c7d3cfc790d87100 Binary files /dev/null and b/src/music/representation_learning/sentence_transfo/sentence_transformers/models/__pycache__/Transformer.cpython-39.pyc differ diff --git a/src/music/representation_learning/sentence_transfo/sentence_transformers/util.py b/src/music/representation_learning/sentence_transfo/sentence_transformers/util.py new file mode 100644 index 0000000000000000000000000000000000000000..392e64fc86adbd6cd6e1d53af8ac6681d699e1f8 --- /dev/null +++ b/src/music/representation_learning/sentence_transfo/sentence_transformers/util.py @@ -0,0 +1,525 @@ +import requests +import torch +from torch import Tensor, device +from typing import List, Callable +from tqdm.autonotebook import tqdm +import sys +import importlib +import os +import torch +import numpy as np +import queue +import logging + + +logger = logging.getLogger(__name__) + +def pytorch_cos_sim(a: Tensor, b: Tensor): + """ + Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. + :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + return cos_sim(a, b) + +def cos_sim(a: Tensor, b: Tensor): + """ + Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. + :return: Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + a_norm = torch.nn.functional.normalize(a, p=2, dim=1) + b_norm = torch.nn.functional.normalize(b, p=2, dim=1) + return torch.mm(a_norm, b_norm.transpose(0, 1)) + + +def dot_score(a: Tensor, b: Tensor): + """ + Computes the dot-product dot_prod(a[i], b[j]) for all i and j. 
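+    Equivalently, res[i][j] = sum_k a[i][k] * b[j][k]; the implementation below computes this
+    in one call as torch.mm(a, b.transpose(0, 1)).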
+ :return: Matrix with res[i][j] = dot_prod(a[i], b[j]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + return torch.mm(a, b.transpose(0, 1)) + + +def pairwise_dot_score(a: Tensor, b: Tensor): + """ + Computes the pairwise dot-product dot_prod(a[i], b[i]) + :return: Vector with res[i] = dot_prod(a[i], b[i]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + return (a * b).sum(dim=-1) + + +def pairwise_cos_sim(a: Tensor, b: Tensor): + """ + Computes the pairwise cossim cos_sim(a[i], b[i]) + :return: Vector with res[i] = cos_sim(a[i], b[i]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + return pairwise_dot_score(normalize_embeddings(a), normalize_embeddings(b)) + + +def normalize_embeddings(embeddings: Tensor): + """ + Normalizes the embeddings matrix, so that each sentence embedding has unit length + """ + return torch.nn.functional.normalize(embeddings, p=2, dim=1) + + +def paraphrase_mining(model, + sentences: List[str], + show_progress_bar: bool = False, + batch_size:int = 32, + *args, + **kwargs): + """ + Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all + other sentences and returns a list with the pairs that have the highest cosine similarity score. + + :param model: SentenceTransformer model for embedding computation + :param sentences: A list of strings (texts or sentences) + :param show_progress_bar: Plotting of a progress bar + :param batch_size: Number of texts that are encoded simultaneously by the model + :param query_chunk_size: Search for most similar pairs for #query_chunk_size at the same time. Decrease, to lower memory footprint (increases run-time). + :param corpus_chunk_size: Compare a sentence simultaneously against #corpus_chunk_size other sentences. Decrease, to lower memory footprint (increases run-time). + :param max_pairs: Maximal number of text pairs returned. + :param top_k: For each sentence, we retrieve up to top_k other sentences + :param score_function: Function for computing scores. By default, cosine similarity. + :return: Returns a list of triplets with the format [score, id1, id2] + """ + + # Compute embedding for the sentences + embeddings = model.encode(sentences, show_progress_bar=show_progress_bar, batch_size=batch_size, convert_to_tensor=True) + + return paraphrase_mining_embeddings(embeddings, *args, **kwargs) + + +def paraphrase_mining_embeddings(embeddings: Tensor, + query_chunk_size: int = 5000, + corpus_chunk_size: int = 100000, + max_pairs: int = 500000, + top_k: int = 100, + score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim): + """ + Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all + other sentences and returns a list with the pairs that have the highest cosine similarity score. + + :param embeddings: A tensor with the embeddings + :param query_chunk_size: Search for most similar pairs for #query_chunk_size at the same time. Decrease, to lower memory footprint (increases run-time). + :param corpus_chunk_size: Compare a sentence simultaneously against #corpus_chunk_size other sentences. Decrease, to lower memory footprint (increases run-time). 
+ :param max_pairs: Maximal number of text pairs returned. + :param top_k: For each sentence, we retrieve up to top_k other sentences + :param score_function: Function for computing scores. By default, cosine similarity. + :return: Returns a list of triplets with the format [score, id1, id2] + """ + + top_k += 1 # A sentence has the highest similarity to itself. Increase +1 as we are interest in distinct pairs + + # Mine for duplicates + pairs = queue.PriorityQueue() + min_score = -1 + num_added = 0 + + for corpus_start_idx in range(0, len(embeddings), corpus_chunk_size): + for query_start_idx in range(0, len(embeddings), query_chunk_size): + scores = score_function(embeddings[query_start_idx:query_start_idx+query_chunk_size], embeddings[corpus_start_idx:corpus_start_idx+corpus_chunk_size]) + + scores_top_k_values, scores_top_k_idx = torch.topk(scores, min(top_k, len(scores[0])), dim=1, largest=True, sorted=False) + scores_top_k_values = scores_top_k_values.cpu().tolist() + scores_top_k_idx = scores_top_k_idx.cpu().tolist() + + for query_itr in range(len(scores)): + for top_k_idx, corpus_itr in enumerate(scores_top_k_idx[query_itr]): + i = query_start_idx + query_itr + j = corpus_start_idx + corpus_itr + + if i != j and scores_top_k_values[query_itr][top_k_idx] > min_score: + pairs.put((scores_top_k_values[query_itr][top_k_idx], i, j)) + num_added += 1 + + if num_added >= max_pairs: + entry = pairs.get() + min_score = entry[0] + + # Get the pairs + added_pairs = set() # Used for duplicate detection + pairs_list = [] + while not pairs.empty(): + score, i, j = pairs.get() + sorted_i, sorted_j = sorted([i, j]) + + if sorted_i != sorted_j and (sorted_i, sorted_j) not in added_pairs: + added_pairs.add((sorted_i, sorted_j)) + pairs_list.append([score, i, j]) + + # Highest scores first + pairs_list = sorted(pairs_list, key=lambda x: x[0], reverse=True) + return pairs_list + + +def information_retrieval(*args, **kwargs): + """This function is deprecated. Use semantic_search instead""" + return semantic_search(*args, **kwargs) + + +def semantic_search(query_embeddings: Tensor, + corpus_embeddings: Tensor, + query_chunk_size: int = 100, + corpus_chunk_size: int = 500000, + top_k: int = 10, + score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim): + """ + This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings. + It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries. + + :param query_embeddings: A 2 dimensional tensor with the query embeddings. + :param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings. + :param query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but requires more memory. + :param corpus_chunk_size: Scans the corpus 100k entries at a time. Increasing that value increases the speed, but requires more memory. + :param top_k: Retrieve top k matching entries. + :param score_function: Function for computing scores. By default, cosine similarity. + :return: Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores. 
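+
+    Example (illustrative sketch, not part of the original file; ``model`` stands for any encoder
+    with an ``encode`` method returning tensors, as used by paraphrase_mining above)::
+
+        query_embeddings = model.encode(['a query'], convert_to_tensor=True)
+        corpus_embeddings = model.encode(['doc one', 'doc two'], convert_to_tensor=True)
+        hits = semantic_search(query_embeddings, corpus_embeddings, top_k=2)
+        # hits[0] is a list of {'corpus_id': ..., 'score': ...} dicts, best match first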
+ """ + + if isinstance(query_embeddings, (np.ndarray, np.generic)): + query_embeddings = torch.from_numpy(query_embeddings) + elif isinstance(query_embeddings, list): + query_embeddings = torch.stack(query_embeddings) + + if len(query_embeddings.shape) == 1: + query_embeddings = query_embeddings.unsqueeze(0) + + if isinstance(corpus_embeddings, (np.ndarray, np.generic)): + corpus_embeddings = torch.from_numpy(corpus_embeddings) + elif isinstance(corpus_embeddings, list): + corpus_embeddings = torch.stack(corpus_embeddings) + + + #Check that corpus and queries are on the same device + if corpus_embeddings.device != query_embeddings.device: + query_embeddings = query_embeddings.to(corpus_embeddings.device) + + queries_result_list = [[] for _ in range(len(query_embeddings))] + + for query_start_idx in range(0, len(query_embeddings), query_chunk_size): + # Iterate over chunks of the corpus + for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size): + # Compute cosine similarities + cos_scores = score_function(query_embeddings[query_start_idx:query_start_idx+query_chunk_size], corpus_embeddings[corpus_start_idx:corpus_start_idx+corpus_chunk_size]) + + # Get top-k scores + cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k, len(cos_scores[0])), dim=1, largest=True, sorted=False) + cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() + cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist() + + for query_itr in range(len(cos_scores)): + for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]): + corpus_id = corpus_start_idx + sub_corpus_id + query_id = query_start_idx + query_itr + queries_result_list[query_id].append({'corpus_id': corpus_id, 'score': score}) + + #Sort and strip to top_k results + for idx in range(len(queries_result_list)): + queries_result_list[idx] = sorted(queries_result_list[idx], key=lambda x: x['score'], reverse=True) + queries_result_list[idx] = queries_result_list[idx][0:top_k] + + return queries_result_list + + +def http_get(url, path): + """ + Downloads a URL to a given path on disc + """ + if os.path.dirname(path) != '': + os.makedirs(os.path.dirname(path), exist_ok=True) + + req = requests.get(url, stream=True) + if req.status_code != 200: + print("Exception when trying to download {}. 
Response {}".format(url, req.status_code), file=sys.stderr) + req.raise_for_status() + return + + download_filepath = path+"_part" + with open(download_filepath, "wb") as file_binary: + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total, unit_scale=True) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + file_binary.write(chunk) + + os.rename(download_filepath, path) + progress.close() + + +def batch_to_device(batch, target_device: device): + """ + send a pytorch batch to a device (CPU/GPU) + """ + for key in batch: + if isinstance(batch[key], Tensor): + batch[key] = batch[key].to(target_device) + return batch + + +# from https://github.com/vlkit/vlkit/blob/master/vlkit/ops/distributed.py +class AllGather(torch.autograd.Function): + """ + all_gather with gradient back-propagation + """ + @staticmethod + def forward(ctx, tensor_list, tensor, group, async_op): + torch.distributed.all_gather(tensor_list, tensor, group=group, async_op=async_op) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_list): + grad_list = list(grad_list) + rank = torch.distributed.get_rank() + + dist_ops = [ + torch.distributed.reduce(grad_list[i], i, async_op=True) for i in range(torch.distributed.get_world_size()) + ] + + for op in dist_ops: + op.wait() + + return None, grad_list[rank], None, None + + +all_gather_with_grad = AllGather.apply + + +def mismatched_sizes_all_gather(tensor: Tensor, group=None, async_op=False, mismatched_axis=0): + # all_gather doesn't support tensor lists where the first dimension is mismatched. This does. + assert torch.distributed.is_initialized(), "torch.distributed not initialized" + world_size = torch.distributed.get_world_size() + # let's get the sizes for everyone + mismatched_sizes = torch.tensor([tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda") + sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)] + torch.distributed.all_gather(sizes, mismatched_sizes, group=group, async_op=async_op) + sizes = torch.cat(sizes).cpu().tolist() + # now pad to the max dim-0 size + max_size = max(sizes) + padded = torch.zeros((*tensor.shape[:mismatched_axis], max_size, *tensor.shape[mismatched_axis+1:]), + device=tensor.device, dtype=tensor.dtype) + # selects the place where we're adding information + padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis]) + padded_to_fill[...] = tensor + # gather the padded tensors + tensor_list = [torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) for _ in range(world_size)] + all_gather_with_grad(tensor_list, padded, group, async_op) + # trim off the padding + for rank in range(world_size): + # checks that the rest is 0 + assert not tensor_list[rank].narrow(mismatched_axis, sizes[rank], padded.shape[mismatched_axis]-sizes[rank]).count_nonzero().is_nonzero(), \ + "This would remove non-padding information" + tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) + return tensor_list + + +def fullname(o): + """ + Gives a full name (package_name.class_name) for a class / object in Python. Will + be used to load the correct classes from JSON files + """ + + module = o.__class__.__module__ + if module is None or module == str.__class__.__module__: + return o.__class__.__name__ # Avoid reporting __builtin__ + else: + return module + '.' 
+ o.__class__.__name__ + +def import_from_string(dotted_path): + """ + Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import failed. + """ + try: + module_path, class_name = dotted_path.rsplit('.', 1) + except ValueError: + msg = "%s doesn't look like a module path" % dotted_path + raise ImportError(msg) + + try: + module = importlib.import_module(dotted_path) + except: + module = importlib.import_module(module_path) + + try: + return getattr(module, class_name) + except AttributeError: + msg = 'Module "%s" does not define a "%s" attribute/class' % (module_path, class_name) + raise ImportError(msg) + + +def community_detection(embeddings, threshold=0.75, min_community_size=10, init_max_size=1000): + """ + Function for Fast Community Detection + + Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold). + + Returns only communities that are larger than min_community_size. The communities are returned + in decreasing order. The first element in each list is the central point in the community. + """ + + # Maximum size for community + init_max_size = min(init_max_size, len(embeddings)) + + # Compute cosine similarity scores + cos_scores = cos_sim(embeddings, embeddings) + + # Minimum size for a community + top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True) + + # Filter for rows >= min_threshold + extracted_communities = [] + for i in range(len(top_k_values)): + if top_k_values[i][-1] >= threshold: + new_cluster = [] + + # Only check top k most similar entries + top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True) + top_idx_large = top_idx_large.tolist() + top_val_large = top_val_large.tolist() + + if top_val_large[-1] < threshold: + for idx, val in zip(top_idx_large, top_val_large): + if val < threshold: + break + + new_cluster.append(idx) + else: + # Iterate over all entries (slow) + for idx, val in enumerate(cos_scores[i].tolist()): + if val >= threshold: + new_cluster.append(idx) + + extracted_communities.append(new_cluster) + + # Largest cluster first + extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True) + + # Step 2) Remove overlapping communities + unique_communities = [] + extracted_ids = set() + + for community in extracted_communities: + add_cluster = True + for idx in community: + if idx in extracted_ids: + add_cluster = False + break + + if add_cluster: + unique_communities.append(community) + for idx in community: + extracted_ids.add(idx) + + return unique_communities + + +################## +# +###################### + +from typing import Dict, Optional, Union +from pathlib import Path +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub import HfApi, hf_hub_url, cached_download +# from huggingface_hub.snapshot_download import REPO_ID_SEPARATOR +import fnmatch + +def snapshot_download( + repo_id: str, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + ignore_files: Optional[List[str]] = None +) -> str: + """ + Method derived from huggingface_hub. 
+ Adds a new parameters 'ignore_files', which allows to ignore certain files / file-patterns + """ + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + _api = HfApi() + model_info = _api.model_info(repo_id=repo_id, revision=revision) + + storage_folder = os.path.join( + cache_dir, repo_id.replace("/", "_") + ) + + for model_file in model_info.siblings: + if ignore_files is not None: + skip_download = False + for pattern in ignore_files: + if fnmatch.fnmatch(model_file.rfilename, pattern): + skip_download = True + break + + if skip_download: + continue + + url = hf_hub_url( + repo_id, filename=model_file.rfilename, revision=model_info.sha + ) + relative_filepath = os.path.join(*model_file.rfilename.split("/")) + + # Create potential nested dir + nested_dirname = os.path.dirname( + os.path.join(storage_folder, relative_filepath) + ) + os.makedirs(nested_dirname, exist_ok=True) + + path = cached_download( + url, + cache_dir=storage_folder, + force_filename=relative_filepath, + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + + if os.path.exists(path + ".lock"): + os.remove(path + ".lock") + + return storage_folder diff --git a/src/music/utilities/__init__.py b/src/music/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/__pycache__/__init__.cpython-39.pyc b/src/music/utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0347ae1dde951b47bbb8211c7366278f129e2054 Binary files /dev/null and b/src/music/utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/utilities/__pycache__/chord_structured.cpython-39.pyc b/src/music/utilities/__pycache__/chord_structured.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8f6a6285d40e740e501ea5d9b1a21d7a8aadc5b Binary files /dev/null and b/src/music/utilities/__pycache__/chord_structured.cpython-39.pyc differ diff --git a/src/music/utilities/__pycache__/midi_processor.cpython-39.pyc b/src/music/utilities/__pycache__/midi_processor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b885099c5a4ff75f3d4dd343eac88b4432be401f Binary files /dev/null and b/src/music/utilities/__pycache__/midi_processor.cpython-39.pyc differ diff --git a/src/music/utilities/chord_structured.py b/src/music/utilities/chord_structured.py new file mode 100644 index 0000000000000000000000000000000000000000..95687f50f5ee4b97cf136b9e24a57bd52ccc2e05 --- /dev/null +++ b/src/music/utilities/chord_structured.py @@ -0,0 +1,523 @@ +""" Structured MIDI encoding method as using in the Piano Inpainting Application +https://arxiv.org/abs/2107.05944 + +""" + +from typing import List, Tuple, Dict, Optional + +import numpy as np +from miditoolkit import Instrument, Note, TempoChange +from miditok import Structured +from miditok.midi_tokenizer_base import MIDITokenizer, Vocabulary, Event +from miditok.constants import * +from itertools import combinations +Cs = np.array([60 + oct for oct in range(-12*4, 12*5, 12)]) + +def get_chord_map(): + my_chord_map = {#'octave': (0, 12), + #'power': (0, 7), + #'power_inv_1': (0, 5), + 'min': (0, 3, 7), + 'maj': (0, 4, 7), + 'dim': (0, 3, 6), + 'aug': (0, 4, 8), + 'sus2': (0, 2, 7), + 'sus4': (0, 5, 7), + '7dom': (0, 4, 7, 10), + '7min': (0, 3, 7, 10), + '7maj': (0, 4, 7, 11), + '7halfdim': 
(0, 3, 6, 10), + '7dim': (0, 3, 6, 9), + '7aug': (0, 4, 8, 11), + '9maj': (0, 4, 7, 10, 14), + '9min': (0, 4, 7, 10, 13)} + + # + for k in list(my_chord_map.keys()).copy(): + n_notes = len(my_chord_map[k]) + if n_notes > 2: + if k not in ['7dim', 'aug', 'sus2', 'sus4']: + if '9' in k: + nb_invs = 3 + else: + nb_invs = n_notes + for i_inv in range(1, nb_invs): + shift = np.array([my_chord_map[k][(i + i_inv) % n_notes] for i in range(n_notes)]) + shift[-i_inv:] += 12 + pattern = [0] + for i in range(1, len(shift)): + pattern.append(shift[i] - shift[0]) + my_chord_map[k + f'_inv_{i_inv}'] = tuple(pattern) + known = set() + for k in my_chord_map.keys(): + assert my_chord_map[k] not in known + inverted_chord_map = dict() + for k, v in my_chord_map.items(): + inverted_chord_map[v] = k + return my_chord_map, inverted_chord_map + +def find_sub_pattern(pattern, candidate_patterns): + for i in np.arange(len(pattern) - 1, 0, -1): + patt_indexes = [(0,) + c for c in combinations(range(1, len(pattern)), i)] + for p_ind in patt_indexes: + sorted_pattern = np.sort(np.array(pattern)[np.array(p_ind)]) + sorted_pattern = tuple(sorted_pattern - sorted_pattern[0]) + if sorted_pattern in candidate_patterns: + return True, sorted_pattern, np.array(p_ind) + return False, None, None + +# def find_sub_pattern(pattern, candidate_patterns, indexes, n_asserted=1): +# if len(candidate_patterns) == 0 or len(pattern) < 3: +# return False, None, None +# else: +# sorted_pattern = np.sort(pattern) +# sorted_pattern = tuple(sorted_pattern - sorted_pattern[0]) +# if sorted_pattern in candidate_patterns: +# return True, sorted_pattern, indexes +# else: +# if n_asserted + 1 == len(pattern): +# return False, None, None +# else: +# # hypothesis that pattern is good up to n_asserted + 1 +# asserted_pattern = pattern[:n_asserted + 1] +# len_asserted = len(asserted_pattern) +# # find candidate patterns matching that beginning +# sorted_asserted_pattern = np.sort(asserted_pattern) +# sorted_asserted_pattern = tuple(sorted_asserted_pattern - sorted_asserted_pattern[0]) +# c_p = [cp for cp in candidate_patterns if cp[:len_asserted] == sorted_asserted_pattern] +# found, found_pattern, found_indexes = find_sub_pattern(pattern, c_p, indexes, n_asserted=n_asserted+1) +# if found: +# return True, found_pattern, found_indexes +# # if the pattern was not found, then we need to remove that note +# else: +# pattern2 = pattern[: n_asserted] + pattern[n_asserted + 1:] +# if pattern2 == pattern: +# stop = 1 +# new_indexes = indexes.copy() +# new_indexes.pop(n_asserted) +# return find_sub_pattern(pattern2, candidate_patterns, new_indexes, n_asserted=n_asserted) + + +def filter_notes_find_chord_and_root(chord, inverted_chord_map): + known_chords = list(inverted_chord_map.keys()) + found, chord_pattern, chord_indexes = find_sub_pattern(tuple(chord), known_chords) + if found: + chord_id = inverted_chord_map[chord_pattern].split('_')[0] + else: + return False, None, None, None + + # find root now :) + if 'inv' not in inverted_chord_map[chord_pattern]: + root_id = 0 + else: + inv_id = int(inverted_chord_map[chord_pattern].split('_')[-1]) + n_notes = len(chord_pattern) + root_id = n_notes - inv_id + + return True, chord_id, root_id, chord_indexes + + +class ChordStructured(MIDITokenizer): + """ Structured MIDI encoding method as using in the Piano Inpainting Application + https://arxiv.org/abs/2107.05944 + The token types follows the specific pattern: + Pitch -> Velocity -> Duration -> Time Shift -> back to Pitch ... 
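+    In this chord-aware variant, each note or detected chord actually expands to six tokens
+    (Chord -> OctavePitch -> RelativePitch -> Velocity -> Duration -> Time-Shift), as built in
+    get_note_events and checked in track_to_tokens via len(events) // 6 == len(notes_and_chords).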
+ NOTE: this encoding uses only "Time Shifts" events to move in the time, and only + from one note to another. Hence it is suitable to encode continuous sequences of + notes without long periods of silence. If your dataset contains music with long + pauses, you might handle them with an appropriate "time shift" dictionary + (which values are made from the beat_res dict) or with a different encoding. + + :param pitch_range: range of used MIDI pitches + :param beat_res: beat resolutions, with the form: + {(beat_x1, beat_x2): beat_res_1, (beat_x2, beat_x3): beat_res_2, ...} + The keys of the dict are tuples indicating a range of beats, ex 0 to 3 for the first bar + The values are the resolution, in samples per beat, of the given range, ex 8 + :param nb_velocities: number of velocity bins + :param program_tokens: will add entries for MIDI programs in the dictionary, to use + in the case of multitrack generation for instance + :param sos_eos_tokens: Adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary + :param params: can be a path to the parameter (json encoded) file or a dictionary + """ + def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES, + nb_velocities: int = NB_VELOCITIES, program_tokens: bool = ADDITIONAL_TOKENS['Program'], + sos_eos_tokens: bool = False, params=None): + # No additional tokens + additional_tokens = {'Chord': False, 'Rest': False, 'Tempo': False, 'TimeSignature': False, 'Program': program_tokens} + self.pitch2octave_relative = dict() + self.octave_relative2pitch = dict() + for p in pitch_range: + self.pitch2octave_relative[p] = self.get_octave_and_relative(p) + self.octave_relative2pitch[self.pitch2octave_relative[p]] = p + self.chord_maps, self.inverted_chord_map = get_chord_map() + super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos_tokens, params) + + def get_octave_and_relative(self, pitch): + octave = np.argwhere(pitch - Cs >=0).flatten()[-1] + relative = pitch - Cs[octave] + return octave, relative + + def get_note_events(self, note, dur_bins, next_note_start): + events = [] + if isinstance(note.pitch, str): # it's a chord + chord_id = '_'.join(note.pitch.split('_')[:-1]) + pitch = int(note.pitch.split('_')[-1]) + else: # it's a note + chord_id = 'note' + pitch = note.pitch + # get octave and relative position of the pitch (root pitch for a chord) + octave, relative = self.pitch2octave_relative[pitch] + # Add chord/note event. 
A note is defined as Chord_note + events.append(Event(type_='Chord', time=note.start, value=chord_id, desc=note.pitch)) + # Add octave of the root + events.append(Event(type_='OctavePitch', time=note.start, value=octave, desc=note.pitch)) + # Add octave relative pitch of the root + events.append(Event(type_='RelativePitch', time=note.start, value=relative, desc=note.pitch)) + # Velocity + events.append(Event(type_='Velocity', time=note.start, value=note.velocity, desc=f'{note.velocity}')) + # Duration + duration = note.end - note.start + index = np.argmin(np.abs(dur_bins - duration)) + events.append(Event(type_='Duration', time=note.start, value='.'.join(map(str, self.durations[index])), desc=f'{duration} ticks')) + # Time-Shift + time_shift = next_note_start - note.start + assert time_shift >= 0 # this asserts that events are sorted + index = np.argmin(np.abs(dur_bins - time_shift)) + events.append(Event(type_='Time-Shift', time=note.start, desc=f'{time_shift} ticks', + value='.'.join(map(str, self.durations[index])) if time_shift != 0 else '0.0.1')) + return events, time_shift + + def track_to_tokens(self, track: Instrument) -> List[int]: + """ Converts a track (miditoolkit.Instrument object) into a sequence of tokens + + :param track: MIDI track to convert + :return: sequence of corresponding tokens + """ + # Make sure the notes are sorted first by their onset (start) times, second by pitch + # notes.sort(key=lambda x: (x.start, x.pitch)) # done in midi_to_tokens + events = [] + + dur_bins = self.durations_ticks[self.current_midi_metadata['time_division']] + + # assume first note is the beginning of the song, no time shift at first. + + # Track chords. For each chord, insert a fake note that contains its info so that it can be converted to the proper event + if self.additional_tokens['Chord'] and not track.is_drum: + notes_and_chords = self.detect_chords(track.notes, self.current_midi_metadata['time_division'], self._first_beat_res) + else: + notes_and_chords = track.notes + + sum_shifts = 0 + # Creates the Pitch, Velocity, Duration and Time Shift events + for n, note in enumerate(notes_and_chords): + if n == len(notes_and_chords) - 1: + next_note_start = note.start # add zero time shift at the end + else: + next_note_start = notes_and_chords[n + 1].start + new_events, time_shift = self.get_note_events(note, dur_bins, next_note_start=next_note_start) + events += new_events + sum_shifts += time_shift + assert len(events) // 6 == len(notes_and_chords) + + return self.events_to_tokens(events) + + def tokens_to_track(self, tokens: List[int], time_division: Optional[int] = TIME_DIVISION, + program: Optional[Tuple[int, bool]] = (0, False)) -> Tuple[Instrument, List[TempoChange]]: + """ Converts a sequence of tokens into a track object + + :param tokens: sequence of tokens to convert + :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create) + :param program: the MIDI program of the produced track and if it drum, (default (0, False), piano) + :return: the miditoolkit instrument object and a "Dummy" tempo change + """ + events = self.tokens_to_events(tokens) + instrument = Instrument(program[0], is_drum=False, name=MIDI_INSTRUMENTS[program[0]]['name']) + current_tick = 0 + count = 0 + # start at first chord event + while count < len(events) and events[count].type != 'Chord': + count += 1 + + while count < len(events): + if events[count].type == 'Chord': + note_chord_events = [events[c] for c in range(count, count + 6)] + events_types = [c.type for c in 
note_chord_events] + if events_types[1:] == ['OctavePitch', 'RelativePitch', 'Velocity', 'Duration', 'Time-Shift']: + octave, relative = int(note_chord_events[1].value), int(note_chord_events[2].value) + duration = self._token_duration_to_ticks(note_chord_events[4].value, time_division) + vel = int(note_chord_events[3].value) + root_pitch = self.octave_relative2pitch[(octave, relative)] + if note_chord_events[0].value == "note": + # pass + instrument.notes.append(Note(vel, root_pitch, current_tick, current_tick + duration)) + else: + pitches = self.find_chord_pitches(root_pitch, note_chord_events[0].value) + for p in pitches: + instrument.notes.append(Note(vel, p, current_tick, current_tick + duration)) + + beat, pos, res = map(int, note_chord_events[5].value.split('.')) + current_tick += (beat * res + pos) * time_division // res # time shift + count += 6 + else: + count += 1 + else: + count += 1 + + return instrument, [TempoChange(TEMPO, 0)] + + def find_chord_pitches(self, root_pitch, chord_name): + chord_map = self.chord_maps[chord_name] + if 'inv' not in chord_map: + root_position = 0 + else: + inv_id = int(chord_name.split('_')[-1]) + n_notes = len(chord_map) + root_position = n_notes - inv_id + deltas = np.array(chord_map) - chord_map[root_position] + pitches = [root_pitch + d for d in deltas] + return pitches + + def _create_vocabulary(self, sos_eos_tokens: bool = False) -> Vocabulary: + """ Creates the Vocabulary object of the tokenizer. + See the docstring of the Vocabulary class for more details about how to use it. + NOTE: token index 0 is often used as a padding index during training + + :param sos_eos_tokens: will include Start Of Sequence (SOS) and End Of Sequence (tokens) + :return: the vocabulary object + """ + vocab = Vocabulary({'PAD_None': 0}) + + if self.additional_tokens['Chord']: + vocab.add_event(f'Chord_{chord_quality}' for chord_quality in CHORD_MAPS) + + # PITCH + vocab.add_event('Chord_note') + vocab.add_event(f'OctavePitch_{i}' for i in range(8)) + vocab.add_event(f'RelativePitch_{i}' for i in range(12)) + # vocab.add_event(f'Pitch_{i}' for i in self.pitch_range) + + # VELOCITY + vocab.add_event(f'Velocity_{i}' for i in self.velocities) + + # DURATION + vocab.add_event(f'Duration_{".".join(map(str, duration))}' for duration in self.durations) + + # TIME SHIFT (same as durations) + vocab.add_event('Time-Shift_0.0.1') # for a time shift of 0 + vocab.add_event(f'Time-Shift_{".".join(map(str, duration))}' for duration in self.durations) + + # PROGRAM + if self.additional_tokens['Program']: + vocab.add_event(f'Program_{program}' for program in range(-1, 128)) + + # SOS & EOS + if sos_eos_tokens: + vocab.add_sos_eos_to_vocab() + + return vocab + + def _create_token_types_graph(self) -> Dict[str, List[str]]: + """ Returns a graph (as a dictionary) of the possible token + types successions. + NOTE: Program type is not referenced here, you can add it manually by + modifying the tokens_types_graph class attribute following your strategy. + + :return: the token types transitions dictionary + """ + dic = {'Pitch': ['Velocity'], 'Velocity': ['Duration'], 'Duration': ['Time-Shift'], 'Time-Shift': ['Pitch']} + self._add_pad_type_to_graph(dic) + return dic + + def token_types_errors(self, tokens: List[int], consider_pad: bool = False) -> float: + """ Checks if a sequence of tokens is constituted of good token types + successions and returns the error ratio (lower is better). 
+ The Pitch values are also analyzed: + - a pitch token should not be present if the same pitch is already played at the time + + :param tokens: sequence of tokens to check + :param consider_pad: if True will continue the error detection after the first PAD token (default: False) + :return: the error ratio (lower is better) + """ + err = 0 + previous_type = self.vocab.token_type(tokens[0]) + current_pitches = [] + + def check(tok: int): + nonlocal err + nonlocal previous_type + nonlocal current_pitches + token_type, token_value = self.vocab.token_to_event[tok].split('_') + + # Good token type + if token_type in self.tokens_types_graph[previous_type]: + if token_type == 'Pitch': + if int(token_value) in current_pitches: + err += 1 # pitch already played at current position + else: + current_pitches.append(int(token_value)) + elif token_type == 'Time-Shift': + if self._token_duration_to_ticks(token_value, 48) > 0: + current_pitches = [] # moving in time, list reset + # Bad token type + else: + err += 1 + previous_type = token_type + + if consider_pad: + for token in tokens[1:]: + check(token) + else: + for token in tokens[1:]: + if previous_type == 'PAD': + break + check(token) + return err / len(tokens) + + def detect_chords(self, list_notes: List[Note], time_division: int, beat_res: int = 4, onset_offset: int = 1, + only_known_chord: bool = False, simul_notes_limit: int = 20, verbose=False) -> List[Event]: + """ Chord detection method. + NOTE: make sure to sort notes by start time then pitch before: notes.sort(key=lambda x: (x.start, x.pitch)) + NOTE2: on very large tracks with high note density this method can be very slow ! + If you plan to use it with the Maestro or GiantMIDI datasets, it can take up to + hundreds of seconds per MIDI depending on your cpu. + One time step at a time, it will analyse the notes played together + and detect possible chords. + + :param notes: notes to analyse (sorted by starting time, them pitch) + :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed) + :param beat_res: beat resolution, i.e. nb of samples per beat (default 4) + :param onset_offset: maximum offset (in samples) ∈ N separating notes starts to consider them + starting at the same time / onset (default is 1) + :param only_known_chord: will select only known chords. 
If set to False, non recognized chords of + n notes will give a chord_n event (default False) + :param simul_notes_limit: nb of simultaneous notes being processed when looking for a chord + this parameter allows to speed up the chord detection (default 20) + :return: the detected chords as Event objects + """ + assert simul_notes_limit >= 5, 'simul_notes_limit must be higher than 5, chords can be made up to 5 notes' + tuples = [] + for note in list_notes: + tuples.append((note.pitch, int(note.start), int(note.end), int(note.velocity))) + notes = np.asarray(tuples) + + time_div_half = time_division // 2 + onset_offset = time_division * onset_offset / beat_res + + count = 0 + previous_tick = -1 + detected_chords = [] + note_belong_to_chord_id = dict() + while count < len(notes): + # Checks we moved in time after last step, otherwise discard this tick + if notes[count, 1] == previous_tick: + count += 1 + continue + + # Gathers the notes around the same time step + # Reduce the scope of the search + notes_to_consider = notes[count:count + simul_notes_limit].copy() + old_true_notes_indexes = np.arange(count, count + simul_notes_limit) # keep track of true note indexes + # Take notes withing onset_offset samples of the first note + indexes_valid = np.where(notes_to_consider[:, 1] <= notes_to_consider[0, 1] + onset_offset) + true_notes_indexes = old_true_notes_indexes[indexes_valid] + onset_notes = notes_to_consider[indexes_valid] + # Take notes that end close to the first note's end + indexes_valid = np.where(np.abs(onset_notes[:, 2] - onset_notes[0, 2]) < time_div_half) + true_notes_indexes = true_notes_indexes[indexes_valid] + onset_notes = onset_notes[indexes_valid] + + # if there are at least 3 notes, try to find the chord + if len(onset_notes) >= 3: + found, chord_name, root_id, chord_notes_indexes = filter_notes_find_chord_and_root(onset_notes[:, 0], self.inverted_chord_map) + # if found: + # found, chord_name, root_id, chord_notes_indexes = filter_notes_find_chord_and_root(notes_to_consider[:, 0], self.inverted_chord_map) + + if found: + detected_chord_id = len(detected_chords) + # get the indexes of the notes in the chord wrt the onset_notes array + relative_indexes_chord_notes_in_onset_notes = np.array(chord_notes_indexes) + # get true indexes of the notes in the chord (indexes of the note stream) + true_indexes = true_notes_indexes[relative_indexes_chord_notes_in_onset_notes] + # for each note, track the chords it belongs to in note_belong_to_chord_id + for i in true_indexes: + if i not in note_belong_to_chord_id.keys(): + note_belong_to_chord_id[i] = [detected_chord_id] + else: + note_belong_to_chord_id[i].append(detected_chord_id) + # save the info of the detected chord + root_position_in_sorted_onset = chord_notes_indexes[root_id] + root_pitch = onset_notes[root_position_in_sorted_onset, 0] + onset = np.min([notes[i, 1] for i in true_indexes]) + offset = int(np.mean([notes[i, 2] for i in true_indexes])) + velocity = self.velocities[int(np.argmin(np.abs(self.velocities - int(np.mean([notes[i, 3] for i in true_indexes])))))] # quantize velocity + detected_chords.append((chord_name, true_indexes, root_pitch, onset, offset, velocity)) + if verbose: print(f'New chord detected: {chord_name}, root {root_pitch} with notes: {true_indexes}, onset: {onset}, offset: {offset}, velocity: {velocity}') + + count += 1 + + # now we need to delete some the redundant detected chords to have just one chord per note + indexes_chords_to_remove = [] + + for note, chord_ids in 
note_belong_to_chord_id.copy().items(): + # remove chords that were already filtered + chord_ids = sorted(set(chord_ids) - set(indexes_chords_to_remove)) + if len(chord_ids) == 0: # if not remaining chords, then the note should be removed + del note_belong_to_chord_id[note] + else: + note_belong_to_chord_id[note] = chord_ids # update the chord_ids + if len(chord_ids) > 1: # if several, we need to filter by the number of notes in the chords + chords = [detected_chords[i] for i in chord_ids] + selected_chord = np.argmax([len(c[1]) for c in chords]) + note_belong_to_chord_id[note] = [chord_ids[selected_chord]] + for i_c, c in enumerate(chord_ids): + if i_c != selected_chord: + indexes_chords_to_remove.append(c) + for note, chord_ids in note_belong_to_chord_id.copy().items(): + chord_ids = sorted(set(chord_ids) - set(indexes_chords_to_remove)) + if len(chord_ids) == 0: # if not remaining chords, then the note should be removed + del note_belong_to_chord_id[note] + else: + note_belong_to_chord_id[note] = chord_ids # update the chord_ids + selected_chords = [detected_chords[i] for i in range(len(detected_chords)) if i not in indexes_chords_to_remove] + selected_chords_ids = [i for i in range(len(detected_chords)) if i not in indexes_chords_to_remove] + # check that all notes are used just once + all_chord_notes = [] + for c in selected_chords: + all_chord_notes += list(c[1]) + assert len(all_chord_notes) == len(set(all_chord_notes)) + + # format new stream of notes, removing chord notes from them, and inserting "chord" to be able to track timeshifts + new_list_notes = [] + note_dict_keys = list(note_belong_to_chord_id.keys()) + inserted_chords = [] + count_added = 0 + for i in range(len(list_notes)): + if i not in note_dict_keys: + new_list_notes.append(list_notes[i]) + else: + assert len(note_belong_to_chord_id[i]) == 1 + chord_id = note_belong_to_chord_id[i][0] + if chord_id not in inserted_chords: + inserted_chords.append(chord_id) + count_added += 1 + chord_id, _, root_pitch, onset, offset, velocity = detected_chords[chord_id] + new_list_notes.append(Note(velocity=velocity, start=onset, end=offset, pitch=chord_id + '_' + str(root_pitch))) + # check the new count of notes (all previous notes - the number of notes in the chords + the number of chords) + assert len(new_list_notes) == (len(list_notes) - len(all_chord_notes) + len(selected_chords)) + return new_list_notes + + +if __name__ == '__main__': + from miditoolkit import MidiFile + + pitch_range = range(21, 109) + beat_res = {(0, 4): 8, (4, 12): 4} + nb_velocities = 32 + tokenizer_structured = ChordStructured(pitch_range, beat_res, nb_velocities) + # tokenizer_structured = Structured(pitch_range, beat_res, nb_velocities) + + path = '/home/cedric/Documents/pianocktail/data/music/processed/vkgoeswild_processed/ac_dc_hells_bells_vkgoeswild_piano_cover_processed.mid' + midi = MidiFile(path) + tokens = tokenizer_structured.midi_to_tokens(midi) + midi = tokenizer_structured.tokens_to_midi(tokens) + midi.dump("/home/cedric/Desktop/tes/transcribed.mid") \ No newline at end of file diff --git a/src/music/utilities/clean_folder_and_file_names.py b/src/music/utilities/clean_folder_and_file_names.py new file mode 100644 index 0000000000000000000000000000000000000000..4808fe31819eda8c8a87d0590171f4f8b4e9ebbb --- /dev/null +++ b/src/music/utilities/clean_folder_and_file_names.py @@ -0,0 +1,34 @@ +import os +from shutil import copy +from src.music.utils import slugify + +def rename_inside_files_with_folder_name(this_path, folder_path): + if 
os.path.isfile(this_path): + filename = this_path.split('/')[-1] # remove path + filename = '.'.join(filename.split('.')[:-1]) # remove extension + new_filename = slugify(folder_path.split('/')[-2] + '_' + filename) + '_midi.mid' + copy(this_path, folder_path + new_filename) + os.remove(this_path) + else: + for file_or_fold in os.listdir(this_path): + file_or_folder_path = this_path + '/' + file_or_fold + rename_inside_files_with_folder_name(file_or_folder_path, folder_path) + os.rmdir(this_path) + +path = '/home/cedric/Documents/pianocktail/data/music/test/midi/64k_midi/' + +n_dirs = len(os.listdir(path)) +for i_fold, folder in enumerate(os.listdir(path)): + print(i_fold + 1, '/', n_dirs) + folder_path = path + folder +'/' + if os.path.isdir(folder_path): + # rename folder + new_folder = slugify(folder) + new_folder_path = path + new_folder +'/' + os.rename(folder_path, new_folder_path) + # for file or folder inside, rename them with this name, recursively + for file_or_fold in os.listdir(new_folder_path): + file_or_folder_path = new_folder_path + file_or_fold + rename_inside_files_with_folder_name(file_or_folder_path, new_folder_path) + else: + stop = 1 \ No newline at end of file diff --git a/src/music/utilities/handcoded_rep_utilities/__init__.py b/src/music/utilities/handcoded_rep_utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/handcoded_rep_utilities/analysis_handcoded.py b/src/music/utilities/handcoded_rep_utilities/analysis_handcoded.py new file mode 100644 index 0000000000000000000000000000000000000000..718ea340565f020a3cf774c492232a48191eb741 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/analysis_handcoded.py @@ -0,0 +1,177 @@ +import json +import os +import numpy as np +import matplotlib.pyplot as plt +from sklearn.neighbors import NearestNeighbors +from shutil import copy +import pretty_midi as pm +import pydub +import pandas as pd +from scipy.io.wavfile import write + +path = "/home/cedric/Documents/pianocktail/data/music/input_jazz/handcoded_reps/" + +keys_to_investigate = sorted(['av_velocity', 'major_minor', 'articulation', 'av_pitch', 'av_sensory_dissonance', 'motion_direction', 'note_density', + 'pulse_strength', 'std_pitch', 'tonal_certainty']) +plot_dir = path + 'plots/' +os.makedirs(plot_dir, exist_ok=True) +sample_dir = path + 'samples/' +os.makedirs(sample_dir, exist_ok=True) +test_dir = sample_dir + 'test/' +os.makedirs(test_dir, exist_ok=True) + +FS = 44100 +FADE_IN_LEN = 2 # second +FADE_IN = np.arange(0, 1, 1/(FADE_IN_LEN * FS)) +SAMPLE_LEN = 15 # in seconds + +def collect_filepaths_and_reps(path): + file_paths = [] + reps = [] + keys = None + for folder in os.listdir(path): + folder_path = path + folder + '/' + if os.path.isdir(folder_path) and 'plot' not in folder_path and 'samples' not in folder_path: + for file in os.listdir(folder_path): + file_path = folder_path + file + #load json + if '.json' in file_path: + with open(file_path, 'r') as f: + data = json.load(f) + if keys == None: + keys = sorted(data.keys()) + rep = [data[k] for k in keys] + reps.append(rep) + file_paths.append(file_path) + + assert all([k in keys for k in keys_to_investigate]) + reps = np.array(reps) + #normalize reps + + valid_indexes = np.argwhere(reps.std(axis=0) != 0).flatten() + print(f'{[keys[k] for k in np.argwhere(reps.std(axis=0) == 0).flatten()]} are invalid keys') + reps[:, valid_indexes] = (reps[:, valid_indexes] - np.nanmin(reps[:, 
valid_indexes], axis=0)) / (np.nanmax(reps[:, valid_indexes], axis=0) - np.nanmin(reps[:, valid_indexes], axis=0)) + assert all(np.nanmin(reps[:, valid_indexes], axis=0) == 0) + assert all(np.nanmax(reps[:, valid_indexes], axis=0) == 1) + return file_paths, reps, keys + +def plot(keys, reps): + for i in range(len(keys)): + plt.figure() + plt.hist(reps[:, i], bins=30) + plt.title(keys[i]) + plt.savefig(plot_dir + f'hist_{keys[i]}.png') + plt.close('all') + stop = 1 + plt.close('all') + + print(f'Stats of {len(reps)} points:') + for i, k in enumerate(keys): + print(f'Key {k} - mean: {np.nanmean(reps[:, i]):.2f}, std: {np.nanstd(reps[:, i]):.2f}') + + correlation_matrix = np.zeros((len(keys), len(keys))) + for j in range(len(keys)): + for k in range(len(keys)): + correlation_matrix[j, k] = np.corrcoef(reps[:, j], reps[:, k])[0, 1] + if j > k: + # if correlation_matrix[j, k] > 0.5: + # print(f'Pos corr: {keys[j]} - {keys[k]}, score: {correlation_matrix[j, k]:.2f}') + if correlation_matrix[j, k] < -0.5: + print(f'Neg corr: {keys[j]} - {keys[k]}, score: {correlation_matrix[j, k]:.2f}') + av_correlations = np.nanmean(np.abs(correlation_matrix), axis=0) + av_correlations[np.where(np.isnan(av_correlations))] = 1 + indexes_low_correlations = np.argsort(av_correlations) + for i in indexes_low_correlations[:10]: + print(keys[i], f'score: {av_correlations[i]}') + plt.figure() + plt.imshow(correlation_matrix, cmap='seismic') + plt.clim(-1, 1) + plt.colorbar() + plt.savefig(plot_dir + 'correlation_matrix.png') + + +def investigate_reps(reps, keys, file_paths): + i = -1 + path = "/home/cedric/Documents/pianocktail/data/music/handcoded_reps/samples/test/" + + i += 1 + plt.close('all') + for f in os.listdir(path): + os.remove(path + f) + key = keys[i] + print(f'\n#{i+1}', key) + indexes = np.argsort(reps[:, i]) + lowest = indexes[:10] + highest = np.flip(indexes)[:10] + for i_ind, ind in enumerate(lowest): + fp = file_paths[ind] + copy(fp.replace('.json', '.mid'), path + f'lowest_top{i_ind+1}_score{reps[ind, i]:.2f}_' + fp.split('/')[-1].replace('.json', '.mid')) + for i_ind, ind in enumerate(highest): + fp = file_paths[ind] + copy(fp.replace('.json', '.mid'), path + f'highest_top{i_ind+1}_score{reps[ind, i]:.2f}_' + fp.split('/')[-1].replace('.json', '.mid')) + plt.figure() + plt.hist(reps[:, i], bins=30) + plt.title(keys[i]) + plt.show() + +def sample(reps, keys, file_paths): + + sub_reps = reps[:, np.array([keys.index(k) for k in keys_to_investigate])] + + all_indexes_to_sample = [] + reasons = [] + for i in range(len(keys_to_investigate)): + features = sub_reps[:, i] + ps = np.arange(2.5, 97.6, 95 / 9) + for p in ps: + reason = f'{keys_to_investigate[i]}_percentile_{p:.2f}_' + percentile = np.percentile(features, q=p) + index = np.argmin(np.abs(features - percentile)) + if index in all_indexes_to_sample: + reasons[all_indexes_to_sample.index(index)] += reason + else: + all_indexes_to_sample.append(index) + reasons.append(reason) + + # shuffle indexes + index_shuffling = np.arange(len(all_indexes_to_sample)) + np.random.shuffle(index_shuffling) + all_indexes_to_sample = np.array(all_indexes_to_sample)[index_shuffling][:90] + reasons = np.array(reasons)[index_shuffling][:90] + + # generate names + names = [f'form_{f}_sample_{i}' for f in ['a', 'b', 'c'] for i in np.arange(1, 31)] + meta_data = pd.DataFrame.from_dict(dict(name=names, reason=reasons, filename=[file_paths[i] for i in all_indexes_to_sample])) + meta_data.to_csv(sample_dir + 'data.csv', index=False) + # fade-in factors + for i, ind, r in 
zip(range(len(reasons)), all_indexes_to_sample, reasons): + fp = file_paths[ind] + new_filepath = '/'.join(fp.split('/')[:-2] + ['samples/test/', r + fp.split('/')[-1]]).replace('.json', '') + # copy(fp.replace('.json', '.mid'), new_filepath + '.mid') + midi = pm.PrettyMIDI(fp.replace('.json', '.mid')) + audio = midi.fluidsynth(fs=FS) + audio = audio[-SAMPLE_LEN * FS:] * (2 ** 31 - 1) + audio[:len(FADE_IN)] *= FADE_IN + audio[-len(FADE_IN):] *= np.flip(FADE_IN) + sound2 = pydub.AudioSegment(data=audio.astype("int32").tobytes(), channels=1, frame_rate=FS, sample_width=4) + sound2.export(new_filepath + '.mp3', format="mp3") + sound2.export(sample_dir + names[i] + '.mp3', format="mp3") + copy(fp, sample_dir + names[i] + '.json') + stop = 1 + + +if __name__ == '__main__': + file_paths, reps, keys = collect_filepaths_and_reps(path) + # plot(keys, reps) + # investigate_reps(reps[:, np.array([keys.index(k) for k in to_test])], to_test, file_paths) + sample(reps, keys, file_paths) + + + + + + + + + + diff --git a/src/music/utilities/handcoded_rep_utilities/loudness.py b/src/music/utilities/handcoded_rep_utilities/loudness.py new file mode 100644 index 0000000000000000000000000000000000000000..f58a65c1800e6f7aaba8c15683e2d42bf92ab9f3 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/loudness.py @@ -0,0 +1,256 @@ +"""Contains functions for generating and using equal-loudness contours for +side-presented steady pure tones according to the ISO/IEC 226 standard. +Code from https://gist.github.com/sammosummo/777debf946d0356acada and Cédric Colas +""" + + +__author__ = 'Sam Mathias' +__version__ = 1.0 + + +import numpy as np +from scipy.interpolate import interp1d +a_freq = 440 +a_pitch = 69 +upright_piano_dynamic_range = 60 # in db +a_440_db_at_max_velocity = 75 +a_440_amplitude_at_max_velocity = 10 ** (a_440_db_at_max_velocity / 20) + +def iso266(phon, return_freq=False): + """Returns an equal-loudness contour evaluated at 29 frequencies between + 20 Hz and 12.5 kHz according to the ISO/IEC 226 standard [1]_. + Parameters + ---------- + phon : int or float + The phon value represented by the equal-loudness contour, where a value + of :math:`x` phon is the loudness of 1-KHz steady pure tone presented + at :math:`x` dB SPL. Must be between 0 and 90. + return_freq : bool, optional + If True, the function returns the frequency values as well as the SPL + values of the contour. Default is False. + Returns + ------- + array_like + Either a 1-D or a 2-D numpy array, depending on `return_freq`. + Reference + --------- + .. [1] ISO/IEC (2003). ISO/IEC 226:2003 Acoustics -- Normal equal-loudness- + level contours. + http://www.iso.org/iso/catalogue_detail.htm?csnumber=34222. + Example + ------- + elc = iso266(60, return_freq=True) + print elc + [[ 20. 25. 31.5 40. 50. + 63. 80. 100. 125. 160. + 200. 250. 315. 400. 500. + 630. 800. 1000. 1250. 1600. + 2000. 2500. 3150. 4000. 5000. + 6300. 8000. 10000. 12500. 
] + [ 109.51132227 104.22789784 99.07786826 94.17731862 + 89.96345731 85.94342131 82.05340072 78.65461863 + 75.56345314 72.4743448 69.86431929 67.53483532 + 65.39173983 63.45099627 62.0511792 60.81495942 + 59.88668375 60.011588 62.1549143 63.18935604 + 59.96161453 57.25515019 56.42385863 57.56993838 + 60.8882125 66.36125056 71.66396598 73.15510401 + 68.63077045]] + """ + if not 0 <= phon <= 90: + raise ValueError('Cannot calculate for this value.') + + f = np.array([ + 20, 25, 31.5, 40, 50, 63, 80, 100, 125, 160, 200, 250, 315, 400, 500, + 630, 800, 1000, 1250, 1600, 2000, 2500, 3150, 4000, 5000, 6300, 8000, + 10000, 12500 + ]) + + af = np.array([ + 0.532, 0.506, 0.480, 0.455, 0.432, 0.409, 0.387, 0.367, 0.349, 0.330, + 0.315, 0.301, 0.288, 0.276, 0.267, 0.259, 0.253, 0.250, 0.246, 0.244, + 0.243, 0.243, 0.243, 0.242, 0.242, 0.245, 0.254, 0.271, 0.301 + ]) + + Lu = np.array([ + -31.6, -27.2, -23.0, -19.1, -15.9, -13.0, -10.3, -8.1, -6.2, -4.5, + -3.1, -2.0, -1.1, -0.4, 0.0, 0.3, 0.5, 0.0, -2.7, -4.1, -1.0, 1.7, + 2.5, 1.2, -2.1, -7.1, -11.2, -10.7, -3.1 + ]) + + Tf = np.array([ + 78.5, 68.7, 59.5, 51.1, 44.0, 37.5, 31.5, 26.5, 22.1, 17.9, 14.4, 11.4, + 8.6, 6.2, 4.4, 3.0, 2.2, 2.4, 3.5, 1.7, -1.3, -4.2, -6.0, -5.4, -1.5, + 6.0, 12.6, 13.9, 12.3 + ]) + + Ln = phon + + Af = 4.47e-3 * (10 ** (.025 * Ln) - 1.15) \ + + (.4 * 10 ** (((Tf + Lu) / 10.) - 9)) ** af + Lp = ((10 / af) * np.log10(Af)) - Lu + 94 + + spl = Lp + freq = f + + if return_freq is True: + return np.array([freq, spl]) + + else: + return spl + + +def equal_loudness(phon, freqs, return_freq=False): + """Returns equal-loudness levels for any frequencies between 20 Hz and + 12.5 kHz according to the ISO/IEC 226 standard [1]_. + Parameters + ---------- + phon : number + The phon value represented by the equal-loudness contour, where a value + of :math:`x` phon is the loudness of 1-KHz steady pure tone presented + at :math:`x` dB SPL. Must be between 0 and 90. + freqs : number or array_like + Frequency or frequencies in Hz to be evaluated. Must be between 20 and + 12500. + return_freq : bool, optional + If True, the function returns the frequency values as well as the SPL + values of the contour. Default is False. + Returns + ------- + array_like + Either a 1-D or a 2-D numpy array, depending on `return_freq`. + Reference + --------- + .. [1] ISO/IEC (2003). ISO/IEC 226:2003 Acoustics -- Normal equal-loudness- + level contours. + http://www.iso.org/iso/catalogue_detail.htm?csnumber=34222. + Example + ------- + >>> el = equal_loudness(60, [500, 1000, 2000], return_freq=True) + >>> print el + [[ 500. 1000. 2000. ] + [ 62.0511792 60.011588 59.96161453]] + """ + f = interp1d(*iso266(phon, True), kind='cubic') + + if return_freq is True: + return np.array([freqs, f(freqs)]) + + else: + return f(freqs) + + +def get_loudness(spl, freq): + """Returns the approximate loudness level in phons for a side-presented + steady pure tone according to the ISO/IEC 226 standard [1]_. + This function generates a range of equal-loudness contours and interpolates + between them. Therefore it is more efficient to pass many level and + frequency values to one function call than it is to make many function + calls. + Parameters + ---------- + spl : number or array_like + Sound pressure level or levels in dB SPL. + freq : number or array_like + Frequency or frequencies in Hz. + Returns + ------- + number or array_like + Phon value(s). + Reference + --------- + .. [1] ISO/IEC (2003). ISO/IEC 226:2003 Acoustics -- Normal equal-loudness- + level contours. 
+ http://www.iso.org/iso/catalogue_detail.htm?csnumber=34222. + Example + ------- + phons = get_loudness([50, 60, 70] [500, 500, 500]) + print phons + [ 47.3 57.8 68.4] + + """ + phons = np.arange(0, 90.1, .1) + freqs = np.arange(20, 12501) + spls = np.empty((len(phons), len(freqs))) + + for i, phon in enumerate(phons): + spls[i] = equal_loudness(phon, freqs) + + if not hasattr(spl, '__iter__'): + spl = [spl] + + if not hasattr(freq, '__iter__'): + freq = [freq] + + spls = spls.T + results = [] + + for _spl, _freq in zip(spl, freq): + ix = (np.abs(freqs - _freq)).argmin() + iy = (np.abs(spls[ix] - _spl)).argmin() + results.append(phons[iy]) + + if len(results) == 1: + return results[0] + + else: + return np.array(results) + + +def pitch2freq(pitch): + # from https://music.arts.uci.edu/dobrian/maxcookbook/pitch-and-loudness-formulae + relative_pitch = pitch - 69 + factor = 2 ** (relative_pitch / 12) + freq = a_freq * factor + return freq + +def velocity2amplitude(velocity): + # from https://www.cs.cmu.edu/~rbd/papers/velocity-icmc2006.pdf + r = 10 ** (upright_piano_dynamic_range / 20) + b = 127 / (126 * np.sqrt(r)) - (1 / 126) + m = (1 - b) / 127 + a = (m * velocity + b) ** 2 + a *= a_440_amplitude_at_max_velocity # scale amplitudes to get realistic perceived loudness + return a + +def amplitude2db(amplitude): + power_db = 20 * np.log10(amplitude) + return power_db + +def get_db_of_equivalent_loudness_at_440hz(freqs, db): + phons = get_loudness(db, freqs) + equal_dbs = [] + for p in phons: + equal_dbs.append(equal_loudness(p, [440])[0]) + return np.array(equal_dbs) + +def compute_total_loudness(eq_amplitudes_440hz, onsets, offsets): + # Compute the instantaneous amplitude, turn it back to dbs, then to perceived loudness with unique freq 440 Hz + # model amplitude as square function, loudness = peak amplitude from onset to offset, 0 afterwards. + # an exponential model might be better + assert all([len(values) == len(onsets) for values in [eq_amplitudes_440hz, offsets]]) + + timepoints = np.array(sorted(onsets + offsets)) + amplitudes_per_time = np.zeros(len(timepoints)) + # on each segment, compute the total amplitude + # amplitudes are not just summed: p1+p2 = sqrt(p1**2 + p2**2) + # ref: https://erlend-viggen.no/acoustic-quantities-1/ + for i_n in range(len(onsets)): + indexes = np.where(np.logical_and(timepoints >= onsets[i_n], timepoints < offsets[i_n])) + amplitudes_per_time[indexes] += eq_amplitudes_440hz[i_n] ** 2 + for i in range(len(amplitudes_per_time)): + amplitudes_per_time[i] = np.sqrt(amplitudes_per_time[i]) + # compute power + power_per_time = amplitude2db(amplitudes_per_time) + power_per_time[np.where(power_per_time == -np.inf)] = 0 + # compute loudness + loudness_per_time = get_loudness(power_per_time, [440] * len(power_per_time)) # amplitudes at 440hz, they were computed to make same loudness as original amplitudes at original F. 
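+ # loudness_per_time[i] is the perceived loudness (in phons) over the segment [timepoints[i], timepoints[i + 1]);
+ # the loop below integrates it over time and divides by the total duration, so the function returns
+ # the time-averaged loudness together with its standard deviation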
+ + # now integrate + total_loudness = 0 + for i_t in range(len(timepoints) - 1): + total_loudness += loudness_per_time[i_t] * (timepoints[i_t + 1] - timepoints[i_t]) + + return total_loudness / (timepoints[-1] - timepoints[0]), np.std(loudness_per_time) + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/src/music/utilities/handcoded_rep_utilities/tht/__init__.py b/src/music/utilities/handcoded_rep_utilities/tht/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/handcoded_rep_utilities/tht/confidence.py b/src/music/utilities/handcoded_rep_utilities/tht/confidence.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad635783bbe8b67fdfba6502530406375ca4e7d --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/confidence.py @@ -0,0 +1,250 @@ +"""Module containing functions to evaluate the confidence of a hypothesis +over a ongoing playback.""" + +import scipy.stats as st +import numpy as np +import warnings + + +def gaussian_weight(distances): + return np.exp(-(distances ** 2)) + + +def conf_exp(xs, proj, onsets, delta): + _, r_p, p = zip(*utils.project(xs, proj, onsets)) + errors = abs(np.array(p) - np.array(r_p)) + relative_errors = errors / float(delta) + ret = 0.01 ** relative_errors + return ret + + +def conf(xs, proj, onsets, delta, mult, decay, weight_func=gaussian_weight): + '''Confidence of a set of tactus projections over a playback. + + Complexity: O(|proj|) \in O(|ongoing_play|) + ''' + xs, r_p, p = list(zip(*utils.project(xs, proj, onsets))) + errors = np.array(p) - np.array(r_p) + relative_errors = decay * errors / float(delta) + ret = weight_func(relative_errors) + return mult * ret + + +def all_history_eval_exp(ht, ongoing_play): + ''' + Evaluates a hypothesis on an ongoing_play. It takes into consideration the + whole history of the playback. Uses exponential function for distance. + + Complexity: O(|ongoing_play|) + ''' + xs, proj = zip(*ht.proj_with_x(ongoing_play)) + conf_sum = sum(conf_exp(xs, proj, ongoing_play.discovered_play(), ht.d)) + return ((conf_sum / len(proj)) * + (conf_sum / len(ongoing_play.discovered_play()))) + + +class WindowedExpEval: + '''Confidence is evaluated with exp function over a window of time.''' + + def __init__(self, window): + self.window = window + + def __call__(self, ht, ongoing_play): + discovered_play = ongoing_play.discovered_play() + last = discovered_play[-1] + discovered_play_f = [o for o in discovered_play + if o > last - self.window] + return all_history_eval_exp(ht, play.Playback(discovered_play_f)) + + +def all_history_eval(ht, ongoing_play): + ''' + Evaluates a hypothesis on an ongoing_play. It takes into consideration the + whole history of the playback. + + Complexity: O(|ongoing_play|) + ''' + xs, proj = list(zip(*ht.proj_with_x(ongoing_play))) + conf_sum = sum(conf(xs, proj, ongoing_play.discovered_play(), + ht.d, 1, 0.01, lambda x: abs(x))) + return ((conf_sum / len(proj)) * + (conf_sum / len(ongoing_play.discovered_play()))) + + +class EvalAssembler: + ''' + Assembler for hypothesis evaluation functions. + + Args: + conf_modifiers: list of confs modifiers. + end_modifiers: list of end modifiers. + + Conf Modifiers modify the list of confidence scores per projected beat. 
+ They must be callable with the following signature: + (ht, projected_onsets, discovered_onsets, confidence_scores) -> + (projected_onsets, discovered_onsets, confidence_scores) + + End Modifiers modify the final confidence score, after it was summed + and normalized. + They must be callable with the following signature: + hypothesis_tracker, ongoing_play, confidence_score -> confidence_score + ''' + + def __init__(self, conf_modifiers, end_modifiers, mult=1, decay=5): + self.conf_modifiers = conf_modifiers + self.end_modifiers = end_modifiers + self.mult = mult + self.decay = decay + + def __call__(self, ht, ongoing_play): + try: + xs, proj = list(zip(*ht.proj_with_x(ongoing_play))) + except ValueError as ve: + print(ht, ht.r, ht.d) + raise ve + discovered_onsets = ongoing_play.discovered_play() + confs = conf(xs, proj, discovered_onsets, ht.d, self.mult, self.decay) + for cm in self.conf_modifiers: + proj, discovered_onsets, confs = cm(ht, proj, discovered_onsets, + confs) + + conf_sum = sum(confs) + try: + if len(proj) == 0: + return 0 + + end_conf = ((conf_sum / len(proj)) * + (conf_sum / len(discovered_onsets))) + + for em in self.end_modifiers: + end_conf = em(ht, ongoing_play, end_conf) + + return end_conf + except ZeroDivisionError: + print(('> Zero Division Error with ongoing_play: {} ' + '| ht = {} | proj = {}').format( + ongoing_play.discovered_play(), + ht, proj)) + + +class OnsetRestrictedConfMod: + ''' + Function class for evaluating a hypothesis on a restricted set of the onsets + of the playback. In this class, onsets are restricted to *n* before the + last discovered onset. + ''' + + def __init__(self, prev_onsets_allowed): + ''' + Initializes the function class. + + Args: + prev_onsets_allowed: number of onsets used in the eval before the + last discovered onset. + ''' + self.prev = prev_onsets_allowed + + def __call__(self, ht, proj, discovered_onsets, confs): + starting_idx = max(len(discovered_onsets) - self.prev, 0) + return (proj[starting_idx:], discovered_onsets[starting_idx:], + confs[starting_idx:]) + + +class TimeRestrictedConfMod: + ''' + Function class for evaluating a hypothesis on a restricted time before now + of the playback. + ''' + + def __init__(self, prev_ms_allowed, mult=1, decay=5): + self.prev = prev_ms_allowed + self.mult = mult + self.decay = decay + + def __call__(self, ht, proj, discovered_onsets, confs): + onsets_idx = 0 + while (discovered_onsets[onsets_idx] < + discovered_onsets[-1] - self.prev): + onsets_idx += 1 + + n_discovered_onsets = discovered_onsets[onsets_idx:] + + xs, n_proj = list(zip(*ht.proj_with_x(play.Playback(n_discovered_onsets)))) + n_confs = conf(xs, n_proj, n_discovered_onsets, ht.d, self.mult, + self.decay) + + return (n_proj, n_discovered_onsets, n_confs) + + +class DeltaPriorEndMod: + ''' + Function class for evaluating hypothesis that scores by usign another + eval function for confidence and also multiplies by a prior distribution + over the delta value. 
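+ The prior over the hypothesis period (delta, in ms) is a truncated normal distribution
+ (scipy.stats.truncnorm) bounded by MIN_DELTA and MAX_DELTA and centered on DELTA_MU,
+ defined as class constants below.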
+ ''' + + MAX_DELTA = 1500 # ms + MIN_DELTA = 187 # ms + DELTA_MU = 600.0 # ms + DELTA_SIGMA = 400 # 250.0 175 # ms + delta_clip_a = (MIN_DELTA - DELTA_MU) / DELTA_SIGMA + delta_clip_b = (MAX_DELTA - DELTA_MU) / DELTA_SIGMA + + def _delta_prior(self, d): + return st.truncnorm.pdf(d, a=self.delta_clip_a, b=self.delta_clip_b, + loc=self.DELTA_MU, scale=self.DELTA_SIGMA) + + def __call__(self, ht, ongoing_play, end_conf): + return self._delta_prior(ht.d) * end_conf + + +class WindowedExpEvalPrior: + + def __init__(self, window): + self.window = window + + def __call__(self, ht, ongoing_play): + win_score = WindowedExpEval(self.window)(ht, ongoing_play) + delta_score = DeltaPriorEndMod()(ht, ongoing_play, win_score) + return delta_score + +conf_all_exp = all_history_eval_exp +conf_all = EvalAssembler([], [], 1, 5) +conf_prev = EvalAssembler([TimeRestrictedConfMod(1000, 1, 5)], []) +conf_all_w_prior = EvalAssembler([], [DeltaPriorEndMod()]) +conf_prev_w_prior = EvalAssembler([TimeRestrictedConfMod(5000)], + [DeltaPriorEndMod()]) + +windowed_conf = WindowedExpEval(6000) + + +try: + from src.music.utilities.handcoded_rep_utilities.tht import povel1985, utils, playback as play + + + class PovelAccentConfMod: + ''' + Function class for evaluating a hypothesis where confidence on each beat + onset is multiplied if the onset is accented according to Povel 1981 rules. + ''' + + def __init__(self, accent_multiplier): + self.multiplier = accent_multiplier + + def __call__(self, ht, proj, discovered_onsets, confs): + accents = set(m2.povel1985.accented_onsets(discovered_onsets)) + accented_confs = [ + c * self.multiplier + for c, o in zip(confs, discovered_onsets) + if o in accents + ] + return proj, discovered_onsets, accented_confs + + conf_accents_prior = EvalAssembler([PovelAccentConfMod(4)], [DeltaPriorEndMod()]) + conf_accents_prev_prior = EvalAssembler( + [PovelAccentConfMod(4), TimeRestrictedConfMod(1000)], [DeltaPriorEndMod()]) +except ImportError as ie: + warnings.warn('Code to measure hypothesis confidence with Povel and Essens' + ' 1985 accent algorithm could not be instanced given that ' + ' module m2.povel1985 is not installed.') + diff --git a/src/music/utilities/handcoded_rep_utilities/tht/correction.py b/src/music/utilities/handcoded_rep_utilities/tht/correction.py new file mode 100644 index 0000000000000000000000000000000000000000..a798f2f1e49e1a3445a25fbc0aee78355e001e76 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/correction.py @@ -0,0 +1,237 @@ +"""Module with correction functions thta given a hypothesis tracker and +a ongoing playback return a HypothesisCorrection class.""" + +import math + +import numpy as np +from src.music.utilities.handcoded_rep_utilities.tht import hypothesis as hs, hypothesis, utils, playback, confidence + +from scipy import stats + + + +def exp_error_conf(error, multiplicator, decay, delta): + return multiplicator * error * (decay ** (np.abs(error) / float(delta))) + + +def gauss_error_conf(error, multiplicator, decay, delta): + return multiplicator * error * confidence.gaussian_weight(decay * error / delta) + + +def error_calc(ht, ongoing_play): + p_w_x = ht.proj_with_x(ongoing_play) + try: + xs, p = list(zip(*p_w_x)) + except ValueError: + print(ht, ongoing_play.onset_times) + + #xs, p, r_p = zip(*utils.centered_real_proj(xs, p, ongoing_play)) + xs, p, r_p = list(zip(*utils.project(xs, p, ongoing_play.discovered_play()))) + + err = np.array(r_p) - np.array(p) + return xs, err, p + + +def proj_error_conf(ht, ongoing_play, mult, decay, 
err_conf_f): + xs, err, p = error_calc(ht, ongoing_play) + return xs, err_conf_f(np.array(err), mult, decay, ht.d), p + + +class HypothesisCorrection(): + """Structure holding information of each hypothesis correction. + + This class contains information pertaining the correction event. + """ + + def __init__(self, o_rho, o_delta, n_rho, n_delta, + r_value=None, p_value=None, stderr=None, + o_mse=None, n_mse=None, d_rho=None, d_delta=None): + self.o_rho = o_rho + self.o_delta = o_delta + self.n_rho = n_rho + self.n_delta = n_delta + self.r_value = r_value + self.p_value = p_value + self.stderr = stderr + self.o_mse = o_mse + self.n_mse = n_mse + self.d_rho = d_rho if d_rho is not None else n_rho - o_rho + self.d_delta = d_delta if d_delta is not None else n_delta - o_delta + + def new_hypothesis(self): + return hs.Hypothesis(self.n_rho, self.n_delta) + + def __repr__(self): + return '(dr: %.2f, dd: %.2f)' % (self.dr, self.dd) + + @property + def dr(self): + return self.d_rho + + @property + def dd(self): + return self.d_delta + + +class HypothesisCorrectionMethod(object): + """Represents a method to correct a hypothesis. Must be callable + with a hypothesis (HypothesisTracker) and a playback.""" + + def __call__(self, ht, ongoing_play): + raise NotImplementedError() + + +class LinearRegressOverSmoothedErrorCorrection(HypothesisCorrectionMethod): + """Generates a HypothesisCorrection for the hypothesis and the ongoing play. + + Correction is performed using a linear regression with the x values being + those corresponding to the projection of the hypothesis. The y values are + the error of the hypothesis prediction versus the closest onset for each + passed through a smoothing function that tones down outliers. + The linear regression then generates a intercept and slope value that + minimize the mse between x and y. The intercept and slope values are + considered the new rho and delta values of the hypothesis. + + Complexity: O(|ongoing_play|) + """ + + def __init__(self, multiplicator=1.0, decay=0.01): + self.mult = multiplicator + self.decay = decay + + def __call__(self, ht, ongoing_play): + xs, err, p = error_calc(ht, ongoing_play) + conf = exp_error_conf(err, self.mult, self.decay, ht.d) + + (delta_delta, delta_rho, r_value, + p_value, stderr) = stats.linregress(xs, conf) + + return HypothesisCorrection(o_rho=ht.r, o_delta=ht.d, + n_rho=ht.r + delta_rho, + n_delta=ht.d + delta_delta, + r_value=r_value, p_value=p_value, + stderr=stderr) + + +class WindowedCorrection(HypothesisCorrectionMethod): + ''' + Correction function in which only part of the past of the percieved onsets + is taken into account for the correction. 
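+ Only onsets within 'window' ms of the last discovered onset are used to fit
+ the corrected (rho, delta); see __call__ below.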
+ ''' + + def __init__(self, mult, decay, window): + ''' + Args: + mult: double multiplier of the error + decay: double multiplier of the error on the decay part + window: ms before last onsets to check for errors + ''' + self.mult = mult + self.decay = decay + self.window = window + + def __call__(self, ht, ongoing_play): + discovered_onsets = np.array(ongoing_play.discovered_play()) + discovered_onsets = discovered_onsets[discovered_onsets > + discovered_onsets[-1] - + self.window] + sub_pl = playback.Playback(discovered_onsets) + xs, err, p = error_calc(ht, sub_pl) + conf = exp_error_conf(err, self.mult, self.decay, ht.d) + + (delta_delta, delta_rho, r_value, + p_value, stderr) = stats.linregress(xs, conf) + + return HypothesisCorrection(o_rho=ht.r, o_delta=ht.d, + n_rho=ht.r + delta_rho, + n_delta=ht.d + delta_delta, + r_value=r_value, p_value=p_value, + stderr=stderr) + + +# TODO: Mergear los métodos de corrección +class MovingWindowedSmoothCorrection(HypothesisCorrectionMethod): + ''' + Correction function in which the new hypothesis is moved forward to + the last projections. Correction is measured over a limited time before + the current onset. + ''' + + def __init__(self, mult, decay, window): + ''' + Args: + mult: double multiplier of the error + decay: double multiplier of the error on the decay part + window: ms before last onsets to check for errors + ''' + self.mult = mult + self.decay = decay + self.window = window + + def __call__(self, ht, ongoing_play): + discovered_onsets = np.array(ongoing_play.discovered_play()) + discovered_onsets = discovered_onsets[discovered_onsets > + discovered_onsets[-1] - + self.window] + sub_pl = playback.Playback(discovered_onsets) + xs, err, p = error_calc(ht, sub_pl) + conf = gauss_error_conf(err, self.mult, self.decay, ht.d) + + if len(p) > 2: + (delta_delta, delta_rho, r_value, + p_value, stderr) = stats.linregress(xs, conf) + + n_h = hypothesis.Hypothesis(ht.r + delta_rho, ht.d + delta_delta) + + n_p = n_h.proj(sub_pl) + + return HypothesisCorrection(o_rho=ht.r, o_delta=ht.d, + n_rho=n_p[-2], + n_delta=n_p[-1] - n_p[-2], + d_rho=delta_rho, d_delta=delta_delta, + r_value=r_value, p_value=p_value, + stderr=stderr) + + else: + return HypothesisCorrection(o_rho=ht.r, o_delta=ht.d, + n_rho=ht.r, n_delta=ht.d) + + + +class LinRegsOverSmoothedErrorCorrectionWithPeak(HypothesisCorrectionMethod): + + def __init__(self, decay=0.0001): + self.decay = decay + + def __call__(self, ht, ongoing_play): + mult = (-1 * ht.d) / math.log(self.decay) + lin_r_corr = LinearRegressOverSmoothedErrorCorrection(mult, self.decay) + return lin_r_corr(ht, ongoing_play) + + +class MultLinRegsOSEC(LinearRegressOverSmoothedErrorCorrection): + + def __init__(self, mult=1.0, decay=0.02, by=5): + self.by = by + + LinearRegressOverSmoothedErrorCorrection.__init__(self, mult, decay) + + def __call__(self, ht, ongoing_play): + nh = ht + for i in range(self.by): + nc = super(self.__class__, self).__call__(nh, ongoing_play) + nh = nc.new_hypothesis() + return nc + +lin_r_corr = LinearRegressOverSmoothedErrorCorrection() +lin_r_corr_alt = LinearRegressOverSmoothedErrorCorrection(1, 0.001) +lin_r_corr_max = LinRegsOverSmoothedErrorCorrectionWithPeak() +lin_r_corr_max_descent = LinRegsOverSmoothedErrorCorrectionWithPeak(0.001) +lin_r_corr_opt_by_5 = MultLinRegsOSEC(2, 0.0001, 5) +lin_r_corr_opt = LinearRegressOverSmoothedErrorCorrection(2, .0001) +windowed_corr = WindowedCorrection(2, 0.0001, 6000) + + +def no_corr(ht, ongoing_play): + 'Correction function that performs no 
correction' + return HypothesisCorrection(ht.r, ht.d, ht.r, ht.d) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/defaults.py b/src/music/utilities/handcoded_rep_utilities/tht/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d678337432bbcced7d525d0cfa9be43b712f59 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/defaults.py @@ -0,0 +1,21 @@ +'''Default configuration for THT''' + +from src.music.utilities.handcoded_rep_utilities.tht import similarity, confidence, correction + +eval_f = confidence.windowed_conf +corr_f = correction.windowed_corr +sim_f = similarity.min_dist_sim +similarity_epsilon = 0.005 +max_delta = (60000.0 / 40) # 40 bpm +min_delta = (60000.0 / 320) # 320 bpm +max_hypotheses = 30 + +config = { + 'eval_f': eval_f, + 'corr_f': corr_f, + 'sim_f': sim_f, + 'similarity_epsilon': similarity_epsilon, + 'max_delta': max_delta, + 'min_delta': min_delta, + 'max_hypotheses': max_hypotheses +} diff --git a/src/music/utilities/handcoded_rep_utilities/tht/hypothesis.py b/src/music/utilities/handcoded_rep_utilities/tht/hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea663a2ca44b1495d81b920879ecd09f33b596e --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/hypothesis.py @@ -0,0 +1,93 @@ +"""This module contains classes representing tactus hypothesis.""" + +import math + +import numpy as np + + +class Hypothesis(object): + """Represents a hypothesis (rho, delta) and contains several + convenience methods.""" + + def __init__(self, rho, delta): + self.htuple = (rho, delta) + + @property + def name(self): + return self.__repr__() + + @property + def r(self): + return self.htuple[0] + + @property + def d(self): + return self.htuple[1] + + @property + def cur(self): + return self.htuple + + def bpm(self): + return 60000.0 / self.d + + def proj_with_x_in_range(self, min, max): + min_x, max_x = self.proj_x_range(min, max) + return ((x, self.r + self.d * x) for x in range(min_x, max_x+1)) + + def proj_with_x(self, play): + return self.proj_with_x_in_range(play.min, play.max) + + def proj_in_range(self, min, max): + return np.array([v[1] for v in self.proj_with_x_in_range(min, max)]) + + def proj(self, play): + return np.array([v[1] for v in self.proj_with_x(play)]) + + def proj_x_range(self, min, max): + min_x = int(math.ceil((min - self.d / 2.0 - self.r) / self.d)) + max_x = int(math.floor((max + self.d / 2.0 - self.r) / self.d)) + return min_x, max_x + + def __getitem__(self, key): + if key == 0: + return self.r + if key == 1: + return self.d + else: + return object.__getitem__(self, key) + + def __setitem__(self, key, value): + if key == 0: + self.htuple = (value, self.d) + if key == 1: + self.htuple = (self.r, value) + else: + return object.__setitem__(self, key, value) + + def __repr__(self): + return 'H(%.2f, %.2f)' % self.htuple + + def __lt__(self, other): + return self.r < other.r or (self.r == other.r and self.d < other.d) + + +class HypothesisFromIndex(Hypothesis): + """Represents a hypothesis created from index on onset times. + Name is represented from the onset numbers, rather than the onset times. 
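+ For example, HypothesisFromIndex(0, 2, [0, 500, 1000]) is named '0-2' and corresponds
+ to rho = 0 and delta = 1000.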
+ + onset_times must be in milliseconds.""" + + def __init__(self, start_idx, end_idx, onset_times): + start_offset = onset_times[start_idx] + end_offset = onset_times[end_idx] + Hypothesis.__init__(self, start_offset, end_offset - start_offset) + self._name = '%d-%d' % (start_idx, end_idx) + self.onset_indexes = (start_idx, end_idx) + + @property + def name(self): + return self._name + + def __repr__(self): + return 'Hi:%s' % (self.name) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/playback.py b/src/music/utilities/handcoded_rep_utilities/tht/playback.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7cb5a93d1bb3f020b64f3e237e584dbff61b26 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/playback.py @@ -0,0 +1,72 @@ +"""Module containing classes that represents playbacks. A playback is an +enhanced container for a set of onset events (onset times).""" + +import numpy as np + + +class Playback(): + """Represents the entire playback of a song. + + Has the same interface as OngoingPlayback except for the discovering + methods. + """ + + def __init__(self, onset_times): + self.onset_times = onset_times + + @property + def min(self): + 'First onset' + return self.onset_times[0] + + @property + def max(self): + 'Last onset' + return self.onset_times[-1] + + def discovered_play(self): + 'Onsets discovered at the moment' + return self.onset_times + + +class OngoingPlayback(Playback): + """Represents a playback that is discovered onset by onset. + + This class is used to manage the discovery process of a song, by exposing + only a chuck of the song, adding one more onset to what's been discovered + at a time. + + Interal Variables + onset_times: numpy array of all milliseconds with events in order + up_to_discovered_index: index up to which all events were discovered + (not inclusive) + """ + + def __init__(self, onset_times): + self.onset_times = np.array(onset_times) + self.up_to_discovered_index = 1 + + def advance(self): + 'Discover a new onset' + if (self.up_to_discovered_index < len(self.onset_times)): + self.up_to_discovered_index += 1 + return True + return False + + @property + def discovered_index(self): + 'Returns the index of the last discovered onset' + return self.up_to_discovered_index - 1 + + @property + def max(self): + 'Last onset discovered. None if no onset has been discovered yet' + return self.onset_times[self.discovered_index] + + @property + def discovered_onset(self): + 'Last onset discovered. Same as max.' + return self.max + + def discovered_play(self): + return self.onset_times[:self.up_to_discovered_index] diff --git a/src/music/utilities/handcoded_rep_utilities/tht/povel1985.py b/src/music/utilities/handcoded_rep_utilities/tht/povel1985.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7f99c539a6d9e77dc29ebf7734d0c08285ad71 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/povel1985.py @@ -0,0 +1,115 @@ +import numpy as np + + +def cluster_onsets(onsets): + '''Returns sorted clusters of onsets. + Does not asume onsets are cyclical. + Args: + onsets: list of ms + Returns: + List of ints where each int represents how many onsets should be in + that cluster. The sum of the return value should be equal to the + length of the 'onsets' list. 
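+ Example:
+ cluster_onsets([0, 100, 200, 1000]) returns [3, 1]: the three closely spaced
+ onsets form one cluster and the isolated last onset forms its own.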
+ ''' + if len(onsets) <= 2: + return [len(onsets)] + + max_clustering_dur = 450 # ms + + iois = np.diff(onsets) + iois_set = list(set(iois)) + median_ioi = sorted(iois_set)[len(iois_set) // 2] + clusters = [1] + for idx in range(1, len(onsets) - 1): + prev_dur = iois[idx - 1] + next_dur = iois[idx] + min_dur = min(prev_dur, next_dur) + if (min_dur == prev_dur and min_dur < median_ioi and + min_dur < max_clustering_dur): + clusters[-1] += 1 + else: + clusters.append(1) + + if next_dur < median_ioi and next_dur < max_clustering_dur: + clusters[-1] += 1 + else: + clusters.append(1) + + return clusters + + +def accented_onsets(onsets): + 'Returns subset of onsets that are rhythmically accented (Povel 1985)' + clusters = cluster_onsets(onsets) + accented = [] + it = iter(onsets) + for cluster in clusters: + if cluster == 1: + accented.append(next(it)) + elif cluster == 2: + next(it) + accented.append(next(it)) + else: + accented.append(next(it)) + cluster -= 2 + while cluster != 0: + next(it) + cluster -= 1 + accented.append(next(it)) + return accented + + +def hypothesis_counterevidence(onsets, hypothesis, phrase_length, W=4): + ''' + Calculates the counter evidence score in Povel 1985. + Args: + onsets: list of time onsets + hypothesis: (phase, period) tuple. Period and onset times should be + multiples of a same base time step + phrase_length: length of a phrase in onsets in base time step units + W: weight of -ev counterevidence (see paper) + Returns: + int with counterevidence score + ''' + accents = accented_onsets(onsets) + counterevidence = 0 + projection = hypothesis[0] + total_length = np.ceil(onsets[-1] / float(phrase_length)) * phrase_length + while projection < total_length: + if projection not in accents: + if projection not in onsets: + counterevidence += W + else: + counterevidence += 1 + projection += hypothesis[1] + return counterevidence + + +def best_clock(onsets, base_time_step, phrase_length): + ''' + Returns the best clock (phase, period) for the onset sequence using an + implementation of Povel 1985's model. 
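+    For example (illustrative onsets, in base time steps),
+    best_clock([0, 4, 8, 10, 12], base_time_step=1, phrase_length=16)
+    scores every candidate clock whose period is shorter than half the
+    phrase and divides it, and returns the least contradicted
+    (phase, period) pair.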
+ Args: + onsets: list of onset times (in ms or multiple of base time step) + base_time_step: length in milliseconds of base timestep or 1 if + onsets are not milliseconds + phrase_length: length in base steps of a phrase + Returns: + ((phase, period), counterevidence) + ''' + hypothesis_space = [ + (phase, period) + for period in np.arange(1, phrase_length / 2) * base_time_step + for phase in np.arange(0, period) + if phrase_length % period == 0 + ] + hypothesis_space_w_score = [(hypothesis, + hypothesis_counterevidence(onsets, hypothesis, + phrase_length)) + for hypothesis in hypothesis_space] + sorted_hs = sorted(hypothesis_space_w_score, key=lambda x: x[1]) + return sorted_hs[0] + + +def cv_to_category(counterevidence): + return counterevidence + 1 \ No newline at end of file diff --git a/src/music/utilities/handcoded_rep_utilities/tht/similarity.py b/src/music/utilities/handcoded_rep_utilities/tht/similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..c7567acb19398a5995260230f01e104c637a4a77 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/similarity.py @@ -0,0 +1,36 @@ +"""Module containing functions to measure similarity between two hypothesis +trackers with respect to a ongoing playback.""" + +from src.music.utilities.handcoded_rep_utilities.tht import playback, confidence + + +def proj_conf_sim(h, i, ongoing_play): + """Evaluates the similarity between two hypothesis measuring the confidence + of one on another.""" + proj = playback.Playback(i.proj(ongoing_play)) + return confidence.all_history_eval(h, proj) + + +def id_sim(h, i, ongoing_play): + """Two hypothesis are similar if they have the same delta and equivalent + phase. + """ + return int(h.d == i.d and ((h.r - i.r) / float(i.d)) % 1 == 0) + + +def min_dist_sim(h, i, *args): + """ + Similarity index comes from relative similarity at their closest point. + + Asumes i is a newer hypothesis than h. + + For how dR is calculated, see https://goo.gl/photos/pSQ6gkvgPkn2D4rm9 + """ + assert (i.r > h.r or (i.r == h.r and i.d > h.d), + 'i (%s) is not newer than h (%s)') + D = abs(h.d - i.d) + dD = D / max(h.d, i.d) + R = abs(i.r - h.r) % h.d + A = h.d / 2 + dR = (A - abs(R - A)) / A + return 1 - max(dD, dR) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tactus_hypothesis_tracker.py b/src/music/utilities/handcoded_rep_utilities/tht/tactus_hypothesis_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7070a398380174913b71ee17c864b986ff282947 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tactus_hypothesis_tracker.py @@ -0,0 +1,211 @@ +"""This module contains the TactusTrackersGenerator class that can be configured +to generate complete hypothesis trackers for the playback of a case.""" + +from . import hypothesis, playback, defaults, confidence +from .correction import HypothesisCorrection, windowed_corr +import collections +import logging +from typing import * + +Rho = NewType('Rho', float) +Delta = NewType('Delta', float) +OnsetIdx = NewType('OnsetIdx', int) +Conf = NewType('Conf', float) + +class HypothesisTracker(hypothesis.HypothesisFromIndex): + """Class that holds information of the hypothesis evolution. + + A hypothesis is defined as a rho and a delta values, where all tactus + predictions are described as: rho + delta * k, for some integer k. + + The 'name' of the hypothesis is given by the two onset indexes used + to originate the hypothesis. The 'beta' value is the first hypothesis. 
+ 'corr' contains a list of Correction objects with information about + each correction performed over the hypothesis. 'cur' is the current value + of the hypothesis. 'confs' contains the evolution of the confidence for + the hypothesis. + + The tracker also contains some convenience methods to work with the + current hypothesis. The 'r' property gives as the current rho value, + the 'd' property the current 'delta'. The 'proj' generates all + tactus predictions by the hypothesis within range of a playback. + + The 'update' method allows us to correct the current hypothesis with + a correction function and to update the confence status with a + confidence function. + """ + beta: Tuple[Rho, Delta] + oonset_times: List[float] + corr: List[Tuple[OnsetIdx, HypothesisCorrection]] + confs: List[Tuple[OnsetIdx, float]] + + def __init__(self, start_idx, end_idx, onset_times): + super(self.__class__, self).__init__(start_idx, end_idx, onset_times) + self.beta = self.htuple + self.onset_times = onset_times + self.corr = [] # [(onset_idx, hypothesis_correction)] + self.confs = [] # [(onset_idx, conf_value)] + + def update(self, ongoing_play, eval_f, corr_f): + "Updates a hypothesis with new conf and applying corrections." + correction = corr_f(self, ongoing_play) + self.corr.append((ongoing_play.discovered_index, correction)) + self.htuple = correction.new_hypothesis() + n_conf = eval_f(self, ongoing_play) + self.confs.append((ongoing_play.discovered_index, n_conf)) + + @property + def cur(self): + return self.htuple + + @property + def conf(self): + return self.confs[-1][1] + + def origin_onsets(self): + return (self.beta[0], sum(self.beta)) + + +class TactusHypothesisTracker(): + """Configurable class to generate hypothesis trackers for a case. + + Configuration includes: + * an eval function that defines how to evaluate a hypothesis over + certain Playback + * a correction functions that produces a HypothesisCorrection for a + hypothesis over a Playback + * a similarity function that defines how similar are two hypothesis + * a similarity_epsilon that defines the threshold for trimming + * a maximun amount of hypothesis trackers to be kept. Only hypotheses + best confidence are kept. + + When called on a set of onset_times it will return the hypothesis trackers + generated by the model. + """ + + logger = logging.getLogger('TactusHypothesisTracker') + + def __init__(self, eval_f, corr_f, sim_f, similarity_epsilon, + min_delta, max_delta, max_hypotheses, + archive_hypotheses=False): + self.eval_f = eval_f + self.corr_f = corr_f + self.sim_f = sim_f + self.similarity_epsilon = similarity_epsilon + self.min_delta = min_delta + self.max_delta = max_delta + self.max_hypotheses = max_hypotheses + self.archive_hypotheses = archive_hypotheses + + def __call__(self, onset_times): + """ + Performs the tracking of tactus hypothesis as defined by the model from + the song represented by the received onset_times. + + Args: + onset_times: a sorted list of ms where the musical events occur. + + Returns: + A dict :: hypothesis_name -> HypothesisTracker + """ + self.logger.debug('Started tracking for onsets (%d) : %s', + len(onset_times), onset_times) + ongoing_play = playback.OngoingPlayback(onset_times) + hypothesis_trackers = [] + archived_hypotheses = [] + while ongoing_play.advance(): + n_hts = list(self._generate_new_hypothesis(ongoing_play)) + self.logger.debug('New step. 
%d hypothesis created', len(n_hts)) + + hypothesis_trackers.extend(n_hts) + + for h in hypothesis_trackers: + h.update(ongoing_play, self.eval_f, self.corr_f) + + kept_hs, trimmed_hs = self._trim_similar_hypotheses( + hypothesis_trackers, ongoing_play) + self.logger.debug('Trimmed by similarity (%d): %s', + ongoing_play.discovered_index, + str([str(h) for h in trimmed_hs])) + + k_best_hs, other_hs = self._split_k_best_hypotheses(kept_hs) + self.logger.debug('Trimmed by score (%d): %s', + ongoing_play.discovered_index, + str([str(h) for h in other_hs])) + hypothesis_trackers = k_best_hs + if (self.archive_hypotheses): + archived_hypotheses.extend(other_hs) + self.logger.debug('End of step. %d trackers remaining', + len(hypothesis_trackers)) + + return dict([(ht.name, ht) + for ht in archived_hypotheses + hypothesis_trackers]) + + def _generate_new_hypothesis(self, ongoing_play): + "Generates new hypothesis trackers given discovered onset in playback." + end_index = ongoing_play.discovered_index + for k in range(end_index): + delta = (ongoing_play.onset_times[end_index] - + ongoing_play.onset_times[k]) + if self.min_delta <= delta and delta <= self.max_delta: + yield HypothesisTracker(k, end_index, + ongoing_play.onset_times) + + def _trim_similar_hypotheses(self, hts, ongoing_play): + """Partitions new hypothesis into those that should be trimmed given + a set of comparsion hypothesis. + + Assumes hypothesis trackers are sorted by when they were generated in + hts. + """ + trimmed_hs_data = [] + kept_hs = [] + remaining_hts = collections.deque(hts) + while remaining_hts: + ht = remaining_hts.popleft() + n_remaining_hts = collections.deque() + kept_hs.append(ht) + while remaining_hts: + n_ht = remaining_hts.popleft() + s = self.sim_f(ht, n_ht, ongoing_play) + if s > (1 - self.similarity_epsilon): + trimmed_hs_data.append((n_ht, ht)) + else: + n_remaining_hts.append(n_ht) + + remaining_hts = n_remaining_hts + + return (kept_hs, trimmed_hs_data) + + def _split_k_best_hypotheses(self, hts): + """Splits hypotheses into the self.max_hypotheses best + (according to confidence) and the rest. + + Both result list will be sorted in order of generation.""" + hts_info = [(-1 * ht.conf, idx) for idx, ht in enumerate(hts)] + sorted_hts_info = sorted(hts_info) + best_hts_idx = set([ + i for _, i in sorted_hts_info[:self.max_hypotheses]]) + best_k_hts = [ht for idx, ht in enumerate(hts) + if idx in best_hts_idx] + other_hts = [ht for idx, ht in enumerate(hts) + if idx not in best_hts_idx] + return best_k_hts, other_hts + + +def default_tht(**kwargs): + '''Returns a TactusHypothesisTracker with the default configuration. + + Default config may be overriden witih kwargs. 
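+
+    Example (illustrative onset times, in ms):
+
+        tht = default_tht(max_hypotheses=10)
+        trackers = tht([0, 500, 1000, 1500, 2000])
+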
See defaults.config + ''' + config = defaults.config.copy() + config.update(kwargs) + return TactusHypothesisTracker(**config) + + +jnmr_tht = default_tht( + **{ + 'eval_f': confidence.WindowedExpEval(6000), + 'corr_f': windowed_corr + } +) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tests/__init__.py b/src/music/utilities/handcoded_rep_utilities/tht/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tests/tactus_hypothesis_tracker_test.py b/src/music/utilities/handcoded_rep_utilities/tht/tests/tactus_hypothesis_tracker_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2658fde351ead518736d61c33a909ace07844ea2 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tests/tactus_hypothesis_tracker_test.py @@ -0,0 +1,155 @@ +import unittest +import pytest +import collections + +from m2.tht import tactus_hypothesis_tracker +from m2.tht import correction + + +class TrimSimHypothesisTest(unittest.TestCase): + + def setUp(self): + pass + + class sim_f_mock(): + + def __init__(self): + self.calls = [] + + def __call__(self, h, i, _): + self.calls.append((h, i)) + return int(i % h == 0) + + def test_trim_sim_hypothesis(self): + sim_f = self.sim_f_mock() + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + None, None, sim_f, 0.00001, None, None, None) + kept, trimmed = tht._trim_similar_hypotheses(range(2, 10), sim_f) + self.assertEqual(sim_f.calls, + [(2, x) for x in range(3, 10)] + + [(3, 5), (3, 7), (3, 9)] + + [(5, 7)]) + self.assertEqual(kept, [2, 3, 5, 7]) + self.assertEqual(trimmed, + [(4, 2), (6, 2), (8, 2)] + + [(9, 3)]) + + def test_k_best_hypothesis(self): + HT = collections.namedtuple('HT', ['id', 'conf']) + hts = [HT(idx, 7 if idx % 3 == 0 else idx) + for idx in range(11)] + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + None, None, None, None, None, None, 5) + k_best, other = tht._split_k_best_hypotheses(hts) + self.assertEqual(k_best, [hts[0], hts[3], hts[6], hts[8], hts[10]]) + self.assertEqual(other, [hts[1], hts[2], hts[4], + hts[5], hts[7], hts[9]]) + +def id_corr_f(ht, op): + return correction.HypothesisCorrection(ht.r, ht.d, ht.r, ht.d) + +class TestHypothesisGeneration: + + def test_generated_hypothesis_with_no_restrictions(self, mocker): + onset_times = list(range(10)) + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, op: ht.r, + corr_f=id_corr_f, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=1000, + min_delta=1, + max_hypotheses=1000) + hts = tht(onset_times) + assert len(hts) == (10 * 9) / 2 + + + def test_generated_hypothesis_with_max_hypothesis_restriction(self, + mocker): + onset_times = list(range(10)) + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, op: ht.r, + corr_f=id_corr_f, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=1000, + min_delta=1, + max_hypotheses=10) + hts = tht(onset_times) + assert len(hts) == 10 + + def test_generated_hypothesis_with_max_delta_restriction(self, mocker): + onset_times = list(range(10)) + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, op: ht.r, + corr_f=id_corr_f, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=1, + min_delta=1, + max_hypotheses=1000) + hts = tht(onset_times) + assert len(hts) == 9 + + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, 
op: ht.r, + corr_f=id_corr_f, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=2, + min_delta=1, + max_hypotheses=1000) + hts = tht(onset_times) + assert len(hts) == 9 + 8 + + def test_generated_hypothesis_with_min_delta_restriction(self, mocker): + onset_times = list(range(10)) + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, op: ht.r, + corr_f=id_corr_f, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=3, + min_delta=3, + max_hypotheses=1000) + hts = tht(onset_times) + assert len(hts) == 7 + + +onset_times = list(range(10)) +proj_1 = lambda xs: [x[0] for x in xs] + +@pytest.fixture +def hts(mocker): + tht = tactus_hypothesis_tracker.TactusHypothesisTracker( + eval_f=lambda ht, op: ht.r, + corr_f=correction.no_corr, + sim_f=lambda h, i, *a: False, + similarity_epsilon=0, + max_delta=4, + min_delta=2, + max_hypotheses=50) + hts = tht(onset_times) + return hts + + +class TestGeneralTrackingResults: + + + def test_deltas_in_range(self, hts): + print([ht.d for ht in hts.values()]) + assert all([ht.d >= 2 and ht.d <= 4 + for ht in hts.values() + ]) + + def test_non_repeated_hypothesis(self, hts): + assert (len(set([ht.origin_onsets() for ht in hts.values()])) == + len(hts)) + + def test_conf_and_corr_onset_index_are_equal(self, hts): + assert all([proj_1(ht.corr) == proj_1(ht.confs) + for ht in hts.values()]) + + def test_conf_onsets_are_complete_and_greater_than_beta_2(self, hts): + assert all([proj_1(ht.confs) == list(range(ht.onset_indexes[1], 10)) + for ht in hts.values()]) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tests/tracking_overtime_test.py b/src/music/utilities/handcoded_rep_utilities/tht/tests/tracking_overtime_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a61b155608a492a125a13695222f3e73cc537e17 --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tests/tracking_overtime_test.py @@ -0,0 +1,85 @@ +import pytest +import addict +from m2.tht import tracking_overtime + + +@pytest.fixture +def basic_hts_mock(mocker): + '''A mock handcoded_rep_utilities result with two HypothesisTracker''' + m = mocker + + onset_times = [0, 100, 200, 300] + + h1 = m.MagicMock() + h1.confs = list(zip(range(1, 4), [1, 1, 4])) + h1.corr = list(zip(range(1, 4), [m.MagicMock() for _ in range(3)])) + h1.onset_times = onset_times + h1.__repr__ = m.Mock(return_value='h1') + + h2 = m.MagicMock() + h2.confs = list(zip(range(2, 4), [2, 3])) + h2.corr = list(zip(range(2, 4), [m.MagicMock() for _ in range(1, 3)])) + h2.onset_times = onset_times + h2.__repr__ = m.Mock(return_value='h2') + + hts = {'h1': h1, 'h2': h2} + return addict.Dict({ + 'hts': hts, + 'h1': h1, + 'h2': h2, + 'onset_times': onset_times + }) + + +def matchesHypothesisAtTime(hts=None, + onset_idx=None, corr=None, + ht_value=None, conf=None): + def match(hat): + if hts: + hat.hts == hts + if onset_idx: + assert hat.onset_idx == onset_idx + if corr: + assert hat.corr == corr + if ht_value: + assert ht_value == ht_value + if conf: + assert conf == conf + return True + return match + + +def equalsToMatchers(hats, matchers): + assert len(hats) == len(matchers) + assert all([m(h) for m, h in zip(matchers, hats)]) + + +def test_overtime_tracking_init(basic_hts_mock): + b = basic_hts_mock + hts_at_time = tracking_overtime.OvertimeTracking(b.hts) + assert hts_at_time.onset_times == b.onset_times + assert sorted(hts_at_time.time.keys()) == b.onset_times[1:] + hts_at_sorted_time = list(hts_at_time.hypothesis_by_time()) + 
print(hts_at_sorted_time) + equalsToMatchers(hts_at_sorted_time[0][1], + [matchesHypothesisAtTime(hts=b.h1, onset_idx=1, conf=1)]) + equalsToMatchers(hts_at_sorted_time[1][1], + [matchesHypothesisAtTime(hts=b.h1, onset_idx=2, conf=1), + matchesHypothesisAtTime(hts=b.h2, onset_idx=2, conf=2)]) + equalsToMatchers(hts_at_sorted_time[2][1], + [matchesHypothesisAtTime(hts=b.h1, onset_idx=3, conf=4), + matchesHypothesisAtTime(hts=b.h2, onset_idx=3, conf=3)]) + + +def test_conf_sorted_hats(basic_hts_mock): + b = basic_hts_mock + hts_at_time = tracking_overtime.OvertimeTracking(b.hts) + hts_at_sorted_time = list(hts_at_time.hypothesis_sorted_by_conf()) + equalsToMatchers(hts_at_sorted_time[0][1], + [matchesHypothesisAtTime(hts=b.h1, onset_idx=1, conf=1)]) + equalsToMatchers(hts_at_sorted_time[1][1], + [matchesHypothesisAtTime(hts=b.h2, onset_idx=2, conf=2), + matchesHypothesisAtTime(hts=b.h1, onset_idx=2, conf=1)]) + equalsToMatchers(hts_at_sorted_time[2][1], + [matchesHypothesisAtTime(hts=b.h1, onset_idx=3, conf=4), + matchesHypothesisAtTime(hts=b.h2, onset_idx=3, conf=3)]) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tests/utils_test.py b/src/music/utilities/handcoded_rep_utilities/tht/tests/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d7c74c6e655d3279441a3f2726849cfe41978c --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tests/utils_test.py @@ -0,0 +1,18 @@ +import unittest +import mock + +from m2.tht import utils + +class RealProjTest(unittest.TestCase): + + def setUp(self): + pass + + def test_real_proj(self): + matched = mock.MagicMock() + matched.discovered_play = mock.MagicMock(return_value=[1, 2, 3, 4, 5]) + to_match = [-2, 2.2, 2.3, 2.5, 4, 4.5, 6, 7] + xs = range(len(to_match)) + expected = [1, 2, 2, 2, 4, 4, 5] + _, _, result = zip(*utils.real_proj(xs, to_match, matched)) + self.assertEqual(list(result), expected) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tracker_analysis.py b/src/music/utilities/handcoded_rep_utilities/tht/tracker_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..c18e61bad2b49fa7d8d803a5ee30667a128d7bca --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tracker_analysis.py @@ -0,0 +1,360 @@ +'''This module contains a class with methods to perform analysis of the tactus +phase.''' + +from typing import Dict, List, Tuple, Union, Optional +from . 
import tactus_hypothesis_tracker, hypothesis, playback, defaults as tht_defaults +import numpy as np +from .tactus_hypothesis_tracker import HypothesisTracker +from scipy.stats import norm +import pandas as pd +import pickle + +Delta = float +Rho = float +Conf = float + +class TactusCaseAnalyzer: + + def __init__(self): + pass + + def top_hypothesis(self, case): + 'Given a case, returns a list of top tactus hypothesis' + def _sorting_key_gen(): + for i in range(3, len(case['onset_times'])): + def key(item): # item :: (ht_name, conf_dict) + return item[1][i] if i in item[1] else 0.0 + yield key + + def _top_hypothesis_iter(): + hts = case['hypothesis_trackers'] + hts_iterators = [(ht, dict(ht.confs)) + for ht in list(hts.values())] + for key in _sorting_key_gen(): + yield max(hts_iterators, key=key)[0] + return list(_top_hypothesis_iter()) + + +def sorting_key_gen(onset_idx): + 'Returns the confidence sorting key for hypothesis items at onset index' + def key(item): # item :: (ht_name, conf_dict) + return item[1][onset_idx] if onset_idx in item[1] else None + return key + + +def hypothesis_ranks_overtime(hypothesis_trackers, playback_length): + """Returns a structure to evaluate hypothesis rank over time. + + Args: + hypothesis_trackers: dictionary of hypothesis_name -> HypothesisTracker + playback_length: total amount of onsets considered of the playback + + Returns: + dict(onset_idx -> hypothesis_ranking) as list + with + hypothesis_ranking :: dict(ranking -> + (hypothesis_tracker, abs_confidence_at_onset_idx)) + as list + """ + results = [] + for i in range(playback_length): # Filtering confidence values + sort_key = sorting_key_gen(i) + enhanced_trackers = [(item[0], sort_key(item)) + for item in [(t, dict(t.confs)) for t + in list(hypothesis_trackers.values())] + if sort_key(item) is not None] + sorted_trackers = sorted(enhanced_trackers, key=lambda x: x[1], + reverse=True) + results.append((i, sorted_trackers)) + + return results + + +def create_trackers_segments(hypothesis_ranks_overtime, trackers_to_show): + 'Creates segments: tracker -> [(onset_times: [], conf: [])]' + trackers_segments = {} + for idx, hypothesis_ranking in enumerate(hypothesis_ranks_overtime): + for t, t_conf in hypothesis_ranking[:trackers_to_show]: + if t not in trackers_segments: + trackers_segments[t] = [] + last_segment = None + if (trackers_segments[t] and + trackers_segments[t][-1][0][-1] == idx - 1): + last_segment = trackers_segments[t][-1] + if last_segment: + last_segment[0].append(idx) + last_segment[1].append(t_conf) + else: + last_segment = ([idx], [t_conf]) + trackers_segments[t].append(last_segment) + + return trackers_segments + + +def tracker_dump(tracker, stream): + print('ht name', tracker.name, file=stream) + print('ht beta %f %f' % tracker.beta, file=stream) + for n, corr in tracker.corr: + print('ht corr %d %f %f' % (n, corr.n_rho, corr.n_delta), file=stream) + for n, conf in tracker.confs: + print('ht conf %d %f' % (n, conf), file=stream) + + +def top_hypothesis(hts, onset_times_count): + ''' + Given a case, returns a list of top tactus hypothesis + + Returns: + :: [(onset_idx :: int, ht :: HypothesisTracker)] + ''' + def _sorting_key_gen(): + for i in range(3, onset_times_count): + def key(item): # item :: (ht_name, conf_dict) + return item[1][i] if i in item[1] else None + yield (i, key) + + def _top_hypothesis_iter(): + hts_iterators = [(ht, dict(ht.confs)) + for ht in list(hts.values())] + for idx, key in _sorting_key_gen(): + f_list = [item for item in hts_iterators if key(item) is not 
None] + if len(f_list) == 0: + continue + yield (idx, max(f_list, key=key)[0]) + return list(_top_hypothesis_iter()) + + +def produce_beats_information(onset_times, top_hts, adapt_period=False, + max_delta_bpm=160, adapt_phase=None, + avoid_quickturns=None): + ''' + Runs through trackers and onset times, generating beats by projecting each + top hypothesis correction at that onset time on the interval between said + onset time and the next one. + + Args: + onset_times :: [ms] + top_hts :: [(onset_idx, HypothesisTracker)] + adapt_period :: Bool - Whether top hypothesis projection period should + be adapted to be slow enough (see 'max_delta_bpm') + max_delta_bpm :: bpm - max projection delta (in bpm) + adapt_phase :: conf_function - function used to evaluate possible phase + values. None is phase should not be + adapted. + avoid_quickturns :: ms - avoid switching hypothesis if its not the new + top hypothesis for longer than + 'avoid_quickturns' + + Returns: + :: [ms] + ''' + top_onset_idxs = [onset_idx for onset_idx, _ in top_hts] + onset_idxs = [0] + top_onset_idxs[1:] + [top_onset_idxs[-1]] + onset_limits_idx = [(onset_idxs[i], onset_idxs[i+1]) + for i in range(0, len(onset_idxs) - 1)] + onset_limits = [(onset_times[l], onset_times[r]) + for l, r in onset_limits_idx] + assert len(onset_limits) == len(top_hts) + + ret = [] + last_ht = None + suggested_change_ht = None + suggested_change_time = None + phase_corr = 0 + for idx in range(len(onset_limits)): + onset_idx, top_ht = top_hts[idx] + left_limit, right_limit = onset_limits[idx] + iht = dict(top_ht.corr)[onset_idx].new_hypothesis() + if avoid_quickturns != None: + if last_ht == None: + last_ht = top_ht + elif (top_ht.origin_onsets() != last_ht.origin_onsets()): + current_time = onset_times[onset_idx] + if (suggested_change_ht == None or top_ht.origin_onsets() != suggested_change_ht.origin_onsets()): + suggested_change_ht = top_ht + suggested_change_time = current_time + iht = dict(last_ht.corr)[onset_idx].new_hypothesis() + elif (suggested_change_ht.origin_onsets() == top_ht.origin_onsets() and + current_time - suggested_change_time < avoid_quickturns): + iht = dict(last_ht.corr)[onset_idx].new_hypothesis() + else: + last_ht = top_ht + + if adapt_period: + d = iht.d + divisions = 0 + while (60000 / d) > max_delta_bpm: + divisions += 1 + d = d * 2 + if (adapt_phase is not None and + (last_ht is None or + last_ht.origin_onsets() != top_ht.origin_onsets())): + possible_k = list(range(2 ** divisions)) + phase_corr = max( + possible_k, + key=lambda k: adapt_phase( + hypothesis.Hypothesis(iht.r + iht.d * k, d), + playback.Playback(onset_times[:onset_idx])) + ) + + r = iht.r + iht.d * phase_corr + iht = hypothesis.Hypothesis(r, d) + beats = np.array(iht.proj_in_range(left_limit, right_limit)) + for beat in beats[1:]: + ret.append(beat) + return ret + + +def track_beats(onset_times, tracker=tactus_hypothesis_tracker.default_tht()): + '''Generates tracked beats from onset_times by projecting to hypothesis + during tracking.''' + hts = tracker(onset_times) + + top_hts = top_hypothesis(hts, len(onset_times)) + + beats = produce_beats_information(onset_times, top_hts) + + return beats + + +def ht_grid(min_delta=tht_defaults.min_delta, + max_delta=tht_defaults.max_delta, + delta_sample_num=60, rho_sample_num=20): + ''' + Delta and rho grid for distribution sampling. 
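+    Deltas are sampled uniformly (in ms) between min_delta and max_delta;
+    rhos are sampled as fractions of the period in [0, 1].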
+ + Return: + delta_values: List + rho_values: List + ''' + delta_values = np.linspace(min_delta, + max_delta, + num=delta_sample_num) + rho_values = np.linspace(0, 1, num=rho_sample_num) + + return delta_values, rho_values + + +def ht_weighted_distribution(points, delta_samples, rho_samples, + delta_sigma=25, rho_sigma=0.1): + ''' + Calculates the probability of tactus hypothesis given 'points'. + + Calculates P(H | D) where H is a rho, delta tactus hypothesis and D is a + set of rho, delta, confidence points. + + P(r, d | t_i) prop= sum_{i} t_i.c * norm.pdf((t_i.r, t_i.d), mu=(r, d), + sigma=(rho_sigma, delta_sigma)) + + Args: + points: List[Tuple[Delta, Rho, Weight]] + delta_samples: delta values on which to calculate the distribution + rho_samples: rho values on which to calculate the distribution + delta_sigma: sigma used to weight the points relative to the sample + rho_sigma: sigma used to weight the points relative to the sample + + Return: + DataFrame with columns rho, delta, weight where rho and delta are the + cross product of 'delta_samples' and 'rho_samples'. + ''' + def weighted_sum(delta, rho, confs: List[Tuple[Delta, Rho, Conf]]): + # Asumes P(delta | D) and P(rho | D) independent + r_weight = norm.pdf([x[1] for x in confs], loc=rho, scale=rho_sigma) + d_weight = norm.pdf([x[0] for x in confs], loc=delta, + scale=delta_sigma) + cs = np.array([x[2] for x in confs]) + return (cs * r_weight * d_weight).sum() + #return sum(( + # norm.pdf(_d, loc=delta, scale=delta_sigma) * + # norm.pdf(_r, loc=rho, scale=rho_sigma) * c + # for _d, _r, c in confs)) + + hist2d = np.array([ + (d, r, weighted_sum(d, r, points)) + for d in delta_samples + for r in rho_samples + ]) + print('#d: {}, #r: {}, #p {}'.format(len(delta_samples), + len(rho_samples), + len(points))) + hist2d[:, 2] = hist2d[:, 2] / hist2d[:, 2].sum() + return pd.DataFrame(hist2d, columns=('delta', 'rho', 'weight')) + + +def tht_ht_points(hts: Dict[str, HypothesisTracker]): + ''' + Extracts delta and rho points from HypothesisTracker set. + + Returns: + List[Tuple[Delta, Rho, Conf]] + ''' + conf_values = [(corr.n_delta, + (corr.n_rho % corr.n_delta) / corr.n_delta, + conf) + for ht in hts.values() + for (idx, corr), (_, conf) in zip(ht.corr, ht.confs)] + return conf_values + + +def tht_grid(hts: Dict[str, HypothesisTracker]): + '''Calculates confidence map over rho and delta hypothesis space. + + The confidence map is calculated by creating a grid over rho and delta + and on each point summing the confidence of hypothesis nearby. + + The resulting map is normalized by the sum of values as a histogram. + + Result: + (n x 3) array with columns as: rho_value, delta_value and conf + ''' + delta_samples, rho_samples = ht_grid() + + conf_values = tht_ht_points(hts) + + df = ht_weighted_distribution(conf_values, delta_samples, + rho_samples) + return df[['rho', 'delta', 'weight']].values + + +def tht_tracking_confs(tht_pkl_fn: Union[str, Dict[str, HypothesisTracker]], + onset_count: Optional[int] = None) -> [float, float]: + ''' + Obtains the tracking confidence from a handcoded_rep_utilities pickle as the avg top conf. 
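+
+    Example (illustrative; 'tracking.pickle' is a hypothetical path):
+
+        confs = tht_tracking_confs('tracking.pickle')
+        mean_conf = np.mean([c for t, c in confs])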
+ + Args: + tht_pkl_fn: either: + * filename of the handcoded_rep_utilities tracking pickle + * unpickled tracking (dict[str, HypothesisTracker]) + onset_count: total number of onsets or None + + Returns: + list of confidence values at each timepoint :: [ms, confidence score] + ''' + if isinstance(tht_pkl_fn, str): + with open(tht_pkl_fn, 'rb') as f: + hts = pickle.load(f) + else: + hts = tht_pkl_fn + + if onset_count is None: + onset_count = max([o for ht in hts.values() for o, c in ht.confs]) + + top_hts = top_hypothesis(hts, onset_count) + + conf_values = [(ht.onset_times[idx], dict(ht.confs)[idx]) + for idx, ht in top_hts] + + return conf_values + + +def tht_tracking_conf(tht_pkl_fn: Union[str, Dict[str, HypothesisTracker]], + onset_count: Optional[int] = None) -> float: + ''' + Mean handcoded_rep_utilities top tracking confidence. + + See tht_tracking_confs for parameters details. + ''' + time_values, conf_values = zip(*tht_tracking_confs(tht_pkl_fn, + onset_count)) + + return np.mean(conf_values) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/tracking_overtime.py b/src/music/utilities/handcoded_rep_utilities/tht/tracking_overtime.py new file mode 100644 index 0000000000000000000000000000000000000000..ef0ffff67440e3cf997ea7ceb40390336a5aa6af --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/tracking_overtime.py @@ -0,0 +1,63 @@ +import numpy as np + +class OvertimeTracking: + ''' + Class for analysing the results of tracking a set of onsets during the + tracking process. + ''' + + def __init__(self, hts): + ''' + Initialize the instance analyze the tracking overtime. + + self.hts: tracking result + self.time: time(ms) -> [HypothesisAtTime] + self.onset_times: [ms] + + Args: + hts: string -> HypothesisTracker result of a + TactusHypothesisTracker + ''' + self.hts = hts + self.time = {} + self.onset_times = sorted(list(hts.values())[0].onset_times) + + assert all([np.array_equal(self.onset_times, ht.onset_times) + for ht in list(hts.values())]) + + # Update time with one HAT per hypothesis at a given time + for name, ht in list(hts.items()): + assert (([x[0] for x in ht.corr] == [x[0] for x in ht.confs]), + 'hts corrections and conf does not have same onset indexes') + for idx in range(len(ht.corr)): + onset_idx, corr = ht.corr[idx] + conf = ht.confs[idx][1] + onset_time = self.onset_times[onset_idx] + hts_at_time = self.time.get(onset_time, []) + hts_at_time.append(HypothesisAtTime(ht, onset_idx, corr, conf)) + self.time[onset_time] = hts_at_time + + def hypothesis_by_time(self): + 'Returns the list of HTS sorted by time' + return ((time, self.time[time]) for time in self.onset_times[1:] + if time in self.time) # TODO(march): verify that check is necessary + + def hypothesis_sorted_by_conf(self): + 'Returns the list of HTS sorted by time and then by confidence' + for time, hats in self.hypothesis_by_time(): + yield (time, sorted(hats, key=lambda hat: hat.conf, reverse=True)) + + +class HypothesisAtTime: + ''' + Class to represent a hypothesis at a given time. 
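+
+    Attributes (set in __init__ below):
+        hts: the originating HypothesisTracker
+        onset_idx: index of the onset at which this snapshot was taken
+        corr: the correction applied at that onset
+        ht_value: the corrected hypothesis value at that time
+        conf: the confidence value at that time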
+ ''' + def __init__(self, hts_ref, onset_idx, corr, conf): + self.hts = hts_ref + self.onset_idx = onset_idx + self.corr = corr + self.ht_value = corr.new_hypothesis() + self.conf = conf + + def __repr__(self): + return '%s (c:%.2f)' % (self.hts, self.conf) diff --git a/src/music/utilities/handcoded_rep_utilities/tht/utils.py b/src/music/utilities/handcoded_rep_utilities/tht/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..788f935986b54aaf65003342fc98ca9dcfda114f --- /dev/null +++ b/src/music/utilities/handcoded_rep_utilities/tht/utils.py @@ -0,0 +1,68 @@ +"""Utils for tactus processing.""" + +import numpy as np +import more_itertools as mit + + +def real_proj(xs, proj, ongoing_play): + return project(xs, proj, ongoing_play.discovered_play()) + + +def centered_real_proj(xs, proj, ongoing_play): + _xs = np.array(xs) + _proj = np.array(proj) + _r_p_pos = project(_xs[_xs >= 0], _proj[_xs >= 0], + ongoing_play.discovered_play()) + _r_p_neg = reversed(project(reversed(_xs[_xs < 0]), + reversed(_proj[_xs < 0]), + reversed(ongoing_play.discovered_play()))) + return list(_r_p_neg) + list(_r_p_pos) + + +def project(xs, base, reference): + ''' + For each value in base obtains the closest value in reference without + repetition. Reference values are processed in order, so some of them may + not be used because there are neighbors closer to the base value. As a + result, some base values might not be matched. + + Args: + xs: index associated with base + base: iterable + reference: iterable + + Returns: + :: [(index, base value, reference value)] + List of matched pairs of values. Some values from base might not appear + in the result. + ''' + _base = list(base) + if len(_base) == 0: + return [] + + play_it = mit.peekable(reference) + proj_it = mit.peekable(zip(xs, _base)) + last_play_onset = next(play_it) + more_proj = True + ret = [] + while more_proj: + try: + last_proj_idx, last_proj_onset = next(proj_it) + last_dist = abs(last_play_onset - last_proj_onset) + try: + while True: + new_dist = abs(play_it.peek() - last_proj_onset) + if new_dist < last_dist: + last_play_onset = next(play_it) + last_dist = new_dist + else: + break + ret.append((last_proj_idx, last_proj_onset, last_play_onset)) + # TODO: Dividir la funcion en dos + #last_play_onset = play_it.next() + except StopIteration: + more_proj = False + ret.append((last_proj_idx, last_proj_onset, last_play_onset)) + except StopIteration: + more_proj = False + return ret diff --git a/src/music/utilities/midi_processor.py b/src/music/utilities/midi_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..fd25b09347d7508739d54f0086b15f437fb874ab --- /dev/null +++ b/src/music/utilities/midi_processor.py @@ -0,0 +1,680 @@ +import pretty_midi +from copy import deepcopy +import numpy as np +from miditok import CPWord, Structured +from miditoolkit import MidiFile +from src.music.config import MAX_EMBEDDING, CHUNK_SIZE +from src.music.utilities.chord_structured import ChordStructured + +# code from https://github.com/jason9693/midi-neural-processor +RANGE_NOTE_ON = 128 +RANGE_NOTE_OFF = 128 +RANGE_VEL = 32 +RANGE_TIME_SHIFT = 100 +MAX_EMBEDDING = RANGE_VEL + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_NOTE_ON + +START_IDX = { + 'note_on': 0, + 'note_off': RANGE_NOTE_ON, + 'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF, + 'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT +} + +# Our parameters +pitch_range = range(21, 109) +beat_res = {(0, 4): 8, (4, 12): 4} +nb_velocities = 32 
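+# pitch_range 21-108 covers the 88 piano keys (A0-C8); beat_res maps beat
+# intervals to the number of positions per beat (8 for beats 0-4, 4 for
+# beats 4-12); velocities are quantized into nb_velocities bins.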
+additional_tokens = {'Chord': True, 'Rest': True, 'Tempo': True, 'TimeSignature': False, 'Program': False, + 'rest_range': (2, 8), # (half, 8 beats) + 'nb_tempos': 32, # nb of tempo bins + 'tempo_range': (40, 250)} # (min, max) + +# Creates the tokenizer_cp and loads a MIDI +# tokenizer_cp = CPWord(pitch_range, beat_res, nb_velocities, additional_tokens) +tokenizer_structured = ChordStructured(pitch_range, beat_res, nb_velocities) + +class SustainAdapter: + def __init__(self, time, type): + self.start = time + self.type = type + + +class SustainDownManager: + def __init__(self, start, end): + self.start = start + self.end = end + self.managed_notes = [] + self._note_dict = {} # key: pitch, value: note.start + + def add_managed_note(self, note: pretty_midi.Note): + self.managed_notes.append(note) + + def transposition_notes(self): + for note in reversed(self.managed_notes): + try: + note.end = self._note_dict[note.pitch] + except KeyError: + note.end = max(self.end, note.end) + self._note_dict[note.pitch] = note.start + + +# Divided note by note_on, note_off +class SplitNote: + def __init__(self, type, time, value, velocity): + ## type: note_on, note_off + self.type = type + self.time = time + self.velocity = velocity + self.value = value + + def __repr__(self): + return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\ + .format(self.time, self.type, self.value, self.velocity) + + +class Event: + def __init__(self, event_type, value): + self.type = event_type + self.value = value + + def __repr__(self): + return ''.format(self.type, self.value) + + def to_int(self): + return START_IDX[self.type] + self.value + + @staticmethod + def from_int(int_value): + info = Event._type_check(int_value) + return Event(info['type'], info['value']) + + @staticmethod + def _type_check(int_value): + range_note_on = range(0, RANGE_NOTE_ON) + range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF) + range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT) + + valid_value = int_value + + if int_value in range_note_on: + return {'type': 'note_on', 'value': valid_value} + elif int_value in range_note_off: + valid_value -= RANGE_NOTE_ON + return {'type': 'note_off', 'value': valid_value} + elif int_value in range_time_shift: + valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF) + return {'type': 'time_shift', 'value': valid_value} + else: + valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT) + return {'type': 'velocity', 'value': valid_value} + + +def _divide_note(notes): + result_array = [] + notes.sort(key=lambda x: x.start) + + for note in notes: + on = SplitNote('note_on', note.start, note.pitch, note.velocity) + off = SplitNote('note_off', note.end, note.pitch, None) + result_array += [on, off] + return result_array + + +def _merge_note(snote_sequence): + note_on_dict = {} + result_array = [] + + for snote in snote_sequence: + # print(note_on_dict) + if snote.type == 'note_on': + note_on_dict[snote.value] = snote + elif snote.type == 'note_off': + try: + on = note_on_dict[snote.value] + off = snote + if off.time - on.time == 0: + continue + result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time) + result_array.append(result) + except: + print('info removed pitch: {}'.format(snote.value)) + return result_array + + +def _snote2events(snote: SplitNote, prev_vel: int): + result = [] + if snote.velocity is not None: + modified_velocity = snote.velocity // 4 + if prev_vel != modified_velocity: + 
result.append(Event(event_type='velocity', value=modified_velocity)) + result.append(Event(event_type=snote.type, value=snote.value)) + return result + + +def _event_seq2snote_seq(event_sequence): + timeline = 0 + velocity = 0 + snote_seq = [] + + for event in event_sequence: + if event.type == 'time_shift': + timeline += ((event.value+1) / 100) + if event.type == 'velocity': + velocity = event.value * 4 + else: + snote = SplitNote(event.type, timeline, event.value, velocity) + snote_seq.append(snote) + return snote_seq + + +def _make_time_sift_events(prev_time, post_time): + time_interval = int(round((post_time - prev_time) * 100)) + results = [] + while time_interval >= RANGE_TIME_SHIFT: + results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1)) + time_interval -= RANGE_TIME_SHIFT + if time_interval == 0: + return results + else: + return results + [Event(event_type='time_shift', value=time_interval-1)] + + +def _control_preprocess(ctrl_changes): + sustains = [] + + manager = None + for ctrl in ctrl_changes: + if ctrl.value >= 64 and manager is None: + # sustain down + manager = SustainDownManager(start=ctrl.time, end=None) + elif ctrl.value < 64 and manager is not None: + # sustain up + manager.end = ctrl.time + sustains.append(manager) + manager = None + elif ctrl.value < 64 and len(sustains) > 0: + sustains[-1].end = ctrl.time + return sustains + + +def _note_preprocess(susteins, notes): + note_stream = [] + count_note_processed = 0 + if susteins: # if the midi file has sustain controls + for sustain in susteins: + if len(notes) > 0: + for note_idx, note in enumerate(notes): + if note.start < sustain.start: + note_stream.append(note) + last_counted = True + elif note.start > sustain.end: + # notes = notes[note_idx:] + # sustain.transposition_notes() + last_counted = False + break + else: + sustain.add_managed_note(note) + last_counted = True + count_note_processed += 1 + sustain.transposition_notes() # transpose what in the sustain + note_stream += sustain.managed_notes # add to stream + # remove notes that were already added to the stream + last_idx = note_idx if not last_counted else note_idx + 1 + if last_idx < len(notes): + notes = notes[last_idx:] # save next notes, previous notes were stored in note stream + else: + notes = [] + note_stream += notes + count_note_processed += len(notes) + else: # else, just push everything into note stream + for note_idx, note in enumerate(notes): + note_stream.append(note) + + note_stream.sort(key= lambda x: x.start) + return note_stream + +def midi_valid(midi) -> bool: + # if any(ts.numerator != 4 or ts.denominator != 4 for ts in midi.time_signature_changes): + # return False # time signature different from 4/4 + # if midi.max_tick < 10 * midi.ticks_per_beat: + # return False # this MIDI is too short + return True + + +def encode_midi_structured(file_path, nb_aug, nb_noise): + notes = [] + mid = MidiFile(file_path) + assert midi_valid(mid) + + # Converts MIDI to tokens, and back to a MIDI + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. 
If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + + assert len(notes) == len(mid.instruments[0].notes) + + # sort notes + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + original_notes = deepcopy(notes) + # convert notes to ints + encoded_main = tokenizer_structured.midi_to_tokens(mid)[0] + + min_pitch = np.min([n.pitch for n in notes]) + + encoded_augmentations = [] + noise_shift = 6 + aug_shift = 3 + embedding_noise = None + for i_aug in range(nb_aug): + a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch) + mid.instruments[0].notes = a_notes + assert midi_valid(mid) + embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0] # encode notes + encoded_augmentations.append(embedding_aug) + if nb_noise > 0: + a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch) + mid.instruments[0].notes = a_notes + assert midi_valid(mid) + embedding_noise = tokenizer_structured.midi_to_tokens(mid)[0] # encode notes + + return encoded_main, encoded_augmentations, embedding_noise + +def encode_midi_cp(file_path, nb_aug, nb_noise): + notes = [] + mid = MidiFile(file_path) + assert midi_valid(mid) + + # Converts MIDI to tokens, and back to a MIDI + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + + assert len(notes) == len(mid.instruments[0].notes) + + # sort notes + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + original_notes = deepcopy(notes) + # convert notes to ints + encoded_main = tokenizer_cp.midi_to_tokens(mid)[0] + + min_pitch = np.min([n.pitch for n in notes]) + + encoded_augmentations = [] + noise_shift = 6 + aug_shift = 3 + embedding_noise = None + for i_aug in range(nb_aug): + a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch) + mid.instruments[0].notes = a_notes + assert midi_valid(mid) + embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0] # encode notes + encoded_augmentations.append(embedding_aug) + if nb_noise > 0: + a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch) + mid.instruments[0].notes = a_notes + assert midi_valid(mid) + embedding_noise = tokenizer_cp.midi_to_tokens(mid)[0] # encode notes + + return encoded_main, encoded_augmentations, embedding_noise + +def alter_notes_exact_tick(notes, shift, min_pitch): + # copy original notes + a_notes = deepcopy(notes) + # sample smart augmentation + pitch_shift, time_scaling = 0, 0 + while pitch_shift == 0 and time_scaling == 0: + pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift+1)) + time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5]) + assert pitch_shift <= shift and pitch_shift >= -shift + # modify notes + for e in a_notes: + e.start = int(e.start * (1. + time_scaling / 100)) + e.end = int(e.end * (1. 
+ time_scaling / 100)) + new_pitch = max(e.pitch + pitch_shift, 0) + e.pitch = new_pitch + return a_notes + +def alter_notes(notes, shift, min_pitch): + # copy original notes + a_notes = deepcopy(notes) + # sample smart augmentation + pitch_shift, time_scaling = 0, 0 + while pitch_shift == 0 and time_scaling == 0: + pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift+1)) + time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5]) + assert pitch_shift <= shift and pitch_shift >= -shift + # modify notes + for e in a_notes: + e.start = e.start * (1. + time_scaling / 100) + e.end = e.end * (1. + time_scaling / 100) + new_pitch = max(e.pitch + pitch_shift, 0) + e.pitch = new_pitch + return a_notes + +def encode_midi(file_path, nb_aug, nb_noise): + notes = [] + mid = pretty_midi.PrettyMIDI(midi_file=file_path) + + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + + assert len(notes) == len(mid.instruments[0].notes) + # sort notes + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + # convert notes to ints + encoded_main = convert_notes(notes) + + min_pitch = np.min([n.pitch for n in notes]) + + encoded_augmentations = [] + noise_shift = 6 + aug_shift = 3 + embedding_noise = None + for i_aug in range(nb_aug): + a_notes = alter_notes(notes, aug_shift, min_pitch) + embedding_group = convert_notes(a_notes) # encode notes + encoded_augmentations.append(embedding_group) + if nb_noise > 0: + a_notes = alter_notes(notes, noise_shift, min_pitch) + embedding_noise = convert_notes(a_notes) # encode notes + + return encoded_main, encoded_augmentations, embedding_noise + + +def chunk_notes(n_notes_per_chunk, notes): + index = 0 + chunks = [] + for n in n_notes_per_chunk: + chunks.append(notes[index:index+n]) + index += n + return chunks + +def chunk_first_embedding(chunk_size, embedding): + chunks = [] + index = 0 + if len(embedding) < chunk_size: + return [embedding] + else: + for i in range(chunk_size, len(embedding) + chunk_size, chunk_size): + if (len(embedding) - index) > (chunk_size / 2): + chunks.append(embedding[index:i]) + index = i + return chunks + +def encode_midi_in_chunks(file_path, n_aug, n_noise): + n_noise = 0 + notes = [] + mid = pretty_midi.PrettyMIDI(midi_file=file_path) + # preprocess midi + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. 
If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + + assert len(notes) == len(mid.instruments[0].notes) + + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + # convert notes to ints + main_embedding = convert_notes(notes) + # split the sequence of events in chunks + if np.max(main_embedding) < MAX_EMBEDDING and np.min(main_embedding) >= 0: + encoded_chunks = chunk_first_embedding(CHUNK_SIZE, main_embedding) + else: + assert False + + n_notes_per_chunk = [np.argwhere(np.array(ec) < 128).flatten().size for ec in encoded_chunks] + + chunked_notes = chunk_notes(n_notes_per_chunk, notes) + + # reencode chunks by shifting notes + encoded_chunks = [] + for note_group in chunked_notes: + note_group = shift_notes(note_group) + embedding_main = convert_notes(note_group)[:CHUNK_SIZE] + encoded_chunks.append(embedding_main) + + min_pitches = [np.min([n.pitch for n in cn]) for cn in chunked_notes] + + encoded_augmentations = [] + aug_shift = 3 + for i_aug in range(n_aug): + chunked_embedding_aug = [] + for note_group, min_pitch in zip(chunked_notes, min_pitches): + a_notes = alter_notes(note_group, aug_shift, min_pitch) + a_notes = shift_notes(a_notes) + assert len(a_notes) == len(note_group) + embedding_group = convert_notes(a_notes)[:CHUNK_SIZE] # encode notes + chunked_embedding_aug.append(embedding_group) + encoded_augmentations += chunked_embedding_aug + + assert len(encoded_augmentations) == n_aug * len(encoded_chunks) + return encoded_chunks, encoded_augmentations, [] + +def encode_miditok_in_chunks(file_path, n_aug, n_noise): + n_noise = 0 + notes = [] + mid = MidiFile(file_path) + assert midi_valid(mid) + + # Converts MIDI to tokens, and back to a MIDI + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. 
If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + assert len(notes) == len(mid.instruments[0].notes) + + # sort notes + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + # convert notes to ints + encoded_main = tokenizer_cp.midi_to_tokens(mid)[0] + + encoded_chunks = chunk_first_embedding(CHUNK_SIZE, encoded_main) + n_notes_per_chunk = [len([tokenizer_cp.vocab.token_to_event[e[0]] for e in enc_chunk if tokenizer_cp.vocab.token_to_event[e[0]] == 'Family_Note']) + for enc_chunk in encoded_chunks] + chunked_notes = chunk_notes(n_notes_per_chunk, notes) + + # reencode chunks by shifting notes + encoded_chunks = [] + for note_group in chunked_notes: + mid.instruments[0].notes = note_group + mid = shift_mid(mid) # shift midi + assert midi_valid(mid) + embedding_main = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE] # tokenize midi + encoded_chunks.append(embedding_main) + + + min_pitch = np.min([n.pitch for n in notes]) + + encoded_augmentations = [] + aug_shift = 3 + for i_aug in range(n_aug): + chunked_embedding_aug = [] + for note_group in chunked_notes: + a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch) + assert len(a_notes) == len(note_group) + mid.instruments[0].notes = a_notes + # shift midi + mid = shift_mid(mid) + assert midi_valid(mid) + # tokenize midi + embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE] # encode notes + chunked_embedding_aug.append(embedding_aug) + encoded_augmentations += chunked_embedding_aug + + assert len(encoded_augmentations) == n_aug * len(encoded_chunks) + return encoded_chunks, encoded_augmentations, [] + + +def encode_midi_chunks_structured(file_path, n_aug, n_noise): + n_noise = 0 + notes = [] + mid = MidiFile(file_path) + assert midi_valid(mid) + + # Converts MIDI to tokens, and back to a MIDI + for inst in mid.instruments: + inst_notes = inst.notes + # ctrl.number is the number of sustain control. 
If you want to know abour the number type of control, + # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2 + ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64]) + notes += _note_preprocess(ctrls, inst_notes) + assert len(notes) == len(mid.instruments[0].notes) + + nb_notes = CHUNK_SIZE // 4 + notes = notes[:50 * nb_notes] # limit to 50 chunks to speed up + # sort notes + arg_rank = np.argsort([n.start for n in notes]) + notes = list(np.array(notes)[arg_rank]) + + assert (len(notes) // nb_notes) > 1 # assert at least 3 chunks + n_notes_per_chunk = [nb_notes for _ in range(len(notes) // nb_notes)] + if len(notes) % nb_notes > nb_notes / 2: + n_notes_per_chunk.append(len(notes) % nb_notes) + chunked_notes = chunk_notes(n_notes_per_chunk, notes) + + # reencode chunks by shifting notes + encoded_chunks = [] + for note_group in chunked_notes: + mid.instruments[0].notes = note_group + mid = shift_mid(mid) # shift midi + assert midi_valid(mid) + embedding_main = tokenizer_structured.midi_to_tokens(mid)[0] # tokenize midi + encoded_chunks.append(embedding_main) + + + min_pitch = np.min([n.pitch for n in notes]) + + encoded_augmentations = [] + aug_shift = 3 + for i_aug in range(n_aug): + chunked_embedding_aug = [] + for note_group in chunked_notes: + a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch) + assert len(a_notes) == len(note_group) + mid.instruments[0].notes = a_notes + # shift midi + mid = shift_mid(mid) + assert midi_valid(mid) + # tokenize midi + embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0] # encode notes + chunked_embedding_aug.append(embedding_aug) + encoded_augmentations += chunked_embedding_aug + + assert len(encoded_augmentations) == n_aug * len(encoded_chunks) + return encoded_chunks, encoded_augmentations, [] + +def shift_mid(mid): + # mid = deepcopy(mid) + to_remove = mid.instruments[0].notes[0].start + if to_remove > 0: + for n in mid.instruments[0].notes: + n.start -= to_remove + n.end -= to_remove + + # for e in mid.tempo_changes: + # e.time = max(0, e.time - to_remove) + # + # for e in mid.time_signature_changes: + # e.time = max(0, e.time - to_remove) + # + # for e in mid.key_signature_changes: + # e.time = max(0, e.time - to_remove) + return mid + +def shift_notes(notes): + to_remove = notes[0].start + for n in notes: + n.start -= to_remove + n.end -= to_remove + return notes + +def convert_notes(notes): + events = [] + dnotes = _divide_note(notes) # split events in on / off + + # print(dnotes) + dnotes.sort(key=lambda x: x.time) + # print('sorted:') + # print(dnotes) + cur_time = 0 + cur_vel = 0 + for snote in dnotes: + events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time) + events += _snote2events(snote=snote, prev_vel=cur_vel) + # events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time) + + cur_time = snote.time + cur_vel = snote.velocity + + event_list = [e.to_int() for e in events] + if not (np.max(event_list) < MAX_EMBEDDING and np.min(event_list) >= 0): + print('weird') + assert False + return event_list + +def decode_midi_structured(encoding, file_path=None): + mid = tokenizer_structured.tokens_to_midi([encoding]) + if file_path: + mid.dump(file_path) + return mid + +def decode_midi_cp(encoding, file_path=None): + mid = tokenizer_cp.tokens_to_midi([encoding]) + if file_path: + mid.dump(file_path) + return mid + +def decode_midi(idx_array, file_path=None): + event_sequence = [Event.from_int(idx) for idx in 
idx_array] + # print(event_sequence) + snote_seq = _event_seq2snote_seq(event_sequence) + note_seq = _merge_note(snote_seq) + note_seq.sort(key=lambda x:x.start) + + mid = pretty_midi.PrettyMIDI() + # if you want to change the instrument, see https://www.midi.org/specifications/item/gm-level-1-sound-set + instrument = pretty_midi.Instrument(1, False, "Developed By Yang-Kichang") + instrument.notes = note_seq + + mid.instruments.append(instrument) + if file_path is not None: + mid.write(file_path) + return mid + + +if __name__ == '__main__': + encoded = encode_midi('bin/ADIG04.mid') + print(encoded) + decoded = decode_midi(encoded, file_path='bin/test.mid') + + ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid') + print(ins) + print(ins.instruments[0]) + for i in ins.instruments: + print(i.control_changes) + print(i.notes) + diff --git a/src/music/utilities/processing_models/__init__.py b/src/music/utilities/processing_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/processing_models/piano_detection_model.py b/src/music/utilities/processing_models/piano_detection_model.py new file mode 100644 index 0000000000000000000000000000000000000000..60d309537101b060713d2adf430d333d1ef8b886 --- /dev/null +++ b/src/music/utilities/processing_models/piano_detection_model.py @@ -0,0 +1,298 @@ +"""This piano solo detection module is trained by Bochen Li in Feb. 2020, and +then is cleaned up by Qiuqiang Kong in Jul. 2020. +Code from https://github.com/bytedance/GiantMIDI-Piano +""" +import numpy as np +import librosa +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + + +# Hyper-parameters +SR = 32000 +FRAME_LEN = 2048 + +if SR == 32000: + FRAME_HOP = 500 + +CH_NUM = 1 +USE_DB = False +OFFSET = 1.0 + +N_FFT = FRAME_LEN +WIN = np.sqrt(np.hanning(N_FFT)) +DIM_F = int(FRAME_LEN / 2) +DIM_F = 256 +DIM_T = 64 +DIM_T_HOP = 64 + + +def read_audio_stereo(filename): + wav, _ = librosa.core.load(filename, sr=SR, mono=None) + if wav.ndim == 1: + wav = np.tile(wav[..., None], (1,2)) + else: + wav = wav.T + return wav + + +def wav2spec_mono(wav): + spec = librosa.core.stft(y=wav, n_fft=N_FFT, + hop_length=FRAME_HOP, win_length=FRAME_LEN, + window=WIN, center=True, pad_mode='constant') + mag, pha = librosa.core.magphase(spec) + if USE_DB: + mag = librosa.core.amplitude_to_db(S=(mag+OFFSET)) + ang = np.angle(pha) + mag = mag[:DIM_F, :] + ang = ang[:DIM_F, :] + mag = mag[None, ...] + ang = ang[None, ...] + return mag, ang + +def spec2wav_mono(mag, ang): + if USE_DB: + mag = librosa.core.db_to_amplitude(S_db=mag) - OFFSET + pha = np.exp(1j * ang) + spec = mag * pha + if DIM_F % 2 == 0: + tmp = np.zeros((1, spec.shape[-1])) + spec = np.concatenate((spec, tmp), axis=0) + + wav = librosa.core.istft(stft_matrix=spec, + hop_length=FRAME_HOP, + win_length=FRAME_LEN, + window=WIN, center=True) + return wav + + +def wav2spec(wav): + """ + input: mono shape=(n,) or stereo shape=(n,2) + output: mag, ang + mono shape=(1,F,T) or stereo shape=(2,F,T) + """ + + if wav.ndim == 1: + + mag, ang = wav2spec_mono(wav) + + else: + + mag1, ang1 = wav2spec_mono(wav[:, 0]) + mag2, ang2 = wav2spec_mono(wav[:, 1]) + mag = np.concatenate((mag1, mag2), axis=0) + ang = np.concatenate((ang1, ang2), axis=0) + + return mag, ang + +def spec2wav(mag, ang): + if mag.shape[0] == 1: + mag = mag[0,...] + ang = ang[0,...]
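+ # mono input: a single spectrogram channel, so invert it directly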
+ wav = spec2wav_mono(mag, ang) + else: + wav1 = spec2wav_mono(mag[0,...], ang[0,...]) + wav2 = spec2wav_mono(mag[1,...], ang[1,...]) + wav = np.concatenate( (wav1[...,None], wav2[...,None]), axis=-1 ) + + return wav + + +class ConvBlock(nn.Module): + def __init__(self, in_plane, out_plane, droprate=0.0): + super(ConvBlock, self).__init__() + + self.conv = nn.Conv2d(in_plane, out_plane, kernel_size=3, stride=1, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_plane) + self.relu = nn.ReLU(inplace=True) + + self.droprate = droprate + + def forward(self, x): + out = self.relu(self.bn(self.conv(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, training=self.training) + return out + + +class PianoSoloDetector(object): + def __init__(self, chkpt_path): + """Piano solo detector.""" + self.model = PianoDetection() + + if torch.cuda.is_available(): + print('Using GPU') + self.model = self.model.cuda() + # else: + # print('Using CPU') + + self.model.load(chkpt_path) + + def predict(self, wav): + """Predict the probabilities of piano solo on 1-second segments. + """ + rms = np.sqrt(np.mean(wav ** 2)) + wav = wav / rms / 20 + duration = len(wav) / SR + + n_seg = int(duration / 1.00) + + mag_segs = [] + batch_size = 32 + + all_probs = [] + zero_locts = [] + + for i in np.arange(n_seg): + wav_seg = wav[i * SR : (i + 1) * SR + 1000] + + if np.sqrt(np.mean(wav_seg**2)) < 0.001: + zero_locts.append(i) + + mag, ang = wav2spec(wav_seg) + mag = mag[..., :DIM_T] + + mag_segs.append(mag) + + if len(mag_segs) == batch_size or i == n_seg - 1: + probs = self.predict_seg(np.array(mag_segs)) + all_probs.append(probs) + mag_segs = [] + + all_probs = np.concatenate(all_probs) + zero_locts = np.array(zero_locts) + + if len(zero_locts) > 0: + all_probs[zero_locts] = 0 + + return all_probs + + def predict_seg(self, mag_seg): + """Predict the probability of piano solo on each segment. 
+ + Args: + mag_seg: (batch_size, 1, F, T) + + Returns: + probs: (batch_size,) + """ + x = np.transpose(mag_seg, (0, 1, 3, 2)) + y = self.model.predict_on_batch(x) # (batch_size, classes_num) + probs = y[:, 1] + return probs + + +class PianoDetection(nn.Module): + def __init__(self): + super(PianoDetection, self).__init__() + + self.net = CNN() + + self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001) + + if torch.cuda.is_available(): + self.criterion = nn.CrossEntropyLoss().cuda() + else: + self.criterion = nn.CrossEntropyLoss() + + def forward(self, x): + x = self.net(x) + return x + + def _convert(self, x): + x_var = [] + x_var = Variable(torch.FloatTensor(x)) + + if torch.cuda.is_available(): + x_var = x_var.cuda() + + return x_var + + def _convert_int(self, x): + x_var = [] + x_var = Variable(torch.LongTensor(x)) + + if torch.cuda.is_available(): + x_var = x_var.cuda() + + return x_var + + def train_on_batch(self, x, t): + self.train() + x = self._convert(x) + t = self._convert_int(t) + y = self.forward(x=x) + self.optimizer.zero_grad() + loss = self.criterion(y, t) + loss.backward() + self.optimizer.step() + return loss.data.cpu().numpy() + + + def eval_on_batch(self, x, t): + self.eval() + x = self._convert(x) + t = self._convert_int(t) + y = self.forward(x) + loss = self.criterion(y, t) + return loss.data.cpu().numpy() + + def predict_on_batch(self, x): + self.eval() + x = self._convert(x) + y = self.forward(x) + y = F.softmax(y, dim=1) + return y.data.cpu().numpy() + + def adjust_learning_rate(self, epoch): + lr = self.lr * (0.8 ** np.floor(epoch / 5)) + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + return lr + + def save(self, filename): + torch.save(self.state_dict(), filename+".pth") + + def load(self, filename): + if torch.cuda.is_available(): + self.load_state_dict(torch.load(filename)) + else: + self.load_state_dict(torch.load(filename, map_location='cpu')) + + +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + + self.cnn1 = ConvBlock(1, 32) + self.cnn2 = ConvBlock(32, 64) + self.cnn3 = ConvBlock(64, 64) + self.cnn4 = ConvBlock(64, 32) + self.fn1 = nn.Linear(2048, 50) + self.fn2 = nn.Linear(50, 2) + + def forward(self, x): + x = self.cnn1(x) + x = F.avg_pool2d(x, 2) + + x = self.cnn2(x) + x = F.avg_pool2d(x, 2) + + x = self.cnn3(x) + x = F.avg_pool2d(x, 2) + + x = self.cnn4(x) + x = F.avg_pool2d(x, 2) + + x_dim = x.shape[1] * x.shape[2] * x.shape[3] + x = x.view(-1, x_dim) + + x = self.fn1(x) + x = F.relu(x) + + x = self.fn2(x) + + return x \ No newline at end of file diff --git a/src/music/utilities/representation_learning_utilities/__init__.py b/src/music/utilities/representation_learning_utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music/utilities/representation_learning_utilities/__pycache__/__init__.cpython-39.pyc b/src/music/utilities/representation_learning_utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1dd537a1bc39ab6e5b207bb4ac62c0dc0507daf Binary files /dev/null and b/src/music/utilities/representation_learning_utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music/utilities/representation_learning_utilities/__pycache__/constants.cpython-39.pyc b/src/music/utilities/representation_learning_utilities/__pycache__/constants.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..4c3a4ee5bc3eaa4878812298a2b071429d9f4fe5 Binary files /dev/null and b/src/music/utilities/representation_learning_utilities/__pycache__/constants.cpython-39.pyc differ diff --git a/src/music/utilities/representation_learning_utilities/argument_funcs.py b/src/music/utilities/representation_learning_utilities/argument_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..adf6ca9ad92719e6bd8f0fb27c1ea02f8031c5fe --- /dev/null +++ b/src/music/utilities/representation_learning_utilities/argument_funcs.py @@ -0,0 +1,293 @@ +import argparse + +from .constants import SEPERATOR + +# parse_train_args +def parse_train_args(): + """ + ---------- + Author: Damon Gwinn + ---------- + Argparse arguments for training a model + ---------- + """ + + parser = argparse.ArgumentParser() + + parser.add_argument("-input_dir", type=str, default="/home/cedric/Documents/pianocktail/data/music/diverse_piano_datasets/diverse_piano/", help="Folder of preprocessed and " + "pickled midi files") + parser.add_argument("-output_dir", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/", help="Folder to save model " + "weights. Saves one every epoch") + parser.add_argument("-model", type=str, default="music_transformer", help="which model to use (autoencoder or standard)") + parser.add_argument("-weight_modulus", type=int, default=5, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)") + parser.add_argument("-print_modulus", type=int, default=10, help="How often to print train results for a batch (batch loss, learn rate, etc.)") + parser.add_argument("-epoch_size", type=int, default=100, help="Number of batches per epoch") + parser.add_argument("-seed", type=int, default=-1, help="seed to use, use -1 to not set a seed") + parser.add_argument("-trial_id", type=int, default=1, help="trial identifier") + parser.add_argument("-trial_name", type=str, default="", help="trial name") + parser.add_argument("-data_augmentation", type=int, default=0, help="whether to implement data augmentation on training set") + parser.add_argument("-noisy_train", type=int, default=0, help="whether to implement data augmentation on training set") + + parser.add_argument("-n_workers", type=int, default=0, help="Number of threads for the dataloader") + parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") + parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting") + + parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on") + parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at") + + parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. 
Leave as None for a custom scheduler.") + parser.add_argument("-lr_warming", type=int, default=4000, help="Number of warming batches for learning rate") + parser.add_argument("-lr_asymptot", type=float, default=0.00015, help="asymptotic lr") + parser.add_argument("-lr_target_up", type=float, default=0.001, help="highest lr") + parser.add_argument("-lr_exp_slope", type=float, default=3e4, help="slope of decreasing exponential for lr") + parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)") + parser.add_argument("-batch_size", type=int, default=4, help="Batch size to use") + parser.add_argument("-max_batch", type=int, default=4, help="maximum batch size the gpu can handle") + parser.add_argument("-epochs", type=int, default=300, help="Number of epochs to use") + + parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations") + parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider") + parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use") + parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention") + parser.add_argument("-d_model", type=int, default=384, help="Dimension of the model (output dim of embedding layers, etc.)") + parser.add_argument("-kdim", type=int, default=256, help="Dimension of keys") + parser.add_argument("-qdim", type=int, default=256, help="Dimension of queries") + parser.add_argument("-multiple_er", type=int, default=1, help="Whether to use relation embeddings for each head (True) or one shared (False)") + parser.add_argument("-nb_evals", type=int, default=250, help="Number of evaluation steps") + + parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer") + + parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate") + + return parser.parse_args() + + +def parse_train_siamese_args(): + """ + ---------- + Author: Damon Gwinn + ---------- + Argparse arguments for training a model + ---------- + """ + + parser = argparse.ArgumentParser() + + parser.add_argument("-input_dir", type=str, default="/home/cedric/Documents/pianocktail/data/music/encoded_chunks_structured/diverse_piano/", help="Folder of " + "preprocessed and " + "pickled midi files") + parser.add_argument("-output_dir", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/", help="Folder to save model " + "weights. 
Saves one every epoch") + parser.add_argument("-music_transfo_path", type=str, default= + "/home/cedric/Documents/pianocktail/data/checkpoints/music_representation/music_transformer_autoencoder_structured/small_structured_auto_64batch1/") + parser.add_argument("-nb_neg", type=int, default=128, help="") + parser.add_argument("-model", type=str, default="siamese_net", help="which model to use (autoencoder or standard)") + parser.add_argument("-max_sequence", type=int, default=512, help="Maximum midi sequence to consider") + parser.add_argument("-weight_modulus", type=int, default=5, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)") + parser.add_argument("-print_modulus", type=int, default=10, help="How often to print train results for a batch (batch loss, learn rate, etc.)") + parser.add_argument("-epoch_size", type=int, default=100, help="Number of batches per epoch") + parser.add_argument("-seed", type=int, default=-1, help="seed to use, use -1 to not set a seed") + parser.add_argument("-trial_id", type=int, default=1, help="trial identifier") + parser.add_argument("-trial_name", type=str, default="", help="trial name") + parser.add_argument("-rep_size", type=int, default=128, help="") + + parser.add_argument("-n_workers", type=int, default=0, help="Number of threads for the dataloader") + parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") + parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting") + + parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on") + parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at") + + parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. 
Leave as None for a custom scheduler.") + parser.add_argument("-lr_warming", type=int, default=4000, help="Number of warming batches for learning rate") + parser.add_argument("-lr_asymptot", type=float, default=0.00015, help="asymptotic lr") + parser.add_argument("-lr_target_up", type=float, default=0.001, help="highest lr") + parser.add_argument("-lr_exp_slope", type=float, default=3e4, help="slope of decreasing exponential for lr") + parser.add_argument("-batch_size", type=int, default=4, help="Batch size to use") + parser.add_argument("-max_batch", type=int, default=2, help="maximum batch size the gpu can handle") + parser.add_argument("-epochs", type=int, default=300, help="Number of epochs to use") + parser.add_argument("-nb_evals", type=int, default=250, help="Number of evaluation steps") + + return parser.parse_args() + +# print_train_args +def print_train_args(args): + """ + ---------- + Author: Damon Gwinn + ---------- + Prints training arguments + ---------- + """ + + print(SEPERATOR) + print("input_dir:", args.input_dir) + print("output_dir:", args.output_dir) + print("weight_modulus:", args.weight_modulus) + print("print_modulus:", args.print_modulus) + print("") + print("n_workers:", args.n_workers) + print("force_cpu:", args.force_cpu) + print("tensorboard:", not args.no_tensorboard) + print("") + print("continue_weights:", args.continue_weights) + print("continue_epoch:", args.continue_epoch) + print("") + print("lr:", args.lr) + print("ce_smoothing:", args.ce_smoothing) + print("batch_size:", args.batch_size) + print("epochs:", args.epochs) + print("") + print("rpr:", args.rpr) + print("max_sequence:", args.max_sequence) + print("n_layers:", args.n_layers) + print("num_heads:", args.num_heads) + print("d_model:", args.d_model) + print("kdim:", args.kdim) + print("qdim:", args.qdim) + print("") + print("dim_feedforward:", args.dim_feedforward) + print("dropout:", args.dropout) + print(SEPERATOR) + print("") + +# parse_eval_args +def parse_eval_args(): + """ + ---------- + Author: Damon Gwinn + ---------- + Argparse arguments for evaluating a model + ---------- + """ + + parser = argparse.ArgumentParser() + + parser.add_argument("-dataset_dir", type=str, default="/home/cedric/Documents/pianocktail/data/midi/diverse_midi_piano_pytorch_processed/", help="Folder of preprocessed and pickled midi files") + parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()") + # parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader") + # parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") + # + # parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use") + # + # parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations") + # parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model") + # parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use") + # parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention") + # parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)") + # + # parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer") + + 
return parser.parse_args() + +# print_eval_args +def print_eval_args(args): + """ + ---------- + Author: Damon Gwinn + ---------- + Prints evaluation arguments + ---------- + """ + + print(SEPERATOR) + print("dataset_dir:", args.dataset_dir) + print("model_weights:", args.model_weights) + print("n_workers:", args.n_workers) + print("force_cpu:", args.force_cpu) + print("") + print("batch_size:", args.batch_size) + print("") + print("rpr:", args.rpr) + print("max_sequence:", args.max_sequence) + print("n_layers:", args.n_layers) + print("num_heads:", args.num_heads) + print("d_model:", args.d_model) + print("kdim:", args.kdim) + print("qdim:", args.qdim) + print("") + print("dim_feedforward:", args.dim_feedforward) + print(SEPERATOR) + print("") + +# parse_generate_args +def parse_generate_args(): + """ + ---------- + Author: Damon Gwinn + ---------- + Argparse arguments for generation + ---------- + """ + + parser = argparse.ArgumentParser() + parser.add_argument("-trial_path", type=str, default=None, help="path of the trial generator with") + parser.add_argument("-midi_root", type=str, default=None, help="Midi file to prime the generator with") + parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to") + parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.") + parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available") + + parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be") + parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with") + parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()") + parser.add_argument("-beam", type=int, default=0, help="Beam search k. 
0 for random probability sample and 1 for greedy") + + return parser.parse_args() + +# print_generate_args +def print_generate_args(args): + """ + ---------- + Author: Damon Gwinn + ---------- + Prints generation arguments + ---------- + """ + + print(SEPERATOR) + print("midi_root:", args.midi_root) + print("output_dir:", args.output_dir) + print("primer_file:", args.primer_file) + print("force_cpu:", args.force_cpu) + print("") + print("target_seq_length:", args.target_seq_length) + print("num_prime:", args.num_prime) + print("model_weights:", args.model_weights) + print("beam:", args.beam) + print("") + print("rpr:", args.rpr) + print("max_sequence:", args.max_sequence) + print("n_layers:", args.n_layers) + print("num_heads:", args.num_heads) + print("d_model:", args.d_model) + print("kdim:", args.kdim) + print("qdim:", args.qdim) + print("") + print("dim_feedforward:", args.dim_feedforward) + print(SEPERATOR) + print("") + +# write_model_params +def write_model_params(args, output_file): + """ + ---------- + Author: Damon Gwinn + ---------- + Writes given training parameters to text file + ---------- + """ + + o_stream = open(output_file, "w") + + o_stream.write("rpr: " + str(args.rpr) + "\n") + o_stream.write("lr: " + str(args.lr) + "\n") + o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n") + o_stream.write("batch_size: " + str(args.batch_size) + "\n") + o_stream.write("max_sequence: " + str(args.max_sequence) + "\n") + o_stream.write("n_layers: " + str(args.n_layers) + "\n") + o_stream.write("num_heads: " + str(args.num_heads) + "\n") + o_stream.write("d_model: " + str(args.d_model) + "\n") + o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n") + o_stream.write("dropout: " + str(args.dropout) + "\n") + + o_stream.close() diff --git a/src/music/utilities/representation_learning_utilities/constants.py b/src/music/utilities/representation_learning_utilities/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..57c197ebe569a1dae44516e69b9a2c462c4fbdd6 --- /dev/null +++ b/src/music/utilities/representation_learning_utilities/constants.py @@ -0,0 +1,28 @@ +import torch + +from src.music.utilities.midi_processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT + +SEPERATOR = "=========================" + +# Taken from the paper +ADAM_BETA_1 = 0.9 +ADAM_BETA_2 = 0.98 +ADAM_EPSILON = 10e-9 + +LR_DEFAULT_START = 1.0 +# SCHEDULER_WARMUP_STEPS = 4000 +# LABEL_SMOOTHING_E = 0.1 + +# DROPOUT_P = 0.1 + +TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT +TOKEN_PAD = TOKEN_END + 1 + +VOCAB_SIZE = TOKEN_PAD + 1 + +TORCH_FLOAT = torch.float32 +TORCH_INT = torch.int32 + +TORCH_LABEL_TYPE = torch.long + +PREPEND_ZEROS_WIDTH = 4 diff --git a/src/music/utilities/representation_learning_utilities/device.py b/src/music/utilities/representation_learning_utilities/device.py new file mode 100644 index 0000000000000000000000000000000000000000..00fbe16bd6ec82cf6019a787e4c8b4612f6ccc10 --- /dev/null +++ b/src/music/utilities/representation_learning_utilities/device.py @@ -0,0 +1,67 @@ +# For all things related to devices +#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS #### + +import torch + +TORCH_CPU_DEVICE = torch.device("cpu") + +if(torch.cuda.device_count() > 0): + TORCH_CUDA_DEVICE = torch.device("cuda") +else: + print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! 
-----") + print("") + TORCH_CUDA_DEVICE = None + +USE_CUDA = True + +# use_cuda +def use_cuda(cuda_bool): + """ + ---------- + Author: Damon Gwinn + ---------- + Sets whether to use CUDA (if available), or use the CPU (not recommended) + ---------- + """ + + global USE_CUDA + USE_CUDA = cuda_bool + +# get_device +def get_device(): + """ + ---------- + Author: Damon Gwinn + ---------- + Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise. + ---------- + """ + + if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)): + return TORCH_CPU_DEVICE + else: + return TORCH_CUDA_DEVICE + +# cuda_device +def cuda_device(): + """ + ---------- + Author: Damon Gwinn + ---------- + Grabs the cuda device (may be None if CUDA is not available) + ---------- + """ + + return TORCH_CUDA_DEVICE + +# cpu_device +def cpu_device(): + """ + ---------- + Author: Damon Gwinn + ---------- + Grabs the cpu device + ---------- + """ + + return TORCH_CPU_DEVICE diff --git a/src/music/utilities/representation_learning_utilities/lr_scheduling.py b/src/music/utilities/representation_learning_utilities/lr_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..16f01ebe55904da3cdcb28bce792c1cd29711bf3 --- /dev/null +++ b/src/music/utilities/representation_learning_utilities/lr_scheduling.py @@ -0,0 +1,127 @@ +#Library Imports +import math +import numpy as np +import matplotlib.pyplot as plt +#Using Adam optimizer with +#Beta_1=0.9, Beta_2=0.98, and Epsilon=10^-9 + +#Learning rate varies over course of training +#lrate = sqrt(d_model)*min((1/sqrt(step_num)), step_num*(1/warmup_steps*sqrt(warmup_steps))) + +# def lr_plot(steps, target_up=7e-4, param=1, d_model=512, asymptot=1e-4, warmup_steps=10000): +# scaled_target = target_up * np.sqrt(d_model) +# asymptot_scale = asymptot * np.sqrt(d_model) +# slope = scaled_target / warmup_steps +# p1 = - param * np.log(scaled_target - asymptot_scale) - warmup_steps +# out = np.zeros(steps.size) +# out[:warmup_steps] = slope * steps[:warmup_steps] * 1/np.sqrt(d_model) +# out[warmup_steps:] = (np.exp(-(steps[warmup_steps:] + p1) / param) + asymptot_scale) * 1/np.sqrt(d_model) +# +# # out[warmup_steps:] = ((steps[warmup_steps:] - (warmup_steps - (scaled_target-asymptot_scale)**-2)) **-0.5 + asymptot_scale) * 1/np.sqrt(d_model) +# plt.figure() +# plt.plot(out) +# plt.ylim([np.min(out), 1e-3]) + + +# LrStepTracker +class MyLrStepTracker: + """ + """ + + def __init__(self, model_dim=512, warmup_steps=4000, asymptot=1e-4, target_up=8e-4, exp_slope=1e4, init_steps=0): + # Store Values + self.warmup_steps = warmup_steps + self.model_dim = model_dim + self.asymptot = asymptot + self.exp_slope = exp_slope + self.init_steps = init_steps + + # Begin Calculations + self.invsqrt_dim = 1 / math.sqrt(model_dim) + self.scaled_target = target_up * math.sqrt(model_dim) + self.asymptot_scale = asymptot * math.sqrt(model_dim) + self.constant = - exp_slope * math.log(self.scaled_target - self.asymptot_scale) - warmup_steps + self.invsqrt_warmup = warmup_steps**(-1.5) + self.slope = self.scaled_target / warmup_steps + + + # step + def step(self, step): + """ + ---------- + Author: Ryan Marshall + Modified: Damon Gwinn + ---------- + Method to pass to LambdaLR. Increments the step and computes the new learn rate. 
+ ---------- + """ + + step += self.init_steps + if(step <= self.warmup_steps): + return self.invsqrt_dim * self.slope * step # linear warmup + else: + return self.invsqrt_dim * (math.exp(-(step + self.constant) / self.exp_slope) + self.asymptot_scale) + +# steps = np.arange(1, 60*2000) +# tracker = MyLrStepTracker(warmup_steps=4000, asymptot=1e-4, target_up=8e-4, exp_slope=2e4) +# out = [tracker.step(s) for s in steps] +# plt.figure() +# plt.plot(out) +# plt.show() +# LrStepTracker +class LrStepTracker: + """ + ---------- + Author: Ryan Marshall + Modified: Damon Gwinn + ---------- + Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR). + + Learn rate for each step (batch) given the warmup steps is: + lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ] + + This is from Attention is All you Need (https://arxiv.org/abs/1706.03762) + ---------- + """ + + def __init__(self, model_dim=512, warmup_steps=4000, baseline=0, init_steps=0): + # Store Values + self.warmup_steps = warmup_steps + self.model_dim = model_dim + self.baseline = baseline + self.init_steps = init_steps + + # Begin Calculations + self.invsqrt_dim = (1 / math.sqrt(model_dim)) + self.invsqrt_warmup = warmup_steps**(-1.5) + + # step + def step(self, step): + """ + ---------- + Author: Ryan Marshall + Modified: Damon Gwinn + ---------- + Method to pass to LambdaLR. Increments the step and computes the new learn rate. + ---------- + """ + + step += self.init_steps + if(step <= self.warmup_steps): + return self.invsqrt_dim * self.invsqrt_warmup * step + self.baseline + else: + invsqrt_step = (1 / math.sqrt(step)) + return self.invsqrt_dim * invsqrt_step + self.baseline + +# get_lr +def get_lr(optimizer): + """ + ---------- + Author: Damon Gwinn + ---------- + Hack to get the current learn rate of the model + ---------- + """ + + for param_group in optimizer.param_groups: + return param_group['lr'] diff --git a/src/music/utilities/representation_learning_utilities/sampler.py b/src/music/utilities/representation_learning_utilities/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c749f4cb9c88dcaaaccb3030ea901ff1340fb372 --- /dev/null +++ b/src/music/utilities/representation_learning_utilities/sampler.py @@ -0,0 +1,45 @@ +from torch.utils.data.sampler import RandomSampler, Sampler +import numpy as np + +class FixedLenRandomSampler(RandomSampler): + """ + Code from mnpinto - Miguel + https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10 + """ + def __init__(self, data_source, bs, epoch_size, *args, **kwargs): + super().__init__(data_source) + self.epoch_size = epoch_size + self.bs = bs + self.not_sampled = np.array([True]*len(data_source)) + self.size_to_sample = self.epoch_size * self.bs + + @property + def _reset_state(self): + self.not_sampled[:] = True + + def __iter__(self): + ns = sum(self.not_sampled) + idx_last = [] + if ns >= self.size_to_sample: + idx = np.random.choice(np.where(self.not_sampled)[0], size=self.size_to_sample, replace=False).tolist() + if ns == self.size_to_sample: + self._reset_state + else: + idx_last = np.where(self.not_sampled)[0].tolist() + self._reset_state + idx = np.random.choice(np.where(self.not_sampled)[0], size=self.size_to_sample-len(idx_last), replace=False).tolist() + self.not_sampled[idx] = False + idx = [*idx_last, *idx] + # print(ns, len(idx), len(idx_last)) # debug + out = [] + i_idx = 0 + for i in range(self.epoch_size): + batch = [] + for j in range(self.bs): + batch.append(idx[i_idx]) + i_idx 
+= 1 + out.append(batch) + return iter(out) + + def __len__(self): + return self.epoch_size diff --git a/src/music/utils.py b/src/music/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b479f8c2a5f315ac81e066e4cb541e37f106cf5 --- /dev/null +++ b/src/music/utils.py @@ -0,0 +1,308 @@ +import os + +import pandas as pd +from pydub import AudioSegment +import numpy as np +from moviepy.editor import * +import time +import pickle +import audioread +import librosa # install numba==0.49.1 +# setup A: numba 0.51.2, librosa 0.6.3, llvmlite: 0.34.0 +# setupB: numba==0.49.1, llvmlite-0.32.1 +from src.music.config import RATE_AUDIO_SAVE +import hashlib +import unicodedata +import re + +# from src.music.piano_detection_model.piano_detection_model import SR + +def clean_removed_mp3_from_csv(path): + print(f"Cleaning meta_data.csv using files from the folder, in {path}") + files = os.listdir(path) + indexes_to_remove = [] + meta_data = pd.read_csv(path + 'meta_data.csv') + for i, fn in enumerate(meta_data['filename']): + if fn not in files: + indexes_to_remove.append(i) + meta_data = meta_data.drop(indexes_to_remove) + meta_data.to_csv(path + 'meta_data.csv', index=False) + print('\tDone.') + +def clean_removed_csv_from_folder(path): + print(f"Cleaning files from folder using meta_data.csv listed file, in {path}") + files = os.listdir(path) + meta_data = pd.read_csv(path + 'meta_data.csv') + hashes = set(meta_data['hash']) + count = 0 + for f in files: + if f not in ['meta_data.csv', 'url.txt']: + if f[:-4] not in hashes: + count += 1 + print(count) + # os.remove(path + f) + stop = 1 + print('\tDone.') + +# def convert_mp3_to_mono_16k(path): +# print(f"\n\n\t\tConverting mp3 to mono and 16k sample rate, in {path}\n") +# if '.mp3' == path[-4:]: +# audio = AudioFileClip(path) +# audio.write_audiofile(path[:-4] + '.mp3', +# verbose=False, +# logger=None, +# fps=FPS, +# ffmpeg_params=["-ac", "1"]) +# else: +# list_files = os.listdir(path) +# for i, f in enumerate(list_files): +# print(compute_progress(i, len(list_files))) +# if ".mp3" in f: +# audio = AudioFileClip(path + f) +# audio.write_audiofile(path + f[:-4] + '.mp3', +# verbose=False, +# logger=None, +# fps=FPS, # 16000 sr +# ffmpeg_params=["-ac", "1"] # make it mono +# ) +# print('\tDone.') + + + +def load_audio(path, sr=22050, mono=True, offset=0.0, duration=None, + dtype=np.float32, res_type='kaiser_best', + backends=[audioread.ffdec.FFmpegAudioFile]): + """Load audio. Copied from librosa.core.load() except that ffmpeg backend is + always used in this function. Code from piano_transcription_inference""" + + y = [] + with audioread.audio_open(os.path.realpath(path), backends=backends) as input_file: + sr_native = input_file.samplerate + n_channels = input_file.channels + + s_start = int(np.round(sr_native * offset)) * n_channels + + if duration is None: + s_end = np.inf + else: + s_end = s_start + (int(np.round(sr_native * duration)) + * n_channels) + + n = 0 + + for frame in input_file: + frame = librosa.core.audio.util.buf_to_float(frame, dtype=dtype) + n_prev = n + n = n + len(frame) + + if n < s_start: + # offset is after the current frame + # keep reading + continue + + if s_end < n_prev: + # we're off the end. stop reading + break + + if s_end < n: + # the end is in this frame. crop. 
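+ # this frame spans samples [n_prev, n), so keep only the first s_end - n_prev samples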
+ frame = frame[:s_end - n_prev] + + if n_prev <= s_start <= n: + # beginning is in this frame + frame = frame[(s_start - n_prev):] + + # tack on the current frame + y.append(frame) + + if y: + y = np.concatenate(y) + + if n_channels > 1: + y = y.reshape((-1, n_channels)).T + if mono: + y = librosa.core.audio.to_mono(y) + + if sr is not None: + y = librosa.core.audio.resample(y, sr_native, sr, res_type=res_type) + + else: + sr = sr_native + + # Final cleanup for dtype and contiguity + y = np.ascontiguousarray(y, dtype=dtype) + + return (y, sr) + +def compute_progress(iter, total): + return f"{int((iter+ 1) / total * 100)}%" + +def compute_progress_and_eta(times, iter, total, n_av=3000): + av_time = np.mean(times[-n_av:]) + progress = int(((iter + 1) / total) * 100) + eta_h = int(av_time * (total - iter) // 3600) + eta_m = int((av_time * (total - iter) - (eta_h * 3600)) // 60) + eta_s = int((av_time * (total - iter) - (eta_h * 3600) - eta_m * 60)) + eta = f"Progress: {progress}%, ETA: {eta_h}H{eta_m}M{eta_s}S." + return eta + +def crop_mp3_from_meta_data_constraints(path, clean_constraints=True): + print(f"Cropping mp3 using constraints from meta_data.csv, in {path}") + meta_data = pd.read_csv(path + 'meta_data.csv') + constraint_start = meta_data['constraint_start'].copy() + length = meta_data['length'].copy() + constraint_end = meta_data['constraint_end'].copy() + filenames = meta_data['filename'].copy() + times = [5] + for i, c_start, c_end, fn, l in zip(range(len(constraint_start)), constraint_start, constraint_end, filenames, length): + if c_start != 0 or c_end != l: + i_time = time.time() + print(compute_progress_and_eta(times, i, len(constraint_start), n_av=100)) + song = AudioSegment.from_mp3(path + fn) + extract = song[c_start*1000:c_end*1000] + extract.export(path + fn, format="mp3") + if clean_constraints: + constraint_start[i] = 0 + constraint_end[i] = length[i] + meta_data['constraint_start'] = constraint_start + meta_data['constraint_end'] = constraint_end + meta_data.to_csv(path + 'meta_data.csv', index=False) + times.append(time.time() - i_time) + print('\tDone.') + +def get_all_subfiles_with_extension(path, max_depth=3, extension='.*', current_depth=0): + folders = [f for f in os.listdir(path) if os.path.isdir(path + f)] + # get all files in current folder with a given extension + if isinstance(extension, list): + assert all([isinstance(e, str) for e in extension]), 'extension can be a str or a list' + files = [path + f for f in os.listdir(path) if os.path.isfile(path + f) and any([ext == f[-len(ext):] for ext in extension])] + elif isinstance(extension, str): + assert extension[0] == '.', 'extension should be an extension or a list of extensions' + if extension == '.*': + files = [path + f for f in os.listdir(path) if os.path.isfile(path + f)] + else: + files = [path + f for f in os.listdir(path) if os.path.isfile(path + f) and f[-len(extension):]==extension] + else: + print('Error: extension should be either a str or a list') + raise ValueError + + if current_depth < max_depth: + for fold in folders: + files += get_all_subfiles_with_extension(path + fold + '/', max_depth=max_depth, extension=extension, current_depth=current_depth+1) + return files + +def get_out_path(in_path, in_word, out_word, out_extension, exclude_paths=()): + splitted_in_path = in_path.split('/') + for i in range(len(splitted_in_path)): + if splitted_in_path[i] == in_word: + splitted_in_path[i] = out_word + playlist_index = i + 1 + file_index = len(splitted_in_path) - 1 + if 
splitted_in_path[playlist_index] in exclude_paths: + to_exclude = True + return None, to_exclude, None + else: + to_exclude = False + if out_word != 'midi': + splitted_in_path[playlist_index] = '_'.join(splitted_in_path[playlist_index].split('_')[:-len(in_word.split('_'))]) + '_' + out_word + else: + splitted_in_path[playlist_index] += '_' + out_word + if 'fake' not in splitted_in_path: + os.makedirs('/'.join(splitted_in_path[:playlist_index + 1]), exist_ok=True) + if out_word != 'midi': + new_filename = '_'.join(splitted_in_path[file_index].split('_')[:-len(in_word.split('_'))]) + '_' + out_word + out_extension + else: + new_filename = '.'.join(splitted_in_path[file_index].split('.')[:-len(in_word.split('_'))]) + '_' + out_word + out_extension + splitted_in_path[file_index] = new_filename + splitted_in_path = splitted_in_path[:playlist_index + 1] + [splitted_in_path[file_index]] + out_path = '/'.join(splitted_in_path) + return out_path, to_exclude, splitted_in_path[playlist_index] + +def set_all_seeds(seed): + import random + import numpy as np + import torch + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + +def get_paths_in_and_out(in_path, in_word, in_extension, out_word, out_extension, max_depth, exclude_paths=()): + # find all files with the in_extension in subfolders of in_path up to max_depth. + # for each, replace the in_word keyword in folders with the out_word, and append out_word to filenames. + all_in_paths = get_all_subfiles_with_extension(in_path, max_depth=max_depth, extension=in_extension) + indexes_not_transcribed = [] + all_out_paths = [] + all_playlists = [] + for i_path, in_path in enumerate(all_in_paths): + out_path, to_exclude, playlist = get_out_path(in_path=in_path, in_word=in_word, out_word=out_word, out_extension=out_extension, exclude_paths=exclude_paths) + if not to_exclude: + indexes_not_transcribed.append(i_path) + all_out_paths.append(out_path) + all_playlists.append(playlist) + all_in_paths = [in_path for i, in_path in enumerate(all_in_paths) if i in indexes_not_transcribed] + assert len(all_out_paths) == len(all_in_paths) + return all_in_paths, all_out_paths, all_playlists + +def get_path_and_filter_existing(in_path, in_word, in_extension, out_word, out_extension, max_depth, exclude_paths=()): + # find all files with the in_extension in subfolders of in_path up to max_depth. + # for each, replace the in_word keyword in folders with the out_word, and append out_word to filenames. 
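+ # note: same logic as get_paths_in_and_out, but inputs whose output file already exists are skipped, so only missing outputs get (re)processed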
+ all_in_paths = get_all_subfiles_with_extension(in_path, max_depth=max_depth, extension=in_extension) + indexes_to_process = [] + all_out_paths = [] + all_playlists = [] + for i_path, in_path in enumerate(all_in_paths): + out_path, to_exclude, playlist = get_out_path(in_path=in_path, in_word=in_word, out_word=out_word, out_extension=out_extension, exclude_paths=exclude_paths) + if not to_exclude: + if not os.path.exists(out_path): + indexes_to_process.append(i_path) + all_out_paths.append(out_path) + all_playlists.append(playlist) + all_in_paths = list(np.array(all_in_paths)[indexes_to_process])#[in_path for i, in_path in enumerate(all_in_paths) if i in indexes_to_process] + assert len(all_out_paths) == len(all_in_paths) + return all_in_paths, all_out_paths, all_playlists + +def md5sum(filename, blocksize=65536): + hash = hashlib.md5() + with open(filename, "rb") as f: + for block in iter(lambda: f.read(blocksize), b""): + hash.update(block) + return hash.hexdigest() + + +emoji_pattern = re.compile("[" + u"\U0001F600-\U0001F64F" # emoticons + u"\U0001F300-\U0001F5FF" # symbols & pictographs + u"\U0001F680-\U0001F6FF" # transport & map symbols + u"\U0001F1E0-\U0001F1FF" # flags (iOS) + "]+", flags=re.UNICODE) +def slugify(value, allow_unicode=False): + """ + Taken from https://github.com/django/django/blob/master/django/utils/text.py + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + """ + value = str(value).lower() + if allow_unicode: + value = unicodedata.normalize('NFKC', value) + else: + value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') + value = re.sub(r'[^\w\s-]', '', value.lower()) + value = emoji_pattern.sub(r'', value) + value = re.sub(r'[-\s]+', '_', value).strip('-_') + # if value == '': + # for i in range(10): + # value += str(np.random.choice(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])) + return value + +if __name__ == '__main__': + path = "/home/cedric/Documents/pianocktail/data/midi/street_piano/" + # for folder in ['my_sheet_music_transcriptions']:#os.listdir(path): + # print('\n\n\t\t', folder) + # convert_mp4_to_mp3(path + folder + '/') + + clean_removed_csv_from_folder(path) + # folder = 'street_piano/' + # for folder in ['street_piano/']: + # clean_removed_mp3_from_csv(path + folder) diff --git a/src/music2cocktailrep/__init__.py b/src/music2cocktailrep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music2cocktailrep/__pycache__/__init__.cpython-39.pyc b/src/music2cocktailrep/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce2d299f4b7e3868f0c4b1e956db34ad4728d43 Binary files /dev/null and b/src/music2cocktailrep/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music2cocktailrep/pipeline/__init__.py b/src/music2cocktailrep/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music2cocktailrep/pipeline/__pycache__/__init__.cpython-39.pyc b/src/music2cocktailrep/pipeline/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c9f14ef634d995eef166117bc5f72d5fbe1db2d Binary files /dev/null and 
b/src/music2cocktailrep/pipeline/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music2cocktailrep/pipeline/__pycache__/music2affect.cpython-39.pyc b/src/music2cocktailrep/pipeline/__pycache__/music2affect.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a00ca693cf0849d3e3110e0cc51cc4e4d8560d01 Binary files /dev/null and b/src/music2cocktailrep/pipeline/__pycache__/music2affect.cpython-39.pyc differ diff --git a/src/music2cocktailrep/pipeline/__pycache__/music2cocktailrep.cpython-39.pyc b/src/music2cocktailrep/pipeline/__pycache__/music2cocktailrep.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dbcbf59a68d3b68d507dd61d18ed0f385a078ce Binary files /dev/null and b/src/music2cocktailrep/pipeline/__pycache__/music2cocktailrep.cpython-39.pyc differ diff --git a/src/music2cocktailrep/pipeline/music2affect.py b/src/music2cocktailrep/pipeline/music2affect.py new file mode 100644 index 0000000000000000000000000000000000000000..1272753dcce4d6a32265bc5e76b5198ebce62bcf --- /dev/null +++ b/src/music2cocktailrep/pipeline/music2affect.py @@ -0,0 +1,49 @@ +import pickle +import numpy as np +from src.music.config import CHECKPOINTS_PATH + +# these can be generated by the fit_glm script. +min_handcoded_reps = np.loadtxt(CHECKPOINTS_PATH + '/music2cocktails/music2affects/min_handcoded_reps.txt') +max_handcoded_reps = np.loadtxt(CHECKPOINTS_PATH + 'music2cocktails/music2affects/max_handcoded_reps.txt') +affective_models_path = CHECKPOINTS_PATH + '/music2cocktails/music2affects/music2affect_models.pickle' +final_keys_path = CHECKPOINTS_PATH + "/music2cocktails/music2affects/final_best_keys.pickle" + +def sigmoid(x, shift, beta): + return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 + +def normalize_handcoded_reps(handcoded_rep): + return (handcoded_rep - min_handcoded_reps) / (max_handcoded_reps - min_handcoded_reps) + +def setup_pretrained_affective_models(): + with open(final_keys_path, 'rb') as f: + best_keys = pickle.load(f) + keys = sorted(set(best_keys[0] + best_keys[1] + best_keys[2])) + bestkeys_indexes = [np.array([keys.index(k) for k in bk]) for bk in best_keys] + + with open(affective_models_path, 'rb') as f: + music2affect_models = pickle.load(f) + + + def music2affect(handcoded_rep): + if handcoded_rep.ndim == 1: + handcoded_rep = handcoded_rep.reshape(1, -1) + assert handcoded_rep.shape[1] == len(keys) + handcoded_rep = normalize_handcoded_reps(handcoded_rep) + affects = [] + for i_dim, dim in enumerate(['valence', 'arousal', 'dominance']): + model = music2affect_models[dim] + my_preds = [] + probas = model.predict_proba(handcoded_rep[:, bestkeys_indexes[i_dim]]) + for r in probas: + my_preds.append(np.mean(np.random.choice(range(1, 11), p=r, size=1000))) + my_preds = np.array(my_preds) + affects.append(my_preds) + # affects.append(model.predict(handcoded_rep)) + affects = np.array(affects).transpose() + affects = ((affects - 1) / 9 - 0.5) * 2 # map to -1, 1 + affects[:, 0] = sigmoid(affects[:, 0], shift=0, beta=7) # stretch for wider distribution + affects[:, 1] = sigmoid(affects[:, 1], shift=-0.05, beta=5) # stretch for wider distribution + affects[:, 2] = sigmoid(affects[:, 2], shift=0.05, beta=8) # stretch for wider distribution + return affects + + return music2affect, keys \ No newline at end of file diff --git a/src/music2cocktailrep/pipeline/music2cocktailrep.py b/src/music2cocktailrep/pipeline/music2cocktailrep.py new file mode 100644 index 
0000000000000000000000000000000000000000..dbf64d3d81edc98943a971afd186595705c41194 --- /dev/null +++ b/src/music2cocktailrep/pipeline/music2cocktailrep.py @@ -0,0 +1,80 @@ +import os + +import numpy as np +import torch +import time + +from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster +from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model +from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models + +global music2affect, find_affective_cluster, translation_vae +import streamlit as st + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +def setup_translation_models(): + global music2affect, find_affective_cluster, translation_vae + music2affect, keys = setup_pretrained_affective_models() + find_affective_cluster = get_affect2affective_cluster() + translation_vae = setup_trained_model() + return translation_vae + +def music2affect_cluster(handcoded_rep): + global music2affect, find_affective_cluster + affects = np.clip(music2affect(handcoded_rep), -1, 1) + cluster_id = find_affective_cluster(affects) + return cluster_id, affects + +def music2flavor(music_ai_rep, affective_cluster_id): + global translation_vae + cocktail_rep = translation_vae(music_ai_rep, modality_out='cocktail') + return cocktail_rep + +def debug_translation(music_ai_rep): + global translation_vae + music_reconstruction = translation_vae(music_ai_rep, modality_out='music') + return music_reconstruction + +def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0): + init_time = time.time() + if verbose: print(' ' * level + 'Synesthetic mapping..') + if verbose: print(' ' * (level*2) + 'Mapping to affective cluster.') + # affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep) + affective_cluster_id, affect = None, None + if verbose: print(' ' * (level*2) + 'Mapping to flavors.') + cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id) + if verbose: print(' ' * (level + 2) + f'Mapped in {int(time.time() - init_time)} seconds.') + return cocktail_rep, affective_cluster_id, affect + +# def sigmoid(x, shift, beta): +# return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 +# +# cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)] + +# def plot_cluster_ids_dataset(handcoded_rep_path): +# import matplotlib.pyplot as plt +# reps, _, _ = get_data(handcoded_rep_path, keys) +# cluster_ids, affects = music2affect_cluster(reps) +# # plt.figure() +# # affects2 = affects.copy() +# # affects2 = sigmoid(affects2, 0.05, 8) +# # plt.hist(affects2[:, 2], bins=30) +# # plt.xlim([-1, 1]) +# fig = plt.figure() +# ax = fig.add_subplot(projection='3d') +# ax.set_xlim([-1, 1]) +# ax.set_ylim([-1, 1]) +# ax.set_zlim([-1, 1]) +# for cluster_id in sorted(set(cluster_ids)): +# indexes = np.argwhere(cluster_ids == cluster_id).flatten() +# if len(indexes) > 0: +# ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150) +# ax.set_xlabel('Valence') +# ax.set_ylabel('Arousal') +# ax.set_zlabel('Dominance') +# plt.figure() +# plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)]) +# plt.show() +# +# plot_cluster_ids_dataset(handcoded_rep_path) \ No newline at end of file diff --git a/src/music2cocktailrep/training/__init__.py b/src/music2cocktailrep/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/src/music2cocktailrep/training/__pycache__/__init__.cpython-39.pyc b/src/music2cocktailrep/training/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd8f7f60faa0fdcd3842adf385fb080ea8acb36 Binary files /dev/null and b/src/music2cocktailrep/training/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music2cocktailrep/training/latent_translation/__init__.py b/src/music2cocktailrep/training/latent_translation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/music2cocktailrep/training/latent_translation/__pycache__/__init__.cpython-39.pyc b/src/music2cocktailrep/training/latent_translation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c389a3dd719eebb7042929fed2d62d7305e292b Binary files /dev/null and b/src/music2cocktailrep/training/latent_translation/__pycache__/__init__.cpython-39.pyc differ diff --git a/src/music2cocktailrep/training/latent_translation/__pycache__/setup_trained_model.cpython-39.pyc b/src/music2cocktailrep/training/latent_translation/__pycache__/setup_trained_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07adc3de2b5144ba7b1edfc9a34859ecd7233114 Binary files /dev/null and b/src/music2cocktailrep/training/latent_translation/__pycache__/setup_trained_model.cpython-39.pyc differ diff --git a/src/music2cocktailrep/training/latent_translation/__pycache__/vae_model.cpython-39.pyc b/src/music2cocktailrep/training/latent_translation/__pycache__/vae_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c74148f32b008711eb4629442e115dc9b2d08c9a Binary files /dev/null and b/src/music2cocktailrep/training/latent_translation/__pycache__/vae_model.cpython-39.pyc differ diff --git a/src/music2cocktailrep/training/latent_translation/dataset.py b/src/music2cocktailrep/training/latent_translation/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f290ab38b212484176108f0e4eaf4818e9115bc --- /dev/null +++ b/src/music2cocktailrep/training/latent_translation/dataset.py @@ -0,0 +1,306 @@ +from torch.utils.data import Dataset +import numpy as np +import torch +device = 'cuda' if torch.cuda.is_available() else 'cpu' +from src.music2cocktailrep.analysis.explore import get_alignment_dataset + +# Add your custom dataset class here +class CocktailDataset(Dataset): + def __init__(self, split, cocktail_reps): + + self.n_cocktails, self.dim_cocktail = cocktail_reps.shape + labels = np.zeros([self.n_cocktails]) + if split == 'train': + self.cocktail_reps = cocktail_reps[:int(0.9 * self.n_cocktails), :].copy() + self.labels = labels[:int(0.9 * self.n_cocktails)].copy() + elif split == 'test': + self.cocktail_reps = cocktail_reps[int(0.9 * self.n_cocktails):, :].copy() + self.labels = labels[int(0.9 * self.n_cocktails):].copy() + elif split == 'all': + self.cocktail_reps = cocktail_reps.copy() + self.labels = labels.copy() + else: + raise ValueError + + # self.n_cocktails = self.cocktail_reps.shape[0] + # indexes = np.arange(self.n_cocktails) + # np.random.shuffle(indexes) + self.cocktail_reps = torch.FloatTensor(self.cocktail_reps).to(device) + # oversample cocktails with eggs and bubbles + ind_egg = np.argwhere(self.cocktail_reps[:, -1] > 0).flatten() + ind_bubbles = np.argwhere(self.cocktail_reps[:, -3] > 0).flatten() + n_copies = 4 + egg_copies = torch.tile(self.cocktail_reps[ind_egg, 
:], dims=(n_copies * 3, 1)) + bubbles_copies = torch.tile(self.cocktail_reps[ind_bubbles, :], dims=(n_copies, 1)) + self.cocktail_reps = torch.cat([self.cocktail_reps, egg_copies, bubbles_copies], dim=0) + self.n_cocktails = self.cocktail_reps.shape[0] + indexes = np.arange(self.n_cocktails) + np.random.shuffle(indexes) + self.cocktail_reps = self.cocktail_reps[indexes] + self.labels = torch.LongTensor(np.zeros([self.n_cocktails])).to(device) + self.contains_egg = self.cocktail_reps[:, -1] > 0 + self.contains_bubbles = self.cocktail_reps[:, -3] > 0 + + def __len__(self): + return self.cocktail_reps.shape[0] + + def __getitem__(self, idx): + return self.cocktail_reps[idx], self.labels[idx], self.contains_egg[idx], self.contains_bubbles[idx] + +class CocktailLabeledDataset(Dataset): + def __init__(self, split, cocktail_reps): + + + dataset = get_alignment_dataset() + labels = sorted(dataset['cocktail'].keys()) + self.n_labels = len(labels) + + # n_cocktails = np.sum([len(dataset['cocktail'][k]) for k in labels]) + all_cocktails = [] + for k in labels: + all_cocktails += dataset['cocktail'][k] + # assert n_cocktails == len(set(all_cocktails)) + all_cocktails = np.array(all_cocktails) + + cocktail_reps = cocktail_reps[all_cocktails] + cocktail_labels = [] + for i in all_cocktails: + for i_k, k in enumerate(labels): + if i in dataset['cocktail'][k]: + cocktail_labels.append(i_k) + break + cocktail_labels = np.array(cocktail_labels) + assert len(cocktail_labels) == len(cocktail_reps) + self.n_cocktails, self.dim_cocktail = cocktail_reps.shape + + indexes_train = [] + indexes_test = [] + for k in labels: + indexes_k = np.argwhere(cocktail_labels == labels.index(k)).flatten() + indexes_train += list(indexes_k[:int(0.9 * len(indexes_k))]) + indexes_test += list(indexes_k[int(0.9 * len(indexes_k)):]) + indexes_train = np.array(indexes_train) + indexes_test = np.array(indexes_test) + assert len(set(indexes_train) & set(indexes_test)) == 0 + if split == 'train': + self.cocktail_reps = cocktail_reps[indexes_train].copy() + self.labels = cocktail_labels[indexes_train].copy() + elif split == 'test': + self.cocktail_reps = cocktail_reps[indexes_test].copy() + self.labels = cocktail_labels[indexes_test].copy() + elif split == 'all': + self.cocktail_reps = cocktail_reps.copy() + self.labels = cocktail_labels.copy() + else: + raise ValueError + + self.n_cocktails = self.cocktail_reps.shape[0] + indexes = np.arange(self.n_cocktails) + np.random.shuffle(indexes) + self.cocktail_reps = torch.FloatTensor(self.cocktail_reps[indexes]).to(device) + self.labels = torch.LongTensor(self.labels[indexes]).to(device) + + def __len__(self): + return self.cocktail_reps.shape[0] + + def __getitem__(self, idx): + return self.cocktail_reps[idx], self.labels[idx] + +class MusicDataset(Dataset): + def __init__(self, split, music_reps, music_rep_paths): + + + self.n_music, self.dim_music = music_reps.shape + labels = np.zeros([self.n_music]) + + if split == 'train': + self.music_reps = music_reps[:int(0.9 * self.n_music), :].copy() + self.labels = labels[:int(0.9 * self.n_music)].copy() + elif split == 'test': + self.music_reps = music_reps[int(0.9 * self.n_music):, :].copy() + self.labels = labels[int(0.9 * self.n_music):].copy() + elif split == 'all': + self.music_reps = music_reps.copy() + self.labels = labels.copy() + else: + raise ValueError + self.n_music = self.music_reps.shape[0] + indexes = np.arange(self.n_music) + np.random.shuffle(indexes) + self.music_reps = torch.FloatTensor(self.music_reps[indexes]).to(device) + 
self.labels = torch.LongTensor(self.labels[indexes]).to(device) + + def __len__(self): + return self.music_reps.shape[0] + + def __getitem__(self, idx): + return self.music_reps[idx], self.labels[idx] + +class RegressedGroundingDataset(Dataset): + def __init__(self, split, music_reps, music_rep_paths, cocktail_reps): + + dataset = get_alignment_dataset() + labels = sorted(dataset['cocktail'].keys()) + self.n_labels = len(labels) + + n_music = np.sum([len(dataset['music'][k]) for k in labels]) + all_music_filenames = [] + for k in labels: + all_music_filenames += dataset['music'][k] + assert n_music == len(set(all_music_filenames)) + all_music_filenames = np.array(all_music_filenames) + + all_cocktails = [] + for k in labels: + all_cocktails += dataset['cocktail'][k] + # assert n_cocktails == len(set(all_cocktails)) + all_cocktails = np.array(all_cocktails) + + indexes = [] + for music_filename in all_music_filenames: + rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt') + found = False + for i, rep_path in enumerate(music_rep_paths): + if rep_name == rep_path[-len(rep_name):]: + indexes.append(i) + found = True + break + assert found + # assert len(indexes) == len(all_music_filenames) + music_reps = music_reps[np.array(indexes)] + music_labels = [] + for music_filename in all_music_filenames: + for i_k, k in enumerate(labels): + if music_filename in dataset['music'][k]: + music_labels.append(i_k) + break + assert len(music_labels) == len(music_reps) + music_labels = np.array(music_labels) + self.n_music, self.dim_music = music_reps.shape + self.classes = labels + + + cocktail_reps = cocktail_reps[all_cocktails] + cocktail_labels = [] + for i in all_cocktails: + for i_k, k in enumerate(labels): + if i in dataset['cocktail'][k]: + cocktail_labels.append(i_k) + break + cocktail_labels = np.array(cocktail_labels) + assert len(cocktail_labels) == len(cocktail_reps) + self.n_cocktails, self.dim_cocktail = cocktail_reps.shape + + cocktail_reps_matching_music_reps = [] + for l in music_labels: + ind_cocktails = np.where(cocktail_labels==l)[0] + cocktail_reps_matching_music_reps.append(cocktail_reps[np.random.choice(ind_cocktails)]) + cocktail_reps_matching_music_reps = np.array(cocktail_reps_matching_music_reps) + + + indexes_train = [] + indexes_test = [] + for k in labels: + indexes_k = np.argwhere(music_labels == labels.index(k)).flatten() + indexes_train += list(indexes_k[:int(0.9 * len(indexes_k))]) + indexes_test += list(indexes_k[int(0.9 * len(indexes_k)):]) + indexes_train = np.array(indexes_train) + indexes_test = np.array(indexes_test) + assert len(set(indexes_train) & set(indexes_test)) == 0 + + if split == 'train': + self.music_reps = music_reps[indexes_train].copy() + self.cocktail_reps = cocktail_reps_matching_music_reps[indexes_train].copy() + # self.labels = music_labels[indexes_train].copy() + elif split == 'test': + self.music_reps = music_reps[indexes_test].copy() + self.cocktail_reps = cocktail_reps_matching_music_reps[indexes_test].copy() + # self.labels = music_labels[indexes_test].copy() + elif split == 'all': + self.music_reps = music_reps.copy() + self.cocktail_reps = cocktail_reps_matching_music_reps.copy() + # self.labels = music_labels.copy() + else: + raise ValueError + self.n_music = self.music_reps.shape[0] + indexes = np.arange(self.n_music) + np.random.shuffle(indexes) + self.music_reps = torch.FloatTensor(self.music_reps[indexes]).to(device) + self.cocktail_reps = torch.FloatTensor(self.cocktail_reps[indexes]).to(device) + # 
self.labels = torch.LongTensor(self.labels[indexes]).to(device) + + def __len__(self): + return self.music_reps.shape[0] + + def __getitem__(self, idx): + return self.music_reps[idx], self.cocktail_reps[idx] + +class MusicLabeledDataset(Dataset): + def __init__(self, split, music_reps, music_rep_paths): + + dataset = get_alignment_dataset() + labels = sorted(dataset['cocktail'].keys()) + self.n_labels = len(labels) + + n_music = np.sum([len(dataset['music'][k]) for k in labels]) + all_music_filenames = [] + for k in labels: + all_music_filenames += dataset['music'][k] + assert n_music == len(set(all_music_filenames)) + all_music_filenames = np.array(all_music_filenames) + + indexes = [] + for music_filename in all_music_filenames: + rep_name = music_filename.replace('_processed.mid', '_b256_r128_represented.txt') + found = False + for i, rep_path in enumerate(music_rep_paths): + if rep_name == rep_path[-len(rep_name):]: + indexes.append(i) + found = True + break + assert found + # assert len(indexes) == len(all_music_filenames) + music_reps = music_reps[np.array(indexes)] + music_labels = [] + for music_filename in all_music_filenames: + for i_k, k in enumerate(labels): + if music_filename in dataset['music'][k]: + music_labels.append(i_k) + break + assert len(music_labels) == len(music_reps) + music_labels = np.array(music_labels) + self.n_music, self.dim_music = music_reps.shape + self.classes = labels + + indexes_train = [] + indexes_test = [] + for k in labels: + indexes_k = np.argwhere(music_labels == labels.index(k)).flatten() + indexes_train += list(indexes_k[:int(0.9 * len(indexes_k))]) + indexes_test += list(indexes_k[int(0.9 * len(indexes_k)):]) + indexes_train = np.array(indexes_train) + indexes_test = np.array(indexes_test) + assert len(set(indexes_train) & set(indexes_test)) == 0 + + if split == 'train': + self.music_reps = music_reps[indexes_train].copy() + self.labels = music_labels[indexes_train].copy() + elif split == 'test': + self.music_reps = music_reps[indexes_test].copy() + self.labels = music_labels[indexes_test].copy() + elif split == 'all': + self.music_reps = music_reps.copy() + self.labels = music_labels.copy() + else: + raise ValueError + self.n_music = self.music_reps.shape[0] + indexes = np.arange(self.n_music) + np.random.shuffle(indexes) + self.music_reps = torch.FloatTensor(self.music_reps[indexes]).to(device) + self.labels = torch.LongTensor(self.labels[indexes]).to(device) + + def __len__(self): + return self.music_reps.shape[0] + + def __getitem__(self, idx): + return self.music_reps[idx], self.labels[idx] \ No newline at end of file diff --git a/src/music2cocktailrep/training/latent_translation/run.py b/src/music2cocktailrep/training/latent_translation/run.py new file mode 100644 index 0000000000000000000000000000000000000000..12779b5cdb790fa2523ea5d8f14b9d49a589606e --- /dev/null +++ b/src/music2cocktailrep/training/latent_translation/run.py @@ -0,0 +1,506 @@ +import os + +import torch; torch.manual_seed(0) +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.distributions +import numpy as np +import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 +from vae_model import get_gml_vae_models +from utils import get_dataloaders, compute_swd_loss +import matplotlib.pyplot as plt +from src.music.config import MUSIC_REP_PATH +from src.cocktails.config import FULL_COCKTAIL_REP_PATH +import json +import argparse +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +if torch.cuda.is_available(): + print('Using 
GPUs') +else: + print('Using CPUs') + +music_rep_path = "/home/cedric/Documents/pianocktail/data/music/represented_small/" +music_rep_path = MUSIC_REP_PATH + "music_reps_normalized_meanstd.pickle" +# music_rep_path = "/home/cedric/Documents/pianocktail/data/music/32_represented/reps.pickle" +LOSS = nn.CrossEntropyLoss() +def run_epoch(epoch, model, data, params, opt, train): + if epoch == params['n_epochs_music_pretrain']: + print(f'Switching to bs: {params["batch_size"]}') + for k in data.keys(): + prefix = 'train' if train else 'test' + data[k].batch_sampler.update_epoch_size_and_batch(params[prefix + '_epoch_size'], params['batch_size']) + if train: + model.train() + else: + model.eval() + keys_to_track = params['keys_to_track'] + losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))])) + step = 0 + cf_matrices_music = [] + cf_matrices_cocktail = [] + for i_batch, data_music, data_cocktail, data_music_lab, data_cocktail_lab, data_reg_grounding \ + in zip(range(len(data['music'])), data['music'], data['cocktail'], data['music_labeled'], data['cocktail_labeled'], data['reg_grounding']): + x_music, _ = data_music + x_cocktail, _, contains_egg, contains_bubbles = data_cocktail + x_music_lab, labels_music = data_music_lab + x_cocktail_lab, labels_cocktail = data_cocktail_lab + x_reg_music, x_reg_cocktail = data_reg_grounding + step += x_music.shape[0] + if train: opt.zero_grad() + + # weight more examples that have bubbles or egg in the mse computation + bubbles_egg_weights = torch.ones([contains_bubbles.shape[0]]) + bubbles_egg_weights[contains_bubbles] += 1 + bubbles_egg_weights[contains_egg] += 3 + + # vae + x_hat_cocktail, z_cocktail, mu_cocktail, log_var_cocktail = model(x_cocktail, modality_in='cocktail', modality_out='cocktail') + mse_loss_cocktail = torch.sum(((x_cocktail - x_hat_cocktail)**2).mean(axis=1) * bubbles_egg_weights) / bubbles_egg_weights.sum() + if contains_bubbles.sum() > 0: + bubble_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_bubbles, -3].mean()) + else: + bubble_mse = np.nan + if contains_egg.sum() > 0: + egg_mse = float(((x_cocktail - x_hat_cocktail)**2)[contains_egg, -1].mean()) + else: + egg_mse = np.nan + + kld_loss_cocktail = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1)) + + x_hat_music, z_music, mu_music, log_var_music = model(x_music, modality_in='music', modality_out='music') + mse_loss_music = ((x_music - x_hat_music)**2).mean() + kld_loss_music = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1)) + + music_vae_loss = mse_loss_music + params['beta_vae'] * kld_loss_music + cocktail_vae_loss = mse_loss_cocktail + params['beta_vae'] * kld_loss_cocktail + vae_loss = cocktail_vae_loss + params['beta_music'] * music_vae_loss + # music_vae_loss = mse_loss_music + params['beta_vae'] * kld_loss_music + brb_kld_loss_cocktail, brb_kld_loss_music, brb_mse_loss_music, brb_mse_loss_cocktail, brb_mse_latent_loss, brb_music_vae_loss, brb_vae_loss = [0] * 7 + + if params['use_brb_vae']: + # vae back to back + out = model.forward_b2b(x_cocktail, modality_in_out='cocktail', modality_intermediate='music') + x_hat_cocktail, x_intermediate_music, mu_cocktail, log_var_cocktail, z_cocktail, mu_music, log_var_music, z_music = out + brb_mse_loss_cocktail = ((x_cocktail - x_hat_cocktail) ** 2).mean() + brb_mse_latent_loss_1 = ((z_music - z_cocktail) ** 2).mean() + brb_kld_loss_cocktail_1 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - 
log_var_cocktail.exp(), dim=1)) + brb_kld_loss_music_1 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1)) + # brb_cocktail_in_loss = mse_loss_cocktail + mse_latents_1 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music) + + out = model.forward_b2b(x_music, modality_in_out='music', modality_intermediate='cocktail') + x_hat_music, x_intermediate_cocktail, mu_music, log_var_music, z_music, mu_cocktail, log_var_cocktail, z_cocktail = out + brb_mse_loss_music = ((x_music - x_hat_music) ** 2).mean() + brb_mse_latent_loss_2 = ((z_music - z_cocktail) ** 2).mean() + brb_kld_loss_cocktail_2 = torch.mean(-0.5 * torch.sum(1 + log_var_cocktail - mu_cocktail ** 2 - log_var_cocktail.exp(), dim=1)) + brb_kld_loss_music_2 = torch.mean(-0.5 * torch.sum(1 + log_var_music - mu_music ** 2 - log_var_music.exp(), dim=1)) + # brb_music_in_loss = mse_loss_music + mse_latents_2 + params['beta_vae'] * (kld_loss_cocktail + kld_loss_music) + brb_mse_latent_loss = (brb_mse_latent_loss_1 + brb_mse_latent_loss_2) / 2 + brb_kld_loss_music = (brb_kld_loss_music_1 + brb_kld_loss_music_2) / 2 + brb_kld_loss_cocktail = (brb_kld_loss_cocktail_1 + brb_kld_loss_cocktail_2) / 2 + brb_vae_loss = brb_mse_latent_loss + brb_mse_loss_cocktail + brb_mse_loss_music + params['beta_vae'] * (brb_kld_loss_music + brb_kld_loss_cocktail) + brb_music_vae_loss = brb_mse_loss_music + params['beta_vae'] * brb_kld_loss_music + brb_mse_latent_loss + + # swd + if params['beta_swd'] > 0: + swd_loss = compute_swd_loss(z_music, z_cocktail, params['latent_dim']) + else: + swd_loss = 0 + + # classif losses + if params['beta_classif'] > 0: + pred_music = model.classify(x_music_lab, modality_in='music') + classif_loss_music = LOSS(pred_music, labels_music) + accuracy_music = torch.mean((torch.argmax(pred_music, dim=1) == labels_music).float()) + cf_matrices_music.append(get_cf_matrix(pred_music, labels_music)) + pred_cocktail = model.classify(x_cocktail_lab, modality_in='cocktail') + classif_loss_cocktail = LOSS(pred_cocktail, labels_cocktail) + accuracy_cocktail = torch.mean((torch.argmax(pred_cocktail, dim=1) == labels_cocktail).float()) + cf_matrices_cocktail.append(get_cf_matrix(pred_cocktail, labels_cocktail)) + + else: + classif_loss_cocktail, classif_loss_music = 0, 0 + accuracy_music, accuracy_cocktail = 0, 0 + cf_matrices_cocktail.append(np.zeros((2, 2))) + cf_matrices_music.append(np.zeros((2, 2))) + + if params['beta_reg_grounding'] > 0: + x_hat_cocktail, _, _, _ = model(x_reg_music, modality_in='music', modality_out='cocktail', freeze_decoder=True) + mse_reg_grounding = ((x_reg_cocktail - x_hat_cocktail) ** 2).mean() + else: + mse_reg_grounding = 0 + + if params['use_brb_vae']: + global_minus_classif = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss + global_loss = params['beta_vae_loss'] * (vae_loss + brb_music_vae_loss) + params['beta_swd'] * swd_loss + \ + params['beta_classif'] * (classif_loss_cocktail + params['beta_music_classif'] * classif_loss_music) + else: + global_minus_classif = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss + global_loss = params['beta_vae_loss'] * vae_loss + params['beta_swd'] * swd_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \ + params['beta_reg_grounding'] * mse_reg_grounding + # global_loss = params['beta_vae_loss'] * cocktail_vae_loss + params['beta_classif'] * (classif_loss_cocktail + classif_loss_music) + \ + # params['beta_reg_grounding'] * mse_reg_grounding + + 
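
For reference, the recurring term `torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))` used throughout `run_epoch` above is the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior, summed over latent dimensions and averaged over the batch. A minimal standalone check on toy tensors (illustrative shapes only, independent of the training code):

```python
import torch
from torch.distributions import Normal, kl_divergence

# toy encoder outputs: batch of 8, latent dim 5 (hypothetical shapes)
mu, log_var = torch.randn(8, 5), torch.randn(8, 5)

# closed form as written in run_epoch
kld_closed_form = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))

# same quantity via torch.distributions: KL(N(mu, sigma) || N(0, 1)), summed over dims, averaged over batch
posterior = Normal(mu, torch.exp(0.5 * log_var))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_reference = kl_divergence(posterior, prior).sum(dim=1).mean()

assert torch.allclose(kld_closed_form, kld_reference, atol=1e-5)
```
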
losses['brb_vae_loss'].append(float(brb_vae_loss)) + losses['brb_mse_latent_loss'].append(float(brb_mse_latent_loss)) + losses['brb_kld_loss_cocktail'].append(float(brb_kld_loss_cocktail)) + losses['brb_kld_loss_music'].append(float(brb_kld_loss_music)) + losses['brb_mse_loss_music'].append(float(brb_mse_loss_music)) + losses['brb_mse_loss_cocktail'].append(float(brb_mse_loss_cocktail)) + losses['swd_losses'].append(float(swd_loss)) + losses['vae_losses'].append(float(vae_loss)) + losses['kld_losses_music'].append(float(kld_loss_music)) + losses['kld_losses_cocktail'].append(float(kld_loss_cocktail)) + losses['mse_losses_music'].append(float(mse_loss_music)) + losses['mse_losses_cocktail'].append(float(mse_loss_cocktail)) + losses['global_losses'].append(float(global_loss)) + losses['classif_losses_music'].append(float(classif_loss_music)) + losses['classif_losses_cocktail'].append(float(classif_loss_cocktail)) + losses['classif_acc_cocktail'].append(float(accuracy_cocktail)) + losses['classif_acc_music'].append(float(accuracy_music)) + losses['beta_reg_grounding'].append(float(mse_reg_grounding)) + losses['bubble_mse'].append(bubble_mse) + losses['egg_mse'].append(egg_mse) + + if train: + # if epoch < params['n_epochs_music_pretrain']: + # music_vae_loss.backward() + # elif epoch >= params['n_epochs_music_pretrain'] and epoch < (params['n_epochs_music_pretrain'] + params['n_epochs_train']): + # global_minus_classif.backward() + # elif epoch >= (params['n_epochs_music_pretrain'] + params['n_epochs_train']): + global_loss.backward() + opt.step() + + if params['log_every'] != 0: + if step != 0 and step % params['log_every'] == 0: + print(f'\tBatch #{i_batch}') + for k in params['keys_to_print']: + if k != 'steps': + print(f'\t {k}: Train: {np.nanmean(losses[k][-params["log_every"]:]):.3f}') + # print(f'\t {k}: Train: {torch.mean(torch.cat(losses[k][-params["log_every"]:])):.3f}') + return losses, [np.mean(cf_matrices_music, axis=0), np.mean(cf_matrices_cocktail, axis=0)] + +def get_cf_matrix(pred, labels): + bs, dim = pred.shape + labels = labels.detach().numpy() + pred_labels = np.argmax(pred.detach().numpy(), axis=1) + confusion_matrix = np.zeros((dim, dim)) + for i in range(bs): + confusion_matrix[labels[i], pred_labels[i]] += 1 + for i in range(dim): + if np.sum(confusion_matrix[i]) != 0: + confusion_matrix[i] /= np.sum(confusion_matrix[i]) + return confusion_matrix + +def train(model, dataloaders, params): + keys_to_track = params['keys_to_track'] + opt = torch.optim.AdamW(list(model.parameters()), lr=params['lr']) + if params['decay_step'] > 0: scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=params['decay_step'], gamma=0.5) + all_train_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))])) + all_eval_losses = dict(zip(keys_to_track, [[] for _ in range(len(keys_to_track))])) + best_eval_loss = np.inf + + data_train = dict() + data_test = dict() + for k in dataloaders.keys(): + if '_train' in k: + data_train[k[:-6]] = dataloaders[k] + elif '_test' in k: + data_test[k[:-5]] = dataloaders[k] + else: + raise ValueError + # run first eval + eval_losses, _ = run_epoch(0, model, data_test, params, opt, train=False) + for k in params['keys_to_track']: + if k == 'steps': + all_train_losses[k].append(0) + all_eval_losses[k].append(0) + else: + all_train_losses[k].append(np.nan) + all_eval_losses[k].append(np.mean(eval_losses[k])) + # all_train_losses[k].append(torch.Tensor([np.nan])) + # 
all_eval_losses[k].append(torch.atleast_1d(torch.mean(torch.cat(eval_losses[k])))) + print(f'Initial evaluation') + for k in params['keys_to_print']: + to_print = all_eval_losses[k][-1] if k != 'steps' else all_eval_losses[k][-1] + # to_print = all_eval_losses[k][-1][0] if k != 'steps' else all_eval_losses[k][-1] + print(f' {k}: Eval: {to_print:.3f}') + step = 0 + for epoch in range(params['epochs']): + print(f'\n------------\nEpoch #{epoch}') + # run training epoch + train_losses, train_cf_matrices = run_epoch(epoch, model, data_train, params, opt, train=True) + # run eval epoch + eval_losses, eval_cf_matrices = run_epoch(epoch, model, data_test, params, opt, train=False) + + if epoch < params['n_epochs_music_pretrain']: + epoch_size = params['pretrain_train_epoch_size'] + else: + epoch_size = params['train_epoch_size'] + step += epoch_size + for k in params['keys_to_track']: + if k == 'steps': + all_train_losses[k].append(epoch) + all_eval_losses[k].append(epoch) + else: + all_train_losses[k].append(np.nanmean(train_losses[k])) + all_eval_losses[k].append(np.nanmean(eval_losses[k])) + # all_train_losses[k].append(torch.atleast_1d(torch.mean(torch.cat(train_losses[k])))) + # all_eval_losses[k].append(torch.atleast_1d(torch.mean(torch.cat(eval_losses[k])))) + if params['decay_step']: scheduler.step() + # logging + print(f'----\n\tEval epoch #{epoch}') + for k in params['keys_to_print']: + to_print_eval = all_eval_losses[k][-1] if k != 'steps' else all_eval_losses[k][-1] + to_print_train = all_train_losses[k][-1] if k != 'steps' else all_train_losses[k][-1] + # to_print_eval = all_eval_losses[k][-1][0] if k != 'steps' else all_eval_losses[k][-1] + # to_print_train = all_train_losses[k][-1][0] if k != 'steps' else all_train_losses[k][-1] + print(f'\t {k}: Eval: {to_print_eval:.3f} / Train: {to_print_train:.3f}') + + if epoch % params['plot_every'] == 0: + plot_all_losses(all_train_losses.copy(), all_eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params) + # saving models + save_losses(all_train_losses, all_eval_losses, params['save_path'] + 'results.txt') + if params['save_every'] != 0: + if epoch % params['save_every'] == 0: + print('Saving model.') + save_model(model, path=params['save_path'], name=f'epoch_{epoch}') + if all_eval_losses['global_losses'][-1] < best_eval_loss: + best_eval_loss = all_eval_losses['global_losses'][-1] + print(f'New best eval loss: {best_eval_loss:.3f}, saving model.') + # print(f'New best eval loss: {best_eval_loss[0]:.3f}, saving model.') + save_model(model, path=params['save_path'], name='best_eval') + print('Saving last model.') + save_model(model, path=params['save_path'], name=f'last') + return model, all_train_losses, all_eval_losses, train_cf_matrices, eval_cf_matrices + +def save_losses(train_losses, eval_losses, path): + results = [] + keys = sorted(train_losses.keys()) + for k in keys: + if k != 'steps': + results.append(train_losses[k])#list(torch.cat(train_losses[k]).detach().cpu().numpy())) + else: + results.append(train_losses[k]) + for k in keys: + if k != 'steps': + results.append(eval_losses[k])#list(torch.cat(eval_losses[k]).detach().cpu().numpy())) + else: + results.append(eval_losses[k]) + np.savetxt(path, np.array(results)) + +def save_model(model, path, name): + torch.save(model.state_dict(), path + f'checkpoints_{name}.save') + +def run_training(params): + params = compute_expe_name_and_save_path(params) + dataloaders, n_labels, stats = get_dataloaders(cocktail_rep_path=params['cocktail_rep_path'], + 
music_rep_path=params['music_rep_path'], + batch_size=params['pretrain_batch_size'], + train_epoch_size=params['pretrain_train_epoch_size'], + test_epoch_size=params['pretrain_test_epoch_size']) + params['nb_classes'] = n_labels + params['stats'] = stats + params['classif_classes'] = dataloaders['music_labeled_train'].dataset.classes + vae_gml_model = get_gml_vae_models(layer_type=params['layer_type'], + input_dim_music=dataloaders['music_train'].dataset.dim_music, + input_dim_cocktail=dataloaders['cocktail_train'].dataset.dim_cocktail, + hidden_dim=params['hidden_dim'], + n_hidden=params['n_hidden'], + latent_dim=params['latent_dim'], + nb_classes=params['nb_classes'], + dropout=params['dropout']) + params['dim_music'] = dataloaders['music_train'].dataset.dim_music + params['dim_cocktail'] = dataloaders['cocktail_train'].dataset.dim_cocktail + with open(params['save_path'] + 'params.json', 'w') as f: + json.dump(params, f) + models, train_losses, eval_losses, train_cf_matrices, eval_cf_matrices = train(vae_gml_model, dataloaders, params) + plot_all_losses(train_losses.copy(), eval_losses.copy(), train_cf_matrices, eval_cf_matrices, params) + return models, train_losses, eval_losses + +def plot_all_losses(train_losses, eval_losses, train_cf_matrices, eval_cf_matrices, params): + plot_losses(train_losses, train_cf_matrices, 'train', params) + plot_losses(eval_losses, eval_cf_matrices, 'eval', params) + +def plot_losses(losses, cf_matrices, split, params): + save_path = params['save_path'] + 'plots/' + os.makedirs(save_path, exist_ok=True) + steps = losses['steps'] + for k in losses.keys(): + # if k != 'steps': + # losses[k] = losses[k]#torch.cat(losses[k]).detach().cpu().numpy() + # else: + losses[k] = np.array(losses[k]) + losses['sum_loss_classif'] = losses['classif_losses_music'] + losses['classif_losses_cocktail'] + losses['av_acc_classif'] = (losses['classif_acc_cocktail'] + losses['classif_acc_music'])/2 + losses['sum_mse_vae'] = losses['mse_losses_cocktail'] + losses['mse_losses_music'] + losses['sum_kld_vae'] = losses['kld_losses_cocktail'] + losses['kld_losses_music'] + + + plt.figure() + for k in ['global_losses', 'vae_losses', 'swd_losses', 'sum_mse_vae', 'sum_kld_vae']: + factor = 10 if k == 'swd_losses' else 1 + plt.plot(steps, losses[k] * factor, label=k) + plt.title(split) + plt.legend() + plt.ylim([0, 2.5]) + plt.savefig(save_path + f'plot_high_level_losses_{split}.png') + plt.close(plt.gcf()) + + plt.figure() + for k in ['classif_acc_cocktail', 'classif_acc_music']: + plt.plot(steps, losses[k], label=k) + plt.title(split) + plt.ylim([0, 1]) + plt.legend() + plt.savefig(save_path + f'plot_classif_accuracies_{split}.png') + plt.close(plt.gcf()) + + plt.figure() + for k in ['mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail', + 'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'beta_reg_grounding', + 'bubble_mse', 'egg_mse']: + factor = 10 if k == 'swd_losses' else 1 + plt.plot(steps, losses[k] * factor, label=k) + plt.title(split) + plt.ylim([0, 2.5]) + plt.legend() + plt.savefig(save_path + f'plot_detailed_losses_{split}.png') + plt.close(plt.gcf()) + + for i_k, k in enumerate(['music', 'cocktail']): + plt.figure() + plt.imshow(cf_matrices[i_k], vmin=0, vmax=1) + labx = plt.xticks(range(len(params['classif_classes'])), params['classif_classes'], rotation=45) + laby = plt.yticks(range(len(params['classif_classes'])), params['classif_classes']) + labxx = plt.xlabel('predicted') + labyy = plt.ylabel('true') + plt.title(split + ' ' 
+ k) + plt.colorbar() + plt.savefig(save_path + f'cf_matrix_{split}_{k}.png', artists=(labx, laby, labxx, labyy)) + plt.close(plt.gcf()) + + if params['use_brb_vae']: + plt.figure() + for k in ['brb_vae_loss', 'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'mse_losses_music', 'brb_mse_latent_loss']: + factor = 10 if k == 'swd_losses' else 1 + plt.plot(steps, losses[k] * factor, label=k) + plt.title(split) + plt.ylim([0, 2.5]) + plt.legend() + plt.savefig(save_path + f'plot_detailed_brb_losses_{split}.png') + plt.close(plt.gcf()) + +def parse_args(): + parser = argparse.ArgumentParser(description="") + parser.add_argument("--save_path", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/latent_translation/") + parser.add_argument("--trial_id", type=str, default="b256_r128_classif001_ld40_meanstd") + parser.add_argument("--hidden_dim", type=int, default=256) #128 + parser.add_argument("--n_hidden", type=int, default=1) + parser.add_argument("--latent_dim", type=int, default=40) #40 + parser.add_argument("--n_epochs_music_pretrain", type=int, default=0) + parser.add_argument("--n_epochs_train", type=int, default=200) + parser.add_argument("--n_epochs_classif_finetune", type=int, default=0) + parser.add_argument("--beta_vae_loss", type=float, default=1.) + parser.add_argument("--beta_vae", type=float, default=1.2) # keep this low~1 to allow music classification... + parser.add_argument("--beta_swd", type=float, default=1) + parser.add_argument("--beta_reg_grounding", type=float, default=2.5) + parser.add_argument("--beta_classif", type=float, default=0.01)#0.01) #TODO: try 0.1, default 0.01 + parser.add_argument("--beta_music", type=float, default=100) # higher loss on the music that needs more to converge + parser.add_argument("--beta_music_classif", type=float, default=300) # try300# higher loss on the music that needs more to converge + parser.add_argument("--pretrain_batch_size", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--decay_step", type=int, default=0) + parser.add_argument("--cocktail_rep_path", type=str, default=FULL_COCKTAIL_REP_PATH) + parser.add_argument("--music_rep_path", type=str, default=music_rep_path) + parser.add_argument("--use_brb_vae", type=bool, default=False) + parser.add_argument("--layer_type", type=str, default='gml') + parser.add_argument("--dropout", type=float, default=0.2) + + # best parameters + # parser = argparse.ArgumentParser(description="") + # parser.add_argument("--save_path", type=str, default="/home/cedric/Documents/pianocktail/experiments/music/representation_learning/saved_models/latent_translation/") + # parser.add_argument("--trial_id", type=str, default="b256_r128_classif001_ld40_meanstd") + # parser.add_argument("--hidden_dim", type=int, default=256) #128 + # parser.add_argument("--n_hidden", type=int, default=1) + # parser.add_argument("--latent_dim", type=int, default=40) #40 + # parser.add_argument("--n_epochs_music_pretrain", type=int, default=0) + # parser.add_argument("--n_epochs_train", type=int, default=200) + # parser.add_argument("--n_epochs_classif_finetune", type=int, default=0) + # parser.add_argument("--beta_vae_loss", type=float, default=1.) + # parser.add_argument("--beta_vae", type=float, default=1) # keep this low~1 to allow music classification... 
+ # parser.add_argument("--beta_swd", type=float, default=1) + # parser.add_argument("--beta_reg_grounding", type=float, default=2.5) + # parser.add_argument("--beta_classif", type=float, default=0.01)#0.01) #TODO: try 0.1, default 0.01 + # parser.add_argument("--beta_music", type=float, default=100) # higher loss on the music that needs more to converge + # parser.add_argument("--beta_music_classif", type=float, default=300) # try300# higher loss on the music that needs more to converge + # parser.add_argument("--pretrain_batch_size", type=int, default=128) + # parser.add_argument("--batch_size", type=int, default=32) + # parser.add_argument("--lr", type=float, default=0.001) + # parser.add_argument("--decay_step", type=int, default=0) + # parser.add_argument("--cocktail_rep_path", type=str, default=FULL_COCKTAIL_REP_PATH) + # parser.add_argument("--music_rep_path", type=str, default=music_rep_path) + # parser.add_argument("--use_brb_vae", type=bool, default=False) + # parser.add_argument("--layer_type", type=str, default='gml') + # parser.add_argument("--dropout", type=float, default=0.2) + args = parser.parse_args() + return args + +def compute_expe_name_and_save_path(params): + save_path = params['save_path'] + params["trial_id"] + if params["use_brb_vae"]: + save_path += '_usebrb' + save_path += f'_lr{params["lr"]}' + save_path += f'_bs{params["batch_size"]}' + save_path += f'_bmusic{params["beta_music"]}' + save_path += f'_bswd{params["beta_swd"]}' + save_path += f'_bclassif{params["beta_classif"]}' + save_path += f'_bvae{params["beta_vae_loss"]}' + save_path += f'_bvaekld{params["beta_vae"]}' + save_path += f'_lat{params["latent_dim"]}' + save_path += f'_hd{params["n_hidden"]}x{params["hidden_dim"]}' + save_path += f'_drop{params["dropout"]}' + save_path += f'_decay{params["decay_step"]}' + save_path += f'_layertype{params["layer_type"]}' + number_added = False + counter = 1 + while os.path.exists(save_path): + if number_added: + save_path = '_'.join(save_path.split('_')[:-1]) + f'_{counter}' + counter += 1 + else: + save_path += f'_{counter}' + params["save_path"] = save_path + '/' + os.makedirs(save_path) + print(f'logging to {save_path}') + return params + +if __name__ == '__main__': + keys_to_track = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail', + 'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', + 'brb_kld_loss_cocktail', 'brb_kld_loss_music', 'brb_mse_loss_music', 'brb_mse_loss_cocktail', 'brb_mse_latent_loss', 'brb_vae_loss', 'beta_reg_grounding', + 'bubble_mse', 'egg_mse'] + + keys_to_print = ['steps', 'global_losses', 'vae_losses', 'mse_losses_cocktail', 'mse_losses_music', 'kld_losses_cocktail', + 'kld_losses_music', 'swd_losses', 'classif_losses_cocktail', 'classif_losses_music', 'classif_acc_cocktail', 'classif_acc_music', 'beta_reg_grounding'] + #TODO: first phase vae pretraining for music + # then in second phase: vae cocktail and music, brb vaes + args = parse_args() + params = dict(nb_classes=None, + save_every=0, #epochs + log_every=0, #32*500, + plot_every=10, # in epochs + keys_to_track=keys_to_track, + keys_to_print=keys_to_print,) + params.update(vars(args)) + + params['train_epoch_size'] = params['batch_size'] * 100 + params['test_epoch_size'] = params['batch_size'] * 10 + params['pretrain_train_epoch_size'] = params['pretrain_batch_size'] * 100 + params['pretrain_test_epoch_size'] = params['pretrain_batch_size'] * 10 + 
params['epochs'] = params['n_epochs_music_pretrain'] + params['n_epochs_train'] + params['n_epochs_classif_finetune'] + run_training(params) + + diff --git a/src/music2cocktailrep/training/latent_translation/setup_trained_model.py b/src/music2cocktailrep/training/latent_translation/setup_trained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c5e2c433ba9fa92d842ae4a80211980702b746 --- /dev/null +++ b/src/music2cocktailrep/training/latent_translation/setup_trained_model.py @@ -0,0 +1,83 @@ +import json +import torch +import numpy as np +from src.music2cocktailrep.training.latent_translation.vae_model import get_gml_vae_models +from src.music.config import TRANSLATION_VAE_CHKP_PATH +from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys +import os +from huggingface_hub import hf_hub_download +from shutil import copy +import hashlib + +device = 'cuda' if torch.cuda.is_available() else 'cpu' +# TOKEN = os.environ['token'] +rep_keys = get_bunch_of_rep_keys()['custom'] + +def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH): + # download translation model + # repo_id = "ccolas/translation_vae" + # filename = "checkpoints_best_eval_old.save" + # downloaded_path = hf_hub_download(repo_id=repo_id, + # filename=filename, + # repo_type='model', + # use_auth_token=TOKEN) + model_path = checkpoint_path + 'checkpoints_best_eval.save' + # copy(downloaded_path, model_path) + with open(checkpoint_path + 'params.json', 'r') as f: + params = json.load(f) + + model = get_gml_vae_models(layer_type=params['layer_type'], + input_dim_music=params['dim_music'], + input_dim_cocktail=params['dim_cocktail'], + hidden_dim=params['hidden_dim'], + n_hidden=params['n_hidden'], + latent_dim=params['latent_dim'], + nb_classes=params['nb_classes'], + dropout=params['dropout']) + model = model.to(device) + stats = params['stats'] + stats_music = np.array(stats['mean_std_music_rep']) + stats_cocktail = np.array(stats['mean_std_cocktail_rep_norm11']) + + def normalize_music_input(input): + return (input - stats_music[0]) / stats_music[1] + + model.load_state_dict(torch.load(model_path)) + model.eval() + print('HEREEE: ', torch.sum(torch.Tensor([param.sum() for param in list(model.parameters())]))) + print('model hash: ', hashlib.md5(open(model_path, 'rb').read()).hexdigest()) + def denormalize_cocktail_output(output): + return output * stats_cocktail[1] + stats_cocktail[0] + + def complete_model(music_input, modality_out): + input = torch.Tensor(music_input).float() + + if input.ndim == 1: input = input.reshape(1, -1) + normalized_input = normalize_music_input(input).float() + if torch.cuda.is_available(): + normalized_input = normalized_input.to(device) + if modality_out == 'music': + music_reconstruction = model(normalized_input, modality_in='music', modality_out='music')[0] + if device == 'cuda': + return music_reconstruction.cpu().detach().numpy().flatten() + else: + return music_reconstruction.detach().numpy().flatten() + elif modality_out == 'cocktail': + cocktail_rep = model(normalized_input, modality_in='music', modality_out='cocktail')[0] + cocktail_rep = cocktail_rep.detach().cpu().numpy() + cocktail_rep = denormalize_cocktail_output(cocktail_rep) + # post processing of the output cocktail reps (clipped and reshaped for eggy and fizzy + cocktail_rep = np.clip(cocktail_rep, -1, 1) # clip to -1, 1 + ind_eggy = rep_keys.index('end eggy') + ind_fizzy = rep_keys.index('end fizzy') + cocktail_rep[:, ind_eggy][np.where(cocktail_rep[:, ind_eggy] < -0.5)] = 
-1 + cocktail_rep[:, ind_eggy][np.where(cocktail_rep[:, ind_eggy] >= -0.5)] = 0.5 * cocktail_rep[:, ind_eggy][np.where(cocktail_rep[:, ind_eggy] >= -0.5)] + 0.15 + cocktail_rep[:, ind_fizzy][np.where(cocktail_rep[:, ind_fizzy] < -0.65)] = -1 + return cocktail_rep.copy() + else: + raise ValueError + + return complete_model + +if __name__ == '__main__': + model = setup_trained_model() diff --git a/src/music2cocktailrep/training/latent_translation/utils.py b/src/music2cocktailrep/training/latent_translation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc18dba0e7ad7dcb04088ef3e44898f6b69d0731 --- /dev/null +++ b/src/music2cocktailrep/training/latent_translation/utils.py @@ -0,0 +1,178 @@ +import os.path + +from dataset import CocktailDataset, MusicDataset, CocktailLabeledDataset, MusicLabeledDataset, RegressedGroundingDataset +import torch +import numpy as np +from src.music.utilities.representation_learning_utilities.sampler import FixedLenRandomSampler +from torch.utils.data.sampler import RandomSampler +from torch.utils.data import DataLoader +from src.music.utils import get_all_subfiles_with_extension +import pickle + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + + +def wasserstein1d(x, y): + x1, _ = torch.sort(x, dim=0) + y1, _ = torch.sort(y, dim=0) + z = (x1-y1).view(-1) + n = z.size(0) + return torch.dot(z, z) / n + +def compute_swd_loss(minibatch1, minibatch2, latent_dim, n_projections=10000): + # sample random projections + theta = torch.randn((latent_dim, n_projections), + requires_grad=False, + device=device) + theta = theta/torch.norm(theta, dim=0)[None, :] + + proj1 = minibatch1@theta + proj2 = minibatch2@theta + + # compute sliced wasserstein distance on projected features + gloss = wasserstein1d(proj1, proj2) + return gloss + + + +def get_dataloaders(cocktail_rep_path, music_rep_path, batch_size, train_epoch_size, test_epoch_size): + assert train_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size' + assert test_epoch_size % batch_size == 0, 'epoch size is expressed in steps, must be a multiple of batch size' + assert '.pickle' in music_rep_path + + if not os.path.exists(music_rep_path): + music_rep_paths = get_all_subfiles_with_extension(music_rep_path.replace('music_reps_normalized_meanstd.pickle', ''), max_depth=3, extension='.txt', current_depth=0) + music_reps = [] + for p in music_rep_paths: + music_reps.append(np.loadtxt(p)) + music_reps = np.array(music_reps) + mean_std = np.array([music_reps.mean(axis=0), music_reps.std(axis=0)]) + music_reps = (music_reps - mean_std[0]) / mean_std[1] + assert len(music_rep_paths) == len(music_reps), 'check bug with mean_std' + data = dict(zip(music_rep_paths, music_reps)) + to_save = dict(musicpath2musicrep=data, + mean_std=mean_std) + with open(music_rep_path, 'wb') as f: + pickle.dump(to_save, f) + + with open(music_rep_path, 'rb') as f: + data = pickle.load(f) + mean_std_music_rep = data['mean_std'] + music_rep_paths = sorted(data['musicpath2musicrep'].keys()) + music_reps = np.array([data['musicpath2musicrep'][k] for k in music_rep_paths]) + + + cocktail_reps = np.loadtxt(cocktail_rep_path) + mean_std_cocktail_rep_norm11 = np.array([cocktail_reps.mean(axis=0), cocktail_reps.std(axis=0)]) + cocktail_reps = (cocktail_reps - cocktail_reps.mean(axis=0)) / cocktail_reps.std(axis=0) + + train_data_cocktail = CocktailDataset(split='train', cocktail_reps=cocktail_reps) + test_data_cocktail = CocktailDataset(split='test', 
cocktail_reps=cocktail_reps) + train_data_music = MusicDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths) + test_data_music = MusicDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths) + + train_sampler_cocktail = FixedLenRandomSampler(train_data_cocktail, bs=batch_size, epoch_size=train_epoch_size) + test_sampler_cocktail = FixedLenRandomSampler(test_data_cocktail, bs=batch_size, epoch_size=test_epoch_size) + train_sampler_music = FixedLenRandomSampler(train_data_music, bs=batch_size, epoch_size=train_epoch_size) + test_sampler_music = FixedLenRandomSampler(test_data_music, bs=batch_size, epoch_size=test_epoch_size) + + train_data_cocktail = DataLoader(train_data_cocktail, batch_sampler=train_sampler_cocktail) + test_data_cocktail = DataLoader(test_data_cocktail, batch_sampler=test_sampler_cocktail) + train_data_music = DataLoader(train_data_music, batch_sampler=train_sampler_music) + test_data_music = DataLoader(test_data_music, batch_sampler=test_sampler_music) + + train_data_cocktail_labeled = CocktailLabeledDataset(split='train', cocktail_reps=cocktail_reps) + test_data_cocktail_labeled = CocktailLabeledDataset(split='test', cocktail_reps=cocktail_reps) + train_data_music_labeled = MusicLabeledDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths) + test_data_music_labeled = MusicLabeledDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths) + + train_sampler_cocktail_labeled = FixedLenRandomSampler(train_data_cocktail_labeled, bs=batch_size, epoch_size=train_epoch_size) + test_sampler_cocktail_labeled = FixedLenRandomSampler(test_data_cocktail_labeled, bs=batch_size, epoch_size=test_epoch_size) + train_sampler_music_labeled = FixedLenRandomSampler(train_data_music_labeled, bs=batch_size, epoch_size=train_epoch_size) + test_sampler_music_labeled = FixedLenRandomSampler(test_data_music_labeled, bs=batch_size, epoch_size=test_epoch_size) + + train_data_cocktail_labeled = DataLoader(train_data_cocktail_labeled, batch_sampler=train_sampler_cocktail_labeled) + test_data_cocktail_labeled = DataLoader(test_data_cocktail_labeled, batch_sampler=test_sampler_cocktail_labeled) + train_data_music_labeled = DataLoader(train_data_music_labeled, batch_sampler=train_sampler_music_labeled) + test_data_music_labeled = DataLoader(test_data_music_labeled, batch_sampler=test_sampler_music_labeled) + + train_data_grounding = RegressedGroundingDataset(split='train', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps) + test_data_grounding = RegressedGroundingDataset(split='test', music_reps=music_reps, music_rep_paths=music_rep_paths, cocktail_reps=cocktail_reps) + + train_sampler_grounding = FixedLenRandomSampler(train_data_grounding, bs=batch_size, epoch_size=train_epoch_size) + test_sampler_grounding = FixedLenRandomSampler(test_data_grounding, bs=batch_size, epoch_size=test_epoch_size) + + train_data_grounding = DataLoader(train_data_grounding, batch_sampler=train_sampler_grounding) + test_data_grounding = DataLoader(test_data_grounding, batch_sampler=test_sampler_grounding) + + data_loaders = dict(music_train=train_data_music, + music_test=test_data_music, + cocktail_train=train_data_cocktail, + cocktail_test=test_data_cocktail, + music_labeled_train=train_data_music_labeled, + music_labeled_test=test_data_music_labeled, + cocktail_labeled_train=train_data_cocktail_labeled, + cocktail_labeled_test=test_data_cocktail_labeled, + reg_grounding_train=train_data_grounding, + 
reg_grounding_test=test_data_grounding + ) + for k in data_loaders.keys(): + print(f'Dataset {k}, size: {len(data_loaders[k].dataset)}') + assert data_loaders['cocktail_labeled_train'].dataset.n_labels == data_loaders['music_labeled_train'].dataset.n_labels + stats = dict(mean_std_music_rep=mean_std_music_rep.tolist(), mean_std_cocktail_rep_norm11=mean_std_cocktail_rep_norm11.tolist()) + return data_loaders, data_loaders['music_labeled_train'].dataset.n_labels, stats + + +class FixedLenRandomSampler(RandomSampler): + """ + Code from mnpinto - Miguel + https://forums.fast.ai/t/epochs-of-arbitrary-length/27777/10 + """ + def __init__(self, data_source, bs, epoch_size, *args, **kwargs): + super().__init__(data_source) + self.not_sampled = np.array([True]*len(data_source)) + self.update_epoch_size_and_batch(epoch_size, bs) + + def update_epoch_size_and_batch(self, epoch_size, bs): + self.epoch_size = epoch_size + self.bs = bs + self.size_to_sample = self.epoch_size + self.nb_batches_per_epoch = self.epoch_size // self.bs + + def _reset_state(self): + self.not_sampled[:] = True + + def reset_and_sample(self, idx, total_to_sample): + n_to_sample = total_to_sample - len(idx) + ns = sum(self.not_sampled) + if ns == 0: + self._reset_state() + return self.reset_and_sample(idx, total_to_sample) + elif ns >= n_to_sample: + new_idx = np.random.choice(np.where(self.not_sampled)[0], size=n_to_sample, replace=False).tolist() + new_idx = [*idx, *new_idx] + assert len(new_idx) == total_to_sample + return new_idx + else: + idx_last = np.where(self.not_sampled)[0].tolist() + new_idx = [*idx, *idx_last] + self._reset_state() + return self.reset_and_sample(new_idx, total_to_sample) + + def __iter__(self): + idx = self.reset_and_sample(idx=[], total_to_sample=self.size_to_sample) + assert len(idx) == self.size_to_sample + self.not_sampled[idx] = False + # print(ns, len(idx), len(idx_last)) # debug + out = [] + i_idx = 0 + for i in range(self.nb_batches_per_epoch): + batch = [] + for j in range(self.bs): + batch.append(idx[i_idx]) + i_idx += 1 + out.append(batch) + return iter(out) + + def __len__(self): + return self.nb_batches_per_epoch diff --git a/src/music2cocktailrep/training/latent_translation/vae_model.py b/src/music2cocktailrep/training/latent_translation/vae_model.py new file mode 100644 index 0000000000000000000000000000000000000000..86731ea2c9c8ed6487b8adb208b79e707aa95c43 --- /dev/null +++ b/src/music2cocktailrep/training/latent_translation/vae_model.py @@ -0,0 +1,307 @@ +import torch +from torch import nn +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +class Encoder(nn.Module): + + def __init__(self, input_dim_music, input_dim_cocktail, hidden_dim, latent_dim, n_hidden, dropout): + super(Encoder, self).__init__() + self.projection_music = nn.Linear(input_dim_music, hidden_dim - 2) # such that the concatenation with domain encoding is of size hidden_dim + self.projection_cocktail = nn.Linear(input_dim_cocktail, hidden_dim - 2) + self.latent_dim = latent_dim + self.n_hidden = n_hidden + assert self.n_hidden in [1, 2] + self.FC_input = nn.Linear(hidden_dim, hidden_dim) + if self.n_hidden > 1: self.FC_input2 = nn.Linear(hidden_dim, hidden_dim) + self.FC_mean = nn.Linear(hidden_dim, latent_dim) + self.FC_var = nn.Linear (hidden_dim, latent_dim) + self.softplus = nn.Softplus() + self.LeakyReLU = nn.LeakyReLU(0.2) + if dropout != 0: + self.use_dropout = True + self.dropout1 = nn.Dropout(dropout) + if self.n_hidden > 1: self.dropout2 = nn.Dropout(dropout) + else: + self.use_dropout = False + + 
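
The `compute_swd_loss` / `wasserstein1d` pair defined in `utils.py` above aligns the music and cocktail latent minibatches by projecting them onto random unit directions, sorting each 1-D projection, and penalising the squared gap between sorted values. A self-contained sketch of that idea on toy CPU tensors (hypothetical shapes, not the training data; restated here without the module-level `device`):

```python
import torch

def sliced_wasserstein(batch1, batch2, latent_dim, n_projections=1000):
    # random unit-norm projection directions (CPU restatement of compute_swd_loss)
    theta = torch.randn(latent_dim, n_projections)
    theta = theta / torch.norm(theta, dim=0, keepdim=True)
    proj1, proj2 = batch1 @ theta, batch2 @ theta
    # 1-D squared Wasserstein distance between the sorted projections (as in wasserstein1d)
    z = (torch.sort(proj1, dim=0)[0] - torch.sort(proj2, dim=0)[0]).view(-1)
    return torch.dot(z, z) / z.numel()

# identical minibatches give ~0, a mean-shifted copy gives a clearly positive loss
a = torch.randn(64, 40)
print(float(sliced_wasserstein(a, a.clone(), latent_dim=40)))
print(float(sliced_wasserstein(a, a + 2.0, latent_dim=40)))
```
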
def forward(self, x, modality): + modality_code = torch.FloatTensor(torch.zeros(size=(x.shape[0], 2))).to(device) + if modality == 'music': + modality_code[:, 0] = 1 + input = self.projection_music(x) + elif modality == 'cocktail': + modality_code[:, 1] = 1 + input = self.projection_cocktail(x) + else: + raise NotImplementedError + input = torch.cat([input, modality_code], dim=1) # batch_size x hidden_dim + + h = self.LeakyReLU(self.FC_input(input)) + if self.use_dropout: h = self.dropout1(h) + if self.n_hidden > 1: + h = self.LeakyReLU(self.FC_input2(h)) + if self.use_dropout: h = self.dropout2(h) + mean = self.FC_mean(h) + std = self.softplus(self.FC_var(h)) + return mean, std + + + + +class Decoder(nn.Module): + def __init__(self, latent_dim, hidden_dim, output_dim_music, output_dim_cocktail, n_hidden, dropout): + super(Decoder, self).__init__() + self.projection_latent = nn.Linear(latent_dim, hidden_dim - 2) + self.n_hidden = n_hidden + assert self.n_hidden in [1, 2] + self.FC_hidden = nn.Linear(hidden_dim, hidden_dim) + if self.n_hidden>1: self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim) + self.projection_out_music = nn.Linear(hidden_dim, output_dim_music) + self.projection_out_cocktail = nn.Linear(hidden_dim, output_dim_cocktail) + self.LeakyReLU = nn.LeakyReLU(0.2) + + if dropout != 0: + self.use_dropout = True + self.dropout1 = nn.Dropout(dropout) + if self.n_hidden > 1: self.dropout2 = nn.Dropout(dropout) + else: + self.use_dropout = False + + def forward(self, x, modality): + modality_code = torch.FloatTensor(torch.zeros(size=(x.shape[0], 2))).to(device) + if modality == 'music': + modality_code[:, 0] = 1 + elif modality == 'cocktail': + modality_code[:, 1] = 1 + else: + raise NotImplementedError + input = torch.cat([self.projection_latent(x), modality_code], dim=1) + + + h = self.LeakyReLU(self.FC_hidden(input)) + if self.use_dropout: h = self.dropout1(h) + if self.n_hidden > 1: + h = self.LeakyReLU(self.FC_hidden2(h)) + if self.use_dropout: h = self.dropout2(h) + + if modality == 'music': + z_out = self.projection_out_music(h) + elif modality == 'cocktail': + z_out = self.projection_out_cocktail(h) + else: + raise NotImplementedError + return z_out + +class GML(nn.Module): + def __init__(self, input_dim, latent_dim): + super(GML, self).__init__() + self.input_dim = input_dim + self.latent_dim = latent_dim + self.FC_hidden_gated = nn.Linear(input_dim, latent_dim * 2) + self.FC_hidden_direct = nn.Linear(input_dim, latent_dim) + self.sigmoid = nn.Sigmoid() + + def forward(self, input1, input2): + z = self.FC_hidden_gated(input1) + z_prime = self.FC_hidden_direct(input2) + dz = z[:, self.latent_dim:] + gates = self.sigmoid(z[:, :self.latent_dim]) + return (1 - gates) * z_prime + gates * dz + +class GMLEncoder(nn.Module): + def __init__(self, input_dim_music, input_dim_cocktail, hidden_dim, latent_dim, n_hidden, dropout): + super(GMLEncoder, self).__init__() + self.input_dim_music = input_dim_music + self.input_dim_cocktail = input_dim_cocktail + self.n_hidden = n_hidden + self.projection_music = nn.Linear(input_dim_music, hidden_dim - 2) # such that the concatenation with domain encoding is of size hidden_dim + self.projection_cocktail = nn.Linear(input_dim_cocktail, hidden_dim - 2) + assert self.n_hidden in [1, 2] + self.FC_input = nn.Linear(hidden_dim, hidden_dim) + if self.n_hidden>1: self.FC_input2 = nn.Linear(hidden_dim, hidden_dim) + self.GML_layer = GML(hidden_dim, latent_dim) + self.GML_layer2 = GML(hidden_dim, latent_dim) + self.latent_dim = latent_dim + self.LeakyReLU 
= nn.LeakyReLU(0.2) + # self.softplus = nn.Softplus() + self.training = True + + if dropout != 0: + self.use_dropout = True + self.dropout1 = nn.Dropout(dropout) + if self.n_hidden > 1: self.dropout2 = nn.Dropout(dropout) + else: + self.use_dropout = False + + def forward(self, x, modality): + modality_code = torch.FloatTensor(torch.zeros(size=(x.shape[0], 2))).to(device) + if modality == 'music': + modality_code[:, 0] = 1 + input = self.projection_music(x) + elif modality == 'cocktail': + modality_code[:, 1] = 1 + input = self.projection_cocktail(x) + else: + raise NotImplementedError + input = torch.cat([input, modality_code], dim=1) # batch_size x hidden_dim + + h = self.LeakyReLU(self.FC_input(input)) + if self.use_dropout: h = self.dropout1(h) + if self.n_hidden > 1: + h = self.LeakyReLU(self.FC_input2(h)) + if self.use_dropout: h = self.dropout2(h) + log_var = self.GML_layer(h, input)#self.softplus(self.GML_layer(h, input)) + mean = self.GML_layer2(h, input) + + return mean, log_var + +class GMLDecoder(nn.Module): + def __init__(self, latent_dim, hidden_dim, output_dim_cocktail, output_dim_music, n_hidden, dropout): + super(GMLDecoder, self).__init__() + self.projection_latent = nn.Linear(latent_dim, hidden_dim - 2) + self.FC_hidden = nn.Linear(hidden_dim, hidden_dim) + self.n_hidden = n_hidden + assert self.n_hidden in [1, 2] + if self.n_hidden>1: self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim) + self.GML_layer = GML(hidden_dim, hidden_dim) + self.projection_out_music = nn.Linear(hidden_dim, output_dim_music) + self.projection_out_cocktail = nn.Linear(hidden_dim, output_dim_cocktail) + self.LeakyReLU = nn.LeakyReLU(0.2) + + if dropout != 0: + self.use_dropout = True + self.dropout1 = nn.Dropout(dropout) + if self.n_hidden > 1: self.dropout2 = nn.Dropout(dropout) + else: + self.use_dropout = False + + def forward(self, x, modality): + modality_code = torch.FloatTensor(torch.zeros(size=(x.shape[0], 2))).to(device) + if modality == 'music': + modality_code[:, 0] = 1 + elif modality == 'cocktail': + modality_code[:, 1] = 1 + else: + raise NotImplementedError + input = torch.cat([self.projection_latent(x), modality_code], dim=1) + + h = self.LeakyReLU(self.FC_hidden(input)) + if self.use_dropout: h = self.dropout1(h) + if self.n_hidden > 1: + h = self.LeakyReLU(self.FC_hidden2(h)) + if self.use_dropout: h = self.dropout2(h) + + z_out = self.GML_layer(h, input) + if modality == 'music': + z_out = self.projection_out_music(z_out) + elif modality == 'cocktail': + # z_out = (torch.sigmoid(self.projection_out_cocktail(z_out)) - 0.5) * 2.2 # normalize in -1, 1 the output + z_out = self.projection_out_cocktail(z_out) + else: + raise NotImplementedError + return z_out + +class GMLVAEModel(nn.Module): + def __init__(self, encoder, decoder, classif_head, dropout): + super(GMLVAEModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.classif_head = classif_head + self.latent_dim = self.encoder.latent_dim + if dropout != 0: + self.use_dropout = True + self.dropout = nn.Dropout(dropout) + else: + self.use_dropout = False + + def reparameterization(self, mean, std): + epsilon = torch.randn_like(std).to(device) # sampling epsilon + z = mean + std * epsilon # reparameterization trick + return z + + def encode(self, x, modality_in): + mean, std = self.encoder(x, modality_in) + z = self.reparameterization(mean, std) # takes exponential function (log var -> std) + return z + + def sample(self, modality, n=1): + assert modality in ['music', 'cocktail'] + z = torch.randn(size=(n, 
self.latent_dim)) + return self.decoder(z, modality) + + def classify(self, x, modality_in): + h = self.classif_head(self.encode(x, modality_in)) + # if self.use_dropout: h = self.dropout(h) + return h + + def forward(self, x, modality_in, modality_out, freeze_decoder=False): + mean, std = self.encoder(x, modality_in) + z = self.reparameterization(mean, std) # takes exponential function (log var -> std) + # z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> std) + if freeze_decoder: + for child in self.decoder.parameters(): + child.require_grad = False + else: + for child in self.decoder.parameters(): + child.require_grad = True + x_hat = self.decoder(z, modality_out) + return x_hat, z, mean, std + + def forward_b2b(self, x, modality_in_out, modality_intermediate): + mean1, std1 = self.encoder(x, modality_in_out) + z1 = self.reparameterization(mean1, std1) + x_intermediate = self.decoder(z1, modality_intermediate) + mean2, std2 = self.encoder(x_intermediate, modality_intermediate) + z2 = self.reparameterization(mean2, std2) + x_hat = self.decoder(z2, modality_in_out) + return x_hat, x_intermediate, mean1, std1, z1, mean2, std2, z2 + + +# class VAEModel(nn.Module): +# def __init__(self, encoder, decoder): +# super(VAEModel, self).__init__() +# self.encoder = encoder +# self.decoder = decoder +# +# def reparameterization(self, mean, var): +# epsilon = torch.randn_like(var).to(device) # sampling epsilon +# z = mean + var * epsilon # reparameterization trick +# return z +# +# def forward(self, x): +# mean, log_var = self.encoder(x) +# z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var) +# x_hat = self.decoder(z) +# return x_hat, z, mean, log_var + +def get_gml_vae_models(layer_type, input_dim_cocktail, input_dim_music, hidden_dim, n_hidden, latent_dim, nb_classes, dropout): + if layer_type == 'dense': + encoder = Encoder(input_dim_cocktail=input_dim_cocktail, input_dim_music=input_dim_music, + hidden_dim=hidden_dim, latent_dim=latent_dim, n_hidden=n_hidden, dropout=dropout) + decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim_cocktail=input_dim_cocktail, + output_dim_music=input_dim_music, n_hidden=n_hidden, dropout=dropout) + elif layer_type == 'gml': + encoder = GMLEncoder(input_dim_cocktail=input_dim_cocktail, input_dim_music=input_dim_music, + hidden_dim=hidden_dim, latent_dim=latent_dim, n_hidden=n_hidden, dropout=dropout) + decoder = GMLDecoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim_cocktail=input_dim_cocktail, + output_dim_music=input_dim_music, n_hidden=n_hidden, dropout=dropout) + else: + raise ValueError + classifier = nn.Linear(in_features=latent_dim, out_features=nb_classes) + vae_gml_model = GMLVAEModel(encoder=encoder, decoder=decoder, classif_head=classifier, dropout=dropout).to(device) + return vae_gml_model + +# def get_vae_models(input_dim, hidden_dim, latent_dim, nb_classes): +# encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim) +# decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = input_dim) +# model = VAEModel(encoder=encoder, decoder=decoder).to(device) +# classifier = ClassifierHead(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=nb_classes) +# return model, classifier + +# class ClassifierHead(nn.Module): +# def __init__(self, input_dim, output_dim): +# super(ClassifierHead, self).__init__() +# self.FC_output = nn.Linear(input_dim, output_dim) +# +# def 
forward(self, x): +# return self.FC_output(x) \ No newline at end of file diff --git a/src/pianocktail.py b/src/pianocktail.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3754e0f2a712d8dba35660f2bae2cad6b6e570 --- /dev/null +++ b/src/pianocktail.py @@ -0,0 +1,79 @@ +import time +import os +import pickle +from src.music.pipeline.music_pipeline import encode_music +from src.music2cocktailrep.pipeline.music2cocktailrep import music2cocktailrep, setup_translation_models, debug_translation +from src.cocktails.pipeline.cocktailrep2recipe import cocktailrep2recipe +from src.debugger import Debugger +from datetime import datetime +from shutil import copy + +synestesia_path = '../data/synesthesia/' +debugger = Debugger() + +def pianocktail(record=False, url=None, midi=None, audio=None, processed=None, crop=40, verbose=False, debug=False, level=0): + assert url is not None or midi is not None or audio is not None or processed is not None + if verbose: print('------\nNew synesthetic exploration!') + init_time = time.time() + music_ai_rep, music_handcoded_rep, all_paths, error = encode_music(record=record, url=url, audio_path=audio, midi_path=midi, nb_aug=0, noise_injection=False, + augmentation=False, processed_path=processed, crop=crop, apply_filtering=False, verbose=verbose, + level=level+2) + if music_ai_rep is not None: + cocktail_rep, affective_cluster_id, affect = music2cocktailrep(music_ai_rep, music_handcoded_rep, verbose=verbose, level=level+2) + cocktail_recipes, scores = cocktailrep2recipe(cocktail_rep, target_affective_cluster=affective_cluster_id, verbose=verbose, full_verbose=verbose, level=level+2) + cocktail_recipe = cocktail_recipes[0] + recipe_score = scores[0] + if debug: + music_reconstruction = debug_translation(music_ai_rep) + debugger.extract_info(all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score, verbose=verbose, level=level+2) + debug_info = debugger.debug_dict + else: + debug_info = None + if verbose: + print(cocktail_recipe.replace('Recipe', ' ' * (level + 2) + 'Generated recipe:').replace('None ()', '')) + debugger.print_debug(level=level+2) + print(f'\nEnd of synesthetic exploration ({int(time.time() - init_time)} secs).\n------') + + else: + cocktail_recipe = None + debug_info = None + return cocktail_recipe, debug_info + +def setup_and_run(url=None, midi=None, audio=None, verbose=False, debug=False, extra_code=None): + assert url is not None or midi is not None or audio is not None + now = datetime.now() + folder_name = f'{now.year}-{now.month}-{now.day}_{now.hour}:{now.minute}:{now.second}' + folder_path = synestesia_path + folder_name + if extra_code is not None: + folder_path += '_' + extra_code + if os.path.exists(folder_path): + folder_path += '_2' + folder_path += '/' + os.makedirs(folder_path, exist_ok=True) + recipe, debug = pianocktail(url=url, verbose=verbose, debug=debug) + with open(folder_path + 'debug.pk', 'wb') as f: + pickle.dump(debug, f) + with open(folder_path + 'recipe.txt', 'w') as f: + f.write(recipe) + paths = debug['all_paths'] + if paths['url'] is not None: + with open(folder_path + 'url.txt', 'w') as f: + f.write(paths['url']) + for k in ['audio_path', 'midi_path']: + origin = paths[k] + copy(origin, folder_path + origin.split('/')[-1]) + + +if __name__ == '__main__': + urls = ["https://www.youtube.com/watch?v=PLFVGwGQcB0", + "https://www.youtube.com/watch?v=VQmuAr93OlI", + "https://www.youtube.com/watch?v=Nv2GgV34qIg&list=PLO9E3V4rGLD8_iWrCioJRWZXJJE3Fzu_J&index=4", + 
"https://www.youtube.com/watch?v=qAEIjWYdoYc&list=PLO9E3V4rGLD8_iWrCioJRWZXJJE3Fzu_J&index=1", + "https://www.youtube.com/watch?v=M73x3O7dhmg&list=PLO9E3V4rGLD8_iWrCioJRWZXJJE3Fzu_J&index=5"] + setup_translation_models() + setup_and_run(url=urls[0], verbose=True, debug=True) + recipes = [] + for url in urls: + recipe = pianocktail(url=url, verbose=True, debug=True)[0] + recipes.append(recipe) + stop = 1 diff --git a/src/readme.md b/src/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..eccb23015bddca6d094f58732cf44ca727cc1d1f --- /dev/null +++ b/src/readme.md @@ -0,0 +1,28 @@ +install pytorch first, then requirements + + +conda create -n pianocktailv2 python=3.9 +conda install pytorch torchvision torchaudio cpuonly -c pytorch + +[comment]: <> (conda install -c conda-forge pysoundfile (for librosa)) + +pip install: +librosa +pytube +pandas +pydub +miditok +accelerate +pynput +pyaudio +music21 +mord +moviepy +pretty_midi +piano_transcription_inference +tensorboard +setuptools==59.5.0 + + + +install ffmpeg with apt.