Benjamin Bossan committed on
Commit
0415b11
1 Parent(s): a88bd97

Users can change the number of clusters

Browse files
Files changed (1) hide show
  1. app.py +65 -31
app.py CHANGED
@@ -20,7 +20,7 @@ plt.style.use('seaborn')
20
 
21
 
22
  SEED = 0
23
- N_CLUSTERS = 4
24
  N_SAMPLES = 1000
25
  np.random.seed(SEED)
26
 
@@ -29,38 +29,52 @@ def normalize(X):
29
  return StandardScaler().fit_transform(X)
30
 
31
 
32
- def get_regular():
33
- centers = [[1, 1], [1, -1], [-1, 1], [-1, -1]]
34
- assert len(centers) == N_CLUSTERS
35
- X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.7, random_state=SEED)
 
 
 
 
 
 
 
 
 
 
 
 
36
  return normalize(X), labels
37
 
38
 
39
- def get_circles():
40
  X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
41
  return normalize(X), labels
42
 
43
 
44
- def get_moons():
45
  X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
46
  return normalize(X), labels
47
 
48
 
49
- def get_noise():
50
  X, labels = np.random.rand(N_SAMPLES, 2), np.zeros(N_SAMPLES)
51
  return normalize(X), labels
52
 
53
 
54
- def get_anisotropic():
55
- X, labels = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170)
56
  transformation = [[0.6, -0.6], [-0.4, 0.8]]
57
  X = np.dot(X, transformation)
58
  return X, labels
59
 
60
 
61
- def get_varied():
 
 
62
  X, labels = make_blobs(
63
- n_samples=N_SAMPLES, cluster_std=[1.0, 2.5, 0.5], random_state=SEED
64
  )
65
  return normalize(X), labels
66
 
@@ -74,41 +88,41 @@ DATA_MAPPING = {
74
  'varied': get_varied,
75
  }
76
 
77
- def get_kmeans(X, **kwargs):
78
- model = KMeans(init="k-means++", n_clusters=N_CLUSTERS, n_init=10, random_state=SEED)
79
  model.set_params(**kwargs)
80
  return model.fit(X)
81
 
82
 
83
- def get_dbscan(X, **kwargs):
84
  model = DBSCAN(eps=0.3)
85
  model.set_params(**kwargs)
86
  return model.fit(X)
87
 
88
 
89
- def get_agglomerative(X, **kwargs):
90
  connectivity = kneighbors_graph(
91
- X, n_neighbors=N_CLUSTERS, include_self=False
92
  )
93
  # make connectivity symmetric
94
  connectivity = 0.5 * (connectivity + connectivity.T)
95
  model = AgglomerativeClustering(
96
- n_clusters=N_CLUSTERS, linkage="ward", connectivity=connectivity
97
  )
98
  model.set_params(**kwargs)
99
  return model.fit(X)
100
 
101
 
102
- def get_meanshift(X, **kwargs):
103
  bandwidth = estimate_bandwidth(X, quantile=0.3)
104
  model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
105
  model.set_params(**kwargs)
106
  return model.fit(X)
107
 
108
 
109
- def get_spectral(X, **kwargs):
110
  model = SpectralClustering(
111
- n_clusters=N_CLUSTERS,
112
  eigen_solver="arpack",
113
  affinity="nearest_neighbors",
114
  )
@@ -116,7 +130,7 @@ def get_spectral(X, **kwargs):
116
  return model.fit(X)
117
 
118
 
