eswardivi committed on
Commit
3531ac7
1 Parent(s): eda9e53

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.express as px
2
+ import numpy as np
3
+
4
+ from sklearn.datasets import fetch_openml
5
+ from sklearn.linear_model import LogisticRegression
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.utils import check_random_state
9
+ import gradio as gr
10
+
11
+
12
# Fetch MNIST (OpenML dataset 554, https://www.openml.org/d/554) as arrays
# rather than a DataFrame (as_frame=False).
X, y = fetch_openml(
    "mnist_784",
    version=1,
    return_X_y=True,
    as_frame=False,
    parser="pandas",
)
print("Data loaded")

# Shuffle samples and labels with one seeded permutation so both stay aligned
# and the resulting order is reproducible across runs.
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X, y = X[permutation], y[permutation]

# Keep one flat feature row per sample.
X = X.reshape((X.shape[0], -1))

# Module-level scaler; train_model() rebinds this name to the scaler it fits.
scaler = StandardScaler()
26
+
27
+
28
def dataset_display(digit, count_per_digit, binary_image):
    """Build a Plotly figure showing random examples of one digit.

    Args:
        digit: Digit to display; must compare equal to an integer in 0..9.
        count_per_digit: How many example images to draw.
        binary_image: ``1`` to render the images as binary strings
            (``binary_string=True`` in ``px.imshow``), anything else to
            render them as regular heatmaps.

    Returns:
        A Plotly figure with one facet per sampled image, or a blank
        28x28 figure whose title reports that the digit is not in the data.
    """
    if digit not in range(10):
        # Return a figure displaying an error message.
        return px.imshow(
            np.zeros((28, 28)),
            labels=dict(x="Pixel columns", y="Pixel rows"),
            title=f"Digit {digit} is not in the data",
        )

    # UI sliders may deliver floats (e.g. 5.0). str(5.0) == "5.0" would match
    # no label, and np.random.choice rejects a float size, so normalise both.
    digit = int(digit)
    digit_idxs = np.flatnonzero(y == str(digit))
    # Sampling without replacement cannot draw more images than exist.
    sample_size = min(int(count_per_digit), digit_idxs.size)
    random_idxs = np.random.choice(digit_idxs, size=sample_size, replace=False)

    fig = px.imshow(
        # Fancy indexing selects every sampled row at once; each flat row is
        # then viewed as a 28x28 image.
        X[random_idxs].reshape(-1, 28, 28),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title=f"Examples of Digit {digit} in Data",
        facet_col=0,
        facet_col_wrap=5,
        binary_string=bool(binary_image == 1),
    )

    return fig
51
+
52
+
53
def predict(img):
    """Classify a sketched digit with the most recently trained model.

    Args:
        img: Image array from the sketchpad, or ``None`` when nothing has
            been drawn.

    Returns:
        The predicted label, or a short hint string when there is no usable
        drawing or no trained model yet.
    """
    if img is None:
        return "Show Your Drawing Skills"

    try:
        # Flatten the drawing into a single feature row.
        features = np.asarray(img, dtype=float).reshape(1, -1)
    except (TypeError, ValueError):
        # The input could not be interpreted as a numeric image.
        return "Show Your Drawing Skills"

    try:
        # `clf` only exists as a module global after train_model() has run.
        model, fitted_scaler = clf, scaler
    except NameError:
        return "Train the model first"

    try:
        return model.predict(fitted_scaler.transform(features))[0]
    except ValueError:
        # scikit-learn's NotFittedError subclasses ValueError; a feature-count
        # mismatch also lands here. Keep the original best-effort hint rather
        # than surfacing a traceback in the UI.
        return "Train the model first"
