|
import gradio as gr |
|
import numpy as np |
|
from functools import partial |
|
|
|
from matplotlib import pyplot as plt |
|
from scipy.cluster.hierarchy import dendrogram |
|
from sklearn.datasets import load_iris |
|
from sklearn.cluster import AgglomerativeClustering |
|
|
|
|
|
# Gradio theme shared by the whole demo UI.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Markdown description card rendered at the top of the page.
# Plain string (the original carried a pointless ``f`` prefix with no
# placeholders); user-facing grammar fixed.
model_card = """

## Description

This demo shows the plot of the corresponding **Dendrogram of Hierarchical Clustering** using **AgglomerativeClustering** and the dendrogram method on the Iris dataset.

There are several metrics that can be used to compute the distance, like `euclidean`, `l1`, `l2`, `manhattan`.

You can play around with different ``linkage criterion``. The linkage criterion determines which distance to use between sets of observations.

Note: If `linkage criterion` is **ward**, only **euclidean** can be used.


## Dataset

Iris dataset
"""

# Load the Iris dataset once at import time; only the feature matrix is used.
iris = load_iris()
X = iris.data
|
|
|
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield is suspended inside a fresh ``gr.Column`` nested in the
    current ``gr.Row``, so the caller's loop body creates exactly one
    component per grid cell.
    """
    row_index = 0
    while row_index < n_rows:
        with gr.Row():
            col_index = 0
            while col_index < n_cols:
                with gr.Column():
                    yield
                col_index += 1
        row_index += 1
|
|
|
def plot_dendrogram(linkage_name, metric_name):
    """Fit AgglomerativeClustering on the Iris features and plot a dendrogram.

    Parameters
    ----------
    linkage_name : str
        Linkage criterion: "ward", "complete", "average" or "single".
    metric_name : str
        Distance metric: "euclidean", "l1", "l2" or "manhattan".

    Returns
    -------
    matplotlib.figure.Figure or None
        The dendrogram figure, or ``None`` for the unsupported
        ward + non-euclidean combination.
    """
    # scikit-learn only supports "ward" linkage with the euclidean metric.
    if linkage_name == "ward" and metric_name != "euclidean":
        return None

    # distance_threshold=0 together with n_clusters=None makes the model
    # build the full tree and populate ``distances_``, which SciPy's
    # dendrogram needs.
    model = AgglomerativeClustering(
        distance_threshold=0,
        n_clusters=None,
        metric=metric_name,
        linkage=linkage_name,
    ).fit(X)

    # Count the original samples under each merge node: children with an
    # index < n_samples are leaves (count 1); larger indices refer to
    # earlier merges, whose counts are reused.
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # SciPy linkage-matrix format: [child_a, child_b, distance, sample_count].
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    fig, axes = plt.subplots()
    dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)

    # Bug fix: the original title interpolated ``metric_name`` while
    # labelling it "Linkage criterion"; report both values accurately.
    axes.set_title(
        f"Hierarchical Clustering Dendrogram. "
        f"Linkage: {linkage_name}, metric: {metric_name}"
    )
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    return fig
|
|
|
|
|
|
|
with gr.Blocks(theme=theme) as demo:
    # Page header, description card and attribution.
    gr.Markdown('''
    <div>
    <h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
    </div>
    ''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py\">scikit-learn</a>")

    # One radio control shared by all four plots.
    input_linkage = gr.Radio(
        choices=["ward", "complete", "average", "single"],
        value="average",
        label="Linkage criterion to use",
    )

    metrics = ["euclidean", "l1", "l2", "manhattan"]
    # One plot per metric on a 2x2 grid. zip replaces the original manual
    # counter/break bookkeeping; iter_grid comes first so its 5th next()
    # call still runs, letting the generator exit its Row/Column contexts.
    for _, input_metric in zip(iter_grid(2, 2), metrics):
        plot = gr.Plot(label=input_metric)
        # partial binds the metric eagerly, avoiding the classic
        # late-binding-closure bug inside the loop.
        fn = partial(plot_dendrogram, metric_name=input_metric)
        input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)

demo.launch()