119
- def get_optics(X, **kwargs):
120
  model = OPTICS(
121
  min_samples=7,
122
  xi=0.05,
@@ -126,15 +140,15 @@ def get_optics(X, **kwargs):
126
  return model.fit(X)
127
 
128
 
129
- def get_birch(X, **kwargs):
130
- model = Birch(n_clusters=3)
131
  model.set_params(**kwargs)
132
  return model.fit(X)
133
 
134
 
135
- def get_gaussianmixture(X, **kwargs):
136
  model = GaussianMixture(
137
- n_components=N_CLUSTERS, covariance_type="full", random_state=SEED,
138
  )
139
  model.set_params(**kwargs)
140
  return model.fit(X)
@@ -153,21 +167,29 @@ MODEL_MAPPING = {
153
 
154
 
155
  def plot_clusters(ax, X, labels):
156
- for label in range(N_CLUSTERS):
 
 
157
  idx = labels == label
158
  if not sum(idx):
159
  continue
160
  ax.scatter(X[idx, 0], X[idx, 1])
161
 
 
 
 
 
 
162
  ax.grid(None)
163
  ax.set_xticks([])
164
  ax.set_yticks([])
165
  return ax
166
 
167
 
168
- def cluster(clustering_algorithm: str, dataset: str):
169
- X, labels = DATA_MAPPING[dataset]()
170
- model = MODEL_MAPPING[clustering_algorithm](X)
 
171
  if hasattr(model, "labels_"):
172
  y_pred = model.labels_.astype(int)
173
  else:
@@ -175,18 +197,24 @@ def cluster(clustering_algorithm: str, dataset: str):
175
 
176
  fig, axes = plt.subplots(1, 2, figsize=(16, 8))
177
 
 
178
  ax = axes[0]
179
  plot_clusters(ax, X, labels)
180
  ax.set_title("True clusters")
181
 
 
182
  ax = axes[1]
183
  plot_clusters(ax, X, y_pred)
184
  ax.set_title(clustering_algorithm)
185
 
186
  return fig
187
 
 
188
  title = "Clustering with Scikit-learn"
189
- description = "This example shows how different clustering algorithms work. Simply pick the algorithm and the dataset to see the clusters algorithms make."
 
 
 
190
  demo = gr.Interface(
191
  fn=cluster,
192
  inputs=[
@@ -200,6 +228,12 @@ demo = gr.Interface(
200
  value="regular",
201
  label="dataset"
202
  ),
 
 
 
 
 
 
203
  ],
204
  title=title,
205
  description=description,
 
20
 
21
 
22
  SEED = 0
23
+ MAX_CLUSTERS = 10
24
  N_SAMPLES = 1000
25
  np.random.seed(SEED)
26
 
 
29
  return StandardScaler().fit_transform(X)
30
 
31
 
32
+ def get_regular(n_clusters):
33
+ # spiral pattern
34
+ centers = [
35
+ [0, 0],
36
+ [1, 0],
37
+ [1, 1],
38
+ [0, 1],
39
+ [-1, 1],
40
+ [-1, 0],
41
+ [-1, -1],
42
+ [0, -1],
43
+ [1, -1],
44
+ [2, -1],
45
+ ][:n_clusters]
46
+ assert len(centers) == n_clusters
47
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=centers, cluster_std=0.25, random_state=SEED)
48
  return normalize(X), labels
49
 
50
 
51
+ def get_circles(n_clusters):
52
  X, labels = make_circles(n_samples=N_SAMPLES, factor=0.5, noise=0.05, random_state=SEED)
53
  return normalize(X), labels
54
 
55
 
56
+ def get_moons(n_clusters):
57
  X, labels = make_moons(n_samples=N_SAMPLES, noise=0.05, random_state=SEED)
58
  return normalize(X), labels
59
 
60
 
61
+ def get_noise(n_clusters):
62
  X, labels = np.random.rand(N_SAMPLES, 2), np.zeros(N_SAMPLES)
63
  return normalize(X), labels
64
 
65
 
66
+ def get_anisotropic(n_clusters):
67
+ X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
68
  transformation = [[0.6, -0.6], [-0.4, 0.8]]
69
  X = np.dot(X, transformation)
70
  return X, labels
71
 
72
 
73
+ def get_varied(n_clusters):
74
+ cluster_std = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0][:n_clusters]
75
+ assert len(cluster_std) == n_clusters
76
  X, labels = make_blobs(
77
+ n_samples=N_SAMPLES, centers=n_clusters, cluster_std=cluster_std, random_state=SEED
78
  )
79
  return normalize(X), labels
80
 
 
88
  'varied': get_varied,
89
  }
90
 
91
+ def get_kmeans(X, n_clusters, **kwargs):
92
+ model = KMeans(init="k-means++", n_clusters=n_clusters, n_init=10, random_state=SEED)
93
  model.set_params(**kwargs)
94
  return model.fit(X)
95
 
96
 
97
+ def get_dbscan(X, n_clusters, **kwargs):
98
  model = DBSCAN(eps=0.3)
