"""Gradio demo: hierarchical-clustering dendrograms of the Iris dataset.

For the selected linkage criterion, the app fits scikit-learn's
``AgglomerativeClustering`` once per distance metric and draws the top
levels of each resulting tree with SciPy's ``dendrogram``.
"""

from functools import partial

import gradio as gr
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

model_card = """
## Description

This demo plots the **Dendrogram of Hierarchical Clustering** computed with
**AgglomerativeClustering** and SciPy's `dendrogram` function on the Iris dataset.
Several metrics can be used to compute the distance: `euclidean`, `l1`, `l2` and `manhattan`.
You can play around with different ``linkage criterion`` values. The linkage criterion
determines which distance to use between sets of observations.

Note: if the `linkage criterion` is **ward**, only **euclidean** can be used; the other plots are left empty.

## Dataset

Iris dataset
"""

iris = load_iris()
X = iris.data


def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield happens inside a ``gr.Row`` / ``gr.Column`` context, so any
    component the caller creates between iterations is placed in that cell.
    Must be consumed inside a ``gr.Blocks`` context.
    """
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield


def _build_linkage_matrix(model):
    """Convert a fitted ``AgglomerativeClustering`` into a SciPy linkage matrix.

    Each row is ``[child_a, child_b, merge_distance, n_samples_under_node]``,
    the format expected by ``scipy.cluster.hierarchy.dendrogram``. The model
    must expose ``distances_`` (it does when fitted with ``distance_threshold``).
    """
    n_samples = len(model.labels_)
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node: one original sample
            else:
                # Internal node produced by an earlier merge (row child_idx - n_samples).
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
    return np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)


def plot_dendrogram(linkage_name, metric_name):
    """Fit agglomerative clustering on Iris and return its dendrogram figure.

    Args:
        linkage_name: linkage criterion passed to ``AgglomerativeClustering``
            (the UI offers "ward", "complete", "average" and "single").
        metric_name: distance metric passed to ``AgglomerativeClustering``.

    Returns:
        A matplotlib ``Figure`` showing the top three levels of the tree, or
        ``None`` for the ward + non-euclidean combination, which clears the
        Gradio plot instead of raising.
    """
    # scikit-learn only accepts the euclidean metric with ward linkage.
    if linkage_name == "ward" and metric_name != "euclidean":
        return None

    # distance_threshold=0 with n_clusters=None computes the full tree and
    # populates model.distances_, which the linkage matrix needs.
    model = AgglomerativeClustering(
        distance_threshold=0,
        n_clusters=None,
        metric=metric_name,
        linkage=linkage_name,
    ).fit(X)
    linkage_matrix = _build_linkage_matrix(model)

    # Build the Figure directly rather than through pyplot: pyplot registers
    # every figure in global state (one leaked per callback here) and is not
    # meant for concurrent use, while Gradio runs sync callbacks in threads.
    fig = Figure()
    axes = fig.subplots()
    # Plot only the top three levels of the dendrogram.
    dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)
    axes.set_title(
        "Hierarchical Clustering Dendrogram\n"
        f"Linkage criterion: {linkage_name}, metric: {metric_name}"
    )
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    return fig


with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Hierarchical Clustering Dendrogram")
    gr.Markdown(model_card)
    gr.Markdown("Author: Vu Minh Chien. Based on the example from scikit-learn")

    input_linkage = gr.Radio(
        choices=["ward", "complete", "average", "single"],
        value="average",
        label="Linkage criterion to use",
    )

    metrics = ["euclidean", "l1", "l2", "manhattan"]
    # zip() stops at the shorter of the two; the 2x2 grid has exactly one
    # cell per metric, so each Plot lands in its own Row/Column.
    for _, input_metric in zip(iter_grid(2, 2), metrics):
        plot = gr.Plot(label=input_metric)
        fn = partial(plot_dendrogram, metric_name=input_metric)
        input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)
        # Draw the default linkage on page load instead of leaving the plots
        # empty until the user changes the radio selection.
        demo.load(fn=fn, inputs=[input_linkage], outputs=plot)


if __name__ == "__main__":
    demo.launch()