marik0's picture
Add details to description and plot
07451ac
raw history blame
No virus
3.78 kB
import numpy as np
from sklearn.cluster import AffinityPropagation
from sklearn import metrics
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import gradio as gr
def generate_data(num_centers, num_samples):
    """Generate a 2-D blob dataset around the first ``num_centers`` fixed centers.

    Args:
        num_centers: How many of the four predefined centers to use; the
            centers are taken from the start of the list.
        num_samples: Total number of points requested from ``make_blobs``.

    Returns:
        The ``(X, labels_true)`` pair produced by ``make_blobs``: the sampled
        points and the blob each point was drawn from.
    """
    candidate_centers = [[1, 1], [-1, -1], [1, -1], [-1, 1]]
    # cluster_std and random_state are fixed, so the same arguments always
    # yield the same dataset.
    points, true_labels = make_blobs(
        n_samples=num_samples,
        centers=candidate_centers[:num_centers],
        cluster_std=0.5,
        random_state=0,
    )
    return points, true_labels
def create_plot(num_clusters, num_samples):
    """Cluster a synthetic dataset with affinity propagation and plot the result.

    Args:
        num_clusters: Number of blob centers, forwarded to ``generate_data``.
        num_samples: Number of samples, forwarded to ``generate_data``.

    Returns:
        A ``(fig, metrics_str)`` tuple: the matplotlib figure holding the
        scatter plot, and a multi-line string (each line newline-terminated)
        with the estimated cluster count and the clustering scores.
    """
    X, labels_true = generate_data(num_clusters, num_samples)

    af = AffinityPropagation(preference=-50, random_state=0).fit(X)
    cluster_centers_indices = af.cluster_centers_indices_
    labels = af.labels_
    n_clusters_ = len(cluster_centers_indices)

    # Scores that compare the assigned labels against the true labels.
    supervised_scores = (
        ("Homogeneity", metrics.homogeneity_score),
        ("Completeness", metrics.completeness_score),
        ("V-measure", metrics.v_measure_score),
        ("Adjusted Rand Index", metrics.adjusted_rand_score),
        ("Adjusted Mutual Information", metrics.adjusted_mutual_info_score),
    )
    report = [f"Estimated number of clusters: {n_clusters_}"]
    report.extend(
        f"{name}: {score(labels_true, labels):0.3f}"
        for name, score in supervised_scores
    )

    # silhouette_score raises ValueError unless the number of distinct labels
    # is between 2 and n_samples - 1, so a degenerate clustering must not be
    # allowed to take down the whole callback.
    n_labels = len(np.unique(labels))
    if 1 < n_labels < len(X):
        silhouette = metrics.silhouette_score(X, labels, metric="sqeuclidean")
        report.append(f"Silhouette Coefficient: {silhouette:0.3f}")
    else:
        report.append("Silhouette Coefficient: undefined")
    metrics_str = "".join(f"{line}\n" for line in report)

    # Figure number 1 is reused and cleared on every call, so repeated calls
    # do not accumulate open figures.
    fig = plt.figure(1)
    plt.clf()

    # One viridis color per estimated cluster.
    colors = plt.cycler("color", plt.cm.viridis(np.linspace(0, 1, n_clusters_)))
    for k, col in zip(range(n_clusters_), colors):
        class_members = labels == k
        cluster_center = X[cluster_centers_indices[k]]
        plt.scatter(
            X[class_members, 0], X[class_members, 1], color=col["color"], marker="."
        )
        plt.scatter(
            cluster_center[0], cluster_center[1], s=14, color=col["color"], marker="o"
        )
        # Connect every member of the cluster to the cluster's center point.
        for x in X[class_members]:
            plt.plot(
                [cluster_center[0], x[0]], [cluster_center[1], x[1]], color=col["color"]
            )

    plt.title("Estimated number of clusters: %d" % n_clusters_)
    plt.xlabel("x")
    plt.ylabel("y")
    return fig, metrics_str
# Heading shown at the top of the demo page.
title = "Affinity propagation clustering algorithm"

# Introductory text shown under the heading. A trailing backslash joins the
# next source line without inserting anything, so every continued line must
# end with a space before the backslash to keep the sentences separated.
description = """
This demo plots clusters of a synthetic 2D dataset that contains up to 4 clusters using the affinity propagation algorithm. \
The 2-dimensional dataset is generated around 2 to 4 predetermined cluster centers, by sampling a Gaussian distribution \
with 0.5 standard deviation around each center. The demo uses the affinity propagation clustering algorithm to assign the data into \
clusters. It also calculates a cluster center. \
The figure shows a scatter plot of the data points and their connection to the respective cluster center. The demo also \
presents several metrics based on the true and assigned labels.
"""
# Build the page: heading, description, two sliders, then the plot and the
# metrics text box side by side.
with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    num_clusters = gr.Slider(minimum=2, maximum=4, step=1, value=2, label="Number of clusters")
    num_samples = gr.Slider(minimum=100, maximum=300, step=100, value=200, label="Number of samples")

    with gr.Row():
        plot = gr.Plot()
        text_box = gr.Textbox(label="Results")

    # Moving either slider recomputes the clustering from both slider values.
    for slider in (num_clusters, num_samples):
        slider.change(fn=create_plot, inputs=[num_clusters, num_samples], outputs=[plot, text_box])

# launch(enable_queue=True) is deprecated in favour of Blocks.queue(), so
# queueing is enabled through the method before launching.
demo.queue()
demo.launch()