monitoringInterface / utils_clustering.py
HugoHE's picture
initial commit
1215771
raw
history blame
2.1 kB
from util import *
from sklearn.cluster import KMeans, SpectralClustering, DBSCAN
from sklearn import metrics
import numpy as np
import warnings
import os
from util.Monitor_construction import Box
from runtime_monitors import *
def k_means_cluster(data, n_clusters):
    """Run k-means on ``data`` and group the samples by assigned cluster.

    Args:
        data: Samples to cluster. It is handed to ``KMeans`` and afterwards
            indexed with an integer index array, so it presumably needs to be
            a NumPy array with one sample per row -- confirm against callers.
        n_clusters: Number of clusters requested from ``KMeans``.

    Returns:
        dict mapping each cluster label to a list of ``(index, sample)``
        tuples, where ``index`` is the sample's position in ``data``.
    """
    model = KMeans(
        n_clusters=n_clusters,
        init='k-means++',
        random_state=0,  # fixed seed, so repeated runs use the same initialisation
        n_init="auto",
    )
    labels = model.fit_predict(data)

    def members_of(label):
        # Positions of the samples assigned to `label`, paired with the samples.
        positions = np.flatnonzero(labels == label)
        return list(zip(positions, data[positions]))

    return {label: members_of(label) for label in set(labels)}
def spectral_cluster(data, n_clusters):
    """Run spectral clustering on ``data`` and group the samples by cluster.

    The estimator uses a nearest-neighbours affinity whose neighbour count is
    ``n_clusters`` capped at 10. Two known ``UserWarning`` messages about a
    disconnected neighbourhood graph are silenced while fitting.

    Args:
        data: Samples to cluster. It is handed to ``SpectralClustering`` and
            afterwards indexed with an integer index array, so it presumably
            needs to be a NumPy array with one sample per row -- confirm
            against callers.
        n_clusters: Number of clusters requested from the estimator.

    Returns:
        dict mapping each cluster label to a list of ``(index, sample)``
        tuples, where ``index`` is the sample's position in ``data``.
    """
    model = SpectralClustering(
        n_clusters=n_clusters,
        affinity='nearest_neighbors',
        n_neighbors=min(n_clusters, 10),
        gamma=1.0,
        eigen_solver="arpack",
        random_state=0,
    )

    # Message patterns (regular expressions) of the warnings to suppress.
    ignored_patterns = (
        "the number of connected components of the "
        "connectivity matrix is [0-9]{1,2}"
        " > 1. Completing it to avoid stopping the tree early.",
        "Graph is not fully connected, spectral embedding"
        " may not work as expected.",
    )

    # The filters only apply inside this context; the previous warning
    # configuration is restored on exit.
    with warnings.catch_warnings():
        for pattern in ignored_patterns:
            warnings.filterwarnings("ignore", message=pattern, category=UserWarning)
        model = model.fit(data)

    labels = model.labels_
    grouped = {}
    for label in set(labels):
        positions = np.flatnonzero(labels == label)
        grouped[label] = list(zip(positions, data[positions]))
    return grouped
def dbscan_cluster(data, eps, min_samples):
    """Run DBSCAN on ``data`` and group the samples by assigned label.

    Args:
        data: Samples to cluster. It is handed to ``DBSCAN`` and afterwards
            indexed with an integer index array, so it presumably needs to be
            a NumPy array with one sample per row -- confirm against callers.
        eps: Forwarded unchanged to ``DBSCAN(eps=...)``.
        min_samples: Forwarded unchanged to ``DBSCAN(min_samples=...)``.

    Returns:
        dict mapping each label to a list of ``(index, sample)`` tuples,
        where ``index`` is the sample's position in ``data``. Samples that
        DBSCAN marks as noise carry the label ``-1`` and are returned under
        that key like any other group, so callers that only want real
        clusters must skip key ``-1`` themselves.
    """
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(data).labels_

    # The former cluster/noise counters were computed but never used or
    # returned, so they are dropped (one of them cost an extra pass over the
    # labels).
    clusters = {}
    for label in set(labels):
        idx = np.where(labels == label)[0]
        clusters[label] = list(zip(idx, data[idx]))
    return clusters