Spaces:
Runtime error
Runtime error
File size: 1,393 Bytes
d5d7329 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from sklearn.cluster import KMeans
def get_cluster_model(ckpt_path: Path | str):
    """Load per-speaker k-means models from a torch checkpoint file.

    The checkpoint is read as a mapping of speaker key to an entry holding
    ``n_features_in_``, ``_n_threads`` and ``cluster_centers_``. For every
    speaker a ``KMeans`` object is created and those three values are
    written onto it as attributes; no fitting is performed here.

    Args:
        ckpt_path: Location of the checkpoint file.

    Returns:
        A dict mapping each speaker key found in the checkpoint to its
        ``KMeans`` instance.

    Raises:
        KeyError: If a speaker entry (assumed to be a dict) lacks one of the
            three expected keys.
    """
    with Path(ckpt_path).open("rb") as f:
        # SECURITY: torch.load unpickles the file and can therefore execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(f, map_location="cpu")

    kmeans_dict = {}
    for spk, ckpt in checkpoint.items():
        n_features = ckpt["n_features_in_"]
        n_threads = ckpt["_n_threads"]
        centers = ckpt["cluster_centers_"]

        # KMeans' first parameter is n_clusters, so it is given the number
        # of stored centres rather than the feature width.
        # NOTE(review): assumes cluster_centers_ is a 2-D array-like with
        # one row per cluster - confirm against the checkpoint writer.
        km = KMeans(n_clusters=len(centers))

        # Restore the learned state directly on the instance instead of
        # calling fit().
        vars(km).update(
            {
                "n_features_in_": n_features,
                "_n_threads": n_threads,
                "cluster_centers_": centers,
            }
        )
        kmeans_dict[spk] = km
    return kmeans_dict
def check_speaker(model: Any, speaker: Any):
    """Ensure ``speaker`` is a key of ``model``.

    Args:
        model: Mapping of speaker keys to per-speaker models.
        speaker: Key to look up.

    Raises:
        ValueError: If ``speaker`` is not contained in ``model``; the
            message lists the keys that are available.
    """
    if speaker in model:
        return
    raise ValueError(f"Speaker {speaker} not in {list(model.keys())}")
def get_cluster_result(model: Any, x: Any, speaker: Any):
    """Return the cluster class predicted for ``x`` by a speaker's model.

    Args:
        model: Mapping of speaker keys to objects exposing ``predict``.
        x: Input handed unchanged to ``predict``; described by the original
            author as an np.array of shape [t, 256].
        speaker: Key selecting which model to use.

    Returns:
        Whatever ``predict`` returns for ``x`` (the cluster class result).

    Raises:
        ValueError: Propagated from ``check_speaker`` for an unknown speaker.
    """
    check_speaker(model, speaker)
    speaker_model = model[speaker]
    labels = speaker_model.predict(x)
    return labels
def get_cluster_center_result(model: Any, x: Any, speaker: Any):
    """Return the cluster centres selected by the prediction for ``x``.

    Args:
        model: Mapping of speaker keys to objects exposing ``predict`` and
            ``cluster_centers_``.
        x: Input handed unchanged to ``predict``; described by the original
            author as an np.array of shape [t, 256].
        speaker: Key selecting which model to use.

    Returns:
        ``cluster_centers_`` indexed by the result of ``predict(x)``.

    Raises:
        ValueError: Propagated from ``check_speaker`` for an unknown speaker.
    """
    check_speaker(model, speaker)
    labels = model[speaker].predict(x)
    centers = model[speaker].cluster_centers_
    return centers[labels]
def get_center(model: Any, x: Any, speaker: Any):
    """Look up cluster centre(s) of a speaker's model by index.

    Args:
        model: Mapping of speaker keys to objects exposing
            ``cluster_centers_``.
        x: Index expression applied to ``cluster_centers_``.
        speaker: Key selecting which model to use.

    Returns:
        ``cluster_centers_[x]`` of the selected model.

    Raises:
        ValueError: Propagated from ``check_speaker`` for an unknown speaker.
    """
    check_speaker(model, speaker)
    centers = model[speaker].cluster_centers_
    return centers[x]
|