vumichien commited on
Commit
75e0a65
1 Parent(s): e7d0eec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from functools import partial
4
+
5
+ from matplotlib import pyplot as plt
6
+ from scipy.cluster.hierarchy import dendrogram
7
+ from sklearn.datasets import load_iris
8
+ from sklearn.cluster import AgglomerativeClustering
9
+
10
+
11
+ theme = gr.themes.Monochrome(
12
+ primary_hue="indigo",
13
+ secondary_hue="blue",
14
+ neutral_hue="slate",
15
+ )
16
+ model_card = f"""
17
+ ## Description
18
+
19
+ This demo shows the plot of the corresponding **Dendrogram of Hierarchical Clustering** using **AgglomerativeClustering** and the dendrogram method on the Iris dataset.
20
+ There are several metrics that use to compute the distance like `euclidean`, `l1`, `l2`, `manhattan`
21
+ You can play around with different ``linkage criterion``. The linkage criterion determines which distance to use between sets of observations.
22
+ Note: If `linkage criterion` is **ward**, only **euclidean** can use
23
+
24
+
25
+ ## Dataset
26
+
27
+ Iris dataset
28
+ """
29
+ iris = load_iris()
30
+ X = iris.data
31
+
32
+ def iter_grid(n_rows, n_cols):
33
+ # create a grid using gradio Block
34
+ for _ in range(n_rows):
35
+ with gr.Row():
36
+ for _ in range(n_cols):
37
+ with gr.Column():
38
+ yield
39
+
40
+ def plot_dendrogram(linkage_name, metric_name):
41
+ # Create linkage matrix and then plot the dendrogram
42
+ if linkage_name == "ward" and metric_name != "euclidean":
43
+ return None
44
+ # setting distance_threshold=0 ensures we compute the full tree.
45
+ model = AgglomerativeClustering(distance_threshold=0, n_clusters=None, metric=metric_name, linkage=linkage_name)
46
+
47
+ model = model.fit(X)
48
+
49
+ # plot the top three levels of the dendrogram
50
+
51
+ counts = np.zeros(model.children_.shape[0])
52
+ n_samples = len(model.labels_)
53
+ for i, merge in enumerate(model.children_):
54
+ current_count = 0
55
+ for child_idx in merge:
56
+ if child_idx < n_samples:
57
+ current_count += 1 # leaf node
58
+ else:
59
+ current_count += counts[child_idx - n_samples]
60
+ counts[i] = current_count
61
+
62
+ linkage_matrix = np.column_stack(
63
+ [model.children_, model.distances_, counts]
64
+ ).astype(float)
65
+ fig, axes = plt.subplots()
66
+
67
+ dn1 = dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)
68
+ # Plot the corresponding dendrogram
69
+ axes.set_title(f"Hierarchical Clustering Dendrogram. Linkage criterion: {metric_name}")
70
+ axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
71
+ return fig
72
+
73
+
74
+
75
+ with gr.Blocks(theme=theme) as demo:
76
+ gr.Markdown('''
77
+ <div>
78
+ <h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
79
+ </div>
80
+ ''')
81
+ gr.Markdown(model_card)
82
+ gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py\">scikit-learn</a>")
83
+ input_linkage = gr.Radio(choices=["ward", "complete", "average", "single"], value="average", label="Linkage criterion to use")
84
+ metrics = ["euclidean", "l1", "l2", "manhattan"]
85
+ counter = 0
86
+ for _ in iter_grid(2, 2):
87
+ if counter >= len(metrics):
88
+ break
89
+
90
+ input_metric = metrics[counter]
91
+ plot = gr.Plot(label=input_metric)
92
+ fn = partial(plot_dendrogram, metric_name=input_metric)
93
+ input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)
94
+ counter += 1
95
+
96
+ demo.launch()