"""Gradio demo: cluster synthetic blob data with MeanShift and plot the result."""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Per-cluster styling. The first three entries reproduce the original demo's
# look; the extra colors ensure every cluster MeanShift finds is drawn (the
# slider allows up to 10 blobs, and the previous 3-entry lists silently
# dropped every cluster after the third).
COLORS = [
    "#dede00",
    "#377eb8",
    "#f781bf",
    "#ff7f00",
    "#4daf4a",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#17becf",
]
MARKERS = ["x", "o", "^"]


def get_clusters_plot(n_blobs, cluster_std):
    """Generate blob data, cluster it with MeanShift, and return a scatter plot.

    Args:
        n_blobs: Number of blob centers passed to ``make_blobs``. Coerced to
            ``int`` because a UI slider may deliver it as a float, and
            ``make_blobs`` treats a non-integral ``centers`` as coordinates.
        cluster_std: Standard deviation of each blob, passed to ``make_blobs``.

    Returns:
        A matplotlib ``Figure`` with each estimated cluster's points and its
        center (drawn larger, with a black edge).
    """
    X, _ = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=int(n_blobs)
    )

    # Estimate the kernel bandwidth on a 500-point subsample to keep it cheap.
    bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    # A standalone Figure (not pyplot) is never registered with pyplot's
    # global figure manager, so figures don't accumulate across requests and
    # concurrent requests don't share a "current figure".
    fig = Figure()
    ax = fig.add_subplot()
    for k, center in enumerate(cluster_centers):
        color = COLORS[k % len(COLORS)]
        marker = MARKERS[k % len(MARKERS)]
        members = labels == k
        ax.plot(X[members, 0], X[members, 1], marker, color=color)
        ax.plot(
            center[0],
            center[1],
            marker,
            markerfacecolor=color,
            markeredgecolor="k",
            markersize=14,
        )
    ax.set_title(f"Estimated number of clusters: {len(cluster_centers)}")
    return fig


demo = gr.Interface(
    get_clusters_plot,
    [
        gr.Slider(
            minimum=2, maximum=10, label="Number of clusters in data", step=1, value=3
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1,
            label="Cluster standard deviation",
            step=0.1,
            value=0.6,
        ),
    ],
    gr.Plot(),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version this demo is pinned to.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()