|
import gradio as gr |
|
from sklearn.datasets import make_blobs |
|
from sklearn.cluster import KMeans |
|
from sklearn.metrics import silhouette_samples, silhouette_score |
|
|
|
import matplotlib.pyplot as plt |
|
import matplotlib.cm as cm |
|
import numpy as np |
|
|
|
# App-wide Gradio theme: the Monochrome base theme with indigo primary,
# blue secondary and slate neutral hues.
theme = gr.themes.Monochrome(
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
|
|
|
|
|
def _plot_silhouette(ax, sample_silhouette_values, cluster_labels, n_clusters, silhouette_avg):
    """Draw the per-cluster silhouette plot onto ``ax``.

    Each cluster's silhouette coefficients are sorted and drawn as a filled
    horizontal band, with a 10-unit vertical gap between bands; a red dashed
    vertical line marks the average silhouette score.
    """
    ax.set_xlim([-0.1, 1])
    # Room for every sample plus a 10-unit gap below, between and above the bands.
    ax.set_ylim([0, len(sample_silhouette_values) + (n_clusters + 1) * 10])

    y_lower = 10
    for i in range(n_clusters):
        # Sorted coefficients of the samples assigned to cluster i.
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[cluster_labels == i])
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = cm.nipy_spectral(float(i) / n_clusters)
        ax.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            ith_cluster_silhouette_values,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )
        # Label the band with its cluster index, at the band's vertical middle.
        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

        # Next band starts 10 units above this one.
        y_lower = y_upper + 10

    ax.set_title("The silhouette plot for the various clusters.")
    ax.set_xlabel("The silhouette coefficient values")
    ax.set_ylabel("Cluster label")

    ax.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax.set_yticks([])
    ax.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])


def _plot_clusters(ax, X, cluster_labels, centers, n_clusters):
    """Scatter the first two features of ``X`` colored by cluster label.

    Each cluster center is drawn as a white circle with its index on top.
    """
    colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
    ax.scatter(
        X[:, 0], X[:, 1], marker=".", s=30, lw=0, alpha=0.7, c=colors, edgecolor="k"
    )

    ax.scatter(
        centers[:, 0],
        centers[:, 1],
        marker="o",
        c="white",
        alpha=1,
        s=200,
        edgecolor="k",
    )
    for i, c in enumerate(centers):
        ax.scatter(c[0], c[1], marker=f"${i}$", alpha=1, s=50, edgecolor="k")

    ax.set_title("The visualization of the clustered data.")
    ax.set_xlabel("Feature space for the 1st feature")
    ax.set_ylabel("Feature space for the 2nd feature")


def main(
    n_clusters: int = 2,
    n_samples: int = 500,
    n_features: int = 2,
    n_centers: int = 4,
    cluster_std: float = 1.0,
):
    """Generate blob data, cluster it with KMeans, and plot a silhouette analysis.

    The average silhouette score is also printed to stdout.

    Args:
        n_clusters: Number of clusters KMeans is asked to find.
        n_samples: Number of points generated by ``make_blobs``.
        n_features: Dimensionality of the generated points. Must be at least 2,
            because the scatter plot uses the first two features.
        n_centers: Number of blob centers used to generate the data.
        cluster_std: Standard deviation of the generated blobs.

    Returns:
        ``(fig1, fig2)``: the silhouette plot and the scatter plot of the
        clustered data, as matplotlib figures.

    Raises:
        ValueError: If ``n_features`` is less than 2.
    """
    # Build figures without pyplot: pyplot keeps every figure it creates in a
    # global registry, which leaks memory in a long-running server since these
    # figures are never closed.
    from matplotlib.figure import Figure

    # Slider values may arrive as floats (e.g. 500.0); scikit-learn validates
    # that these parameters are true integers.
    n_clusters = int(n_clusters)
    n_samples = int(n_samples)
    n_features = int(n_features)
    n_centers = int(n_centers)
    cluster_std = float(cluster_std)
    if n_features < 2:
        raise ValueError(f"n_features must be at least 2 to plot the data, got {n_features}")

    X, _ = make_blobs(
        n_samples=n_samples,
        n_features=n_features,
        centers=n_centers,
        cluster_std=cluster_std,
        center_box=(-10.0, 10.0),
        shuffle=True,
        random_state=1,
    )

    # Fixed seed so the same inputs always give the same clustering.
    clusterer = KMeans(n_clusters=n_clusters, n_init="auto", random_state=10)
    cluster_labels = clusterer.fit_predict(X)

    silhouette_avg = silhouette_score(X, cluster_labels)
    print(
        "For n_clusters =",
        n_clusters,
        "The average silhouette_score is :",
        silhouette_avg,
    )
    sample_silhouette_values = silhouette_samples(X, cluster_labels)

    fig1 = Figure(figsize=(9, 4))
    ax1 = fig1.subplots()
    fig2 = Figure(figsize=(9, 4))
    ax2 = fig2.subplots()

    _plot_silhouette(ax1, sample_silhouette_values, cluster_labels, n_clusters, silhouette_avg)
    _plot_clusters(ax2, X, cluster_labels, clusterer.cluster_centers_, n_clusters)

    return fig1, fig2
|
|
|
|
|
title = """# Selecting the number of clusters with silhouette analysis on KMeans clustering π"""

description = """
This app demonstrates a silhouette analysis for KMeans clustering on sample data.

The purpose of a clustering algorithm is to find groups of similar data points. The purpose of a silhouette analysis is to determine the optimal number of clusters for a given clustering algorithm. The silhouette analysis can be used on any clustering algorithm, but it is most commonly used with KMeans clustering.
"""

# Page layout: header text, the dataset/clustering sliders, a run button, and
# two side-by-side plots filled by `main`.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("""### Dataset Generation Parameters""")

    with gr.Row():
        with gr.Column():
            # gr.Slider(value=...) replaces the deprecated
            # gr.inputs.Slider(default=...), which was removed in Gradio 4.
            n_samples = gr.Slider(
                minimum=100,
                maximum=1000,
                value=500,
                step=50,
                label="Number of Samples",
            )
            n_features = gr.Slider(
                minimum=2, maximum=5, value=2, step=1, label="Number of Features"
            )
            n_centers = gr.Slider(
                minimum=2, maximum=5, value=4, step=1, label="Number of Centers"
            )
            cluster_std = gr.Slider(
                minimum=0.0, maximum=1.0, value=1, step=0.1, label="Cluster deviation"
            )
            n_clusters = gr.Slider(
                minimum=2, maximum=6, value=2, step=1, label="Number of Clusters"
            )

    run_button = gr.Button("Analyse Silhouette")

    with gr.Row():
        plot_silhouette = gr.Plot()
        plot_clusters = gr.Plot()

    outputs = [plot_silhouette, plot_clusters]
    # Order must match main's positional parameters.
    inputs = [n_clusters, n_samples, n_features, n_centers, cluster_std]
    run_button.click(fn=main, inputs=inputs, outputs=outputs)

demo.launch()
|
|