vumichien commited on
Commit
eceb46f
·
1 Parent(s): 1f2b512

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.datasets import load_breast_cancer
6
+ from sklearn.tree import DecisionTreeClassifier
7
+
8
+ theme = gr.themes.Monochrome(
9
+ primary_hue="indigo",
10
+ secondary_hue="blue",
11
+ neutral_hue="slate",
12
+ )
13
+
14
+ X, y = load_breast_cancer(return_X_y=True)
15
+
16
+ def get_ccp(test_size, random_state):
17
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state, test_size=test_size)
18
+ clf = DecisionTreeClassifier(random_state=random_state)
19
+ path = clf.cost_complexity_pruning_path(X_train, y_train)
20
+ ccp_alphas, impurities = path.ccp_alphas, path.impurities
21
+
22
+ fig1, ax1 = plt.subplots()
23
+ ax1.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
24
+ ax1.set_xlabel("effective alpha")
25
+ ax1.set_ylabel("total impurity of leaves")
26
+ ax1.set_title("Total Impurity vs effective alpha for training set")
27
+
28
+ clfs = []
29
+ for ccp_alpha in ccp_alphas:
30
+ clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
31
+ clf.fit(X_train, y_train)
32
+ clfs.append(clf)
33
+ clfs = clfs[:-1]
34
+ ccp_alphas = ccp_alphas[:-1]
35
+
36
+ node_counts = [clf.tree_.node_count for clf in clfs]
37
+ depth = [clf.tree_.max_depth for clf in clfs]
38
+
39
+ fig2, ax2 = plt.subplots()
40
+ ax2.plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
41
+ ax2.set_xlabel("alpha")
42
+ ax2.set_ylabel("number of nodes")
43
+ ax2.set_title("Number of nodes vs alpha")
44
+
45
+ fig3, ax3 = plt.subplots()
46
+
47
+ ax3.plot(ccp_alphas, depth, marker="o", drawstyle="steps-post")
48
+ ax3.set_xlabel("alpha")
49
+ ax3.set_ylabel("depth of tree")
50
+ ax3.set_title("Depth vs alpha")
51
+ fig3.tight_layout()
52
+
53
+ train_scores = [clf.score(X_train, y_train) for clf in clfs]
54
+ test_scores = [clf.score(X_test, y_test) for clf in clfs]
55
+
56
+ fig4, ax4 = plt.subplots()
57
+ ax4.set_xlabel("alpha")
58
+ ax4.set_ylabel("accuracy")
59
+ ax4.set_title("Accuracy vs alpha for training and testing sets")
60
+ ax4.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
61
+ ax4.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
62
+ ax4.legend()
63
+
64
+ score_gap = []
65
+ for train_score, test_score, ccp_alpha in zip(test_scores, train_scores, ccp_alphas):
66
+ score_gap.append((train_score, test_score, abs(train_score - test_score), ccp_alpha))
67
+ score_gap.sort(key=lambda a: a[2])
68
+ top3_score = score_gap[:3]
69
+ top3_score.sort(key=lambda a: a[1], reverse=True)
70
+ text = f"Train accuracy: {round(top3_score[0][0], 2)}, Test accuracy: {round(top3_score[0][1], 2)}, The best value of cost complexity parameter alpha (ccp_alpha): {round(top3_score[0][2], 2)}"
71
+ return fig1, fig2, fig3, fig4, text
72
+
73
+
74
+ with gr.Blocks(theme=theme) as demo:
75
+ gr.Markdown('''
76
+ <div>
77
+ <h1 style='text-align: center'>⚒ Post pruning decision trees with cost complexity pruning 🛠</h1>
78
+ </div>
79
+ ''')
80
+ gr.Markdown("The DecisionTreeClassifier employs a pruning technique that can be configured using the cost complexity parameter, commonly referred to as ccp_alpha."
81
+ " By increasing the value of ccp_alpha, a greater number of nodes can be pruned. This demo demonstrate the impact of ccp_alpha on tree regularization")
82
+ gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py\">scikit-learn</a>")
83
+ test_size = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Test size")
84
+ random_state = gr.Slider(minimum=0, maximum=2000, step=1, value=0, label="Random state")
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ plot_impurity = gr.Plot()
89
+ with gr.Column():
90
+ plot_node = gr.Plot()
91
+
92
+ with gr.Row():
93
+ with gr.Column():
94
+ plot_depth = gr.Plot()
95
+ with gr.Column():
96
+ plot_compare = gr.Plot()
97
+ with gr.Row():
98
+ result = gr.Textbox(label="Resusts")
99
+ test_size.change(fn=get_ccp, inputs=[test_size, random_state], outputs=[plot_impurity, plot_node, plot_depth, plot_compare, result])
100
+ random_state.change(fn=get_ccp, inputs=[test_size, random_state], outputs=[plot_impurity, plot_node, plot_depth, plot_compare,result])
101
+
102
+ demo.launch()