# https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html
"""Gradio demo comparing BIRCH and MiniBatchKMeans on a 10 x 10 grid of blobs.

Port of the scikit-learn example linked above, with the number of samples and
the BIRCH ``threshold`` / ``n_clusters`` parameters exposed as sliders.
"""

from itertools import cycle
from time import time

import gradio as gr
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
from joblib import cpu_count
from matplotlib.figure import Figure
from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets import make_blobs

plt.switch_backend("agg")

# Half-width of the square plotting window; blob centers lie in [-22, 22].
AXIS_LIMIT = 25
# Number of clusters MiniBatchKMeans is asked to find (one per blob center).
MBK_N_CLUSTERS = 100


def _make_grid_centers():
    """Return the (100, 2) array of blob centers on a 10 x 10 grid over [-22, 22]^2."""
    xx, yy = np.meshgrid(np.linspace(-22, 22, 10), np.linspace(-22, 22, 10))
    return np.column_stack((np.ravel(xx), np.ravel(yy)))


def _scatter_clusters(ax, X, labels, centers, n_clusters, colors_, show_centers):
    """Scatter the points of clusters 0..n_clusters-1 of ``X`` on ``ax``, one color each.

    Iteration stops at whichever of ``n_clusters`` / ``len(centers)`` is smaller.
    When ``show_centers`` is true, the matching row of ``centers`` is marked with
    a black "+". ``colors_`` is a shared iterator, so consecutive calls continue
    the color sequence rather than restarting it.
    """
    for k, center, col in zip(range(n_clusters), centers, colors_):
        mask = labels == k
        ax.scatter(
            X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
        )
        if show_centers:
            ax.scatter(center[0], center[1], marker="+", c="k", s=25)
    ax.set_xlim([-AXIS_LIMIT, AXIS_LIMIT])
    ax.set_ylim([-AXIS_LIMIT, AXIS_LIMIT])
    ax.set_autoscaley_on(False)


def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Cluster a synthetic blob dataset with BIRCH and MiniBatchKMeans and plot it.

    Args:
        n_samples: Number of points to generate (coerced to int).
        birch_threshold: BIRCH ``threshold`` used by both BIRCH models
            (coerced to float).
        birch_n_clusters: BIRCH ``n_clusters`` for the model that runs the
            global clustering step (coerced to int; must be >= 1).

    Returns:
        A ``(figure, text)`` tuple: a 1 x 3 matplotlib figure (BIRCH without
        and with global clustering, then MiniBatchKMeans) and a text report of
        timings and cluster counts, one entry per line.

    Raises:
        gr.Error: If ``birch_n_clusters`` is smaller than 1.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    birch_n_clusters = int(birch_n_clusters)
    if birch_n_clusters < 1:
        # Birch rejects n_clusters=0; the "no global step" (None) case is
        # already shown in the first panel.
        raise gr.Error("Birch number of clusters must be at least 1.")

    # Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
    X, _ = make_blobs(
        n_samples=n_samples, centers=_make_grid_centers(), random_state=0
    )

    # Use all colors that matplotlib provides by default (shared by all panels).
    colors_ = cycle(colors.cnames.keys())
    # Build the Figure directly rather than through pyplot, so figures from
    # repeated or concurrent requests are not retained by pyplot's global state.
    fig = Figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)
    report = []

    # Compute clustering with BIRCH with and without the final clustering step
    # and plot. Both models now honour the user-selected parameters.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=birch_n_clusters),
    ]
    final_step = ["without global clustering", "with global clustering"]
    for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
        t = time()
        birch_model.fit(X)
        report.append(
            "BIRCH %s as the final step took %0.2f seconds" % (info, time() - t)
        )

        labels = birch_model.labels_
        n_clusters = np.unique(labels).size
        report.append("n_clusters : %d" % n_clusters)

        ax = fig.add_subplot(1, 3, ind + 1)
        _scatter_clusters(
            ax,
            X,
            labels,
            birch_model.subcluster_centers_,
            n_clusters,
            colors_,
            # Subcluster centers are the final clusters only without a global step.
            show_centers=birch_model.n_clusters is None,
        )
        ax.set_title("BIRCH %s" % info)

    # Compute clustering with MiniBatchKMeans.
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=MBK_N_CLUSTERS,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report.append("Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch)

    ax = fig.add_subplot(1, 3, 3)
    # Iterate over MiniBatchKMeans' own clusters, not the BIRCH cluster count
    # left over from the loop above.
    _scatter_clusters(
        ax,
        X,
        mbk.labels_,
        mbk.cluster_centers_,
        mbk.n_clusters,
        colors_,
        show_centers=True,
    )
    ax.set_title("MiniBatchKMeans")

    return fig, "".join(line + "\n" for line in report)


# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

title = "Compare BIRCH and MiniBatchKMeans"
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)."
    )
    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global "
        "clustering step) and MiniBatchKMeans on a synthetic dataset having 25,000 "
        "samples and 2 features generated using make_blobs."
        "\n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could "
        "run efficiently on hundreds of thousands or even millions of datapoints. "
        "We chose to limit the dataset size of this example in the interest of "
        "keeping our Continuous Integration resource usage reasonable but the "
        "interested reader might enjoy editing this script to rerun it with a "
        "larger value for n_samples."
        "\n\n"
        "If n_clusters is set to None, the data is reduced from 25,000 samples to "
        "a set of 158 clusters. This can be viewed as a preprocessing step before "
        "the final (global) clustering step that further reduces these 158 "
        "clusters to 100 clusters."
    )

    n_samples = gr.Slider(
        minimum=20000,
        maximum=80000,
        label="Number of samples",
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        minimum=0.5,
        maximum=2.0,
        label="Birch Threshold",
        step=0.1,
        value=1.7,
    )
    # Birch requires n_clusters >= 1 (or None, which the first panel already shows).
    birch_n_clusters = gr.Slider(
        minimum=1,
        maximum=100,
        label="Birch number of clusters",
        step=1,
        value=100,
    )

    plt_out = gr.Plot()
    # The report has one line per timing / cluster count (5 lines in total).
    output = gr.Textbox(label="Output", lines=5)
    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )


if __name__ == "__main__":
    demo.launch()