66
+
67
def train_model(
    train_sample=5000, c=0.1, tol=0.1, solver="saga", penalty="l1", max_iter=100
):
    """Train a logistic-regression digit classifier and summarise it.

    On success the module-level ``clf`` and ``scaler`` are rebound to the
    newly fitted model and scaler so that ``predict`` can use them.

    Args:
        train_sample: Number of training samples to draw from the data.
        c: Inverse regularisation strength passed as ``C``.
        tol: Tolerance for the solver's stopping criterion.
        solver: Optimisation algorithm; must be compatible with ``penalty``.
        penalty: One of ``"l1"``, ``"l2"`` or ``"elasticnet"``.
        max_iter: Maximum number of solver iterations.

    Returns:
        ``(score, sparsity, fig)`` where ``score`` is the accuracy on a
        10000-sample test split, ``sparsity`` is the percentage of
        coefficients equal to zero and ``fig`` is a Plotly figure of the
        per-class coefficient images. For an unsupported penalty/solver
        combination a tuple of two message strings and ``None`` is returned.
    """
    global clf
    global scaler

    # Solvers accepted for each penalty.
    supported_solvers = {
        "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
        "l1": ["liblinear", "saga"],
        "elasticnet": ["saga"],
    }

    # Validate before doing any work; an unknown penalty is reported the same
    # way as an unsupported solver instead of raising KeyError.
    if solver not in supported_solvers.get(penalty, ()):
        return (
            "Solver not supported for the selected penalty",
            "Change the Combination",
            None,
        )

    # train_test_split treats a float train_size as a fraction, so make sure a
    # slider value such as 5000.0 is passed as an absolute count.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=int(train_sample), test_size=10000
    )

    new_scaler = StandardScaler()
    X_train = new_scaler.fit_transform(X_train)
    X_test = new_scaler.transform(X_test)

    # The elastic-net penalty needs an explicit L1/L2 mixing ratio; use an
    # even mix. Other penalties keep the estimator's own default.
    extra_params = {"l1_ratio": 0.5} if penalty == "elasticnet" else {}
    model = LogisticRegression(
        C=c,
        penalty=penalty,
        solver=solver,
        tol=tol,
        max_iter=int(max_iter),
        **extra_params,
    )
    model.fit(X_train, y_train)

    # Publish the fitted pair only after training succeeded, so a failed run
    # never leaves predict() with a scaler and model from different runs.
    scaler, clf = new_scaler, model

    coef = model.coef_
    sparsity = np.mean(coef == 0) * 100
    score = model.score(X_test, y_test)
    scale = np.abs(coef).max()

    fig = px.imshow(
        # One 28x28 coefficient image per class.
        coef.reshape(-1, 28, 28),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title="Classification vector for each digit",
        # Symmetric colour range so zero maps to the middle of the scale.
        range_color=[-scale, scale],
        facet_col=0,
        facet_col_wrap=5,
        facet_col_spacing=0.01,
        color_continuous_scale="RdBu",
        zmin=-scale,
        zmax=scale,
    )

    return score, sparsity, fig
113
+
114
+
115
# Three-tab Gradio UI: browse the data, train a model, classify a sketch.
with gr.Blocks() as demo:
    gr.Markdown("# MNIST classification using multinomial logistic + L1 ")
    gr.Markdown(
        """This interactive demo is based on the [MNIST classification using multinomial logistic + L1](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sparse_logistic_regression_mnist.html#sphx-glr-auto-examples-linear-model-plot-sparse-logistic-regression-mnist-py) example from the popular [scikit-learn](https://scikit-learn.org/stable/) library, which is a widely-used library for machine learning in Python. The primary goal of this demo is to showcase the use of logistic regression in classifying handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset, which is a well-known benchmark dataset in computer vision. The dataset is loaded from [OpenML](https://www.openml.org/d/554), which is an open platform for machine learning research that provides easy access to a large number of datasets.
        The model is trained using the scikit-learn library, which provides a range of tools for machine learning, including classification, regression, and clustering algorithms, as well as tools for data preprocessing and model evaluation. The demo calculates the score and sparsity metrics using test data, which provides insight into the model's performance and sparsity, respectively. The score metric indicates how well the model is performing, while the sparsity metric provides information about the number of non-zero coefficients in the model, which can be useful for interpreting the model and reducing its complexity.
        """
    )

    with gr.Tab("Explore the Data"):
        gr.Markdown("## ")
        with gr.Row():
            digit = gr.Slider(0, 9, label="Select the Digit", value=5, step=1)
            count_per_digit = gr.Slider(
                1, 10, label="Number of Images", value=10, step=1
            )
            binary_image = gr.Slider(0, 1, label="Binary Image", value=0, step=1)

        gen_btn = gr.Button("Show Me ")
        gen_btn.click(
            dataset_display,
            inputs=[digit, count_per_digit, binary_image],
            outputs=gr.Plot(),
        )

    with gr.Tab("Train Your Model"):
        gr.Markdown("# Play with the parameters to see how the model changes")

        gr.Markdown("## Solver and penalty")

        gr.Markdown(
            """
            Penalty | Solver
            -------|---------------
            l1 | liblinear, saga
            l2 | lbfgs, newton-cg, newton-cholesky, sag, saga
            elasticnet | saga
            """
        )

        with gr.Row():
            train_sample = gr.Slider(
                1000, 60000, label="Train Sample", value=5000, step=1
            )

            c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
            tol = gr.Slider(
                0.1, 1, label="Tolerance for stopping criteria.", value=0.1, step=0.1
            )
            # NOTE(review): this slider is not listed in the train_btn.click
            # inputs below, so its value never reaches train_model.
            max_iter = gr.Slider(100, 1000, label="Max Iter", value=100, step=1)

            penalty = gr.Dropdown(
                ["l1", "l2", "elasticnet"], label="Penalty", value="l1"
            )
            solver = gr.Dropdown(
                ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
                label="Solver",
                value="saga",
            )

        train_btn = gr.Button("Train")
        train_btn.click(
            train_model,
            # Order matches train_model's positional parameters.
            inputs=[train_sample, c, tol, solver, penalty],
            outputs=[
                gr.Textbox(label="Score"),
                gr.Textbox(label="Sparsity"),
                gr.Plot(),
            ],
        )

    with gr.Tab("Predict the Digit"):
        gr.Markdown("## Draw a digit and see the model's prediction")
        inputs = gr.Sketchpad(brush_radius=1.0)
        outputs = gr.Textbox(label="Predicted Label", lines=1)
        sketch_btn = gr.Button("Classify the Sketch")
        sketch_btn.click(predict, inputs, outputs)


demo.launch()