File size: 8,069 Bytes
6674a4f
 
 
 
 
 
7c9b8f5
 
 
6674a4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0415b11
6674a4f
7c9b8f5
 
 
 
 
 
6674a4f
 
 
 
 
 
 
0415b11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6674a4f
 
 
0415b11
6674a4f
 
 
 
0415b11
6674a4f
 
 
 
0415b11
e04e7d0
7a8ea79
6674a4f
 
 
0415b11
 
6674a4f
 
 
 
 
0415b11
 
 
6674a4f
0415b11
6674a4f
 
 
 
9be4332
 
e04e7d0
9be4332
 
 
 
 
dffd3b5
9be4332
 
 
 
 
6674a4f
 
 
 
9be4332
6674a4f
 
 
 
 
7c9b8f5
 
 
 
 
 
 
 
 
 
 
0415b11
6674a4f
 
 
 
7c9b8f5
6674a4f
 
 
 
 
7c9b8f5
6674a4f
0415b11
6674a4f
 
 
 
0415b11
6674a4f
 
 
 
 
7c9b8f5
 
6674a4f
 
 
 
 
7c9b8f5
6674a4f
0415b11
6674a4f
 
 
 
 
 
 
7c9b8f5
6674a4f
 
 
 
 
 
 
 
 
7c9b8f5
0415b11
6674a4f
 
 
 
7c9b8f5
6674a4f
0415b11
6674a4f
 
 
 
 
 
7c9b8f5
6674a4f
 
 
 
 
 
 
7c9b8f5
6674a4f
 
 
 
0415b11
 
7c9b8f5
6674a4f
 
 
7c9b8f5
6674a4f
0415b11
 
 
 
 
6674a4f
 
 
 
 
 
7c9b8f5
 
 
 
 
 
0415b11
7c9b8f5
6674a4f
 
 
 
 
7c9b8f5
6674a4f
 
7c9b8f5
6674a4f
 
 
0415b11
a88bd97
0415b11
 
305364f
 
6674a4f
 
5444690
 
 
 
 
 
 
 
 
 
7c9b8f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5444690
 
 
 
 
 
 
 
 
 
7c9b8f5
 
6674a4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
"""Gradio demo for different clustering techniques

Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html

"""

import math
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler


