classify-digit / app.py
vumichien's picture
Update app.py
1f22a6d
import plotly.express as px
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
import gradio as gr
# Load data from https://www.openml.org/d/554
X, y = fetch_openml(
"mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)
print("Data loaded")
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))
scaler = StandardScaler()
def dataset_display(digit, count_per_digit, binary_image):
if digit not in range(10):
# return a figure displaying an error message
return px.imshow(
np.zeros((28, 28)),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Digit {digit} is not in the data",
)
binary_value = True if binary_image == 1 else False
digit_idxs = np.where(y == str(digit))[0]
random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False)
fig = px.imshow(
np.array([X[i].reshape(28, 28) for i in random_idxs]),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Examples of Digit {digit} in Data",
facet_col=0,
facet_col_wrap=5,
binary_string=binary_value,
)
return fig
def predict(img):
try:
img = img.reshape(1, -1)
except:
return "Show Your Drawing Skills"
try:
img = scaler.transform(img)
prediction = clf.predict(img)
return prediction[0]
except:
return "Train the model first"
def train_model(train_sample=5000, c=0.1, tol=0.1, solver="sage", penalty="l1"):
X_train, X_test, y_train, y_test = train_test_split(
X, y, train_size=train_sample, test_size=10000
)
penalty_dict = {
"l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
"l1": ["liblinear", "saga"],
"elasticnet": ["saga"],
}
if solver not in penalty_dict[penalty]:
return (
"Solver not supported for the selected penalty",
"Change the Combination",
None,
)
global clf
global scaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol)
clf.fit(X_train, y_train)
sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)
coef = clf.coef_.copy()
scale = np.abs(coef).max()
fig = px.imshow(
np.array([coef[i].reshape(28, 28) for i in range(10)]),
labels=dict(x="Pixel columns", y="Pixel rows"),
title=f"Classification vector for each digit",
range_color=[-scale, scale],
facet_col=0,
facet_col_wrap=5,
facet_col_spacing=0.01,
color_continuous_scale="RdBu",
zmin=-scale,
zmax=scale,
)
return score, sparsity, fig
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Phân loại dữ liệu MNIST bằng mô hình logistic đa thức và chính quy hóa L1")
gr.Markdown(
"""Mục tiêu chính của bản demo này là giới thiệu cách sử dụng hồi quy logistic trong việc phân loại các chữ số viết tay từ tập dữ liệu [MNIST](https://en.wikipedia.org/wiki/MNIST_database), một tập dữ liệu điểm chuẩn nổi tiếng trong máy tính tầm nhìn. Tập dữ liệu được tải từ [OpenML](https://www.openml.org/d/554), đây là một nền tảng mở dành cho nghiên cứu máy học giúp dễ dàng truy cập vào số lượng lớn tập dữ liệu.
Mô hình này được đào tạo bằng thư viện scikit-learn, thư viện này cung cấp nhiều công cụ cho máy học, bao gồm các thuật toán phân loại, hồi quy và phân cụm, cũng như các công cụ tiền xử lý dữ liệu và đánh giá mô hình. Bản demo tính toán điểm số và số liệu thưa thớt bằng cách sử dụng dữ liệu thử nghiệm, cung cấp thông tin chi tiết tương ứng về hiệu suất và độ thưa thớt của mô hình. Số liệu điểm cho biết mô hình đang hoạt động tốt như thế nào, trong khi số liệu thưa thớt cung cấp thông tin về số hệ số khác 0 trong mô hình, có thể hữu ích cho việc diễn giải mô hình và giảm độ phức tạp của nó.
"""
)
with gr.Tab("Khám phá dữ liệu"):
gr.Markdown("## ")
with gr.Row():
digit = gr.Slider(0, 9, label="Lựa chọn số", value=5, step=1)
count_per_digit = gr.Slider(
1, 10, label="Số lượng ảnh", value=10, step=1
)
binary_image = gr.Slider(0, 1, label="Phân loại ảnh nhị phân", value=0, step=1)
gen_btn = gr.Button("Hiển thị")
gen_btn.click(
dataset_display,
inputs=[digit, count_per_digit, binary_image],
outputs=gr.Plot(),
)
with gr.Tab("Huấn luyện mô hình"):
gr.Markdown("# Thay đổi các tham số để xem mô hình thay đổi như thế nào")
gr.Markdown("## Solver and penalty")
gr.Markdown(
"""
Penalty | Solver
-------|---------------
l1 | saga
l2 | saga
"""
)
with gr.Row():
train_sample = gr.Slider(
1000, 60000, label="Số lượng dữ liệu huấn luyện", value=5000, step=1
)
c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
tol = gr.Slider(
0.1, 1, label="Dung sai cho tiêu chí dừng.", value=0.1, step=0.1
)
max_iter = gr.Slider(100, 1000, label="Số vòng huấn luyện", value=100, step=1)
penalty = gr.Dropdown(
["l1", "l2",], label="Chính quy hóa", value="l1"
)
solver = gr.Dropdown(
["saga"],
label="Thuật toán",
value="saga",
)
train_btn = gr.Button("Huấn luyện")
train_btn.click(
train_model,
inputs=[train_sample, c, tol, solver, penalty],
outputs=[
gr.Textbox(label="Độ chính xác"),
gr.Textbox(label="Độ thưa thớt"),
gr.Plot(),
],
)
with gr.Tab("Dự đoán số mới"):
gr.Markdown("## Draw a digit and see the model's prediction")
inputs = gr.Sketchpad(brush_radius=1.0)
outputs = gr.Textbox(label="Kết quả dự đoán", lines=1)
skecth_btn = gr.Button("Dự đoán ảnh vẽ tay")
skecth_btn.click(predict, inputs, outputs)
demo.launch()