|
import gradio as gr |
|
from sklearn.datasets import make_blobs |
|
from sklearn.cluster import BisectingKMeans, KMeans |
|
from functools import partial |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def train_models(n_samples, n_clusters, cls_name):
    """Fit one clustering model on synthetic blobs and plot the result.

    Args:
        n_samples: Number of points to generate with ``make_blobs``. ``None``
            falls back to 500.
        n_clusters: Number of clusters the model is asked to find. ``None``
            falls back to 1.
        cls_name: Which algorithm to run; either ``"Bisecting K-Means"`` or
            ``"K-Means"``.

    Returns:
        The matplotlib figure showing the points coloured by their assigned
        label, with the fitted cluster centers drawn in red.

    Raises:
        KeyError: If ``cls_name`` is not one of the supported names.
    """
    clustering_algorithms = {
        "Bisecting K-Means": BisectingKMeans,
        "K-Means": KMeans,
    }
    # Resolve the estimator first so an unknown name fails before any data is
    # generated or any figure is opened.
    model = clustering_algorithms[cls_name]

    # Coerce to int so a float-valued payload such as 3.0 is still accepted
    # by make_blobs and by the estimators' integer-only ``n_clusters``.
    n_samples = 500 if n_samples is None else int(n_samples)
    n_clusters = 1 if n_clusters is None else int(n_clusters)

    X, _ = make_blobs(n_samples=n_samples, centers=2, random_state=0)

    algo = model(n_clusters=n_clusters, random_state=0, n_init=3)
    algo.fit(X)
    centers = algo.cluster_centers_

    # Open the figure only after fitting succeeded, so a failed fit cannot
    # leave an orphaned figure behind.
    fig, ax = plt.subplots()
    try:
        ax.scatter(X[:, 0], X[:, 1], s=10, c=algo.labels_)
        ax.scatter(centers[:, 0], centers[:, 1], c="r", s=20)

        ax.set_title(f"{cls_name} : {n_clusters} clusters")

        ax.label_outer()
        ax.set_xticks([])
        ax.set_yticks([])
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # without this, each call would leak one figure. Closing only
        # deregisters it from pyplot -- the returned Figure object can still
        # be rendered by the caller.
        plt.close(fig)
    return fig
|
|
|
|
|
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Every yield happens while a ``gr.Row()`` and a nested ``gr.Column()``
    context are open, so whatever the consumer creates while handling that
    yield is created inside those two contexts. The yielded value itself is
    always ``None``.
    """
    for _row in range(n_rows):
        with gr.Row():
            for _col in range(n_cols):
                with gr.Column():
                    yield
|
|
|
|
|
title = "π Performance Comparison: Bisecting vs Regular K-Means"

# Assemble the UI: two sliders shared by every model, and one plot per model
# laid out side by side in a single-row grid.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This example shows differences between "
        "Regular K-Means algorithm and Bisecting K-Means. "
    )

    input_models = ["Bisecting K-Means", "K-Means"]

    n_samples = gr.Slider(
        minimum=500, maximum=2000, step=50, label="Number of Samples"
    )
    n_clusters = gr.Slider(
        minimum=1, maximum=20, step=1, label="Number of Clusters"
    )

    # The grid is listed first so a cell is opened before the next model name
    # is taken; iteration stops as soon as either side runs out.
    for _, input_model in zip(iter_grid(1, 2), input_models):
        plot = gr.Plot(label=input_model)

        # Pre-bind the algorithm name so the handler only needs the two
        # slider values as inputs.
        fn = partial(train_models, cls_name=input_model)

        # Moving either slider redraws this model's plot.
        for slider in (n_samples, n_clusters):
            slider.change(fn=fn, inputs=[n_samples, n_clusters], outputs=plot)

demo.launch()