|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.datasets import load_breast_cancer |
|
from sklearn.tree import DecisionTreeClassifier |
|
|
|
# Visual theme shared by the whole Blocks app.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Markdown rendered under the page title. A plain (non-f) string: nothing is interpolated.
model_card = """
## Description

The **DecisionTreeClassifier** supports a pruning technique that is configured with the cost complexity parameter, commonly referred to as **ccp_alpha**.

Increasing **ccp_alpha** prunes a greater number of nodes.

In this demo, a DecisionTreeClassifier is trained on the Breast Cancer dataset. The figures then show how **ccp_alpha** affects several properties of the tree-based model: the total impurity of the leaves, the depth, the number of nodes, and the accuracy on the train and test data.

Based on this information, the best value of **ccp_alpha** is chosen. The demo also reports the train and test accuracy obtained with that best **ccp_alpha**.

You can play around with different ``test size`` and ``random state`` values.

## Dataset

Breast Cancer

"""

# Feature matrix and labels; loaded once and re-split on every get_ccp call.
X, y = load_breast_cancer(return_X_y=True)
|
|
|
def _step_plot(x, series, xlabel, ylabel, title):
    """Return a step plot of each ``(label, values)`` series against ``x``.

    A legend is drawn only when at least one series has a label. The figure is
    removed from pyplot's global registry before it is returned, so repeated
    slider events do not accumulate open figures. The returned Figure object
    itself is still available to the caller for rendering.
    """
    fig, ax = plt.subplots()
    for label, values in series:
        ax.plot(x, values, marker="o", label=label, drawstyle="steps-post")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if any(label is not None for label, _ in series):
        ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig


def _select_best_alpha(ccp_alphas, train_scores, test_scores, n_candidates=3):
    """Choose an alpha among the candidates with the smallest train/test gap.

    Keeps the ``n_candidates`` alphas whose ``|train - test|`` accuracy gap is
    smallest. Among those, it picks the one with the highest test accuracy;
    ties keep the smaller-gap candidate.

    Returns:
        ``(train_score, test_score, alpha)`` for the chosen alpha, or ``None``
        when there are no candidates at all.
    """
    candidates = sorted(
        zip(train_scores, test_scores, ccp_alphas),
        key=lambda cand: abs(cand[0] - cand[1]),
    )[:n_candidates]
    if not candidates:
        return None
    # max() returns the first maximal element, i.e. the smallest-gap one on ties.
    return max(candidates, key=lambda cand: cand[1])


def get_ccp(test_size, random_state):
    """Explore minimal cost-complexity pruning on the breast cancer data.

    Splits the module-level ``X``/``y``, computes the cost-complexity pruning
    path of a ``DecisionTreeClassifier`` on the training split, fits one tree
    per effective alpha and plots how leaf impurity, node count, depth and
    accuracy change with alpha.

    Args:
        test_size: Fraction of samples held out for testing. ``train_test_split``
            requires a float strictly between 0 and 1.
        random_state: Seed used for the split and for every tree. It is cast to
            ``int`` because scikit-learn rejects float seeds such as ``5.0``.

    Returns:
        ``(fig_impurity, fig_nodes, fig_depth, fig_accuracy, text)``, where the
        first four are matplotlib figures and ``text`` summarises the chosen alpha.
    """
    random_state = int(random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state, test_size=test_size
    )

    path = DecisionTreeClassifier(random_state=random_state).cost_complexity_pruning_path(
        X_train, y_train
    )
    # The final alpha prunes the tree down to its root node; it is excluded
    # from every plot and from the candidate set, so that tree is never fit.
    ccp_alphas = path.ccp_alphas[:-1]
    impurities = path.impurities[:-1]

    fig_impurity = _step_plot(
        ccp_alphas,
        [(None, impurities)],
        "effective alpha",
        "total impurity of leaves",
        "Total Impurity vs effective alpha for training set",
    )

    # Use the user's seed here too (the original hard-coded 0), so the pruned
    # trees are grown consistently with the tree the path was computed from.
    clfs = [
        DecisionTreeClassifier(random_state=random_state, ccp_alpha=ccp_alpha).fit(X_train, y_train)
        for ccp_alpha in ccp_alphas
    ]
    node_counts = [clf.tree_.node_count for clf in clfs]
    depths = [clf.tree_.max_depth for clf in clfs]
    train_scores = [clf.score(X_train, y_train) for clf in clfs]
    test_scores = [clf.score(X_test, y_test) for clf in clfs]

    fig_nodes = _step_plot(
        ccp_alphas, [(None, node_counts)], "alpha", "number of nodes", "Number of nodes vs alpha"
    )
    fig_depth = _step_plot(
        ccp_alphas, [(None, depths)], "alpha", "depth of tree", "Depth vs alpha"
    )
    fig_accuracy = _step_plot(
        ccp_alphas,
        [("train", train_scores), ("test", test_scores)],
        "alpha",
        "accuracy",
        "Accuracy vs alpha for training and testing sets",
    )

    best = _select_best_alpha(ccp_alphas, train_scores, test_scores)
    if best is None:
        # Happens when the fully grown tree is already a single leaf.
        text = "The pruning path has only one alpha, so there is no pruned tree to compare."
    else:
        best_train, best_test, best_alpha = best
        # Alphas are typically small, so 2 decimals would hide them; use 4.
        text = (
            f"Train accuracy: {round(float(best_train), 2)}, "
            f"Test accuracy: {round(float(best_test), 2)}, "
            f"The best value of cost complexity parameter alpha (ccp_alpha): {round(float(best_alpha), 4)}"
        )
    return fig_impurity, fig_nodes, fig_depth, fig_accuracy, text
|
|
|
|
|
with gr.Blocks(theme=theme) as demo:
    # Title banner (the original contained mis-encoded characters around the text).
    gr.Markdown('''
<div>
<h1 style='text-align: center'>Post pruning decision trees with cost complexity pruning</h1>
</div>
''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py\">scikit-learn</a>")

    # train_test_split rejects test_size of 0.0 and 1.0, so keep the range strictly inside (0, 1).
    test_size = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.2, label="Test size")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")

    with gr.Row():
        with gr.Column():
            plot_impurity = gr.Plot()
        with gr.Column():
            plot_node = gr.Plot()

    with gr.Row():
        with gr.Column():
            plot_depth = gr.Plot()
        with gr.Column():
            plot_compare = gr.Plot()

    with gr.Row():
        result = gr.Textbox(label="Results")

    inputs = [test_size, random_state]
    outputs = [plot_impurity, plot_node, plot_depth, plot_compare, result]
    # Recompute whenever either control changes...
    for control in inputs:
        control.change(fn=get_ccp, inputs=inputs, outputs=outputs)
    # ...and once when the page loads, so the plots are not empty initially.
    demo.load(fn=get_ccp, inputs=inputs, outputs=outputs)

demo.launch()