# Commit 75e0a65 by vumichien: "Create app.py"
from functools import partial

import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris
# Colour scheme for the whole Blocks page: indigo primary, blue secondary,
# slate neutral tones on top of Gradio's Monochrome base theme.
theme = gr.themes.Monochrome(
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Markdown description rendered on the demo page.
# A plain string: there is nothing to interpolate, so no f-prefix is needed.
model_card = """
## Description
This demo shows the plot of the corresponding **Dendrogram of Hierarchical Clustering** using **AgglomerativeClustering** and the dendrogram method on the Iris dataset.
There are several metrics that can be used to compute the distance, such as `euclidean`, `l1`, `l2` and `manhattan`.
You can play around with different linkage criteria. The linkage criterion determines which distance to use between sets of observations.
Note: If the `linkage criterion` is **ward**, only **euclidean** can be used (the other plots stay empty).
## Dataset
Iris dataset
"""
# Iris feature matrix; every dendrogram below is fitted on this same data.
iris = load_iris()
X = iris["data"]
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield happens while that cell's ``gr.Row`` and ``gr.Column``
    contexts are open, so components the caller creates inside its loop body
    are placed in the current cell. Cells are visited row by row, and the
    yielded value itself is always ``None``.
    """

    def _columns_of_current_row():
        # One gr.Column per cell, nested inside the row opened in the loop below.
        for _col in range(n_cols):
            with gr.Column():
                yield

    for _row in range(n_rows):
        with gr.Row():
            yield from _columns_of_current_row()
def _linkage_matrix(model):
    """Build the linkage matrix for ``dendrogram`` from a fitted clustering model.

    Each row is ``[child_a, child_b, merge_distance, n_samples_under_node]``.
    Expects ``model.children_``, ``model.distances_`` and ``model.labels_`` to
    be populated (``plot_dendrogram`` fits with ``distance_threshold=0``).
    """
    n_samples = len(model.labels_)
    # Number of original samples under each merge node. Child ids below
    # n_samples are leaves; larger ids refer to an earlier merge (row
    # ``child_idx - n_samples``), whose count has already been filled in.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
    return np.column_stack([model.children_, model.distances_, counts]).astype(float)


def plot_dendrogram(linkage_name, metric_name):
    """Fit agglomerative clustering on ``X`` and plot the top of its dendrogram.

    Args:
        linkage_name: Linkage criterion passed to ``AgglomerativeClustering``.
        metric_name: Distance metric passed to ``AgglomerativeClustering``.

    Returns:
        A matplotlib ``Figure`` containing the dendrogram, or ``None`` for the
        unsupported combination of ward linkage with a non-euclidean metric.
    """
    if linkage_name == "ward" and metric_name != "euclidean":
        return None
    # Setting distance_threshold=0 (with n_clusters=None) computes the full
    # tree, so every merge distance is available for the dendrogram.
    model = AgglomerativeClustering(
        distance_threshold=0, n_clusters=None, metric=metric_name, linkage=linkage_name
    )
    model = model.fit(X)
    linkage_matrix = _linkage_matrix(model)

    # Create the figure without pyplot. pyplot keeps a global reference to
    # every figure it creates, and this function runs on every radio change,
    # so unclosed pyplot figures would pile up in the long-running server.
    fig = Figure()
    axes = fig.subplots()
    # Plot only the top three levels of the dendrogram.
    dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)
    axes.set_title(f"Hierarchical Clustering Dendrogram. Linkage criterion: {linkage_name}")
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    return fig
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''
<div>
<h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
</div>
''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py\">scikit-learn</a>")
    input_linkage = gr.Radio(choices=["ward", "complete", "average", "single"], value="average", label="Linkage criterion to use")
    metrics = ["euclidean", "l1", "l2", "manhattan"]
    # One plot per metric in a 2x2 grid. iter_grid is zip's first argument so
    # that, once every cell is used, the generator is advanced to completion
    # (closing its Row/Column contexts) before zip stops.
    for _, input_metric in zip(iter_grid(2, 2), metrics):
        plot = gr.Plot(label=input_metric)
        fn = partial(plot_dendrogram, metric_name=input_metric)
        # Redraw whenever the linkage criterion changes...
        input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)
        # ...and on page load, so the default linkage is plotted immediately
        # instead of leaving every plot empty until the radio is touched.
        demo.load(fn=fn, inputs=[input_linkage], outputs=plot)

if __name__ == "__main__":
    demo.launch()