Nogizaka46-so / cluster /__init__.py
darksakura's picture
Duplicate from darksakura/l1
f82071f
raw
history blame
901 Bytes
import numpy as np
import torch
from sklearn.cluster import KMeans
def get_cluster_model(ckpt_path):
    """Load per-speaker k-means models from a checkpoint file.

    The checkpoint is read with ``torch.load`` and is iterated as a mapping
    from speaker key to a dict holding the entries ``"n_features_in_"``,
    ``"_n_threads"`` and ``"cluster_centers_"``.  For every speaker a
    ``KMeans`` object is created and those three values are written straight
    into the instance, so the estimator is restored without re-fitting.

    Args:
        ckpt_path: Location of the checkpoint, forwarded to ``torch.load``.

    Returns:
        dict mapping each speaker key in the checkpoint to its restored
        ``KMeans`` instance.

    Raises:
        KeyError: If a speaker entry lacks one of the three required keys.
    """
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be passed in.
    # map_location="cpu" keeps any tensor stored in the checkpoint on the
    # host, which is where scikit-learn consumes the data.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        centers = ckpt["cluster_centers_"]
        # The first KMeans argument is n_clusters, i.e. the number of
        # centroids (rows of cluster_centers_), not the feature dimension.
        km = KMeans(n_clusters=len(centers))
        # Write the fitted attributes directly into the instance dict so the
        # estimator behaves as fitted without running fit().
        km.__dict__["n_features_in_"] = ckpt["n_features_in_"]
        km.__dict__["_n_threads"] = ckpt["_n_threads"]
        km.__dict__["cluster_centers_"] = centers
        kmeans_dict[spk] = km
    return kmeans_dict
def get_cluster_result(model, x, speaker):
    """Predict the cluster class of every row of ``x`` for one speaker.

    Args:
        model: Mapping from speaker key to an object exposing ``predict``.
        x: Feature matrix; documented by the original code as an
            ``np.array`` of shape ``[t, 256]``.
        speaker: Key selecting the entry of ``model`` to use.

    Returns:
        Whatever ``predict`` of the selected speaker model returns for ``x``.
    """
    speaker_model = model[speaker]
    labels = speaker_model.predict(x)
    return labels
def get_cluster_center_result(model, x, speaker):
    """Return, for every row of ``x``, the centre of its predicted cluster.

    Args:
        model: Mapping from speaker key to an object exposing ``predict``
            and ``cluster_centers_``.
        x: Feature matrix; documented by the original code as an
            ``np.array`` of shape ``[t, 256]``.
        speaker: Key selecting the entry of ``model`` to use.

    Returns:
        ``cluster_centers_`` of the selected speaker model indexed by the
        labels that its ``predict`` returns for ``x``.
    """
    speaker_model = model[speaker]
    labels = speaker_model.predict(x)
    centers = speaker_model.cluster_centers_
    return centers[labels]
def get_center(model, x, speaker):
    """Look up cluster centre(s) of one speaker by index.

    Args:
        model: Mapping from speaker key to an object exposing
            ``cluster_centers_``.
        x: Index (or any indexer accepted by ``cluster_centers_``) of the
            centre(s) to fetch.
        speaker: Key selecting the entry of ``model`` to use.

    Returns:
        ``cluster_centers_[x]`` of the selected speaker model.
    """
    centers = model[speaker].cluster_centers_
    return centers[x]