"""Gradio demo for different clustering techniques.

Derived from
https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS,
    SpectralClustering, estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler

# Matplotlib 3.6 renamed its bundled "seaborn" style to "seaborn-v0_8" and
# 3.8 removed the old name, so plt.style.use('seaborn') raises on current
# releases. Apply whichever of the two names this installation provides.
for _style in ("seaborn-v0_8", "seaborn"):
    if _style in plt.style.available:
        plt.style.use(_style)
        break

SEED = 0
N_CLUSTERS = 4
N_SAMPLES = 1000
np.random.seed(SEED)


def normalize(X):
    """Standardize each column of ``X`` to zero mean and unit variance."""
    return StandardScaler().fit_transform(X)


# ---------------------------------------------------------------------------
# Datasets: each function returns ``(X, labels)`` with X standardized.
# ---------------------------------------------------------------------------

def get_regular():
    """Return N_CLUSTERS Gaussian blobs centred on the corners of a square."""
    centers = [[1, 1], [1, -1], [-1, 1], [-1, -1]]
    assert len(centers) == N_CLUSTERS
    X, labels = make_blobs(
        n_samples=N_SAMPLES, centers=centers, cluster_std=0.7,
        random_state=SEED,
    )
    return normalize(X), labels


def get_circles():
    """Return two noisy concentric circles (make_circles with factor=0.5)."""
    X, labels = make_circles(
        n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED
    )
    return normalize(X), labels


def get_moons():
    """Return two noisy interleaving half circles."""
    X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
    return normalize(X), labels


def get_noise():
    """Return uniformly random points with no cluster structure.

    Every ground-truth label is 0. A dedicated RandomState seeded with SEED
    makes the data identical on every call, like the other datasets. Drawing
    from the global RNG would give a different point cloud per request.
    """
    X = np.random.RandomState(SEED).rand(N_SAMPLES, 2)
    labels = np.zeros(N_SAMPLES)
    return normalize(X), labels


def get_anisotropic():
    """Return N_CLUSTERS blobs sheared by a fixed linear transformation.

    The result is standardized like every other dataset. The distance-based
    model parameters below (e.g. DBSCAN's eps=0.3) assume standardized input.
    """
    X, labels = make_blobs(
        n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170
    )
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return normalize(X), labels


def get_varied():
    """Return blobs with different spreads.

    With ``centers`` left unset, make_blobs generates three centres, one per
    entry of ``cluster_std``.
    """
    X, labels = make_blobs(
        n_samples=N_SAMPLES, cluster_std=[1.0, 2.5, 0.5], random_state=SEED
    )
    return normalize(X), labels


DATA_MAPPING = {
    'regular': get_regular,
    'circles': get_circles,
    'moons': get_moons,
    'noise': get_noise,
    'anisotropic': get_anisotropic,
    'varied': get_varied,
}


# ---------------------------------------------------------------------------
# Models: each function builds an estimator, applies any ``kwargs`` overrides
# via ``set_params``, and returns the estimator fitted on ``X``.
# ---------------------------------------------------------------------------

def get_kmeans(X, **kwargs):
    """Fit k-means (k-means++ init, 10 restarts) with N_CLUSTERS clusters."""
    model = KMeans(
        init="k-means++", n_clusters=N_CLUSTERS, n_init=10, random_state=SEED
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_dbscan(X, **kwargs):
    """Fit DBSCAN with eps=0.3 (chosen for standardized data)."""
    model = DBSCAN(eps=0.3)
    model.set_params(**kwargs)
    return model.fit(X)


def get_agglomerative(X, **kwargs):
    """Fit Ward agglomerative clustering constrained by a k-NN graph."""
    connectivity = kneighbors_graph(
        X, n_neighbors=N_CLUSTERS, include_self=False
    )
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
    model = AgglomerativeClustering(
        n_clusters=N_CLUSTERS, linkage="ward", connectivity=connectivity
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_meanshift(X, **kwargs):
    """Fit MeanShift with a bandwidth estimated from ``X``."""
    bandwidth = estimate_bandwidth(X, quantile=0.3)
    model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    model.set_params(**kwargs)
    return model.fit(X)


def get_spectral(X, **kwargs):
    """Fit spectral clustering on a nearest-neighbours affinity graph."""
    model = SpectralClustering(
        n_clusters=N_CLUSTERS,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, **kwargs):
    """Fit OPTICS using xi-based cluster extraction."""
    model = OPTICS(
        min_samples=7,
        xi=0.05,
        min_cluster_size=0.1,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_birch(X, **kwargs):
    """Fit Birch with N_CLUSTERS final clusters, matching the other models."""
    model = Birch(n_clusters=N_CLUSTERS)
    model.set_params(**kwargs)
    return model.fit(X)


def get_gaussianmixture(X, **kwargs):
    """Fit a full-covariance Gaussian mixture with N_CLUSTERS components."""
    model = GaussianMixture(
        n_components=N_CLUSTERS,
        covariance_type="full",
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


MODEL_MAPPING = {
    'KMeans': get_kmeans,
    'DBSCAN': get_dbscan,
    'AgglomerativeClustering': get_agglomerative,
    'MeanShift': get_meanshift,
    'SpectralClustering': get_spectral,
    'OPTICS': get_optics,
    'Birch': get_birch,
    'GaussianMixture': get_gaussianmixture,
}


def plot_clusters(ax, X, labels):
    """Scatter-plot the 2-D points ``X`` on ``ax``, one colour per label.

    Every distinct label is drawn, so algorithms that find more than
    N_CLUSTERS clusters (e.g. MeanShift, DBSCAN, OPTICS) are shown in full.
    Points labelled -1, the noise label used by DBSCAN and OPTICS, are drawn
    in light grey. An explicit colour does not advance the colour cycle, so
    real clusters keep the same colours as in the ground-truth plot.

    Returns:
        The same ``ax``, for chaining.
    """
    labels = np.asarray(labels)
    for label in np.unique(labels):
        idx = labels == label
        if label == -1:
            ax.scatter(X[idx, 0], X[idx, 1], c="lightgray")
        else:
            ax.scatter(X[idx, 0], X[idx, 1])
    ax.grid(None)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(clustering_algorithm: str, dataset: str):
    """Cluster ``dataset`` with ``clustering_algorithm`` and plot the result.

    Args:
        clustering_algorithm: A key of MODEL_MAPPING.
        dataset: A key of DATA_MAPPING.

    Returns:
        A matplotlib Figure with the true clusters on the left and the
        predicted clusters on the right.
    """
    X, labels = DATA_MAPPING[dataset]()
    model = MODEL_MAPPING[clustering_algorithm](X)
    # Most estimators expose ``labels_`` after fit; GaussianMixture does not
    # and must be asked to predict.
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        y_pred = model.predict(X)

    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(16, 8))
    plot_clusters(ax_true, X, labels)
    ax_true.set_title("True clusters")
    plot_clusters(ax_pred, X, y_pred)
    ax_pred.set_title(clustering_algorithm)
    # Drop the figure from pyplot's global registry. Otherwise every request
    # leaks a figure in the long-running server. The Figure object can still
    # be rendered by the caller.
    plt.close(fig)
    return fig


title = "Clustering with Scikit-learn"
description = "This example shows how different clustering algorithms work. Simply pick the algorithm and the dataset to see the clusters algorithms make."

demo = gr.Interface(
    fn=cluster,
    inputs=[
        gr.Radio(
            list(MODEL_MAPPING),
            value="KMeans",
            label="clustering algorithm"
        ),
        gr.Radio(
            list(DATA_MAPPING),
            value="regular",
            label="dataset"
        ),
    ],
    title=title,
    description=description,
    outputs=gr.Plot(),
)
demo.launch()