# NOTE: web-page residue captured together with this file (not part of the module):
#   ccolas's picture | Upload 174 files | 93c029f | raw | history blame | 2.35 kB
import pickle
import numpy as np
from src.music.config import CHECKPOINTS_PATH
# Pre-computed artefacts for the music -> affect models.
# These can be generated by the fit_glm script.
# Every path is derived from one base directory so that no entry can miss the
# '/' separator after CHECKPOINTS_PATH.
_music2affects_dir = CHECKPOINTS_PATH + '/music2cocktails/music2affects'
# Lower / upper bounds consumed by normalize_handcoded_reps for min-max scaling.
min_handcoded_reps = np.loadtxt(_music2affects_dir + '/min_handcoded_reps.txt')
max_handcoded_reps = np.loadtxt(_music2affects_dir + '/max_handcoded_reps.txt')
# Pickled objects loaded by setup_pretrained_affective_models.
affective_models_path = _music2affects_dir + '/music2affect_models.pickle'
final_keys_path = _music2affects_dir + "/final_best_keys.pickle"
def sigmoid(x, shift, beta):
    """Return a shifted, scaled logistic of *x* rescaled to the open interval (-1, 1).

    Computes ``2 * logistic(beta * (x + shift)) - 1``, which is mathematically
    ``tanh(beta * (x + shift) / 2)``.

    Args:
        x: Scalar or NumPy array of input values.
        shift: Offset added to ``x`` before scaling; the output is 0 at ``x == -shift``.
        beta: Slope factor; larger values give a steeper transition around ``-shift``.

    Returns:
        A value (or array with the shape of ``x``) in (-1, 1).
    """
    exponent = -(x + shift) * beta
    logistic = 1 / (1 + np.exp(exponent))
    # Re-centre the logistic from (0, 1) to (-1, 1).
    return (logistic - 0.5) * 2
def normalize_handcoded_reps(handcoded_rep):
    """Min-max scale a handcoded representation with the module-level bounds.

    Applies ``(rep - min_handcoded_reps) / (max_handcoded_reps - min_handcoded_reps)``
    element-wise (NumPy broadcasting applies).

    Args:
        handcoded_rep: Array of handcoded features; presumably its last axis
            matches the length of the loaded bounds (confirm against caller).

    Returns:
        The scaled array. Values equal to the stored minimum map to 0 and
        values equal to the stored maximum map to 1; nothing is clipped.
    """
    value_range = max_handcoded_reps - min_handcoded_reps
    return (handcoded_rep - min_handcoded_reps) / value_range
def setup_pretrained_affective_models():
    """Load the pretrained music-to-affect models and build a prediction function.

    Reads the selected feature keys from ``final_keys_path`` and the fitted
    models from ``affective_models_path``.

    Returns:
        A ``(music2affect, keys)`` tuple. ``keys`` is the sorted list of the
        distinct feature keys used by the valence, arousal and dominance models;
        it fixes the column order that ``music2affect`` expects as input.
    """
    # NOTE(security): pickle.load can execute arbitrary code stored in the file;
    # only load checkpoints that come from a trusted source.
    with open(final_keys_path, 'rb') as f:
        best_keys = pickle.load(f)
    # Union of the keys selected for the three affective dimensions.
    keys = sorted(set(best_keys[0] + best_keys[1] + best_keys[2]))
    # Map each key to its column once instead of scanning ``keys`` per lookup.
    key_to_column = {key: column for column, key in enumerate(keys)}
    bestkeys_indexes = [np.array([key_to_column[k] for k in bk]) for bk in best_keys]
    with open(affective_models_path, 'rb') as f:
        music2affect_models = pickle.load(f)

    affect_dims = ('valence', 'arousal', 'dominance')
    # (shift, beta) of the stretching sigmoid, in the same order as affect_dims.
    stretch_params = ((0, 7), (-0.05, 5), (0.05, 8))
    # Class labels the sampled ratings are drawn from; assumes each model's
    # predict_proba returns exactly 10 columns ordered as ratings 1..10
    # (TODO confirm against the fitted models).
    ratings = range(1, 11)

    def music2affect(handcoded_rep):
        """Predict stretched affect scores from handcoded representations.

        Args:
            handcoded_rep: 1-D array-like with ``len(keys)`` features, or a 2-D
                array-like of shape ``(n_samples, len(keys))``, with columns
                ordered like ``keys``.

        Returns:
            Array of shape ``(n_samples, 3)`` with valence, arousal and
            dominance columns, each in (-1, 1). The result is stochastic: it
            is estimated by sampling from NumPy's global random state.
        """
        handcoded_rep = np.asarray(handcoded_rep)
        if handcoded_rep.ndim == 1:
            handcoded_rep = handcoded_rep.reshape(1, -1)
        assert handcoded_rep.shape[1] == len(keys), (
            'expected %d features per sample, got %d' % (len(keys), handcoded_rep.shape[1])
        )
        handcoded_rep = normalize_handcoded_reps(handcoded_rep)
        affects = []
        for i_dim, dim in enumerate(affect_dims):
            model = music2affect_models[dim]
            # Each model only sees the feature columns selected for its dimension.
            probas = model.predict_proba(handcoded_rep[:, bestkeys_indexes[i_dim]])
            # The mean of 1000 ratings sampled from the predicted class
            # distribution is used as the score, rather than the model's
            # hard class prediction.
            my_preds = np.array(
                [np.mean(np.random.choice(ratings, p=r, size=1000)) for r in probas]
            )
            affects.append(my_preds)
        affects = np.array(affects).transpose()
        affects = ((affects - 1) / 9 - 0.5) * 2  # map ratings in [1, 10] to [-1, 1]
        # Stretch each dimension for a wider distribution of scores.
        for i_dim, (shift, beta) in enumerate(stretch_params):
            affects[:, i_dim] = sigmoid(affects[:, i_dim], shift=shift, beta=beta)
        return affects

    return music2affect, keys