Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from sklearn.cluster import KMeans | |
def get_cluster_model(ckpt_path):
    """Load per-speaker KMeans models from a checkpoint.

    The checkpoint is read with ``torch.load`` and iterated as a mapping of
    speaker key -> dict. Each per-speaker dict must provide the keys
    ``"n_features_in_"``, ``"_n_threads"`` and ``"cluster_centers_"``; these
    are written straight into a fresh ``KMeans`` instance so that it can be
    used without refitting.

    Args:
        ckpt_path: Location of the checkpoint, passed unchanged to
            ``torch.load``.

    Returns:
        dict mapping every speaker key found in the checkpoint to its
        restored ``KMeans`` instance.

    Raises:
        KeyError: If a speaker entry lacks one of the three required keys.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects -- only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path)
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        centers = ckpt["cluster_centers_"]
        # The first KMeans argument is n_clusters, so it has to describe the
        # number of stored centroids rather than the feature dimension.
        # Assumes the centers are a sized sequence/array with one row per
        # cluster -- TODO confirm against the code that writes the checkpoint.
        km = KMeans(n_clusters=len(centers))
        # Write directly into the instance dict to restore the fitted state
        # without running fit().
        vars(km).update(
            n_features_in_=ckpt["n_features_in_"],
            _n_threads=ckpt["_n_threads"],
            cluster_centers_=centers,
        )
        kmeans_dict[spk] = km
    return kmeans_dict
def get_cluster_result(model, x, speaker):
    """Predict the cluster class of every row in ``x`` for one speaker.

    Args:
        model: Mapping of speaker key -> object exposing ``predict``.
        x: np.array [t, 256]
        speaker: Key selecting which entry of ``model`` to use.

    Returns:
        Whatever the selected model's ``predict(x)`` returns (the cluster
        class result).
    """
    speaker_model = model[speaker]
    return speaker_model.predict(x)
def get_cluster_center_result(model, x, speaker):
    """Return the cluster centers selected by the prediction for ``x``.

    Args:
        model: Mapping of speaker key -> object exposing ``predict`` and
            ``cluster_centers_``.
        x: np.array [t, 256]
        speaker: Key selecting which entry of ``model`` to use.

    Returns:
        ``cluster_centers_`` of the selected model, indexed by the result of
        its ``predict(x)``.
    """
    speaker_model = model[speaker]
    labels = speaker_model.predict(x)
    return speaker_model.cluster_centers_[labels]
def get_center(model, x, speaker):
    """Look up cluster centers of one speaker's model by index.

    Args:
        model: Mapping of speaker key -> object exposing ``cluster_centers_``.
        x: Index (or indices) applied to ``cluster_centers_``.
        speaker: Key selecting which entry of ``model`` to use.

    Returns:
        ``cluster_centers_[x]`` of the selected model.
    """
    centers = model[speaker].cluster_centers_
    return centers[x]