# Author: merve (HF staff)
# Commit dd426fe: "Update app.py"
# https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html
from itertools import cycle
from time import time

import gradio as gr
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
from joblib import cpu_count
from matplotlib.figure import Figure
from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets import make_blobs
# Render off-screen with the non-interactive Agg backend: the app only hands
# figures back to the browser and never opens a GUI window.
plt.switch_backend(newbackend="agg")
def _make_grid_blobs(n_samples):
    """Return an (n_samples, 2) array drawn from 100 blobs centred on a 10 x 10 grid."""
    xx, yy = np.meshgrid(np.linspace(-22, 22, 10), np.linspace(-22, 22, 10))
    # One (x, y) center per grid node -> shape (100, 2).
    centers = np.column_stack((xx.ravel(), yy.ravel()))
    X, _ = make_blobs(n_samples=n_samples, centers=centers, random_state=0)
    return X


def _plot_birch_panel(ax, birch_model, X, colors_, info):
    """Scatter the points of a fitted BIRCH model on *ax*.

    Returns the number of distinct labels the model assigned.
    """
    labels = birch_model.labels_
    centroids = birch_model.subcluster_centers_
    n_clusters = np.unique(labels).size
    for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
        mask = labels == k
        ax.scatter(X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5)
        # Centroid markers are drawn only for the model without the global
        # clustering step.
        if birch_model.n_clusters is None:
            ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    ax.set_ylim([-25, 25])
    ax.set_xlim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title("BIRCH %s" % info)
    return n_clusters


def _plot_minibatch_panel(ax, mbk, X, colors_):
    """Scatter the points and centers of a fitted MiniBatchKMeans model on *ax*."""
    labels = mbk.labels_
    # Iterate over the model's own centers; label k belongs to center k.
    for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):
        mask = labels == k
        ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    ax.set_xlim([-25, 25])
    ax.set_ylim([-25, 25])
    ax.set_title("MiniBatchKMeans")
    ax.set_autoscaley_on(False)


def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Fit BIRCH (without/with global clustering) and MiniBatchKMeans, then plot them.

    Parameters
    ----------
    n_samples : int-like
        Number of points drawn from the 10 x 10 grid of blobs.
    birch_threshold : float-like
        ``threshold`` passed to both BIRCH models.
    birch_n_clusters : int-like
        ``n_clusters`` for the BIRCH model that runs the global clustering
        step. Values below 1 are treated as ``None`` (no global step), since
        scikit-learn's ``Birch`` only accepts ``None`` or an int >= 1.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Three side-by-side panels: the two BIRCH models, then MiniBatchKMeans
        (always fitted with 100 clusters).
    result : str
        Timing and cluster-count report, one newline-terminated line per item.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    birch_n_clusters = int(birch_n_clusters)
    global_n_clusters = birch_n_clusters if birch_n_clusters >= 1 else None

    X = _make_grid_blobs(n_samples)

    # Use all colors that matplotlib provides by default. A single iterator is
    # shared by all panels, so each panel continues where the previous stopped.
    colors_ = cycle(colors.cnames.keys())
    report = []

    # Build the Figure directly instead of via plt.figure(): pyplot keeps every
    # figure it creates registered until closed (a leak across requests) and
    # is not thread-safe for concurrent Gradio handlers.
    fig = Figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Compute clustering with BIRCH with and without the final clustering step.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=global_n_clusters),
    ]
    for ind, birch_model in enumerate(birch_models):
        info = (
            "without global clustering"
            if birch_model.n_clusters is None
            else "with global clustering"
        )
        t = time()
        birch_model.fit(X)
        report.append(
            "BIRCH %s as the final step took %0.2f seconds" % (info, time() - t)
        )
        ax = fig.add_subplot(1, 3, ind + 1)
        n_clusters = _plot_birch_panel(ax, birch_model, X, colors_, info)
        report.append("n_clusters : %d" % n_clusters)

    # Compute clustering with MiniBatchKMeans.
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report.append("Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch)
    _plot_minibatch_panel(fig.add_subplot(1, 3, 3), mbk, X, colors_)

    return fig, "".join(line + "\n" for line in report)
# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
# Monochrome base theme with indigo/blue accents, slate neutrals, small corner
# radii and an Open Sans font followed by generic sans-serif fallbacks.
_theme_fonts = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    font=_theme_fonts,
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
title = "Compare BIRCH and MiniBatchKMeans"
# Page layout: title, description, three input sliders, then the figure and
# text report returned by do_submit when "Submit" is clicked.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)."
    )
    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global clustering step) and "
        "MiniBatchKMeans on a synthetic dataset with 2 features generated using make_blobs "
        "(25,000 samples by default)."
        "\n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on "
        "hundreds of thousands or even millions of datapoints. The original example limits the dataset "
        "size to keep its Continuous Integration resource usage reasonable; here you can use the "
        "\"Number of samples\" slider to rerun the comparison on a larger dataset."
        "\n\n"
        "With the default settings, setting n_clusters to None reduces the 25,000 samples to a set of "
        "158 subclusters. This can be viewed as a preprocessing step before the final (global) "
        "clustering step that further reduces these 158 subclusters to 100 clusters."
    )
    n_samples = gr.Slider(
        minimum=20000,
        maximum=80000,
        label="Number of samples",
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        minimum=0.5,
        maximum=2.0,
        label="Birch Threshold",
        step=0.1,
        value=1.7,
    )
    # scikit-learn's Birch rejects n_clusters=0 (it must be None or >= 1),
    # so the slider starts at 1.
    birch_n_clusters = gr.Slider(
        minimum=1,
        maximum=100,
        label="Birch number of clusters",
        step=1,
        value=100,
    )
    plt_out = gr.Plot()
    output = gr.Textbox(label="Output", multiline=True)
    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )
def _launch():
    """Start the Gradio app ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch()