"""Gradio demo for different clustering techniques.

Derived from
https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
"""

import math
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering,
    estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler

# matplotlib 3.6 renamed the 'seaborn' style to 'seaborn-v0_8' and 3.8 removed
# the old name entirely, so use whichever name this matplotlib version ships.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

SEED = 0
MAX_CLUSTERS = 10
N_SAMPLES = 1000
N_COLS = 3
FIGSIZE = 7, 7  # does not affect size in webpage
COLORS = [
    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
]
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"

np.random.seed(SEED)


def normalize(X):
    """Standardize every column of ``X`` to zero mean and unit variance."""
    return StandardScaler().fit_transform(X)


def get_regular(n_clusters):
    """Return ``n_clusters`` isotropic Gaussian blobs whose centers follow a spiral.

    Raises:
        ValueError: if ``n_clusters`` is not between 1 and the number of
            predefined centers (10).
    """
    # spiral pattern
    all_centers = [
        [0, 0], [1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1], [2, -1],
    ]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 1 <= n_clusters <= len(all_centers):
        raise ValueError(
            f"n_clusters must be between 1 and {len(all_centers)}, got {n_clusters}"
        )
    centers = all_centers[:n_clusters]
    X, labels = make_blobs(
        n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED
    )
    return normalize(X), labels


def get_circles(n_clusters):
    """Return two concentric noisy circles; ``n_clusters`` is ignored."""
    X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
    return normalize(X), labels


def get_moons(n_clusters):
    """Return two interleaving noisy half-moons; ``n_clusters`` is ignored."""
    X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
    return normalize(X), labels


def get_noise(n_clusters):
    """Return uniform 2-D noise with random labels in ``[0, n_clusters)``.

    There is no real cluster structure; the labels are assigned at random.
    """
    # Re-seeding the global RNG keeps the data identical across calls.
    np.random.seed(SEED)
    X = np.random.rand(N_SAMPLES, 2)
    labels = np.random.randint(0, n_clusters, size=(N_SAMPLES,))
    return normalize(X), labels


def get_anisotropic(n_clusters):
    """Return ``n_clusters`` Gaussian blobs sheared by a fixed linear transform."""
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    # Standardize like every other dataset so scale-dependent parameters
    # (e.g. DBSCAN's eps) mean the same thing across datasets.
    return normalize(X), labels


def get_varied(n_clusters):
    """Return ``n_clusters`` Gaussian blobs with varying standard deviations.

    Raises:
        ValueError: if ``n_clusters`` is not between 1 and 10.
    """
    all_stds = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0]
    if not 1 <= n_clusters <= len(all_stds):
        raise ValueError(
            f"n_clusters must be between 1 and {len(all_stds)}, got {n_clusters}"
        )
    cluster_std = all_stds[:n_clusters]
    X, labels = make_blobs(
        n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED
    )
    return normalize(X), labels


def get_spiral(n_clusters):
    """Return a single noisy spiral; all points share label 0 and ``n_clusters`` is ignored."""
    # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
    np.random.seed(SEED)
    t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
    x = t * np.cos(t)
    y = t * np.sin(t)
    X = np.concatenate((x, y))
    X += 0.7 * np.random.randn(2, N_SAMPLES)
    X = np.ascontiguousarray(X.T)
    labels = np.zeros(N_SAMPLES, dtype=int)
    return normalize(X), labels


DATA_MAPPING = {
    'regular': get_regular,
    'circles': get_circles,
    'moons': get_moons,
    'spiral': get_spiral,
    'noise': get_noise,
    'anisotropic': get_anisotropic,
    'varied': get_varied,
}


# Every model factory below has the signature (X, labels, n_clusters, **kwargs).
# It returns a fitted object exposing either ``labels_`` or ``predict``;
# ``kwargs`` are forwarded to ``set_params`` to override the defaults chosen here.

def get_groundtruth_model(X, labels, n_clusters, **kwargs):
    """Return a dummy "model" whose ``labels_`` are the true labels."""
    # dummy model to show true label distribution
    class Dummy:
        def __init__(self, y):
            self.labels_ = y

    return Dummy(labels)


def get_kmeans(X, labels, n_clusters, **kwargs):
    """Fit k-means++ with ``n_clusters`` clusters."""
    model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
    model.set_params(**kwargs)
    return model.fit(X)


def get_dbscan(X, labels, n_clusters, **kwargs):
    """Fit DBSCAN; ``n_clusters`` is unused because DBSCAN picks its own count."""
    model = DBSCAN(eps=0.3)
    model.set_params(**kwargs)
    return model.fit(X)


def get_agglomerative(X, labels, n_clusters, **kwargs):
    """Fit ward-linkage agglomerative clustering constrained by a k-NN graph."""
    # NOTE(review): n_clusters doubles as the k-NN neighbor count here; kept
    # as-is to preserve the demo's behavior.
    connectivity = kneighbors_graph(
        X, n_neighbors=n_clusters, include_self=False
    )
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage="ward", connectivity=connectivity
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_meanshift(X, labels, n_clusters, **kwargs):
    """Fit MeanShift with an estimated bandwidth; ``n_clusters`` is unused."""
    bandwidth = estimate_bandwidth(X, quantile=0.25)
    model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    model.set_params(**kwargs)
    return model.fit(X)