99
  model.set_params(**kwargs)
100
  return model.fit(X)
101
 
102
 
103
+ def get_agglomerative(X, n_clusters, **kwargs):
104
  connectivity = kneighbors_graph(
105
+ X, n_neighbors=n_clusters, include_self=False
106
  )
107
  # make connectivity symmetric
108
  connectivity = 0.5 * (connectivity + connectivity.T)
109
  model = AgglomerativeClustering(
110
+ n_clusters=n_clusters, linkage="ward", connectivity=connectivity
111
  )
112
  model.set_params(**kwargs)
113
  return model.fit(X)
114
 
115
 
116
+ def get_meanshift(X, n_clusters, **kwargs):
117
  bandwidth = estimate_bandwidth(X, quantile=0.3)
118
  model = MeanShift(bandwidth=bandwidth, bin_seeding=True)
119
  model.set_params(**kwargs)
120
  return model.fit(X)
121
 
122
 
123
+ def get_spectral(X, n_clusters, **kwargs):
124
  model = SpectralClustering(
125
+ n_clusters=n_clusters,
126
  eigen_solver="arpack",
127
  affinity="nearest_neighbors",
128
  )
 
130
  return model.fit(X)
131
 
132
 
133
+ def get_optics(X, n_clusters, **kwargs):
134
  model = OPTICS(
135
  min_samples=7,
136
  xi=0.05,
 
140
  return model.fit(X)
141
 
142
 
143
+ def get_birch(X, n_clusters, **kwargs):
144
+ model = Birch(n_clusters=n_clusters)
145
  model.set_params(**kwargs)
146
  return model.fit(X)
147
 
148
 
149
+ def get_gaussianmixture(X, n_clusters, **kwargs):
150
  model = GaussianMixture(
151
+ n_components=n_clusters, covariance_type="full", random_state=SEED,
152
  )
153
  model.set_params(**kwargs)
154
  return model.fit(X)
 
167
 
168
 
169
  def plot_clusters(ax, X, labels):
170
+ set_clusters = set(labels)
171
+ set_clusters.discard(-1) # -1 signifies outliers, which we plot separately
172
+ for label in sorted(set_clusters):
173
  idx = labels == label
174
  if not sum(idx):
175
  continue
176
  ax.scatter(X[idx, 0], X[idx, 1])
177
 
178
+ # show outliers (if any)
179
+ idx = labels == -1
180
+ if sum(idx):
181
+ ax.scatter(X[idx, 0], X[idx, 1], c='k', marker='x')
182
+
183
  ax.grid(None)
184
  ax.set_xticks([])
185
  ax.set_yticks([])
186
  return ax
187
 
188
 
189
+ def cluster(clustering_algorithm: str, dataset: str, n_clusters: int):
190
+ n_clusters = int(n_clusters)
191
+ X, labels = DATA_MAPPING[dataset](n_clusters)
192
+ model = MODEL_MAPPING[clustering_algorithm](X, n_clusters=n_clusters)
193
  if hasattr(model, "labels_"):
194
  y_pred = model.labels_.astype(int)
195
  else:
 
197
 
198
  fig, axes = plt.subplots(1, 2, figsize=(16, 8))
199
 
200
+ # show true labels in first panel
201
  ax = axes[0]
202
  plot_clusters(ax, X, labels)
203
  ax.set_title("True clusters")
204
 
205
+ # show learned clusters in second panel
206
  ax = axes[1]
207
  plot_clusters(ax, X, y_pred)
208
  ax.set_title(clustering_algorithm)
209
 
210
  return fig
211
 
212
+
213
  title = "Clustering with Scikit-learn"
214
+ description = (
215
+ "This example shows how different clustering algorithms work. Simply pick "
216
+ "the algorithm and the dataset to see how the clustering algorithms work."
217
+ )
218
  demo = gr.Interface(
219
  fn=cluster,
220
  inputs=[
 
228
  value="regular",
229
  label="dataset"
230
  ),
231
+ gr.Slider(
232
+ minimum=1,
233
+ maximum=MAX_CLUSTERS,
234
+ value=4,
235
+ step=1,
236
+ )
237
  ],
238
  title=title,
239
  description=description,