"""Gradio demo: post-pruning decision trees with cost-complexity pruning.

Trains DecisionTreeClassifier models along the full cost-complexity pruning
path on the Breast Cancer dataset and visualizes how ``ccp_alpha`` affects
leaf impurity, node count, tree depth, and train/test accuracy.
"""

import gradio as gr
import numpy as np  # kept for parity with the original file (unused here)
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Plain string (the original used an f-string with no placeholders).
model_card = """
## Description

The **DecisionTreeClassifier** employs a pruning technique that can be configured using the cost complexity parameter, commonly referred to as **ccp_alpha**. By increasing the value of **ccp_alpha**, a greater number of nodes can be pruned. In this demo, a DecisionTreeClassifier will be trained on the Breast Cancer dataset. Then, the effect of **ccp_alpha** in many terms of the tree-based model like the impurity of leaves, depth, number of nodes, and accuracy on train and test data are shown in many figures. Based on this information, the best number of **ccp_alpha** is chosen. This demo also shows the results of the best **ccp_alpha** with accuracy on train and test datasets.

You can play around with different ``test size`` and ``random state``

## Dataset

Breast Cancer
"""

X, y = load_breast_cancer(return_X_y=True)


def get_ccp(test_size, random_state):
    """Fit pruned trees along the ccp_alpha path and report the best alpha.

    Parameters
    ----------
    test_size : float
        Fraction of the dataset held out for the test split.
    random_state : int
        Seed used for the train/test split and every tree estimator.

    Returns
    -------
    tuple
        Four matplotlib figures (leaf impurity, node count, depth, and
        train/test accuracy vs alpha) and a summary string reporting the
        chosen ``ccp_alpha`` with its train and test accuracy.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state, test_size=test_size
    )
    clf = DecisionTreeClassifier(random_state=random_state)
    path = clf.cost_complexity_pruning_path(X_train, y_train)
    ccp_alphas, impurities = path.ccp_alphas, path.impurities

    # Figure 1: total leaf impurity vs effective alpha.
    # The last path entry is the trivial single-node tree, so drop it.
    fig1, ax1 = plt.subplots()
    ax1.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
    ax1.set_xlabel("effective alpha")
    ax1.set_ylabel("total impurity of leaves")
    ax1.set_title("Total Impurity vs effective alpha for training set")

    # Fit one tree per candidate alpha.
    # BUG FIX: use the caller's random_state (was hard-coded to 0 here,
    # inconsistent with the split and the pruning-path computation above).
    clfs = []
    for ccp_alpha in ccp_alphas:
        clf = DecisionTreeClassifier(random_state=random_state, ccp_alpha=ccp_alpha)
        clf.fit(X_train, y_train)
        clfs.append(clf)

    # Discard the degenerate single-node tree at the end of the path.
    clfs = clfs[:-1]
    ccp_alphas = ccp_alphas[:-1]

    node_counts = [clf.tree_.node_count for clf in clfs]
    depth = [clf.tree_.max_depth for clf in clfs]

    # Figure 2: number of nodes vs alpha.
    fig2, ax2 = plt.subplots()
    ax2.plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
    ax2.set_xlabel("alpha")
    ax2.set_ylabel("number of nodes")
    ax2.set_title("Number of nodes vs alpha")

    # Figure 3: tree depth vs alpha.
    fig3, ax3 = plt.subplots()
    ax3.plot(ccp_alphas, depth, marker="o", drawstyle="steps-post")
    ax3.set_xlabel("alpha")
    ax3.set_ylabel("depth of tree")
    ax3.set_title("Depth vs alpha")
    fig3.tight_layout()

    train_scores = [clf.score(X_train, y_train) for clf in clfs]
    test_scores = [clf.score(X_test, y_test) for clf in clfs]

    # Figure 4: train/test accuracy vs alpha.
    fig4, ax4 = plt.subplots()
    ax4.set_xlabel("alpha")
    ax4.set_ylabel("accuracy")
    ax4.set_title("Accuracy vs alpha for training and testing sets")
    ax4.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
    ax4.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
    ax4.legend()

    # Rank candidates: smallest train/test gap first, then among the three
    # smallest gaps pick the one with the best test accuracy.
    # BUG FIX: the original zipped (test_scores, train_scores, ...) into
    # (train_score, test_score, ...), swapping the two accuracies everywhere.
    score_gap = []
    for train_score, test_score, ccp_alpha in zip(train_scores, test_scores, ccp_alphas):
        score_gap.append((train_score, test_score, abs(train_score - test_score), ccp_alpha))
    score_gap.sort(key=lambda a: a[2])  # ascending gap
    top3_score = score_gap[:3]
    top3_score.sort(key=lambda a: a[1], reverse=True)  # best test accuracy first

    # BUG FIX: the alpha is element 3 of the tuple; the original reported
    # element 2 (the train/test gap) as the "best ccp_alpha". Alpha is
    # rounded to 4 decimals because typical values are well below 0.01.
    best_train, best_test, _, best_alpha = top3_score[0]
    text = (
        f"Train accuracy: {round(best_train, 2)}, "
        f"Test accuracy: {round(best_test, 2)}, "
        f"The best value of cost complexity parameter alpha (ccp_alpha): "
        f"{round(best_alpha, 4)}"
    )
    return fig1, fig2, fig3, fig4, text


with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''

            ⚒ Post pruning decision trees with cost complexity pruning 🛠

    ''')
    gr.Markdown(model_card)
    gr.Markdown("Author: Vu Minh Chien. Based on the example from scikit-learn")
    # BUG FIX: bounds were 0..1; test_size=0 (empty test set) or 1 (empty
    # train set) crashes train_test_split, so restrict to 0.1..0.9.
    test_size = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.2, label="Test size")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")
    with gr.Row():
        with gr.Column():
            plot_impurity = gr.Plot()
        with gr.Column():
            plot_node = gr.Plot()
    with gr.Row():
        with gr.Column():
            plot_depth = gr.Plot()
        with gr.Column():
            plot_compare = gr.Plot()
    with gr.Row():
        # BUG FIX: label typo "Resusts" -> "Results".
        result = gr.Textbox(label="Results")

    # Re-run the analysis whenever either control changes.
    test_size.change(
        fn=get_ccp,
        inputs=[test_size, random_state],
        outputs=[plot_impurity, plot_node, plot_depth, plot_compare, result],
    )
    random_state.change(
        fn=get_ccp,
        inputs=[test_size, random_state],
        outputs=[plot_impurity, plot_node, plot_depth, plot_compare, result],
    )

demo.launch()