File size: 781 Bytes
93c029f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from src.music.config import CHECKPOINTS_PATH
import pickle
import numpy as np

# Path to the pickled affective-cluster checkpoint: a dict from which
# get_affect2affective_cluster reads 'cluster_model' and 'dimensions_weights',
# and get_affective_cluster_centers reads 'cluster_centers'.
# Can be computed from cocktail2affect.
# NOTE(review): presumably regenerated by the cocktail2affect pipeline;
# confirm which entry point writes this file.
cluster_model_path = CHECKPOINTS_PATH + "/music2cocktails/affects2affect_cluster/cluster_model.pickle"

def get_affect2affective_cluster(model_path=None):
    """Load the affective-cluster model and return a cluster-lookup function.

    The checkpoint is a pickled dict holding a ``'cluster_model'`` object
    exposing a ``predict`` method, and ``'dimensions_weights'``: per-dimension
    weights multiplied into the affect coordinates before prediction.

    Args:
        model_path: Path to the pickled checkpoint. When ``None`` (default),
            the module-level ``cluster_model_path`` is used, looked up at
            call time.

    Returns:
        A function ``find_cluster(aff_coord)``. It takes either one affect
        vector (1-D) or a batch of vectors (2-D, one per row) as any
        array-like. It scales the coordinates by the stored weights and
        returns ``cluster_model.predict`` applied to the result. A 1-D input
        is treated as a one-row batch.
    """
    if model_path is None:
        model_path = cluster_model_path
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # locally produced checkpoints.
    with open(model_path, 'rb') as f:
        data = pickle.load(f)
    model = data['cluster_model']
    # Build the weight vector once here instead of on every lookup.
    weights = np.array(data['dimensions_weights'])

    def find_cluster(aff_coord):
        # Accept any array-like (list, tuple, ndarray), not only ndarrays.
        aff_coord = np.asarray(aff_coord)
        if aff_coord.ndim == 1:
            # Treat a single vector as a batch containing one row.
            aff_coord = aff_coord.reshape(1, -1)
        return model.predict(aff_coord * weights)

    return find_cluster

def get_affective_cluster_centers(model_path=None):
    """Return the affective-cluster centers stored in the checkpoint.

    Args:
        model_path: Path to the pickled checkpoint dict. When ``None``
            (default), the module-level ``cluster_model_path`` is used,
            looked up at call time.

    Returns:
        The object stored under the checkpoint's ``'cluster_centers'`` key,
        unchanged.
    """
    if model_path is None:
        model_path = cluster_model_path
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # locally produced checkpoints.
    with open(model_path, 'rb') as f:
        data = pickle.load(f)
    return data['cluster_centers']