import os
import time

import numpy as np
import torch
import streamlit as st

from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model
from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Module-level handles to the pretrained models, populated by setup_translation_models().
music2affect, find_affective_cluster, translation_vae = None, None, None


def setup_translation_models():
    """Load the pretrained affective models, the affect-to-cluster mapping and the
    translation VAE, and store them in the module-level globals."""
    global music2affect, find_affective_cluster, translation_vae
    music2affect, keys = setup_pretrained_affective_models()
    find_affective_cluster = get_affect2affective_cluster()
    translation_vae = setup_trained_model()
    return translation_vae
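

# Hedged usage sketch (not part of the original file): in a Streamlit app the setup above is
# typically run once per session. Assuming a Streamlit version that provides st.cache_resource,
# it could be wrapped as below; the name cached_setup_translation_models is hypothetical.
# @st.cache_resource
# def cached_setup_translation_models():
#     return setup_translation_models()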


def music2affect_cluster(handcoded_rep):
    """Map a handcoded music representation to affective dimensions (valence, arousal,
    dominance), clipped to [-1, 1], and to the id of the closest affective cluster."""
    global music2affect, find_affective_cluster
    affects = np.clip(music2affect(handcoded_rep), -1, 1)
    cluster_id = find_affective_cluster(affects)
    return cluster_id, affects


def music2flavor(music_ai_rep, affective_cluster_id):
    """Translate a music latent representation into a cocktail representation with the
    translation VAE. The affective cluster id is currently unused by the translation."""
    global translation_vae
    cocktail_rep = translation_vae(music_ai_rep, modality_out='cocktail')
    return cocktail_rep


def debug_translation(music_ai_rep):
    """Reconstruct the music representation through the translation VAE (debugging helper)."""
    global translation_vae
    music_reconstruction = translation_vae(music_ai_rep, modality_out='music')
    return music_reconstruction


def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0):
    """Run the synesthetic mapping: (optionally) map the music to an affective cluster,
    then translate the music representation into a cocktail representation."""
    init_time = time.time()
    if verbose: print(' ' * level + 'Synesthetic mapping..')
    if verbose: print(' ' * (level * 2) + 'Mapping to affective cluster.')
    # The affective-cluster mapping is currently disabled; see music2affect_cluster.
    # affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep)
    affective_cluster_id, affect = None, None
    if verbose: print(' ' * (level * 2) + 'Mapping to flavors.')
    cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id)
    if verbose: print(' ' * (level * 2) + f'Mapped in {int(time.time() - init_time)} seconds.')
    return cocktail_rep, affective_cluster_id, affect


# Debugging / visualisation helpers, kept commented out for reference.
# def sigmoid(x, shift, beta):
#     return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
#
# cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)]
# def plot_cluster_ids_dataset(handcoded_rep_path):
#     import matplotlib.pyplot as plt
#     reps, _, _ = get_data(handcoded_rep_path, keys)
#     cluster_ids, affects = music2affect_cluster(reps)
#     # plt.figure()
#     # affects2 = affects.copy()
#     # affects2 = sigmoid(affects2, 0.05, 8)
#     # plt.hist(affects2[:, 2], bins=30)
#     # plt.xlim([-1, 1])
#     fig = plt.figure()
#     ax = fig.add_subplot(projection='3d')
#     ax.set_xlim([-1, 1])
#     ax.set_ylim([-1, 1])
#     ax.set_zlim([-1, 1])
#     for cluster_id in sorted(set(cluster_ids)):
#         indexes = np.argwhere(cluster_ids == cluster_id).flatten()
#         if len(indexes) > 0:
#             ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150)
#     ax.set_xlabel('Valence')
#     ax.set_ylabel('Arousal')
#     ax.set_zlabel('Dominance')
#     plt.figure()
#     plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)])
#     plt.show()
#
# plot_cluster_ids_dataset(handcoded_rep_path)
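

# Hedged usage sketch (not from the original pipeline code): the expected call order is
# setup_translation_models() once at startup, then music2cocktailrep() per music track.
# music_ai_rep and handcoded_music_rep come from the upstream music-encoding pipeline and
# their shapes are not defined in this file, so the mapping call is left commented out.
# if __name__ == '__main__':
#     setup_translation_models()
#     # cocktail_rep, cluster_id, affect = music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=True)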