# Hugging Face Space app (last update "Update app.py" by merve, rev e686cf2)
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import time
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics.pairwise import pairwise_distances_argmin
# Markdown shown at the top of the demo page. Plain string: no interpolation
# is needed, so the former f-string prefix was dropped (it would silently
# interpolate any future "{...}" added to the text).
model_card = """
## Description
This demo compares the performance of the **MiniBatchKMeans** and **KMeans**. The MiniBatchKMeans is faster, but gives slightly different results.
The points that are labelled differently between the two algorithms are also plotted.
You can play around with different ``number of samples`` and ``number of mini batch size`` to see the effect
## Dataset
Simulation dataset
"""
def _plot_clusters(X, labels, cluster_centers, colors, title):
    """Scatter-plot each cluster in its own color and mark its center.

    Parameters
    ----------
    X : ndarray of shape (n_points, 2) -- the 2-D data points.
    labels : ndarray of shape (n_points,) -- cluster index per point.
    cluster_centers : ndarray of shape (n_clusters, 2).
    colors : sequence of color strings, one per cluster.
    title : str -- axes title.

    Returns the created matplotlib figure.
    """
    fig, ax = plt.subplots()
    for k, col in zip(range(len(cluster_centers)), colors):
        members = labels == k
        ax.plot(X[members, 0], X[members, 1], "w", markerfacecolor=col, marker=".", markersize=15)
        ax.plot(
            cluster_centers[k][0],
            cluster_centers[k][1],
            "o",
            markerfacecolor=col,
            markeredgecolor="k",
            markersize=12,
        )
    ax.set_title(title)
    ax.set_xticks(())
    ax.set_yticks(())
    return fig


def do_train(n_samples, batch_size):
    """Fit KMeans and MiniBatchKMeans on a synthetic blob dataset and compare.

    Parameters
    ----------
    n_samples : int -- number of points generated by make_blobs.
    batch_size : int -- mini-batch size for MiniBatchKMeans.

    Returns
    -------
    (fig_kmeans, fig_minibatch, fig_difference, summary_text) where the figures
    show the two clusterings and the points labelled differently, and the text
    reports training time and inertia for both estimators.
    """
    # Fixed seed so center placement and make_blobs output are reproducible
    # across slider moves (make_blobs draws from the global NumPy RNG).
    np.random.seed(0)
    centers = np.random.rand(3, 2)
    n_clusters = len(centers)
    X, labels_true = make_blobs(n_samples=n_samples, centers=centers, cluster_std=0.7)

    # --- full-batch KMeans, timed ---
    k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10)
    t0 = time.time()
    k_means.fit(X)
    t_batch = time.time() - t0

    # --- MiniBatchKMeans, timed ---
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=n_clusters,
        batch_size=batch_size,
        n_init=10,
        max_no_improvement=10,
        verbose=0,
    )
    t0 = time.time()
    mbk.fit(X)
    t_mini_batch = time.time() - t0

    # Reorder the mini-batch centers so cluster k refers to the same region in
    # both models; otherwise colors and the difference plot would not line up.
    k_means_cluster_centers = k_means.cluster_centers_
    order = pairwise_distances_argmin(k_means.cluster_centers_, mbk.cluster_centers_)
    mbk_means_cluster_centers = mbk.cluster_centers_[order]

    # Recompute labels against the (aligned) centers for a fair comparison.
    k_means_labels = pairwise_distances_argmin(X, k_means_cluster_centers)
    mbk_means_labels = pairwise_distances_argmin(X, mbk_means_cluster_centers)

    colors = ["#4EACC5", "#FF9C34", "#4E9A06"]
    fig1 = _plot_clusters(X, k_means_labels, k_means_cluster_centers, colors, "KMeans")
    fig2 = _plot_clusters(X, mbk_means_labels, mbk_means_cluster_centers, colors, "MiniBatchKMeans")

    # Points whose cluster assignment disagrees between the two models.
    # (Explicit boolean init; the original `mbk_means_labels == 4` relied on
    # label 4 never occurring to produce an all-False array.)
    different = np.zeros(X.shape[0], dtype=bool)
    for k in range(n_clusters):
        different |= (k_means_labels == k) != (mbk_means_labels == k)
    identic = np.logical_not(different)

    fig3, axes3 = plt.subplots()
    axes3.plot(X[identic, 0], X[identic, 1], "w", markerfacecolor="#bbbbbb", marker=".", markersize=15)
    axes3.plot(X[different, 0], X[different, 1], "w", markerfacecolor="m", marker=".", markersize=15)
    axes3.set_title("Difference")
    axes3.set_xticks(())
    axes3.set_yticks(())

    text = f"KMeans Train time: {t_batch:.2f}s Inertia: {k_means.inertia_:.4f}. MiniBatchKMeans Train time: {t_mini_batch:.2f}s Inertia: {mbk.inertia_:.4f}"
    # Close all figures in pyplot's global registry to avoid accumulating open
    # figures across calls (the original plt.close() only closed the last one).
    # The Figure objects themselves stay alive and are returned to Gradio.
    plt.close("all")
    return fig1, fig2, fig3, text
# Gradio UI: two sliders drive do_train; three plots plus a text summary show
# the results. Any slider change retrains both models.
with gr.Blocks() as demo:
    gr.Markdown('''
            <div>
            <h1 style='text-align: center'>Comparison of the K-Means and MiniBatchKMeans clustering algorithms</h1>
            </div>
        ''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_mini_batch_kmeans.html#sphx-glr-auto-examples-cluster-plot-mini-batch-kmeans-py\">scikit-learn</a>")
    n_samples = gr.Slider(minimum=500, maximum=5000, step=500, value=500, label="Number of samples")
    batch_size = gr.Slider(minimum=100, maximum=2000, step=100, value=100, label="Size of the mini batches")
    with gr.Row():
        with gr.Column():
            plot1 = gr.Plot(label="KMeans")
        with gr.Column():
            plot2 = gr.Plot(label="MiniBatchKMeans")
        with gr.Column():
            plot3 = gr.Plot(label="Difference")
    with gr.Row():
        results = gr.Textbox(label="Results")
    # Render an initial comparison on page load so the plots are not empty
    # until a slider is moved (previously only .change events were wired).
    demo.load(fn=do_train, inputs=[n_samples, batch_size], outputs=[plot1, plot2, plot3, results])
    n_samples.change(fn=do_train, inputs=[n_samples, batch_size], outputs=[plot1, plot2, plot3, results])
    batch_size.change(fn=do_train, inputs=[n_samples, batch_size], outputs=[plot1, plot2, plot3, results])

demo.launch()