vumichien's picture
Update app.py
0b7082e
raw
history blame
4.56 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
# Visual theme for the Gradio Blocks app.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Markdown description rendered below the page title. It has no interpolated
# values, so it is a plain string rather than an f-string.
model_card = """
## Description
The **DecisionTreeClassifier** employs a pruning technique that can be configured using the cost complexity parameter, commonly referred to as **ccp_alpha**.
By increasing the value of **ccp_alpha**, a greater number of nodes can be pruned. This demo demonstrates the impact of **ccp_alpha** on tree regularization.
## Dataset
Breast Cancer
"""

# Breast cancer dataset as (features, labels) arrays. It is loaded once at
# import time and shared by every get_ccp call.
X, y = load_breast_cancer(return_X_y=True)
def _step_figure(xlabel, ylabel, title, series):
    """Build a closed pyplot figure with one steps-post line per series.

    Args:
        xlabel, ylabel, title: Axis labels and figure title.
        series: Iterable of ``(xs, ys, label)`` triples; ``label`` may be None.
            A legend is drawn only if at least one label is given.

    Returns:
        The matplotlib Figure. It is unregistered from pyplot via
        ``plt.close`` so repeated calls do not accumulate open figures.
    """
    series = list(series)
    fig, ax = plt.subplots()
    for xs, ys, label in series:
        ax.plot(xs, ys, marker="o", drawstyle="steps-post", label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if any(label is not None for _, _, label in series):
        ax.legend()
    fig.tight_layout()
    # Each UI event creates new figures; without closing them pyplot keeps
    # every one alive for the lifetime of the process.
    plt.close(fig)
    return fig


def _pick_best_alpha(train_scores, test_scores, ccp_alphas, n_candidates=3):
    """Suggest a ccp_alpha from per-alpha train/test accuracies.

    The ``n_candidates`` alphas with the smallest |train - test| accuracy gap
    are kept. Among those, the one with the highest test accuracy wins (ties
    go to the smaller gap, i.e. the earlier candidate).

    Returns:
        ``(train_accuracy, test_accuracy, alpha)`` for the chosen alpha, or
        ``None`` when there are no candidates.
    """
    candidates = sorted(
        zip(train_scores, test_scores, ccp_alphas),
        key=lambda c: abs(c[0] - c[1]),
    )[:n_candidates]
    if not candidates:
        return None
    # max() returns the first maximal element, preserving gap order on ties.
    return max(candidates, key=lambda c: c[1])


def get_ccp(test_size, random_state):
    """Compute the cost-complexity pruning path and visualise its effect.

    The module-level data ``X``, ``y`` is split into train/test sets. The
    pruning path of a DecisionTreeClassifier is computed on the training
    split, and one tree is fitted per effective alpha (excluding the last
    alpha, which prunes the tree to a single node).

    Args:
        test_size: Fraction of samples held out for testing. It is passed to
            ``train_test_split``, which requires a float strictly between
            0 and 1.
        random_state: Seed for the split, the path tree and every pruned tree.

    Returns:
        ``(fig1, fig2, fig3, fig4, text)``:
        - fig1: total leaf impurity vs effective alpha.
        - fig2: node count vs alpha.
        - fig3: tree depth vs alpha.
        - fig4: train/test accuracy vs alpha.
        - text: a summary of the suggested ccp_alpha.
    """
    # Slider values may arrive as floats; sklearn expects an integer seed.
    random_state = int(random_state)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state, test_size=test_size
    )

    path = DecisionTreeClassifier(random_state=random_state).cost_complexity_pruning_path(
        X_train, y_train
    )
    ccp_alphas, impurities = path.ccp_alphas, path.impurities

    fig1 = _step_figure(
        "effective alpha",
        "total impurity of leaves",
        "Total Impurity vs effective alpha for training set",
        [(ccp_alphas[:-1], impurities[:-1], None)],
    )

    # The last alpha collapses the tree to a lone root, so it is dropped.
    ccp_alphas = ccp_alphas[:-1]
    # Reuse the caller's seed so each pruned tree matches the tree the
    # pruning path was computed from.
    clfs = [
        DecisionTreeClassifier(random_state=random_state, ccp_alpha=alpha).fit(X_train, y_train)
        for alpha in ccp_alphas
    ]

    node_counts = [clf.tree_.node_count for clf in clfs]
    depths = [clf.tree_.max_depth for clf in clfs]
    fig2 = _step_figure(
        "alpha", "number of nodes", "Number of nodes vs alpha",
        [(ccp_alphas, node_counts, None)],
    )
    fig3 = _step_figure(
        "alpha", "depth of tree", "Depth vs alpha",
        [(ccp_alphas, depths, None)],
    )

    train_scores = [clf.score(X_train, y_train) for clf in clfs]
    test_scores = [clf.score(X_test, y_test) for clf in clfs]
    fig4 = _step_figure(
        "alpha", "accuracy", "Accuracy vs alpha for training and testing sets",
        [(ccp_alphas, train_scores, "train"), (ccp_alphas, test_scores, "test")],
    )

    best = _pick_best_alpha(train_scores, test_scores, ccp_alphas)
    if best is None:
        text = "Not enough pruning candidates to suggest a ccp_alpha value."
    else:
        train_acc, test_acc, best_alpha = best
        # Effective alphas are typically small fractions, so rounding them to
        # 2 decimals would hide the value; show 4 significant digits instead.
        text = (
            f"Train accuracy: {round(train_acc, 2)}, Test accuracy: {round(test_acc, 2)}, "
            f"The best value of cost complexity parameter alpha (ccp_alpha): {best_alpha:.4g}"
        )
    return fig1, fig2, fig3, fig4, text
# Gradio UI: two sliders drive get_ccp, whose four figures and summary text
# fill a 2x2 plot grid and a results textbox.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        "<div>\n"
        "<h1 style='text-align: center'>⚒ Post pruning decision trees with cost complexity pruning 🛠</h1>\n"
        "</div>"
    )
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py\">scikit-learn</a>")

    # train_test_split rejects a float test_size of 0 or 1, so the slider
    # range stays strictly inside (0, 1).
    test_size = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.2, label="Test size")
    random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")

    with gr.Row():
        with gr.Column():
            plot_impurity = gr.Plot()
        with gr.Column():
            plot_node = gr.Plot()
    with gr.Row():
        with gr.Column():
            plot_depth = gr.Plot()
        with gr.Column():
            plot_compare = gr.Plot()
    with gr.Row():
        result = gr.Textbox(label="Results")

    inputs = [test_size, random_state]
    outputs = [plot_impurity, plot_node, plot_depth, plot_compare, result]
    # Recompute whenever either slider changes.
    for control in inputs:
        control.change(fn=get_ccp, inputs=inputs, outputs=outputs)
    # Render the default configuration on page load, so the plots are not
    # empty until the user moves a slider.
    demo.load(fn=get_ccp, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.launch()