File size: 901 Bytes
da6d421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import torch
from sklearn.cluster import KMeans

def get_cluster_model(ckpt_path):
    """Rebuild one ready-to-predict ``KMeans`` estimator per speaker.

    The checkpoint is read with ``torch.load`` and iterated with ``.items()``,
    so it must be a mapping of ``speaker -> state``. Each ``state`` is indexed
    with the keys ``"n_features_in_"``, ``"_n_threads"`` and
    ``"cluster_centers_"``.

    Args:
        ckpt_path: Checkpoint location, passed straight to ``torch.load``.

    Returns:
        dict mapping each speaker key to a ``KMeans`` instance whose fitted
        attributes were restored from the checkpoint.

    Raises:
        KeyError: If a speaker's state is a dict missing one of the keys above.
    """
    # SECURITY NOTE: torch.load may unpickle arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path)
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        centers = ckpt["cluster_centers_"]
        # n_clusters is the number of centroids (rows of cluster_centers_),
        # not the feature dimensionality that was previously passed here.
        km = KMeans(n_clusters=len(centers))
        # Restore the fitted state directly instead of re-running fit().
        km.n_features_in_ = ckpt["n_features_in_"]
        km._n_threads = ckpt["_n_threads"]
        km.cluster_centers_ = centers
        kmeans_dict[spk] = km
    return kmeans_dict

def get_cluster_result(model, x, speaker):
    """Return the cluster label predicted for every row of ``x``.

    Args:
        model: Mapping of speaker key -> estimator exposing ``predict``.
        x: Feature array; the original note gives its shape as [t, 256].
        speaker: Key selecting which estimator in ``model`` to use.

    Returns:
        Whatever ``predict`` returns for ``x`` (the cluster class result).
    """
    speaker_kmeans = model[speaker]
    labels = speaker_kmeans.predict(x)
    return labels

def get_cluster_center_result(model, x, speaker):
    """Replace every row of ``x`` with the centroid of its predicted cluster.

    Args:
        model: Mapping of speaker key -> estimator exposing ``predict`` and
            ``cluster_centers_``.
        x: Feature array; the original note gives its shape as [t, 256].
        speaker: Key selecting which estimator in ``model`` to use.

    Returns:
        ``cluster_centers_`` indexed by the predicted labels.
    """
    speaker_kmeans = model[speaker]
    nearest = speaker_kmeans.predict(x)
    centroids = speaker_kmeans.cluster_centers_
    return centroids[nearest]

def get_center(model, x, speaker):
    """Look up centroid(s) of ``speaker``'s estimator by cluster index.

    Args:
        model: Mapping of speaker key -> estimator exposing ``cluster_centers_``.
        x: Index (or index array) used to subscript ``cluster_centers_``.
        speaker: Key selecting which estimator in ``model`` to use.

    Returns:
        ``cluster_centers_[x]`` for the selected estimator.
    """
    centroids = model[speaker].cluster_centers_
    return centroids[x]