# Hugging Face Space status (scraped page header): Running on CPU Upgrade
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from sklearn.model_selection import train_test_split | |
from sklearn.datasets import load_breast_cancer | |
from sklearn.tree import DecisionTreeClassifier | |
# UI theme: Gradio's Monochrome base with indigo primary, blue secondary
# and slate neutral hues.
theme = gr.themes.Monochrome(
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Markdown description rendered under the page title. A plain string (not an
# f-string): nothing is interpolated into it.
model_card = """
## Description

The **DecisionTreeClassifier** supports post-pruning through minimal cost-complexity pruning, controlled by the cost complexity parameter **ccp_alpha**.
Larger values of **ccp_alpha** prune more nodes.

In this demo, a DecisionTreeClassifier is trained on the Breast Cancer dataset.
The figures show how **ccp_alpha** affects the tree: the total impurity of its leaves, its depth, its number of nodes, and its accuracy on the train and test data.

Based on the train and test accuracies, the best value of **ccp_alpha** is chosen, and its train and test accuracy are reported below the figures.

You can play around with different ``test size`` and ``random state``.

## Dataset

Breast Cancer
"""
# Load the Breast Cancer data once at import time; every get_ccp call re-splits
# this same feature matrix X and target vector y.
X, y = load_breast_cancer(
    return_X_y=True,
)
def _step_figure(x, series, xlabel, ylabel, title):
    """Return a new figure with one step-post line per ``(values, label)`` in *series*.

    The figure is detached from pyplot's global registry (``plt.close``) before
    it is returned, so repeated UI callbacks do not pile up open figures; the
    Figure object itself can still be rendered and saved.
    """
    fig, ax = plt.subplots()
    for values, label in series:
        ax.plot(x, values, marker="o", drawstyle="steps-post", label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if any(label is not None for _, label in series):
        ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig


def _select_best_alpha(train_scores, test_scores, ccp_alphas, n_candidates=3):
    """Pick the ccp_alpha whose tree generalizes best.

    Keeps the ``n_candidates`` alphas with the smallest absolute gap between
    train and test accuracy, then returns the one among them with the highest
    test accuracy. Ties keep the earlier candidate (both ``sorted`` and ``max``
    are stable / first-wins).

    Args:
        train_scores: Train accuracy per alpha.
        test_scores: Test accuracy per alpha, aligned with *train_scores*.
        ccp_alphas: The alphas the scores belong to.
        n_candidates: How many smallest-gap alphas to consider.

    Returns:
        Tuple ``(train_score, test_score, ccp_alpha)`` for the chosen alpha.

    Raises:
        ValueError: If there are no alphas to choose from.
    """
    candidates = [
        (train, test, abs(train - test), alpha)
        for train, test, alpha in zip(train_scores, test_scores, ccp_alphas)
    ]
    if not candidates:
        raise ValueError("no ccp_alpha candidates to choose from")
    smallest_gap = sorted(candidates, key=lambda c: c[2])[:n_candidates]
    train, test, _, alpha = max(smallest_gap, key=lambda c: c[1])
    return train, test, alpha


def get_ccp(test_size, random_state):
    """Train decision trees along the cost-complexity pruning path and plot the effects.

    Args:
        test_size: Fraction of the Breast Cancer samples held out for testing,
            passed to ``train_test_split``; must lie strictly between 0 and 1.
        random_state: Seed used for the split and for every tree. Cast to
            ``int`` because a UI slider may deliver it as a float.

    Returns:
        Tuple ``(fig_impurity, fig_nodes, fig_depth, fig_accuracy, text)``:
        four matplotlib figures (leaf impurity, node count, depth and
        train/test accuracy versus alpha) and a summary of the chosen alpha.
    """
    seed = int(random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=seed, test_size=test_size
    )
    path = DecisionTreeClassifier(random_state=seed).cost_complexity_pruning_path(
        X_train, y_train
    )
    # The largest alpha of the path prunes the tree down to a single node; it
    # is dropped everywhere, as in the scikit-learn example this demo follows.
    ccp_alphas = path.ccp_alphas[:-1]
    impurities = path.impurities[:-1]

    fig_impurity = _step_figure(
        ccp_alphas,
        [(impurities, None)],
        "effective alpha",
        "total impurity of leaves",
        "Total Impurity vs effective alpha for training set",
    )

    # Use the caller's seed for every tree so results are reproducible per seed
    # and consistent with the tree used to compute the pruning path.
    clfs = [
        DecisionTreeClassifier(random_state=seed, ccp_alpha=alpha).fit(X_train, y_train)
        for alpha in ccp_alphas
    ]

    node_counts = [clf.tree_.node_count for clf in clfs]
    depths = [clf.tree_.max_depth for clf in clfs]
    fig_nodes = _step_figure(
        ccp_alphas, [(node_counts, None)], "alpha", "number of nodes", "Number of nodes vs alpha"
    )
    fig_depth = _step_figure(
        ccp_alphas, [(depths, None)], "alpha", "depth of tree", "Depth vs alpha"
    )

    train_scores = [clf.score(X_train, y_train) for clf in clfs]
    test_scores = [clf.score(X_test, y_test) for clf in clfs]
    fig_accuracy = _step_figure(
        ccp_alphas,
        [(train_scores, "train"), (test_scores, "test")],
        "alpha",
        "accuracy",
        "Accuracy vs alpha for training and testing sets",
    )

    best_train, best_test, best_alpha = _select_best_alpha(
        train_scores, test_scores, ccp_alphas
    )
    # Alphas are typically around 1e-2 or smaller, so show significant digits
    # rather than rounding to two decimals.
    text = (
        f"Train accuracy: {best_train:.2f}, Test accuracy: {best_test:.2f}, "
        f"The best value of cost complexity parameter alpha (ccp_alpha): {best_alpha:.4g}"
    )
    return fig_impurity, fig_nodes, fig_depth, fig_accuracy, text
# Build the Gradio UI: title, description, the two controls, a 2x2 grid of
# plots and a results box, all driven by get_ccp.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
<div>
<h1 style='text-align: center'>Post pruning decision trees with cost complexity pruning</h1>
</div>
"""
    )
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py\">scikit-learn</a>")
    # train_test_split rejects a float test_size of 0 or 1 (and treats an int 1
    # as "one sample"), so keep the slider strictly inside (0, 1).
    test_size = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.2, label="Test size")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")
    with gr.Row():
        with gr.Column():
            plot_impurity = gr.Plot()
        with gr.Column():
            plot_node = gr.Plot()
    with gr.Row():
        with gr.Column():
            plot_depth = gr.Plot()
        with gr.Column():
            plot_compare = gr.Plot()
    with gr.Row():
        result = gr.Textbox(label="Results")

    event_inputs = [test_size, random_state]
    event_outputs = [plot_impurity, plot_node, plot_depth, plot_compare, result]
    # Recompute whenever either control changes, and once on page load so the
    # plots are not empty before the user touches a slider.
    for control in event_inputs:
        control.change(fn=get_ccp, inputs=event_inputs, outputs=event_outputs)
    demo.load(fn=get_ccp, inputs=event_inputs, outputs=event_outputs)

demo.launch()