freddyaboulton HF staff committed on
Commit
58047f4
1 Parent(s): d53bfda

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -0
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio demo for different clustering techniques
2
+
3
+ Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html
4
+
5
+ """
6
+
7
+ import math
8
+ from functools import partial
9
+
10
+ import gradio as gr
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from sklearn.cluster import (
14
+ AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
15
+ )
16
+ from sklearn.datasets import make_blobs, make_circles, make_moons
17
+ from sklearn.mixture import GaussianMixture
18
+ from sklearn.neighbors import kneighbors_graph
19
+ from sklearn.preprocessing import StandardScaler
20
+
21
+
22
# The 'seaborn' style sheet was renamed to 'seaborn-v0_8' (deprecated under the old
# name in matplotlib 3.6 and removed in later releases); use whichever name this
# matplotlib installation actually provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')


SEED = 0  # random seed shared by the data generators and the seeded models
MAX_CLUSTERS = 10  # upper bound of the "Number of clusters" slider
N_SAMPLES = 1000  # number of points in every generated dataset
N_COLS = 3  # number of plot columns in the UI grid
FIGSIZE = 7, 7  # does not affect size in webpage
COLORS = [
    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
]
# module-level invariant: one distinct color per selectable cluster
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
np.random.seed(SEED)
35
+
36
+
37
def normalize(X):
    """Return ``X`` transformed by a ``StandardScaler`` freshly fitted on ``X``."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)
39
+
40
+
41
def get_regular(n_clusters):
    """Return normalized blobs around the first ``n_clusters`` grid centers.

    Returns a ``(X, labels)`` pair as produced by ``make_blobs`` (with ``X``
    passed through ``normalize``). Fails an assertion when more clusters are
    requested than there are predefined centers.
    """
    # candidate centers, listed in a spiral order starting at the origin
    spiral_centers = [
        [0, 0],
        [1, 0],
        [1, 1],
        [0, 1],
        [-1, 1],
        [-1, 0],
        [-1, -1],
        [0, -1],
        [1, -1],
        [2, -1],
    ]
    blob_centers = spiral_centers[:n_clusters]
    assert len(blob_centers) == n_clusters
    points, true_labels = make_blobs(
        n_samples=N_SAMPLES,
        centers=blob_centers,
        cluster_std=0.25,
        random_state=SEED,
    )
    return normalize(points), true_labels
58
+
59
+
60
def get_circles(n_clusters):
    """Return the normalized ``make_circles`` dataset and its labels.

    ``n_clusters`` is accepted so all dataset generators share one signature;
    it is not used here.
    """
    points, true_labels = make_circles(
        n_samples=N_SAMPLES,
        factor=0.5,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), true_labels
63
+
64
+
65
def get_moons(n_clusters):
    """Return the normalized ``make_moons`` dataset and its labels.

    ``n_clusters`` is accepted so all dataset generators share one signature;
    it is not used here.
    """
    points, true_labels = make_moons(
        n_samples=N_SAMPLES,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), true_labels
68
+
69
+
70
def get_noise(n_clusters):
    """Return normalized uniform random 2-D points with random integer labels.

    The global NumPy seed is reset first so repeated calls give the same data.
    Labels are drawn from ``[0, n_clusters)``.
    """
    np.random.seed(SEED)
    # draw the points before the labels: both consume the same random stream
    points = np.random.rand(N_SAMPLES, 2)
    random_labels = np.random.randint(0, n_clusters, size=(N_SAMPLES,))
    return normalize(points), random_labels
