import plotly.express as px
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
import gradio as gr


# Load data from https://www.openml.org/d/554
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

print("Data loaded")
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))


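# Shared preprocessing scaler; train_model() re-fits it and predict() reuses it on new sketches.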
scaler = StandardScaler()


def dataset_display(digit, count_per_digit, binary_image):
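    """Plot `count_per_digit` randomly chosen examples of the selected digit from the raw data."""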
    if digit not in range(10):
        # return a figure displaying an error message
        return px.imshow(
            np.zeros((28, 28)),
            labels=dict(x="Pixel columns", y="Pixel rows"),
            title=f"Digit {digit} is not in the data",
        )

    # Slider values arrive as numbers; cast them before comparing with the string labels in y.
    binary_value = bool(binary_image)
    digit_idxs = np.where(y == str(int(digit)))[0]
    random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False)

    fig = px.imshow(
        np.array([X[i].reshape(28, 28) for i in random_idxs]),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title=f"Examples of Digit {digit} in Data",
        facet_col=0,
        facet_col_wrap=5,
        binary_string=binary_value,
    )

    return fig


def predict(img):
    """Classify a sketch with the globally trained classifier, if one exists."""
    try:
        img = img.reshape(1, -1)
    except AttributeError:
        # No drawing was provided (img is None), so there is nothing to classify.
        return "Show Your Drawing Skills"

    try:
        img = scaler.transform(img)
        prediction = clf.predict(img)
        return prediction[0]
    except Exception:
        # The scaler and classifier are only fitted once the "Train" button has been used.
        return "Train the model first"


def train_model(train_sample=5000, c=0.1, tol=0.1, solver="saga", penalty="l1", max_iter=100):
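    """Fit a logistic regression on a random MNIST subset; return score, sparsity, and a coefficient figure."""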
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=train_sample, test_size=10000
    )

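    # Solvers that support each penalty in scikit-learn's LogisticRegression (same table shown in the UI).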
    penalty_dict = {
        "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
        "l1": ["liblinear", "saga"],
        "elasticnet": ["saga"],
    }

    if solver not in penalty_dict[penalty]:
        return (
            "Solver not supported for the selected penalty",
            "Change the Combination",
            None,
        )

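    # Re-fit the shared scaler and classifier so that predict() can reuse them on sketches.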
    global clf
    global scaler
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    l1_ratio = 0.5 if penalty == "elasticnet" else None  # elasticnet requires an explicit l1_ratio
    clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol, max_iter=int(max_iter), l1_ratio=l1_ratio)
    clf.fit(X_train, y_train)
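    # Sparsity = percentage of coefficients that are exactly zero (encouraged by the L1 penalty).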
    sparsity = np.mean(clf.coef_ == 0) * 100
    score = clf.score(X_test, y_test)

    coef = clf.coef_.copy()
    scale = np.abs(coef).max()

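    # Show each class's 784 weights as a 28x28 image on a symmetric, diverging color scale.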
    fig = px.imshow(
        np.array([coef[i].reshape(28, 28) for i in range(10)]),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title="Classification vector for each digit",
        range_color=[-scale, scale],
        facet_col=0,
        facet_col_wrap=5,
        facet_col_spacing=0.01,
        color_continuous_scale="RdBu",
    )

    return score, sparsity, fig


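# Gradio UI: three tabs for exploring the data, training the model, and classifying a hand-drawn digit.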
with gr.Blocks() as demo:
    gr.Markdown("# MNIST classification using multinomial logistic + L1 ")
    gr.Markdown(
        """This interactive demo is based on the [MNIST classification using multinomial logistic + L1](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sparse_logistic_regression_mnist.html#sphx-glr-auto-examples-linear-model-plot-sparse-logistic-regression-mnist-py) example from the popular [scikit-learn](https://scikit-learn.org/stable/)  library, which is a widely-used library for machine learning in Python. The primary goal of this demo is to showcase the use of logistic regression in classifying handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset, which is a well-known benchmark dataset in computer vision. The dataset is loaded from [OpenML](https://www.openml.org/d/554), which is an open platform for machine learning research that provides easy access to a large number of datasets.
The model is trained using the scikit-learn library, which provides a range of tools for machine learning, including classification, regression, and clustering algorithms, as well as tools for data preprocessing and model evaluation. The demo calculates the score and sparsity metrics using test data, which provides insight into the model's performance and sparsity, respectively. The score metric indicates how well the model is performing, while the sparsity metric provides information about the number of non-zero coefficients in the model, which can be useful for interpreting the model and reducing its complexity.
    """
    )
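    # Both metrics shown on the "Train Your Model" tab come from train_model():
    #   score    = clf.score(X_test, y_test)        (mean accuracy on the test split)
    #   sparsity = np.mean(clf.coef_ == 0) * 100    (percentage of zero coefficients)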

    with gr.Tab("Explore the Data"):
        gr.Markdown("## ")
        with gr.Row():
            digit = gr.Slider(0, 9, label="Select the Digit", value=5, step=1)
            count_per_digit = gr.Slider(
                1, 10, label="Number of Images", value=10, step=1
            )
            binary_image = gr.Slider(0, 1, label="Binary Image", value=0, step=1)

        gen_btn = gr.Button("Show Me ")
        gen_btn.click(
            dataset_display,
            inputs=[digit, count_per_digit, binary_image],
            outputs=gr.Plot(),
        )

    with gr.Tab("Train Your Model"):
        gr.Markdown("# Play with the parameters to see how the model changes")

        gr.Markdown("## Solver and penalty")

        gr.Markdown(
            """
        Penalty | Solver
        -------|---------------
        l1 | liblinear, saga 
        l2  | lbfgs, newton-cg, newton-cholesky, sag, saga             
        elasticnet | saga
        """
        )

        with gr.Row():
            train_sample = gr.Slider(
                1000, 60000, label="Train Sample", value=5000, step=1
            )

            c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
            tol = gr.Slider(
                0.1, 1, label="Tolerance for stopping criteria.", value=0.1, step=0.1
            )
            max_iter = gr.Slider(100, 1000, label="Max Iter", value=100, step=1)

            penalty = gr.Dropdown(
                ["l1", "l2", "elasticnet"], label="Penalty", value="l1"
            )
            solver = gr.Dropdown(
                ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
                label="Solver",
                value="saga",
            )

        train_btn = gr.Button("Train")
        train_btn.click(
            train_model,
            inputs=[train_sample, c, tol, solver, penalty, max_iter],
            outputs=[
                gr.Textbox(label="Score"),
                gr.Textbox(label="Sparsity"),
                gr.Plot(),
            ],
        )

    with gr.Tab("Predict the Digit"):
        gr.Markdown("## Draw a digit and see the model's prediction")
        inputs = gr.Sketchpad(brush_radius=1.0)
        outputs = gr.Textbox(label="Predicted Label", lines=1)
        sketch_btn = gr.Button("Classify the Sketch")
        sketch_btn.click(predict, inputs, outputs)


demo.launch()