# cmpatino's picture
# Add missing period in description
# e4abb69
# raw
# history blame contribute delete
# No virus
# 3.77 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster freshly generated 2D blobs with Mean Shift and plot the result.

    Args:
        n_blobs: Number of blob centers to generate (the "true" cluster count).
        quantile: Quantile forwarded to ``estimate_bandwidth``.
        cluster_std: Standard deviation used for every generated blob.

    Returns:
        A ``(figure, message)`` tuple: the Matplotlib figure with one scatter
        series per estimated cluster and an "x" on each estimated center, and
        an HTML paragraph saying whether the estimated cluster count matches
        ``n_blobs``.
    """
    # UI sliders may hand over whole numbers as floats; make_blobs needs an
    # integer count when ``centers`` is a number.
    n_blobs = int(n_blobs)
    X, _, centers = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
    )

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    if bandwidth > 0:
        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        ms.fit(X)
        labels = ms.labels_
        cluster_centers = ms.cluster_centers_
        labels_unique = np.unique(labels)
    else:
        # A quantile of (nearly) 0 yields a zero bandwidth, which MeanShift
        # refuses. Report zero clusters instead of crashing the callback.
        labels = cluster_centers = None
        labels_unique = ()
    n_clusters_ = len(labels_unique)

    # Draw through an explicit Axes rather than pyplot's implicit "current
    # figure", so the drawing calls cannot land on another figure.
    fig, ax = plt.subplots()
    if n_clusters_ == 0:
        # No clustering available: still show the raw data.
        ax.scatter(X[:, 0], X[:, 1])
    # Iterate over the labels actually present instead of assuming they are
    # exactly 0..n_clusters_-1.
    for k in labels_unique:
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.set_title(f"Estimated number of clusters: {n_clusters_}")

    if len(centers) != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )

    # Unregister the figure from pyplot so repeated calls do not pile up open
    # figures; the Figure object itself remains usable by the caller.
    plt.close(fig)
    return fig, message
# Introductory Markdown shown at the top of the demo page.
_DESCRIPTION = """
# Mean Shift Clustering
This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points. You can change the parameters using the sliders and see how the model performs.
This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py).
"""

with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)
    with gr.Row():
        # Left column: the three sliders that parameterize the demo.
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            quantile = gr.Slider(
                # Start one step above 0: a quantile of 0 leads to a zero
                # bandwidth estimate, which MeanShift does not accept.
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        # Right column: the resulting plot and the match/mismatch message.
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    # Every slider change, and the initial page load, re-runs the same
    # callback with the same inputs and outputs.
    callback_inputs = [n_blobs, quantile, cluster_std]
    callback_outputs = [clusters_plots, message]
    for slider in callback_inputs:
        slider.change(
            get_clusters_plot,
            callback_inputs,
            callback_outputs,
            queue=False,
        )
    demo.load(
        get_clusters_plot,
        callback_inputs,
        callback_outputs,
        queue=False,
    )

if __name__ == "__main__":
    demo.launch()