Spaces:
Runtime error
Runtime error
import pickle | |
import numpy as np | |
from src.music.config import CHECKPOINTS_PATH | |
# These checkpoint files can be generated by the fit_glm script.
# Every path is built from one shared prefix so they all get the same '/'
# separator after CHECKPOINTS_PATH.
_MUSIC2AFFECTS_DIR = CHECKPOINTS_PATH + '/music2cocktails/music2affects/'
min_handcoded_reps = np.loadtxt(_MUSIC2AFFECTS_DIR + 'min_handcoded_reps.txt')
max_handcoded_reps = np.loadtxt(_MUSIC2AFFECTS_DIR + 'max_handcoded_reps.txt')
affective_models_path = _MUSIC2AFFECTS_DIR + 'music2affect_models.pickle'
final_keys_path = _MUSIC2AFFECTS_DIR + 'final_best_keys.pickle'
def sigmoid(x, shift, beta):
    """Return a logistic curve rescaled from (0, 1) onto (-1, 1).

    Computes ``2 * logistic(beta * (x + shift)) - 1``: ``shift`` slides the
    curve's centre along x and ``beta`` sets its steepness. Works
    element-wise on NumPy arrays as well as on scalars.
    """
    logistic = 1 / (1 + np.exp(-(x + shift) * beta))
    centred = logistic - 0.5
    return centred * 2
def normalize_handcoded_reps(handcoded_rep, min_reps=None, max_reps=None):
    """Min-max scale hand-coded features with per-feature bounds.

    Args:
        handcoded_rep: array-like feature vector or batch; it is broadcast
            against the bounds, so its last axis must match their length.
        min_reps: optional per-feature lower bounds. Defaults to the
            module-level ``min_handcoded_reps`` loaded from the checkpoint
            directory.
        max_reps: optional per-feature upper bounds. Defaults to the
            module-level ``max_handcoded_reps``.

    Returns:
        ndarray ``(handcoded_rep - min) / (max - min)``. For a feature whose
        range is zero (max == min) the divisor is replaced by 1, so it yields
        ``handcoded_rep - min`` instead of NaN/inf.
    """
    if min_reps is None:
        min_reps = min_handcoded_reps
    if max_reps is None:
        max_reps = max_handcoded_reps
    lower = np.asarray(min_reps)
    span = np.asarray(max_reps) - lower
    # Guard constant features: dividing by a zero span would produce NaN/inf.
    span = np.where(span == 0, 1, span)
    return (np.asarray(handcoded_rep) - lower) / span
def setup_pretrained_affective_models():
    """Load the pretrained music-to-affect models and build a predictor.

    Unpickles the per-dimension key selections from ``final_keys_path`` and
    the fitted models from ``affective_models_path``. pickle can execute
    arbitrary code while loading, so these checkpoint files must come from a
    trusted source.

    Returns:
        tuple: ``(music2affect, keys)``. ``music2affect`` is the predictor
        closure defined below. ``keys`` is the sorted union of the keys
        selected for the three affect dimensions, and it fixes the column
        order that ``music2affect`` expects.
    """
    with open(final_keys_path, 'rb') as f:
        best_keys = pickle.load(f)
    keys = sorted(set(best_keys[0] + best_keys[1] + best_keys[2]))
    # Dict lookup instead of keys.index(), which would rescan the list per key.
    key_to_column = {key: column for column, key in enumerate(keys)}
    bestkeys_indexes = [np.array([key_to_column[k] for k in bk], dtype=int) for bk in best_keys]

    with open(affective_models_path, 'rb') as f:
        music2affect_models = pickle.load(f)

    def music2affect(handcoded_rep):
        """Predict (valence, arousal, dominance) for hand-coded feature rows.

        Args:
            handcoded_rep: array-like of shape (len(keys),) or
                (n_samples, len(keys)), columns ordered as ``keys``.

        Returns:
            ndarray of shape (n_samples, 3) with columns valence, arousal and
            dominance, each in [-1, 1]. Each score is derived from the mean of
            1000 random draws from the model's predicted class probabilities,
            so repeated calls on the same input can differ slightly.

        Raises:
            ValueError: if the input is not 1-D/2-D with len(keys) columns.
        """
        handcoded_rep = np.asarray(handcoded_rep)
        if handcoded_rep.ndim == 1:
            handcoded_rep = handcoded_rep.reshape(1, -1)
        # Explicit raise rather than assert so the check survives `python -O`.
        if handcoded_rep.ndim != 2 or handcoded_rep.shape[1] != len(keys):
            raise ValueError(
                f'expected input of shape ({len(keys)},) or (n, {len(keys)}), '
                f'got {handcoded_rep.shape}'
            )
        handcoded_rep = normalize_handcoded_reps(handcoded_rep)
        # Rating scale the draws come from; predict_proba rows must therefore
        # have exactly 10 entries (assumed to be ratings 1..10 in order).
        rating_scale = np.arange(1, 11)
        affects = []
        for i_dim, dim in enumerate(['valence', 'arousal', 'dominance']):
            model = music2affect_models[dim]
            probas = model.predict_proba(handcoded_rep[:, bestkeys_indexes[i_dim]])
            # Per row: mean of 1000 draws from the predicted rating
            # distribution, i.e. a stochastic estimate of the expected rating.
            affects.append(np.array([
                np.mean(np.random.choice(rating_scale, p=r, size=1000)) for r in probas
            ]))
        affects = np.array(affects).transpose()  # (n_samples, 3)
        affects = ((affects - 1) / 9 - 0.5) * 2  # map ratings [1, 10] onto [-1, 1]
        # Stretch each dimension for a wider distribution.
        affects[:, 0] = sigmoid(affects[:, 0], shift=0, beta=7)
        affects[:, 1] = sigmoid(affects[:, 1], shift=-0.05, beta=5)
        affects[:, 2] = sigmoid(affects[:, 2], shift=0.05, beta=8)
        return affects

    return music2affect, keys