# Author: vumichien
# Commit 32f27ed: "Update app.py"
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
# Monochrome base theme, recoloured with indigo/blue accents over slate neutrals.
_THEME_HUES = dict(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
theme = gr.themes.Monochrome(**_THEME_HUES)
# Markdown description rendered under the page title. This is a plain string:
# it has no placeholders, so an f-string prefix is unnecessary.
model_card = """
## Description
The **DecisionTreeClassifier** employs a pruning technique that can be configured using the cost complexity parameter, commonly referred to as **ccp_alpha**.
By increasing the value of **ccp_alpha**, a greater number of nodes can be pruned.
In this demo, a DecisionTreeClassifier will be trained on the Breast Cancer dataset.
Then, the effect of **ccp_alpha** on several properties of the tree-based model (the total impurity of the leaves, the depth, the number of nodes, and the accuracy on the train and test data) is shown in several figures.
Based on this information, the best value of **ccp_alpha** is chosen. This demo also shows the train and test accuracy obtained with the best **ccp_alpha**.
You can play around with different ``test size`` and ``random state`` values.
## Dataset
Breast Cancer
"""
# Load the Breast Cancer data once at import time; get_ccp reads these
# module-level arrays on every call.
_breast_cancer = load_breast_cancer()
X = _breast_cancer.data
y = _breast_cancer.target
def _step_plot(x, series, xlabel, ylabel, title):
    """Return a new figure with one step curve per ``(values, label)`` pair in *series*.

    A ``label`` of ``None`` draws an unlabeled curve; a legend is added only when
    at least one curve is labeled. The figure is closed in pyplot's registry
    before being returned so that repeated UI events do not keep accumulating
    open figures; the returned Figure object itself is still usable by the caller.
    """
    fig, ax = plt.subplots()
    for values, label in series:
        style = {"marker": "o", "drawstyle": "steps-post"}
        if label is not None:
            style["label"] = label
        ax.plot(x, values, **style)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if any(label is not None for _, label in series):
        ax.legend()
    fig.tight_layout()
    plt.close(fig)
    return fig


def _select_best_alpha(ccp_alphas, train_scores, test_scores, top_k=3):
    """Choose the pruning strength whose tree generalizes best.

    Keeps the ``top_k`` candidates with the smallest absolute gap between train
    and test accuracy, then returns ``(train_score, test_score, ccp_alpha)`` for
    the one with the highest test accuracy (first one wins on ties). Returns
    ``None`` when there are no candidates.
    """
    candidates = sorted(
        zip(train_scores, test_scores, ccp_alphas),
        key=lambda cand: abs(cand[0] - cand[1]),
    )[:top_k]
    if not candidates:
        return None
    return max(candidates, key=lambda cand: cand[1])


def get_ccp(test_size, random_state):
    """Fit cost-complexity-pruned trees on the Breast Cancer data and plot the results.

    Args:
        test_size: Fraction of samples held out for testing; forwarded to
            ``train_test_split``.
        random_state: Seed for the train/test split, the pruning path and every
            fitted tree. Coerced to ``int`` because UI sliders may deliver floats.

    Returns:
        ``(fig1, fig2, fig3, fig4, text)``: total leaf impurity vs effective
        alpha, node count vs alpha, depth vs alpha, train/test accuracy vs alpha,
        and a summary string with the accuracies of the selected ``ccp_alpha``.
    """
    random_state = int(random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state, test_size=test_size
    )

    path = DecisionTreeClassifier(random_state=random_state).cost_complexity_pruning_path(
        X_train, y_train
    )
    ccp_alphas, impurities = path.ccp_alphas, path.impurities
    fig1 = _step_plot(
        ccp_alphas[:-1],
        [(impurities[:-1], None)],
        "effective alpha",
        "total impurity of leaves",
        "Total Impurity vs effective alpha for training set",
    )

    # The largest alpha prunes the tree down to its root node, so it is dropped
    # before fitting. Trees use the same seed that produced the pruning path.
    ccp_alphas = ccp_alphas[:-1]
    clfs = [
        DecisionTreeClassifier(random_state=random_state, ccp_alpha=ccp_alpha).fit(
            X_train, y_train
        )
        for ccp_alpha in ccp_alphas
    ]
    node_counts = [clf.tree_.node_count for clf in clfs]
    depth = [clf.tree_.max_depth for clf in clfs]
    train_scores = [clf.score(X_train, y_train) for clf in clfs]
    test_scores = [clf.score(X_test, y_test) for clf in clfs]

    fig2 = _step_plot(
        ccp_alphas, [(node_counts, None)], "alpha", "number of nodes", "Number of nodes vs alpha"
    )
    fig3 = _step_plot(ccp_alphas, [(depth, None)], "alpha", "depth of tree", "Depth vs alpha")
    fig4 = _step_plot(
        ccp_alphas,
        [(train_scores, "train"), (test_scores, "test")],
        "alpha",
        "accuracy",
        "Accuracy vs alpha for training and testing sets",
    )

    best = _select_best_alpha(ccp_alphas, train_scores, test_scores)
    if best is None:
        # The pruning path had a single alpha (e.g. a pure training set), so no
        # non-trivial tree exists to choose from.
        text = (
            "Not enough pruning candidates to select a ccp_alpha; "
            "try a different test size or random state."
        )
    else:
        best_train, best_test, best_alpha = best
        # Alphas are typically far below 0.01, so show more than 2 decimals.
        text = (
            f"Train accuracy: {round(best_train, 2)}, Test accuracy: {round(best_test, 2)}, "
            f"The best value of cost complexity parameter alpha (ccp_alpha): {round(best_alpha, 4)}"
        )
    return fig1, fig2, fig3, fig4, text
with gr.Blocks(theme=theme) as demo:
    # Page title. Built as an explicit string so source indentation never leaks
    # into the Markdown/HTML content.
    gr.Markdown(
        "\n<div>\n"
        "<h1 style='text-align: center'>⚒ Post pruning decision trees with cost complexity pruning 🛠</h1>\n"
        "</div>\n"
    )
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py\">scikit-learn</a>")

    # train_test_split rejects a float test_size of 0 or 1, so the slider keeps
    # both train and test sets non-empty.
    test_size = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.2, label="Test size")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")

    with gr.Row():
        with gr.Column():
            plot_impurity = gr.Plot()
        with gr.Column():
            plot_node = gr.Plot()
    with gr.Row():
        with gr.Column():
            plot_depth = gr.Plot()
        with gr.Column():
            plot_compare = gr.Plot()
    with gr.Row():
        result = gr.Textbox(label="Results")

    controls = [test_size, random_state]
    outputs = [plot_impurity, plot_node, plot_depth, plot_compare, result]
    # Any slider change recomputes every plot and the summary text.
    for control in controls:
        control.change(fn=get_ccp, inputs=controls, outputs=outputs)
    # Render the default configuration on page load instead of showing empty plots.
    demo.load(fn=get_ccp, inputs=controls, outputs=outputs)

demo.launch()