def get_spectral(X, labels, n_clusters, **kwargs):
    """Fit spectral clustering on a nearest-neighbors affinity graph."""
    model = SpectralClustering(
        n_clusters=n_clusters,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        # Seed like KMeans/GaussianMixture so identical requests give identical
        # plots instead of depending on the global RNG state.
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, labels, n_clusters, **kwargs):
    """Fit OPTICS; ``n_clusters`` is unused because OPTICS picks its own count."""
    model = OPTICS(
        min_samples=7,
        xi=0.05,
        min_cluster_size=0.1,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_birch(X, labels, n_clusters, **kwargs):
    """Fit Birch with ``n_clusters`` final clusters."""
    model = Birch(n_clusters=n_clusters)
    model.set_params(**kwargs)
    return model.fit(X)


def get_gaussianmixture(X, labels, n_clusters, **kwargs):
    """Fit a full-covariance Gaussian mixture with ``n_clusters`` components."""
    model = GaussianMixture(
        n_components=n_clusters,
        covariance_type="full",
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


MODEL_MAPPING = {
    'True labels': get_groundtruth_model,
    'KMeans': get_kmeans,
    'DBSCAN': get_dbscan,
    'MeanShift': get_meanshift,
    'SpectralClustering': get_spectral,
    'OPTICS': get_optics,
    'Birch': get_birch,
    'GaussianMixture': get_gaussianmixture,
    'AgglomerativeClustering': get_agglomerative,
}


def plot_clusters(ax, X, labels):
    """Scatter the 2-D points ``X`` on ``ax``, one color per label.

    Points labelled ``-1`` are treated as outliers and drawn as black crosses.

    Returns:
        The same ``ax``, for chaining.
    """
    # -1 signifies outliers, which we plot separately
    cluster_ids = sorted(set(labels) - {-1})
    for i, label in enumerate(cluster_ids):
        idx = labels == label
        # Cycle the palette: models that choose their own cluster count
        # (DBSCAN, OPTICS, MeanShift) can exceed len(COLORS), and zipping with
        # COLORS would silently drop those clusters from the plot.
        ax.scatter(X[idx, 0], X[idx, 1], color=COLORS[i % len(COLORS)])

    # show outliers (if any)
    outliers = labels == -1
    if outliers.any():
        ax.scatter(X[outliers, 0], X[outliers, 1], c='k', marker='x')

    # grid(None) only *toggles* the grid; be explicit about hiding it.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
    """Generate ``dataset``, cluster it and return a matplotlib figure of the result.

    Args:
        dataset: key of ``DATA_MAPPING``.
        n_clusters: requested cluster count; also accepts a ``{'value': ...}`` dict.
        clustering_algorithm: key of ``MODEL_MAPPING``; also used as the plot title.
    """
    if isinstance(n_clusters, dict):
        n_clusters = n_clusters['value']
    n_clusters = int(n_clusters)

    X, labels = DATA_MAPPING[dataset](n_clusters)
    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        y_pred = model.predict(X)

    # Build the figure without pyplot: each input change renders one figure per
    # model, and pyplot-managed figures stay alive until explicitly closed,
    # which leaks memory in a long-running server.
    fig = Figure(figsize=FIGSIZE)
    ax = fig.subplots()
    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm, fontsize=16)

    return fig


title = "Clustering with Scikit-learn"
description = (
    "This example shows how different clustering algorithms work. Simply pick "
    "the dataset and the number of clusters to see how the clustering algorithms work. "
    "Colored circles are (predicted) labels and black x are outliers."
)


def iter_grid(n_rows, n_cols):
    """Yield once inside each cell of an ``n_rows`` x ``n_cols`` Gradio Row/Column grid."""
    # create a grid using gradio Block
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield


with gr.Blocks(title=title) as demo:
    gr.HTML(f"{title}")
    gr.Markdown(description)

    input_models = list(MODEL_MAPPING)
    input_data = gr.Radio(
        list(DATA_MAPPING), value="regular", label="dataset"
    )
    input_n_clusters = gr.Slider(
        minimum=1, maximum=MAX_CLUSTERS, value=4, step=1, label='Number of clusters'
    )
    n_rows = int(math.ceil(len(input_models) / N_COLS))
    inputs = [input_data, input_n_clusters]
    for counter, _ in enumerate(iter_grid(n_rows, N_COLS)):
        if counter >= len(input_models):
            break

        input_model = input_models[counter]
        plot = gr.Plot(label=input_model)
        # partial binds the model name now, avoiding late-binding closures.
        fn = partial(cluster, clustering_algorithm=input_model)
        input_data.change(fn=fn, inputs=inputs, outputs=plot)
        input_n_clusters.change(fn=fn, inputs=inputs, outputs=plot)


if __name__ == "__main__":
    demo.launch()