import os import numpy as np import torch import time from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models global music2affect, find_affective_cluster, translation_vae import streamlit as st os.environ["TOKENIZERS_PARALLELISM"] = "false" def setup_translation_models(): global music2affect, find_affective_cluster, translation_vae music2affect, keys = setup_pretrained_affective_models() find_affective_cluster = get_affect2affective_cluster() translation_vae = setup_trained_model() return translation_vae def music2affect_cluster(handcoded_rep): global music2affect, find_affective_cluster affects = np.clip(music2affect(handcoded_rep), -1, 1) cluster_id = find_affective_cluster(affects) return cluster_id, affects def music2flavor(music_ai_rep, affective_cluster_id): global translation_vae cocktail_rep = translation_vae(music_ai_rep, modality_out='cocktail') return cocktail_rep def debug_translation(music_ai_rep): global translation_vae music_reconstruction = translation_vae(music_ai_rep, modality_out='music') return music_reconstruction def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0): init_time = time.time() if verbose: print(' ' * level + 'Synesthetic mapping..') if verbose: print(' ' * (level*2) + 'Mapping to affective cluster.') # affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep) affective_cluster_id, affect = None, None if verbose: print(' ' * (level*2) + 'Mapping to flavors.') cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id) if verbose: print(' ' * (level + 2) + f'Mapped in {int(time.time() - init_time)} seconds.') return cocktail_rep, affective_cluster_id, affect # def sigmoid(x, shift, beta): # return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2 # # cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)] # def plot_cluster_ids_dataset(handcoded_rep_path): # import matplotlib.pyplot as plt # reps, _, _ = get_data(handcoded_rep_path, keys) # cluster_ids, affects = music2affect_cluster(reps) # # plt.figure() # # affects2 = affects.copy() # # affects2 = sigmoid(affects2, 0.05, 8) # # plt.hist(affects2[:, 2], bins=30) # # plt.xlim([-1, 1]) # fig = plt.figure() # ax = fig.add_subplot(projection='3d') # ax.set_xlim([-1, 1]) # ax.set_ylim([-1, 1]) # ax.set_zlim([-1, 1]) # for cluster_id in sorted(set(cluster_ids)): # indexes = np.argwhere(cluster_ids == cluster_id).flatten() # if len(indexes) > 0: # ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150) # ax.set_xlabel('Valence') # ax.set_ylabel('Arousal') # ax.set_zlabel('Dominance') # plt.figure() # plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)]) # plt.show() # # plot_cluster_ids_dataset(handcoded_rep_path)