"""Gradio demo comparing linkage strategies in agglomerative clustering."""
import time
import warnings
from functools import partial
from itertools import cycle, islice

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn import cluster, datasets
from sklearn.preprocessing import StandardScaler

# UI label -> ``linkage`` argument of ``AgglomerativeClustering``.
_LINKAGE_BY_NAME = {
    "Single Linkage": "single",
    "Average Linkage": "average",
    "Complete Linkage": "complete",
    "Ward Linkage": "ward",
}

# Colours assigned to cluster labels; cycled if there are more labels
# than entries.
_PALETTE = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]


def _make_dataset(selected_data, n_samples):
    """Return the standardised sample matrix for the dataset named *selected_data*.

    Every dataset is generated on each call, in a fixed order, after
    re-seeding NumPy's global RNG. Several generators draw from that global
    RNG, so the fixed order keeps each dataset's points reproducible from one
    call to the next.

    Raises:
        KeyError: if *selected_data* is not one of the known dataset names.
    """
    np.random.seed(0)

    noisy_circles = datasets.make_circles(
        n_samples=n_samples, factor=0.5, noise=0.05
    )
    noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05)
    blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
    no_structure = np.random.rand(n_samples, 2), None

    # Anisotropically distributed data: blobs pushed through a linear map.
    random_state = 170
    X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    aniso = (np.dot(X, transformation), y)

    # Blobs with varied variances.
    varied = datasets.make_blobs(
        n_samples=n_samples,
        cluster_std=[1.0, 2.5, 0.5],
        random_state=random_state,
    )

    dataset_by_name = {
        "Noisy Circles": noisy_circles,
        "Noisy Moons": noisy_moons,
        "Varied": varied,
        "Aniso": aniso,
        "Blobs": blobs,
        "No Structure": no_structure,
    }
    points, _ = dataset_by_name[selected_data]
    return StandardScaler().fit_transform(points)


def train_models(selected_data, n_samples, n_clusters, n_neighbors, cls_name):
    """Cluster one synthetic dataset with one linkage and return the scatter plot.

    Args:
        selected_data: Dataset name ("Noisy Circles", "Noisy Moons", "Varied",
            "Aniso", "Blobs" or "No Structure").
        n_samples: Number of points to generate; cast to ``int``.
        n_clusters: Number of clusters to extract; cast to ``int`` and applied
            to every dataset.
        n_neighbors: Accepted so the UI wiring keeps working; none of the
            estimators built here reads it.
        cls_name: Linkage label ("Single Linkage", "Average Linkage",
            "Complete Linkage" or "Ward Linkage").

    Returns:
        The matplotlib figure holding the coloured scatter plot. It has
        already been closed in pyplot, so pyplot keeps no reference to it.

    Raises:
        KeyError: if *selected_data* or *cls_name* is not a known name.
    """
    X = _make_dataset(selected_data, int(n_samples))

    # Build only the estimator that was asked for. The cluster count always
    # comes from the caller, so the slider affects every dataset.
    algorithm = cluster.AgglomerativeClustering(
        n_clusters=int(n_clusters), linkage=_LINKAGE_BY_NAME[cls_name]
    )

    # Defensive filter for the kneighbors_graph connectivity warning; no
    # connectivity matrix is passed to the estimator in this file.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="the number of connected components of the "
            + "connectivity matrix is [0-9]{1,2}"
            + " > 1. Completing it to avoid stopping the tree early.",
            category=UserWarning,
        )
        algorithm.fit(X)

    y_pred = algorithm.labels_.astype(int)
    colors = np.array(list(islice(cycle(_PALETTE), int(y_pred.max()) + 1)))

    fig, ax = plt.subplots()
    ax.scatter(X[:, 0], X[:, 1], color=colors[y_pred])
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.set_xticks(())
    ax.set_yticks(())

    # Drop the figure from pyplot's registry. Without this, every callback
    # leaves one more open figure behind for the lifetime of the server.
    # The Figure object itself stays usable by the caller.
    plt.close(fig)
    return fig


def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` grid of Gradio columns.

    Each ``yield`` happens inside an open ``gr.Row()`` / ``gr.Column()``
    context, so components created by the consumer during that iteration are
    created inside that cell.
    """
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield


title = "🧑🏻‍🔬 Compare linkages in hierarchical clustering 🧑🏻‍🔬"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This app demonstrates different linkage methods in"
        " hierarchical clustering 🔗"
    )

    input_models = [
        "Single Linkage",
        "Average Linkage",
        "Complete Linkage",
        "Ward Linkage",
    ]
    input_data = gr.Radio(
        choices=[
            "Noisy Circles",
            "Noisy Moons",
            "Varied",
            "Aniso",
            "Blobs",
            "No Structure",
        ],
        value="Noisy Moons",
    )
    n_samples = gr.Slider(
        minimum=500, maximum=2000, step=50, label="Number of Samples"
    )
    n_neighbors = gr.Slider(
        minimum=2, maximum=5, step=1, label="Number of neighbors"
    )
    n_clusters = gr.Slider(
        minimum=2, maximum=5, step=1, label="Number of Clusters"
    )

    # Argument order expected by train_models (cls_name is bound via partial).
    callback_inputs = [input_data, n_samples, n_clusters, n_neighbors]
    # Any of these controls changing redraws every plot.
    triggers = [input_data, n_samples, n_neighbors, n_clusters]

    # iter_grid comes first in zip so that, with as many models as cells, the
    # generator is exhausted and its Row/Column contexts are exited cleanly.
    for _, input_model in zip(iter_grid(2, 2), input_models):
        plot = gr.Plot(label=input_model)
        fn = partial(train_models, cls_name=input_model)
        for control in triggers:
            control.change(fn=fn, inputs=callback_inputs, outputs=plot)

demo.launch()