rashmi committed on
Commit
c6738f5
1 Parent(s): 715c53e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html
2
+
3
+ from itertools import cycle
4
+ from time import time
5
+
6
+ import gradio as gr
7
+ import matplotlib.colors as colors
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from joblib import cpu_count
11
+ from sklearn.cluster import Birch, MiniBatchKMeans
12
+ from sklearn.datasets import make_blobs
13
+
14
+ plt.switch_backend("agg")
15
+
16
+
17
def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Cluster synthetic blobs with BIRCH and MiniBatchKMeans and plot the results.

    Two BIRCH models are fitted (one without and one with the final global
    clustering step), followed by one MiniBatchKMeans model. Each model gets
    its own subplot in a 1x3 figure, and the timings / cluster counts are
    collected into a text report.

    Args:
        n_samples: Number of points to generate with ``make_blobs``; coerced
            to ``int``.
        birch_threshold: ``threshold`` passed to both BIRCH models; coerced
            to ``float``.
        birch_n_clusters: ``n_clusters`` for the BIRCH model that runs the
            global clustering step; coerced to ``int``. A value of 0 or less
            is mapped to ``None`` (no global step), because BIRCH does not
            accept a non-positive integer cluster count.

    Returns:
        A ``(figure, report)`` tuple: the matplotlib figure holding the three
        scatter plots and a newline-terminated, multi-line summary string.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    birch_n_clusters = int(birch_n_clusters)
    # The UI slider starts at 0; translate that into "no global clustering"
    # instead of handing BIRCH an invalid cluster count.
    global_n_clusters = birch_n_clusters if birch_n_clusters > 0 else None
    report_lines = []

    # Generate centers for the blobs so that it forms a 10 X 10 grid.
    axis_ticks = np.linspace(-22, 22, 10)
    xx, yy = np.meshgrid(axis_ticks, axis_ticks)
    n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

    # Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
    X, _ = make_blobs(n_samples=n_samples, centers=n_centers, random_state=0)

    # Use all colors that matplotlib provides by default. The cycle is shared
    # by all three subplots, so colouring continues where the previous
    # subplot stopped.
    colors_ = cycle(colors.cnames.keys())

    fig = plt.figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Compute clustering with BIRCH with and without the final clustering step
    # and plot. Both models honour the threshold chosen by the caller.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=global_n_clusters),
    ]
    final_step = ["without global clustering", "with global clustering"]

    for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
        t = time()
        birch_model.fit(X)
        report_lines.append(
            "BIRCH %s as the final step took %0.2f seconds" % (info, (time() - t))
            + "\n"
        )

        # Plot result
        labels = birch_model.labels_
        centroids = birch_model.subcluster_centers_
        n_clusters = np.unique(labels).size
        report_lines.append("n_clusters : %d" % n_clusters + "\n")

        ax = fig.add_subplot(1, 3, ind + 1)
        for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
            mask = labels == k
            ax.scatter(
                X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
            )
            # Subcluster centroids are only marked when no global clustering
            # step was applied to the model.
            if birch_model.n_clusters is None:
                ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
        ax.set_ylim([-25, 25])
        ax.set_xlim([-25, 25])
        ax.set_autoscaley_on(False)
        ax.set_title("BIRCH %s" % info)

    # Compute clustering with MiniBatchKMeans.
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report_lines.append(
        "Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch + "\n"
    )

    ax = fig.add_subplot(1, 3, 3)
    # Walk over every k-means center rather than reusing the BIRCH cluster
    # count, which can be smaller than the number of k-means clusters.
    for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):
        mask = mbk.labels_ == k
        ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    ax.set_xlim([-25, 25])
    ax.set_ylim([-25, 25])
    ax.set_title("MiniBatchKMeans")
    ax.set_autoscaley_on(False)

    # Drop pyplot's own reference so figures do not pile up across requests;
    # the Figure object itself is still returned to the caller.
    plt.close(fig)

    return fig, "".join(report_lines)
100
+
101
+
102
# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
# Font preference order: the Google-hosted "Open Sans" first, then generic
# sans-serif fallbacks.
_font_stack = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Monochrome base theme with customised hues, small corner radius and the
# font stack defined above.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_font_stack,
)
115
+
116
title = "Compare BIRCH and MiniBatchKMeans"

# Page layout: heading, link to the upstream example, description, three input
# sliders, the plot/text outputs and the button that triggers do_submit.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)"
    )

    # Description text, assembled from adjacent literals.
    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global clustering step) and "
        "MiniBatchKMeans on a synthetic dataset having 25,000 samples and 2 features generated using make_blobs."
        "\n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on hundreds of thousands or "
        "even millions of datapoints. We chose to limit the dataset size of this example in the interest of keeping our "
        "Continuous Integration resource usage reasonable but the interested reader might enjoy editing this script to "
        "rerun it with a larger value for n_samples."
        "\n\n"
        "If n_clusters is set to None, the data is reduced from 25,000 samples to a set of 158 clusters. This can be viewed as a preprocessing step before the final (global) clustering step that further reduces these 158 clusters to 100 clusters."
    )

    # Input controls, passed to do_submit in this order.
    n_samples = gr.Slider(
        label="Number of samples",
        minimum=20000,
        maximum=80000,
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        label="Birch Threshold",
        minimum=0.5,
        maximum=2.0,
        step=0.1,
        value=1.7,
    )
    birch_n_clusters = gr.Slider(
        label="Birch number of clusters",
        minimum=0,
        maximum=100,
        step=1,
        value=100,
    )

    # Output widgets: the figure and the textual report returned by do_submit.
    plt_out = gr.Plot()
    output = gr.Textbox(label="Output", multiline=True)

    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )


# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()
169
+