File size: 5,803 Bytes
c6738f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd426fe
c6738f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html

from itertools import cycle
from time import time

import gradio as gr
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
from joblib import cpu_count
from matplotlib.figure import Figure
from sklearn.cluster import Birch, MiniBatchKMeans
from sklearn.datasets import make_blobs

# Select matplotlib's non-interactive Agg backend: plots are only rendered to
# images for the Gradio UI, so no GUI/display backend is needed.
plt.switch_backend("agg")


def _finish_axis(ax, title):
    """Fix *ax* to the shared [-25, 25] x [-25, 25] view and set its *title*."""
    ax.set_xlim([-25, 25])
    ax.set_ylim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title(title)


def do_submit(n_samples, birch_threshold, birch_n_clusters):
    """Fit BIRCH (without/with global clustering) and MiniBatchKMeans, then plot.

    Mirrors the scikit-learn "Compare BIRCH and MiniBatchKMeans" example, but
    driven by the values chosen in the Gradio UI.

    Parameters
    ----------
    n_samples : int-like
        Total number of points generated by ``make_blobs``.
    birch_threshold : float-like
        ``threshold`` used by both BIRCH models.
    birch_n_clusters : int-like
        ``n_clusters`` of the BIRCH model that runs the global clustering step.
        Values below 1 are clamped to 1 because ``Birch`` rejects them.

    Returns
    -------
    fig : matplotlib.figure.Figure
        One row of three scatter plots: BIRCH without global clustering, BIRCH
        with global clustering, and MiniBatchKMeans.
    result : str
        Timing and cluster-count report, one newline-terminated line per entry.
    """
    n_samples = int(n_samples)
    birch_threshold = float(birch_threshold)
    # Birch only accepts n_clusters >= 1 (or None for "no global step").
    birch_n_clusters = max(1, int(birch_n_clusters))
    report = []

    # Generate centers for the blobs so that they form a 10 x 10 grid.
    xx, yy = np.meshgrid(np.linspace(-22, 22, 10), np.linspace(-22, 22, 10))
    n_centers = np.column_stack((xx.ravel(), yy.ravel()))

    # Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
    X, _ = make_blobs(n_samples=n_samples, centers=n_centers, random_state=0)

    # Use all colors that matplotlib provides by default. The iterator is
    # shared by all three subplots, so each plot continues the color cycle.
    colors_ = cycle(colors.cnames.keys())

    # Build the Figure directly instead of via pyplot so it is not registered
    # in pyplot's global figure list; otherwise every request of this
    # long-running server would leak an open figure.
    fig = Figure(figsize=(12, 4))
    fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

    # Compute clustering with BIRCH with and without the final clustering step
    # and plot, using the threshold / cluster count chosen in the UI.
    birch_models = [
        Birch(threshold=birch_threshold, n_clusters=None),
        Birch(threshold=birch_threshold, n_clusters=birch_n_clusters),
    ]
    final_step = ["without global clustering", "with global clustering"]

    for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
        t = time()
        birch_model.fit(X)
        report.append(
            "BIRCH %s as the final step took %0.2f seconds" % (info, time() - t)
        )

        labels = birch_model.labels_
        centroids = birch_model.subcluster_centers_
        n_clusters = np.unique(labels).size
        report.append("n_clusters : %d" % n_clusters)

        ax = fig.add_subplot(1, 3, ind + 1)
        for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
            mask = labels == k
            ax.scatter(
                X[mask, 0], X[mask, 1], c="w", edgecolor=col, marker=".", alpha=0.5
            )
            # Subcluster centroids are only meaningful without the global step.
            if birch_model.n_clusters is None:
                ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
        _finish_axis(ax, "BIRCH %s" % info)

    # Compute clustering with MiniBatchKMeans.
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=256 * cpu_count(),
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    t_mini_batch = time() - t0
    report.append("Time taken to run MiniBatchKMeans %0.2f seconds" % t_mini_batch)

    # Plot every KMeans cluster. Enumerate the KMeans centers themselves rather
    # than reusing the BIRCH cluster count, which may differ.
    ax = fig.add_subplot(1, 3, 3)
    for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):
        mask = mbk.labels_ == k
        ax.scatter(X[mask, 0], X[mask, 1], marker=".", c="w", edgecolor=col, alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker="+", c="k", s=25)
    _finish_axis(ax, "MiniBatchKMeans")

    return fig, "".join(line + "\n" for line in report)


# Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
# Font fallback chain for the UI: a Google-hosted font first, then generic
# system sans-serif families.
_THEME_FONTS = [
    gr.themes.GoogleFont("Open Sans"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_THEME_FONTS,
)

title = "Compare BIRCH and MiniBatchKMeans"
# Gradio UI: three sliders feed do_submit, whose figure and text report are
# shown in the Plot and Textbox outputs.
with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This is an interactive demo for this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_birch_vs_minibatchkmeans.html)."
    )

    gr.Markdown(
        "This example compares the timing of BIRCH (with and without the global clustering step) and \
        MiniBatchKMeans on a synthetic dataset having 25,000 samples and 2 features generated using make_blobs.\
 \n Both MiniBatchKMeans and BIRCH are very scalable algorithms and could run efficiently on hundreds of thousands or \
    even millions of datapoints. We chose to limit the dataset size of this example in the interest of keeping our \
    Continuous Integration resource usage reasonable but the interested reader might enjoy editing this script to \
    rerun it with a larger value for n_samples.\
\n\n\
If n_clusters is set to None, the data is reduced from 25,000 samples to a set of 158 clusters. This can be viewed as a preprocessing step before the final (global) clustering step that further reduces these 158 clusters to 100 clusters."
    )

    n_samples = gr.Slider(
        minimum=20000,
        maximum=80000,
        label="Number of samples",
        step=500,
        value=25000,
    )
    birch_threshold = gr.Slider(
        minimum=0.5,
        maximum=2.0,
        label="Birch Threshold",
        step=0.1,
        value=1.7,
    )
    # Birch requires n_clusters >= 1, so the slider must not offer 0.
    birch_n_clusters = gr.Slider(
        minimum=1,
        maximum=100,
        label="Birch number of clusters",
        step=1,
        value=100,
    )

    plt_out = gr.Plot()
    # `multiline` is not a Textbox parameter; `lines` sets the visible height.
    # The report produced by do_submit is 5 lines long.
    output = gr.Textbox(label="Output", lines=5)

    sub_btn = gr.Button("Submit")
    sub_btn.click(
        fn=do_submit,
        inputs=[n_samples, birch_threshold, birch_n_clusters],
        outputs=[plt_out, output],
    )


if __name__ == "__main__":
    demo.launch()