74
+
75
+
76
def get_anisotropic(n_clusters):
    """Return blobs distorted by a fixed linear map, plus their true labels.

    ``n_clusters`` blobs are generated with ``make_blobs`` and multiplied by a
    constant 2x2 matrix. The result is passed through ``normalize`` like every
    other dataset in this file, so models with fixed scale parameters (for
    example DBSCAN's ``eps``) see data on a comparable scale.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
    # linear map applied to every point to distort the isotropic blobs
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return normalize(X), labels
81
+
82
+
83
def get_varied(n_clusters):
    """Return normalized blobs whose standard deviations differ per cluster.

    Fails an assertion when ``n_clusters`` exceeds the number of predefined
    standard deviations.
    """
    all_stds = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0]
    chosen_stds = all_stds[:n_clusters]
    assert len(chosen_stds) == n_clusters
    points, true_labels = make_blobs(
        n_samples=N_SAMPLES,
        centers=n_clusters,
        cluster_std=chosen_stds,
        random_state=SEED,
    )
    return normalize(points), true_labels
90
+
91
+
92
def get_spiral(n_clusters):
    """Return normalized noisy points along a spiral, all labelled ``0``.

    ``n_clusters`` is accepted so all dataset generators share one signature;
    it is not used here. The global NumPy seed is reset first so repeated
    calls give the same data.
    """
    # adapted from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
    np.random.seed(SEED)
    angle = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
    # stack the x row and the y row into a (2, N_SAMPLES) array
    curve = np.concatenate((angle * np.cos(angle), angle * np.sin(angle)))
    noisy_curve = curve + 0.7 * np.random.randn(2, N_SAMPLES)
    points = np.ascontiguousarray(noisy_curve.T)

    single_label = np.zeros(N_SAMPLES, dtype=int)
    return normalize(points), single_label
104
+
105
+
106
# dataset name -> generator function taking ``n_clusters`` and returning (X, labels);
# the key order is the order in which the names are listed in the UI
DATA_MAPPING = dict(
    regular=get_regular,
    circles=get_circles,
    moons=get_moons,
    spiral=get_spiral,
    noise=get_noise,
    anisotropic=get_anisotropic,
    varied=get_varied,
)
115
+
116
+
117
def get_groundtruth_model(X, labels, n_clusters, **kwargs):
    """Return a stand-in "model" whose ``labels_`` are the true labels.

    This lets the ground truth be plotted through the same code path as the
    real clustering models. ``X``, ``n_clusters`` and ``kwargs`` are accepted
    only for signature parity with the other model factories and are ignored.
    """
    # dummy model to show true label distribution
    class Dummy:
        def __init__(self, y):
            # store the constructor argument itself rather than reaching into
            # the enclosing function's scope
            self.labels_ = y

    return Dummy(labels)
124
+
125
+
126
def get_kmeans(X, labels, n_clusters, **kwargs):
    """Fit and return a ``KMeans`` model on ``X``.

    ``labels`` is unused; ``kwargs`` are applied via ``set_params`` on top of
    the defaults chosen here.
    """
    estimator = KMeans(
        init="k-means++",
        n_clusters=n_clusters,
        n_init=10,
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)
130
+
131
+
132
def get_dbscan(X, labels, n_clusters, **kwargs):
    """Fit and return a ``DBSCAN`` model (``eps=0.3``) on ``X``.

    ``labels`` and ``n_clusters`` are unused; ``kwargs`` are applied via
    ``set_params`` before fitting.
    """
    estimator = DBSCAN(eps=0.3)
    return estimator.set_params(**kwargs).fit(X)
136
+
137
+
138
def get_agglomerative(X, labels, n_clusters, **kwargs):
    """Fit and return a ward-linkage ``AgglomerativeClustering`` model on ``X``.

    The model is constrained by a k-nearest-neighbors connectivity graph built
    from ``X`` with ``n_clusters`` neighbors. ``labels`` is unused; ``kwargs``
    are applied via ``set_params`` before fitting.
    """
    knn_graph = kneighbors_graph(X, n_neighbors=n_clusters, include_self=False)
    # averaging the graph with its transpose makes the connectivity symmetric
    symmetric_graph = 0.5 * (knn_graph + knn_graph.T)
    estimator = AgglomerativeClustering(
        n_clusters=n_clusters,
        linkage="ward",
        connectivity=symmetric_graph,
    )
    return estimator.set_params(**kwargs).fit(X)
149
+
150
+
151
def get_meanshift(X, labels, n_clusters, **kwargs):
    """Fit and return a ``MeanShift`` model on ``X``.

    The bandwidth is estimated from ``X`` with ``estimate_bandwidth``
    (``quantile=0.25``). ``labels`` and ``n_clusters`` are unused; ``kwargs``
    are applied via ``set_params`` before fitting.
    """
    estimator = MeanShift(
        bandwidth=estimate_bandwidth(X, quantile=0.25),
        bin_seeding=True,
    )
    return estimator.set_params(**kwargs).fit(X)
156
+
157
+
158
def get_spectral(X, labels, n_clusters, **kwargs):
    """Fit and return a ``SpectralClustering`` model on ``X``.

    Uses the ``arpack`` eigen solver and a nearest-neighbors affinity.
    ``labels`` is unused; ``kwargs`` are applied via ``set_params`` before
    fitting.
    """
    estimator = SpectralClustering(
        n_clusters=n_clusters, eigen_solver="arpack", affinity="nearest_neighbors"
    )
    return estimator.set_params(**kwargs).fit(X)
166
+
167
+
168
def get_optics(X, labels, n_clusters, **kwargs):
    """Fit and return an ``OPTICS`` model on ``X``.

    ``labels`` and ``n_clusters`` are unused; ``kwargs`` are applied via
    ``set_params`` on top of the defaults chosen here.
    """
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    return estimator.set_params(**kwargs).fit(X)
176
+
177
+
178
def get_birch(X, labels, n_clusters, **kwargs):
    """Fit and return a ``Birch`` model with ``n_clusters`` clusters on ``X``.

    ``labels`` is unused; ``kwargs`` are applied via ``set_params`` before
    fitting.
    """
    estimator = Birch(n_clusters=n_clusters)
    return estimator.set_params(**kwargs).fit(X)
182
+
183
+
184
def get_gaussianmixture(X, labels, n_clusters, **kwargs):
    """Fit and return a full-covariance ``GaussianMixture`` on ``X``.

    ``n_clusters`` is used as the number of mixture components. ``labels`` is
    unused; ``kwargs`` are applied via ``set_params`` before fitting.
    """
    estimator = GaussianMixture(
        n_components=n_clusters,
        covariance_type="full",
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)
190
+
191
+
192
# display name -> factory called as ``factory(X, labels, n_clusters=...)`` that
# returns a fitted model; the key order is the order of the plots in the UI grid
MODEL_MAPPING = dict((
    ('True labels', get_groundtruth_model),
    ('KMeans', get_kmeans),
    ('DBSCAN', get_dbscan),
    ('MeanShift', get_meanshift),
    ('SpectralClustering', get_spectral),
    ('OPTICS', get_optics),
    ('Birch', get_birch),
    ('GaussianMixture', get_gaussianmixture),
    ('AgglomerativeClustering', get_agglomerative),
))
203
+
204
+
205
def plot_clusters(ax, X, labels):
    """Scatter-plot the points of ``X`` on ``ax``, colored by cluster label.

    Args:
        ax: matplotlib axes to draw on.
        X: array indexed as ``X[mask, 0]`` / ``X[mask, 1]`` for the two plot
            coordinates.
        labels: one cluster label per row of ``X``; ``-1`` marks outliers.

    Returns:
        The same ``ax``, with grid and ticks removed.
    """
    from itertools import cycle

    labels = np.asarray(labels)

    # -1 signifies outliers, which we plot separately
    cluster_ids = np.unique(labels[labels != -1])
    # cycle through the palette so clusters beyond len(COLORS) are still drawn
    # (models that choose their own cluster count can exceed the palette size)
    for cluster_id, color in zip(cluster_ids, cycle(COLORS)):
        mask = labels == cluster_id
        ax.scatter(X[mask, 0], X[mask, 1], color=color)

    # show outliers (if any)
    outliers = labels == -1
    if outliers.any():
        ax.scatter(X[outliers, 0], X[outliers, 1], c='k', marker='x')

    # grid(False) always hides the grid; grid(None) would merely toggle it
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax
223
+
224
+
225
def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
    """Generate a dataset, cluster it, and return a figure of the result.

    Args:
        dataset: key of ``DATA_MAPPING`` selecting the data generator.
        n_clusters: requested number of clusters; either a number or a dict
            carrying the number under the ``'value'`` key.
        clustering_algorithm: key of ``MODEL_MAPPING`` selecting the model;
            also used as the plot title.

    Returns:
        A matplotlib ``Figure`` showing the predicted clusters.

    Raises:
        KeyError: if ``dataset`` or ``clustering_algorithm`` is not a known key.
    """
    # imported locally; matplotlib is already a dependency of this file
    from matplotlib.figure import Figure

    if isinstance(n_clusters, dict):
        n_clusters = n_clusters['value']
    # cast in both cases so the downstream slicing always gets an int
    n_clusters = int(n_clusters)

    X, labels = DATA_MAPPING[dataset](n_clusters)
    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        # models without ``labels_`` are asked to predict instead
        y_pred = model.predict(X)

    # Build the figure directly instead of via pyplot: pyplot keeps a global
    # reference to every figure it creates, so a long-running server that never
    # closes them leaks one figure per request.
    fig = Figure(figsize=FIGSIZE)
    ax = fig.subplots()

    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm, fontsize=16)

    return fig
244
+
245
+
246
# page title, shown both as the browser title and as the bold heading
title = "Clustering with Scikit-learn"
# introductory text rendered as Markdown below the heading
description = (
    "This example shows how different clustering algorithms work. Simply pick "
    "the dataset and the number of clusters to see how the clustering algorithms work. "
    "Colored circles are (predicted) labels and black x are outliers."
)
252
+
253
+
254
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield happens inside a ``gr.Column`` nested in a ``gr.Row``, so any
    component the caller creates while handling the yield is placed in that
    cell. The yielded value itself is ``None``.
    """
    for _row_index in range(n_rows):
        row = gr.Row()
        with row:
            for _col_index in range(n_cols):
                cell = gr.Column()
                with cell:
                    yield None
261
+
262
+
263
# Build the UI: a dataset selector, a cluster-count slider, and one plot per
# clustering algorithm laid out in a grid with N_COLS columns.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_models = list(MODEL_MAPPING)
    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="regular",
        label="dataset"
    )
    input_n_clusters = gr.Slider(
        minimum=1,
        maximum=MAX_CLUSTERS,
        value=4,
        step=1,
        label='Number of clusters'
    )
    n_rows = int(math.ceil(len(input_models) / N_COLS))
    grid = iter_grid(n_rows, N_COLS)
    # The grid is advanced first on every step, so when it has exactly as many
    # cells as there are models it runs to completion and closes its layout
    # contexts normally.
    for _, input_model in zip(grid, input_models):
        plot = gr.Plot(label=input_model)
        # bind the algorithm name now; the two UI inputs supply the rest
        fn = partial(cluster, clustering_algorithm=input_model)
        input_data.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
        input_n_clusters.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)
    # If the grid has spare cells the generator is still suspended inside an
    # open Row/Column; close it explicitly instead of relying on garbage
    # collection to leave those contexts.
    grid.close()


# only start the server when run as a script, not when the module is imported
if __name__ == "__main__":
    demo.launch()