ccolas's picture
Upload 174 files
93c029f
raw
history blame
3.23 kB
import os
import numpy as np
import torch
import time
from src.cocktails.pipeline.get_affect2affective_cluster import get_affect2affective_cluster
from src.music2cocktailrep.training.latent_translation.setup_trained_model import setup_trained_model
from src.music2cocktailrep.pipeline.music2affect import setup_pretrained_affective_models
global music2affect, find_affective_cluster, translation_vae
import streamlit as st
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_translation_models():
    """Load the models used by this module and store them as module globals.

    Populates the module-level names ``music2affect``,
    ``find_affective_cluster`` and ``translation_vae``. These names are only
    assigned here, so this function must be called before
    ``music2affect_cluster``, ``music2flavor``, ``debug_translation`` or
    ``music2cocktailrep``; otherwise those raise ``NameError``.

    Returns:
        The translation model returned by ``setup_trained_model()``.
    """
    global music2affect, find_affective_cluster, translation_vae
    # The second return value (``keys``) is not used by this module.
    music2affect, _ = setup_pretrained_affective_models()
    find_affective_cluster = get_affect2affective_cluster()
    translation_vae = setup_trained_model()
    return translation_vae
def music2affect_cluster(handcoded_rep):
    """Predict affect values for a hand-coded music representation and cluster them.

    Relies on the module globals ``music2affect`` and ``find_affective_cluster``
    set by ``setup_translation_models()``.

    Args:
        handcoded_rep: Input passed unchanged to ``music2affect``
            (format defined by that model — TODO confirm).

    Returns:
        Tuple ``(cluster_id, affects)`` where ``affects`` is the model output
        clipped to the range [-1, 1] and ``cluster_id`` is what
        ``find_affective_cluster`` returns for those clipped affects.
    """
    predicted = music2affect(handcoded_rep)
    # Bound the predicted affects before assigning a cluster.
    bounded_affects = np.clip(predicted, -1, 1)
    cluster = find_affective_cluster(bounded_affects)
    return cluster, bounded_affects
def music2flavor(music_ai_rep, affective_cluster_id):
    """Translate a music representation into the cocktail modality.

    Uses the module global ``translation_vae`` set by
    ``setup_translation_models()``.

    Args:
        music_ai_rep: Music representation fed to the translation model.
        affective_cluster_id: Accepted for interface compatibility; not used
            by the current implementation.

    Returns:
        Whatever ``translation_vae(..., modality_out='cocktail')`` returns.
    """
    return translation_vae(music_ai_rep, modality_out='cocktail')
def debug_translation(music_ai_rep):
    """Run the translation model back into the music modality (for debugging).

    Uses the module global ``translation_vae`` set by
    ``setup_translation_models()``.

    Returns:
        Whatever ``translation_vae(..., modality_out='music')`` returns,
        presumably a reconstruction of ``music_ai_rep`` — TODO confirm.
    """
    return translation_vae(music_ai_rep, modality_out='music')
def music2cocktailrep(music_ai_rep, handcoded_music_rep, verbose=False, level=0):
    """Map a music representation to a cocktail representation.

    Args:
        music_ai_rep: Music representation passed to ``music2flavor``.
        handcoded_music_rep: Hand-coded music features. Currently unused
            because the affective-cluster step is disabled (see note below).
        verbose: If True, print progress messages.
        level: Number of spaces used to indent the top-level progress message;
            sub-step messages are indented two further spaces.

    Returns:
        Tuple ``(cocktail_rep, affective_cluster_id, affect)``. While the
        affective mapping is disabled, the last two values are always ``None``.
    """
    init_time = time.time()
    # All sub-step messages share one indentation (previously they mixed
    # ``level * 2`` and ``level + 2``).
    step_indent = ' ' * (level + 2)
    if verbose:
        print(' ' * level + 'Synesthetic mapping..')
        print(step_indent + 'Mapping to affective cluster.')
    # NOTE: the affective-cluster mapping is currently disabled. To re-enable:
    #     affective_cluster_id, affect = music2affect_cluster(handcoded_music_rep)
    affective_cluster_id, affect = None, None
    if verbose:
        print(step_indent + 'Mapping to flavors.')
    cocktail_rep = music2flavor(music_ai_rep, affective_cluster_id)
    if verbose:
        print(step_indent + f'Mapped in {int(time.time() - init_time)} seconds.')
    return cocktail_rep, affective_cluster_id, affect
# def sigmoid(x, shift, beta):
# return (1 / (1 + np.exp(-(x + shift) * beta)) - 0.5) * 2
#
# cluster_colors = ['#%06X' % random.randint(0, 0xFFFFFF) for _ in range(10)]
# def plot_cluster_ids_dataset(handcoded_rep_path):
# import matplotlib.pyplot as plt
# reps, _, _ = get_data(handcoded_rep_path, keys)
# cluster_ids, affects = music2affect_cluster(reps)
# # plt.figure()
# # affects2 = affects.copy()
# # affects2 = sigmoid(affects2, 0.05, 8)
# # plt.hist(affects2[:, 2], bins=30)
# # plt.xlim([-1, 1])
# fig = plt.figure()
# ax = fig.add_subplot(projection='3d')
# ax.set_xlim([-1, 1])
# ax.set_ylim([-1, 1])
# ax.set_zlim([-1, 1])
# for cluster_id in sorted(set(cluster_ids)):
# indexes = np.argwhere(cluster_ids == cluster_id).flatten()
# if len(indexes) > 0:
# ax.scatter(affects[indexes, 0], affects[indexes, 1], affects[indexes, 2], c=cluster_colors[cluster_id], s=150)
# ax.set_xlabel('Valence')
# ax.set_ylabel('Arousal')
# ax.set_zlabel('Dominance')
# plt.figure()
# plt.bar(range(10), [np.argwhere(cluster_ids == i).size for i in range(10)])
# plt.show()
#
# plot_cluster_ids_dataset(handcoded_rep_path)