huabdul commited on
Commit
a74c801
1 Parent(s): a96b0d2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.colors import ListedColormap
4
+ from itertools import combinations
5
+ from functools import partial
6
+
7
+ plt.rcParams['figure.dpi'] = 100
8
+
9
+ from sklearn.datasets import load_iris
10
+ from sklearn.ensemble import (
11
+ RandomForestClassifier,
12
+ ExtraTreesClassifier,
13
+ AdaBoostClassifier,
14
+ )
15
+ from sklearn.tree import DecisionTreeClassifier
16
+
17
+ import gradio as gr
18
+
19
+ # ========================================
20
+
21
+ C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
22
+ CMAP = ListedColormap([C1, C2, C3])
23
+ GRANULARITY = 0.01
24
+ SEED = 1
25
+ N_ESTIMATORS = 30
26
+
27
+ FEATURES = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
28
+ LABELS = ["Setosa", "Versicolour", "Virginica"]
29
+ MODEL_NAMES = ['DecisionTreeClassifier', 'RandomForestClassifier', 'ExtraTreesClassifier', 'AdaBoostClassifier']
30
+
31
+ iris = load_iris()
32
+
33
+ MODELS = [
34
+ DecisionTreeClassifier(max_depth=None),
35
+ RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
36
+ ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
37
+ AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=N_ESTIMATORS)
38
+ ]
39
+
40
+ # ========================================
41
+
42
+ def create_plot(feature_string, n_estimators, model_idx):
43
+ np.random.seed(SEED)
44
+
45
+ feature_list = feature_string.split(',')
46
+ feature_list = [s.strip() for s in feature_list]
47
+ idx_x = FEATURES.index(feature_list[0])
48
+ idx_y = FEATURES.index(feature_list[1])
49
+
50
+ X = iris.data[:, [idx_x, idx_y]]
51
+ y = iris.target
52
+
53
+ rnd_idx = np.random.permutation(X.shape[0])
54
+ X = X[rnd_idx]
55
+ y = y[rnd_idx]
56
+
57
+ X = (X - X.mean(0)) / X.std(0)
58
+
59
+ model_name = MODEL_NAMES[model_idx]
60
+ model = MODELS[model_idx]
61
+
62
+ model.fit(X, y)
63
+ score = round(model.score(X, y), 3)
64
+
65
+ x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
66
+ y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
67
+ xrange = np.arange(x_min, x_max, 0.1)
68
+ yrange = np.arange(y_min, y_max, 0.1)
69
+ xx, yy = np.meshgrid(xrange, yrange)
70
+
71
+ Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
72
+ Z = Z.reshape(xx.shape)
73
+
74
+ fig = plt.figure()
75
+ ax = fig.add_subplot(111)
76
+
77
+ ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)
78
+
79
+ for i, label in enumerate(LABELS):
80
+ X_label = X[y==i,:]
81
+ y_label = y[y==i]
82
+ ax.scatter(X_label[:, 0], X_label[:, 1], c=[[C1], [C2], [C3]][i]*len(y_label), edgecolor='k', s=40, label=label)
83
+
84
+ ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
85
+ ax.legend()
86
+ ax.set_title(f'{model_name} | Score: {score}')
87
+
88
+ return fig
89
+
90
+ def iter_grid(n_rows, n_cols):
91
+ for _ in range(n_rows):
92
+ with gr.Row():
93
+ for _ in range(n_cols):
94
+ with gr.Column():
95
+ yield
96
+
97
+ with gr.Blocks() as demo:
98
+ selections = combinations(FEATURES, 2)
99
+ selections = [f'{s[0]}, {s[1]}' for s in selections]
100
+ dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
101
+ slider = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
102
+
103
+ counter = 0
104
+ for _ in iter_grid(2, 2):
105
+ if counter >= len(MODELS):
106
+ break
107
+
108
+ plot = gr.Plot(label=f'{MODEL_NAMES[counter]}')
109
+ fn = partial(create_plot, model_idx=counter)
110
+
111
+ dd.change(fn, inputs=[dd, slider], outputs=[plot])
112
+ slider.change(fn, inputs=[dd, slider], outputs=[plot])
113
+ demo.load(fn, inputs=[dd, slider], outputs=[plot])
114
+
115
+ counter += 1
116
+
117
+ demo.launch(share=True, debug=True)