# The plain 'seaborn' style was deprecated in Matplotlib 3.6 and removed in 3.8,
# where it survives as 'seaborn-v0_8'. Prefer the new name, fall back to the old
# one on older Matplotlib, and keep the default style if neither exists instead
# of crashing at import time.
for _style in ('seaborn-v0_8', 'seaborn'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break


SEED = 0
MAX_CLUSTERS = 10
N_SAMPLES = 1000
N_COLS = 3
FIGSIZE = 7, 7  # does not affect size in webpage
COLORS = [
    'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'
]
assert len(COLORS) >= MAX_CLUSTERS, "Not enough different colors for all clusters"
np.random.seed(SEED)


def normalize(X):
    """Standardize each feature column of ``X`` to zero mean and unit variance."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)


def get_regular(n_clusters):
    """Return ``n_clusters`` tight, standardized Gaussian blobs and their labels.

    Blob centers are taken, in order, from a fixed spiral walk around the
    origin (at most 10 positions), so raising ``n_clusters`` only adds blobs.
    """
    spiral_walk = (
        (0, 0), (1, 0), (1, 1), (0, 1), (-1, 1),
        (-1, 0), (-1, -1), (0, -1), (1, -1), (2, -1),
    )
    centers = [list(point) for point in spiral_walk[:n_clusters]]
    assert len(centers) == n_clusters
    points, truth = make_blobs(
        n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED
    )
    return normalize(points), truth


def get_circles(n_clusters):
    """Return standardized noisy concentric circles and their labels.

    ``n_clusters`` is accepted only for interface compatibility and is ignored.
    """
    points, truth = make_circles(
        n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED
    )
    return normalize(points), truth


def get_moons(n_clusters):
    """Return standardized noisy interleaved half-moons and their labels.

    ``n_clusters`` is accepted only for interface compatibility and is ignored.
    """
    points, truth = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
    return normalize(points), truth


def get_noise(n_clusters):
    """Return standardized uniform random points with random labels.

    NumPy's global RNG is reseeded first so the output is reproducible. The
    points are drawn before the labels, which come from ``range(n_clusters)``.
    """
    np.random.seed(SEED)
    points = np.random.rand(N_SAMPLES, 2)
    truth = np.random.randint(0, n_clusters, size=(N_SAMPLES,))
    return normalize(points), truth


def get_anisotropic(n_clusters):
    """Return standardized Gaussian blobs sheared by a fixed linear map.

    Like every other dataset in this demo, the points are standardized before
    being returned. Scale-dependent model defaults (e.g. DBSCAN's ``eps=0.3``,
    MeanShift's estimated bandwidth) therefore mean the same thing here as on
    the other datasets; previously this dataset was the only one left unscaled.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return normalize(X), labels


def get_varied(n_clusters):
    """Return standardized Gaussian blobs whose spread differs per cluster.

    Per-cluster standard deviations cycle through 1.0, 2.5, 0.5 (10 entries at
    most, so ``n_clusters`` above 10 fails the assertion).
    """
    spreads = ([1.0, 2.5, 0.5] * 3 + [1.0])[:n_clusters]
    assert len(spreads) == n_clusters
    points, truth = make_blobs(
        n_samples=N_SAMPLES,
        centers=n_clusters,
        cluster_std=spreads,
        random_state=SEED,
    )
    return normalize(points), truth


def get_spiral(n_clusters):
    """Return a standardized noisy 2-D spiral; ``n_clusters`` is ignored.

    Adapted from
    https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
    There is no ground truth, so every point is labelled 0.
    """
    np.random.seed(SEED)
    angle = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
    # Stack the (1, N) x and y rows into a (2, N) array, then add jitter.
    coords = np.vstack((angle * np.cos(angle), angle * np.sin(angle)))
    coords += 0.7 * np.random.randn(2, N_SAMPLES)
    points = np.ascontiguousarray(coords.T)

    truth = np.zeros(N_SAMPLES, dtype=int)
    return normalize(points), truth


# Dataset name shown in the UI -> generator called as ``f(n_clusters)`` that
# returns ``(X, labels)``. Insertion order sets the order of the radio buttons.
DATA_MAPPING = dict(
    regular=get_regular,
    circles=get_circles,
    moons=get_moons,
    spiral=get_spiral,
    noise=get_noise,
    anisotropic=get_anisotropic,
    varied=get_varied,
)


def get_groundtruth_model(X, labels, n_clusters, **kwargs):
    """Return a stand-in "model" whose ``labels_`` are the true labels.

    This lets the ground truth be plotted through the same code path as the
    fitted clustering models. ``X``, ``n_clusters`` and ``kwargs`` are ignored.
    """
    class Dummy:
        def __init__(self, y):
            # Store the argument that was passed in instead of silently
            # capturing ``labels`` from the enclosing scope.
            self.labels_ = y

    return Dummy(labels)


def get_kmeans(X, labels, n_clusters, **kwargs):
    """Fit seeded k-means (k-means++ init, 10 restarts); ``kwargs`` override parameters."""
    estimator = KMeans(
        n_clusters=n_clusters, init="k-means++", n_init=10, random_state=SEED
    ).set_params(**kwargs)
    return estimator.fit(X)


def get_dbscan(X, labels, n_clusters, **kwargs):
    """Fit DBSCAN with ``eps=0.3``; ``n_clusters`` is unused, ``kwargs`` override parameters."""
    return DBSCAN(eps=0.3).set_params(**kwargs).fit(X)


def get_agglomerative(X, labels, n_clusters, *, n_neighbors=10, **kwargs):
    """Fit Ward agglomerative clustering constrained by a k-nearest-neighbour graph.

    Args:
        X: points to cluster.
        labels: unused (kept for the common factory signature).
        n_clusters: number of clusters to produce.
        n_neighbors: neighbourhood size of the connectivity graph. This used
            to be tied to ``n_clusters``, which made the graph very sparse
            (and prone to splitting into disconnected components) whenever few
            clusters were requested. The two quantities are unrelated, so they
            are now set independently.
        **kwargs: overrides for ``AgglomerativeClustering`` parameters.
    """
    connectivity = kneighbors_graph(
        X, n_neighbors=n_neighbors, include_self=False
    )
    # make connectivity symmetric: the raw kNN relation is not mutual
    connectivity = 0.5 * (connectivity + connectivity.T)
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage="ward", connectivity=connectivity
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_meanshift(X, labels, n_clusters, **kwargs):
    """Fit mean shift with bin seeding and a bandwidth from ``estimate_bandwidth(quantile=0.25)``.

    ``n_clusters`` is unused; ``kwargs`` override parameters.
    """
    bw = estimate_bandwidth(X, quantile=0.25)
    return MeanShift(bandwidth=bw, bin_seeding=True).set_params(**kwargs).fit(X)


def get_spectral(X, labels, n_clusters, **kwargs):
    """Fit spectral clustering on a nearest-neighbours affinity graph.

    ``random_state=SEED`` is passed, as it is for the other seeded models
    (KMeans, GaussianMixture). Repeated calls on the same data then give the
    same result instead of depending on the current state of NumPy's global
    RNG. ``kwargs`` override any parameter.
    """
    model = SpectralClustering(
        n_clusters=n_clusters,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, labels, n_clusters, **kwargs):
    """Fit OPTICS (min_samples=7, xi=0.05, min_cluster_size=0.1).

    ``n_clusters`` is unused; ``kwargs`` override parameters.
    """
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    return estimator.set_params(**kwargs).fit(X)


def get_birch(X, labels, n_clusters, **kwargs):
    """Fit BIRCH asking for ``n_clusters`` clusters; ``kwargs`` override parameters."""
    return Birch(n_clusters=n_clusters).set_params(**kwargs).fit(X)


def get_gaussianmixture(X, labels, n_clusters, **kwargs):
    """Fit a seeded Gaussian mixture of ``n_clusters`` full-covariance components.

    ``kwargs`` override any parameter.
    """
    mixture = GaussianMixture(
        n_components=n_clusters,
        covariance_type="full",
        random_state=SEED,
    ).set_params(**kwargs)
    return mixture.fit(X)


# Display name -> factory called as ``f(X, labels, n_clusters=...)`` that returns
# a fitted model. Insertion order sets the order of the plots in the grid.
MODEL_MAPPING = dict((
    ('True labels', get_groundtruth_model),
    ('KMeans', get_kmeans),
    ('DBSCAN', get_dbscan),
    ('MeanShift', get_meanshift),
    ('SpectralClustering', get_spectral),
    ('OPTICS', get_optics),
    ('Birch', get_birch),
    ('GaussianMixture', get_gaussianmixture),
    ('AgglomerativeClustering', get_agglomerative),
))


def plot_clusters(ax, X, labels, colors=None):
    """Scatter-plot 2-D points on ``ax``, one color per cluster label.

    Args:
        ax: a Matplotlib ``Axes`` (anything with ``scatter``, ``grid``,
            ``set_xticks`` and ``set_yticks``).
        X: array of points; columns 0 and 1 are plotted.
        labels: per-point integer labels; ``-1`` marks outliers, which are
            drawn as black crosses.
        colors: color sequence for the clusters, defaulting to ``COLORS``. It is
            cycled, so clusters beyond ``len(colors)`` reuse colors. Previously
            ``zip`` silently dropped those clusters' points, which can happen
            with DBSCAN/MeanShift, since they pick their own cluster count.

    Returns:
        ``ax``, with grid lines and ticks removed.
    """
    from itertools import cycle

    if colors is None:
        colors = COLORS
    labels = np.asarray(labels)

    # -1 signifies outliers, which we plot separately
    cluster_ids = sorted(set(labels.tolist()) - {-1})
    for label, color in zip(cluster_ids, cycle(colors)):
        # every id came from ``labels`` itself, so this mask is never empty
        idx = labels == label
        ax.scatter(X[idx, 0], X[idx, 1], color=color)

    # show outliers (if any)
    idx = labels == -1
    if idx.any():
        ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')

    # grid(False) always hides the grid; grid(None) toggles it and would turn
    # it on under a style whose grid is off by default.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
    """Cluster one dataset with one algorithm and return the scatter-plot figure.

    Args:
        dataset: key of ``DATA_MAPPING``.
        n_clusters: requested number of clusters. A ``{'value': ...}`` dict is
            unwrapped (presumably some Gradio versions deliver the slider value
            that way); either form is converted to ``int``.
        clustering_algorithm: key of ``MODEL_MAPPING``; also used as the title.

    Returns:
        A ``matplotlib.figure.Figure`` showing the predicted clusters.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps a
    # global reference to every figure it creates until plt.close(), and this
    # callback runs for every model on every input change, so those figures
    # would pile up for the lifetime of the server.
    from matplotlib.figure import Figure

    if isinstance(n_clusters, dict):
        n_clusters = n_clusters['value']
    n_clusters = int(n_clusters)

    X, labels = DATA_MAPPING[dataset](n_clusters)
    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        # models without ``labels_`` (e.g. GaussianMixture) must predict
        y_pred = model.predict(X)

    fig = Figure(figsize=FIGSIZE)
    ax = fig.subplots()
    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm, fontsize=16)

    return fig


# Page title and intro text shown above the plot grid.
title = "Clustering with Scikit-learn"
description = (
    "This example shows how different clustering algorithms work. Simply pick "
    "the dataset and the number of clusters to see how the clustering algorithms work. "
    "Colored circles are (predicted) labels and black x are outliers."
)


def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each ``yield`` is executed inside a ``gr.Column`` nested in a ``gr.Row``,
    so components the caller creates in its loop body land in that cell.
    """
    for _row_index in range(n_rows):
        with gr.Row():
            for _col_index in range(n_cols):
                with gr.Column():
                    yield None


with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_models = list(MODEL_MAPPING)
    input_data = gr.Radio(
        list(DATA_MAPPING),
        value="regular",
        label="dataset"
    )
    input_n_clusters = gr.Slider(
        minimum=1,
        maximum=MAX_CLUSTERS,
        value=4,
        step=1,
        label='Number of clusters'
    )
    inputs = [input_data, input_n_clusters]
    # One plot per model, N_COLS per row (math.ceil already returns an int).
    n_rows = math.ceil(len(input_models) / N_COLS)
    for counter, _ in enumerate(iter_grid(n_rows, N_COLS)):
        if counter >= len(input_models):
            break

        input_model = input_models[counter]
        plot = gr.Plot(label=input_model)
        # partial binds this iteration's model name now (no late-binding closure).
        fn = partial(cluster, clustering_algorithm=input_model)
        # Render each plot once when the page loads; previously the plots stayed
        # empty until the user changed one of the inputs.
        demo.load(fn=fn, inputs=inputs, outputs=plot)
        for trigger in inputs:
            trigger.change(fn=fn, inputs=inputs, outputs=plot)


demo.launch()