vumichien commited on
Commit
4e1e636
·
1 Parent(s): a56876a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.express as px
2
+ import numpy as np
3
+
4
+ from sklearn.datasets import fetch_openml
5
+ from sklearn.linear_model import LogisticRegression
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.preprocessing import StandardScaler
8
+ from sklearn.utils import check_random_state
9
+ import gradio as gr
10
+
11
+
12
+ # Load data from https://www.openml.org/d/554
13
+ X, y = fetch_openml(
14
+ "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
15
+ )
16
+
17
+ print("Data loaded")
18
+ random_state = check_random_state(0)
19
+ permutation = random_state.permutation(X.shape[0])
20
+ X = X[permutation]
21
+ y = y[permutation]
22
+ X = X.reshape((X.shape[0], -1))
23
+
24
+
25
+ scaler = StandardScaler()
26
+
27
+
28
+ def dataset_display(digit, count_per_digit, binary_image):
29
+ if digit not in range(10):
30
+ # return a figure displaying an error message
31
+ return px.imshow(
32
+ np.zeros((28, 28)),
33
+ labels=dict(x="Pixel columns", y="Pixel rows"),
34
+ title=f"Digit {digit} is not in the data",
35
+ )
36
+
37
+ binary_value = True if binary_image == 1 else False
38
+ digit_idxs = np.where(y == str(digit))[0]
39
+ random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False)
40
+
41
+ fig = px.imshow(
42
+ np.array([X[i].reshape(28, 28) for i in random_idxs]),
43
+ labels=dict(x="Pixel columns", y="Pixel rows"),
44
+ title=f"Examples of Digit {digit} in Data",
45
+ facet_col=0,
46
+ facet_col_wrap=5,
47
+ binary_string=binary_value,
48
+ )
49
+
50
+ return fig
51
+
52
+
53
+ def predict(img):
54
+ try:
55
+ img = img.reshape(1, -1)
56
+ except:
57
+ return "Show Your Drawing Skills"
58
+ try:
59
+ img = scaler.transform(img)
60
+ prediction = clf.predict(img)
61
+ return prediction[0]
62
+ except:
63
+ return "Train the model first"
64
+
65
+
66
+ def train_model(train_sample=5000, c=0.1, tol=0.1, solver="sage", penalty="l1"):
67
+ X_train, X_test, y_train, y_test = train_test_split(
68
+ X, y, train_size=train_sample, test_size=10000
69
+ )
70
+
71
+ penalty_dict = {
72
+ "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
73
+ "l1": ["liblinear", "saga"],
74
+ "elasticnet": ["saga"],
75
+ }
76
+
77
+ if solver not in penalty_dict[penalty]:
78
+ return (
79
+ "Solver not supported for the selected penalty",
80
+ "Change the Combination",
81
+ None,
82
+ )
83
+
84
+ global clf
85
+ global scaler
86
+ scaler = StandardScaler()
87
+ X_train = scaler.fit_transform(X_train)
88
+ X_test = scaler.transform(X_test)
89
+
90
+ clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol)
91
+ clf.fit(X_train, y_train)
92
+ sparsity = np.mean(clf.coef_ == 0) * 100
93
+ score = clf.score(X_test, y_test)
94
+
95
+ coef = clf.coef_.copy()
96
+ scale = np.abs(coef).max()
97
+
98
+ fig = px.imshow(
99
+ np.array([coef[i].reshape(28, 28) for i in range(10)]),
100
+ labels=dict(x="Pixel columns", y="Pixel rows"),
101
+ title=f"Classification vector for each digit",
102
+ range_color=[-scale, scale],
103
+ facet_col=0,
104
+ facet_col_wrap=5,
105
+ facet_col_spacing=0.01,
106
+ color_continuous_scale="RdBu",
107
+ zmin=-scale,
108
+ zmax=scale,
109
+ )
110
+
111
+ return score, sparsity, fig
112
+
113
+
114
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
115
+ gr.Markdown("# Phân loại dữ liệu MNIST bằng mô hình logistic logistic đa thức và chính quy hóa L1")
116
+ gr.Markdown(
117
+ """Mục tiêu chính của bản demo này là giới thiệu cách sử dụng hồi quy logistic trong việc phân loại các chữ số viết tay từ tập dữ liệu [MNIST](https://en.wikipedia.org/wiki/MNIST_database), một tập dữ liệu điểm chuẩn nổi tiếng trong máy tính tầm nhìn. Tập dữ liệu được tải từ [OpenML](https://www.openml.org/d/554), đây là một nền tảng mở dành cho nghiên cứu máy học giúp dễ dàng truy cập vào số lượng lớn tập dữ liệu.
118
+ Mô hình này được đào tạo bằng thư viện scikit-learn, thư viện này cung cấp nhiều công cụ cho máy học, bao gồm các thuật toán phân loại, hồi quy và phân cụm, cũng như các công cụ tiền xử lý dữ liệu và đánh giá mô hình. Bản demo tính toán điểm số và số liệu thưa thớt bằng cách sử dụng dữ liệu thử nghiệm, cung cấp thông tin chi tiết tương ứng về hiệu suất và độ thưa thớt của mô hình. Số liệu điểm cho biết mô hình đang hoạt động tốt như thế nào, trong khi số liệu thưa thớt cung cấp thông tin về số hệ số khác 0 trong mô hình, có thể hữu ích cho việc diễn giải mô hình và giảm độ phức tạp của nó.
119
+ """
120
+ )
121
+
122
+ with gr.Tab("Khám phá dữ liệu"):
123
+ gr.Markdown("## ")
124
+ with gr.Row():
125
+ digit = gr.Slider(0, 9, label="Lựa chọn số", value=5, step=1)
126
+ count_per_digit = gr.Slider(
127
+ 1, 10, label="Số lượng ảnh", value=10, step=1
128
+ )
129
+ binary_image = gr.Slider(0, 1, label="Phân loại ảnh nhị phân", value=0, step=1)
130
+
131
+ gen_btn = gr.Button("Hiển thị")
132
+ gen_btn.click(
133
+ dataset_display,
134
+ inputs=[digit, count_per_digit, binary_image],
135
+ outputs=gr.Plot(),
136
+ )
137
+
138
+ with gr.Tab("Huấn luyện mô hình"):
139
+ gr.Markdown("# Thay đổi các tham số để xem mô hình thay đổi như thế nào")
140
+
141
+ gr.Markdown("## Solver and penalty")
142
+ gr.Markdown(
143
+ """
144
+ Penalty | Solver
145
+ -------|---------------
146
+ l1 | saga
147
+ l2 | saga
148
+ """
149
+ )
150
+
151
+ with gr.Row():
152
+ train_sample = gr.Slider(
153
+ 1000, 60000, label="Số lượng dữ liệu huấn luyện", value=5000, step=1
154
+ )
155
+
156
+ c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
157
+ tol = gr.Slider(
158
+ 0.1, 1, label="Dung sai cho tiêu chí dừng.", value=0.1, step=0.1
159
+ )
160
+ max_iter = gr.Slider(100, 1000, label="Số vòng huấn luyện", value=100, step=1)
161
+
162
+ penalty = gr.Dropdown(
163
+ ["l1", "l2",], label="Chính quy hóa", value="l1"
164
+ )
165
+ solver = gr.Dropdown(
166
+ ["saga"],
167
+ label="Thuật toán",
168
+ value="saga",
169
+ )
170
+
171
+ train_btn = gr.Button("Huấn luyện")
172
+ train_btn.click(
173
+ train_model,
174
+ inputs=[train_sample, c, tol, solver, penalty],
175
+ outputs=[
176
+ gr.Textbox(label="Độ chính xác"),
177
+ gr.Textbox(label="Độ thưa thớt"),
178
+ gr.Plot(),
179
+ ],
180
+ )
181
+
182
+ with gr.Tab("Dự đoán số mới"):
183
+ gr.Markdown("## Draw a digit and see the model's prediction")
184
+ inputs = gr.Sketchpad(brush_radius=1.0)
185
+ outputs = gr.Textbox(label="Kết quả dự đoán", lines=1)
186
+ skecth_btn = gr.Button("Dự đoán ảnh vẽ tay")
187
+ skecth_btn.click(predict, inputs, outputs)
188
+
189
+
190
+ demo.launch()