import plotly.express as px import numpy as np from sklearn.datasets import fetch_openml from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.utils import check_random_state import gradio as gr # Load data from https://www.openml.org/d/554 X, y = fetch_openml( "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas" ) print("Data loaded") random_state = check_random_state(0) permutation = random_state.permutation(X.shape[0]) X = X[permutation] y = y[permutation] X = X.reshape((X.shape[0], -1)) scaler = StandardScaler() def dataset_display(digit, count_per_digit, binary_image): if digit not in range(10): # return a figure displaying an error message return px.imshow( np.zeros((28, 28)), labels=dict(x="Pixel columns", y="Pixel rows"), title=f"Digit {digit} is not in the data", ) binary_value = True if binary_image == 1 else False digit_idxs = np.where(y == str(digit))[0] random_idxs = np.random.choice(digit_idxs, size=count_per_digit, replace=False) fig = px.imshow( np.array([X[i].reshape(28, 28) for i in random_idxs]), labels=dict(x="Pixel columns", y="Pixel rows"), title=f"Examples of Digit {digit} in Data", facet_col=0, facet_col_wrap=5, binary_string=binary_value, ) return fig def predict(img): try: img = img.reshape(1, -1) except: return "Show Your Drawing Skills" try: img = scaler.transform(img) prediction = clf.predict(img) return prediction[0] except: return "Train the model first" def train_model(train_sample=5000, c=0.1, tol=0.1, solver="sage", penalty="l1"): X_train, X_test, y_train, y_test = train_test_split( X, y, train_size=train_sample, test_size=10000 ) penalty_dict = { "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"], "l1": ["liblinear", "saga"], "elasticnet": ["saga"], } if solver not in penalty_dict[penalty]: return ( "Solver not supported for the selected penalty", "Change the Combination", None, ) global clf global scaler scaler = StandardScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) clf = LogisticRegression(C=c, penalty=penalty, solver=solver, tol=tol) clf.fit(X_train, y_train) sparsity = np.mean(clf.coef_ == 0) * 100 score = clf.score(X_test, y_test) coef = clf.coef_.copy() scale = np.abs(coef).max() fig = px.imshow( np.array([coef[i].reshape(28, 28) for i in range(10)]), labels=dict(x="Pixel columns", y="Pixel rows"), title=f"Classification vector for each digit", range_color=[-scale, scale], facet_col=0, facet_col_wrap=5, facet_col_spacing=0.01, color_continuous_scale="RdBu", zmin=-scale, zmax=scale, ) return score, sparsity, fig with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# Phân loại dữ liệu MNIST bằng mô hình logistic đa thức và chính quy hóa L1") gr.Markdown( """Mục tiêu chính của bản demo này là giới thiệu cách sử dụng hồi quy logistic trong việc phân loại các chữ số viết tay từ tập dữ liệu [MNIST](https://en.wikipedia.org/wiki/MNIST_database), một tập dữ liệu điểm chuẩn nổi tiếng trong máy tính tầm nhìn. Tập dữ liệu được tải từ [OpenML](https://www.openml.org/d/554), đây là một nền tảng mở dành cho nghiên cứu máy học giúp dễ dàng truy cập vào số lượng lớn tập dữ liệu. Mô hình này được đào tạo bằng thư viện scikit-learn, thư viện này cung cấp nhiều công cụ cho máy học, bao gồm các thuật toán phân loại, hồi quy và phân cụm, cũng như các công cụ tiền xử lý dữ liệu và đánh giá mô hình. Bản demo tính toán điểm số và số liệu thưa thớt bằng cách sử dụng dữ liệu thử nghiệm, cung cấp thông tin chi tiết tương ứng về hiệu suất và độ thưa thớt của mô hình. Số liệu điểm cho biết mô hình đang hoạt động tốt như thế nào, trong khi số liệu thưa thớt cung cấp thông tin về số hệ số khác 0 trong mô hình, có thể hữu ích cho việc diễn giải mô hình và giảm độ phức tạp của nó. """ ) with gr.Tab("Khám phá dữ liệu"): gr.Markdown("## ") with gr.Row(): digit = gr.Slider(0, 9, label="Lựa chọn số", value=5, step=1) count_per_digit = gr.Slider( 1, 10, label="Số lượng ảnh", value=10, step=1 ) binary_image = gr.Slider(0, 1, label="Phân loại ảnh nhị phân", value=0, step=1) gen_btn = gr.Button("Hiển thị") gen_btn.click( dataset_display, inputs=[digit, count_per_digit, binary_image], outputs=gr.Plot(), ) with gr.Tab("Huấn luyện mô hình"): gr.Markdown("# Thay đổi các tham số để xem mô hình thay đổi như thế nào") gr.Markdown("## Solver and penalty") gr.Markdown( """ Penalty | Solver -------|--------------- l1 | saga l2 | saga """ ) with gr.Row(): train_sample = gr.Slider( 1000, 60000, label="Số lượng dữ liệu huấn luyện", value=5000, step=1 ) c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1) tol = gr.Slider( 0.1, 1, label="Dung sai cho tiêu chí dừng.", value=0.1, step=0.1 ) max_iter = gr.Slider(100, 1000, label="Số vòng huấn luyện", value=100, step=1) penalty = gr.Dropdown( ["l1", "l2",], label="Chính quy hóa", value="l1" ) solver = gr.Dropdown( ["saga"], label="Thuật toán", value="saga", ) train_btn = gr.Button("Huấn luyện") train_btn.click( train_model, inputs=[train_sample, c, tol, solver, penalty], outputs=[ gr.Textbox(label="Độ chính xác"), gr.Textbox(label="Độ thưa thớt"), gr.Plot(), ], ) with gr.Tab("Dự đoán số mới"): gr.Markdown("## Draw a digit and see the model's prediction") inputs = gr.Sketchpad(brush_radius=1.0) outputs = gr.Textbox(label="Kết quả dự đoán", lines=1) skecth_btn = gr.Button("Dự đoán ảnh vẽ tay") skecth_btn.click(predict, inputs, outputs) demo.launch()