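# NOTE(review): the lines below look like repository-viewer residue captured
# with the file, not program text; kept as comments so the module parses.
# cmpatino's picture
# Add header explaining what the space does
# 48ad3a5
# raw
# history blame
# 3.71 kB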
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster synthetic blobs with Mean Shift and plot the outcome.

    Generates 10,000 points with ``make_blobs``, estimates a bandwidth from
    ``quantile``, fits ``MeanShift`` and draws one scatter series per
    estimated cluster, marking every cluster center with an ``x``.

    Args:
        n_blobs: Number of blob centers handed to ``make_blobs``.
        quantile: Quantile handed to ``estimate_bandwidth``.
        cluster_std: Standard deviation of each generated blob.

    Returns:
        A ``(figure, message)`` tuple: the Matplotlib figure and an HTML
        paragraph saying whether the estimated cluster count matches
        ``n_blobs``. If no positive bandwidth can be estimated, the figure
        shows the unclustered data and the message asks for a larger
        quantile.
    """
    # Imported locally: the figure is built through Matplotlib's
    # object-oriented API instead of pyplot. pyplot keeps every figure it
    # creates in a global registry (a leak in a long-running server) and draws
    # on an implicit "current" figure, which concurrent requests can clobber.
    from matplotlib.figure import Figure

    X, _, centers = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
    )

    fig = Figure()
    ax = fig.subplots()

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    if bandwidth <= 0:
        # A quantile of 0 yields a zero bandwidth, which MeanShift rejects
        # with an exception; report it instead of failing the request.
        ax.scatter(X[:, 0], X[:, 1])
        ax.set_title("Bandwidth could not be estimated")
        message = (
            '<p style="text-align: center;">'
            + "The bandwidth could not be estimated."
            + " Try increasing the `Quantile` parameter.</p>"
        )
        return fig, message

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    n_clusters_ = len(np.unique(labels))

    for k in range(n_clusters_):
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    ax.set_title(f"Estimated number of clusters: {n_clusters_}")

    if len(centers) != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )
    return fig, message
# Demo UI: three sliders feed get_clusters_plot, whose figure and HTML message
# are shown next to them.
# NOTE(review): the nesting below is reconstructed; in particular `message` is
# assumed to sit under the plot inside the wide column — confirm against the
# deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Mean Shift Clustering
This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points. You can change the parameters using the sliders and see how the model performs.
This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py)
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            quantile = gr.Slider(
                # Starts at one step above 0: a quantile of 0 produces a zero
                # bandwidth, which MeanShift refuses to fit.
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    # Every slider change and the initial page load recompute the plot with
    # the same inputs and outputs, so register them in one place.
    plot_inputs = [n_blobs, quantile, cluster_std]
    plot_outputs = [clusters_plots, message]
    for register in (n_blobs.change, quantile.change, cluster_std.change, demo.load):
        register(
            get_clusters_plot,
            plot_inputs,
            plot_outputs,
            queue=False,
        )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()