huabdul commited on
Commit
4432d63
1 Parent(s): 7292509

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -6
app.py CHANGED
@@ -39,7 +39,7 @@ MODELS = [
39
 
40
  # ========================================
41
 
42
- def create_plot(feature_string, n_estimators, model_idx):
43
  np.random.seed(SEED)
44
 
45
  feature_list = feature_string.split(',')
@@ -58,6 +58,10 @@ def create_plot(feature_string, n_estimators, model_idx):
58
 
59
  model_name = MODEL_NAMES[model_idx]
60
  model = MODELS[model_idx]
 
 
 
 
61
 
62
  model.fit(X, y)
63
  score = round(model.score(X, y), 3)
@@ -94,11 +98,23 @@ def iter_grid(n_rows, n_cols):
94
  with gr.Column():
95
  yield
96
 
 
 
 
 
 
 
 
 
 
 
97
  with gr.Blocks() as demo:
 
98
  selections = combinations(FEATURES, 2)
99
  selections = [f'{s[0]}, {s[1]}' for s in selections]
100
  dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
101
- slider = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
 
102
 
103
  counter = 0
104
  for _ in iter_grid(2, 2):
@@ -108,10 +124,11 @@ with gr.Blocks() as demo:
108
  plot = gr.Plot(label=f'{MODEL_NAMES[counter]}')
109
  fn = partial(create_plot, model_idx=counter)
110
 
111
- dd.change(fn, inputs=[dd, slider], outputs=[plot])
112
- slider.change(fn, inputs=[dd, slider], outputs=[plot])
113
- demo.load(fn, inputs=[dd, slider], outputs=[plot])
 
114
 
115
  counter += 1
116
 
117
- demo.launch()
 
39
 
40
  # ========================================
41
 
42
+ def create_plot(feature_string, n_estimators, max_depth, model_idx):
43
  np.random.seed(SEED)
44
 
45
  feature_list = feature_string.split(',')
 
58
 
59
  model_name = MODEL_NAMES[model_idx]
60
  model = MODELS[model_idx]
61
+
62
+ if model_idx != 0: model.n_estimators = n_estimators
63
+ if model_idx != 3: model.max_depth = max_depth
64
+ if model_idx == 3: model.estimator.max_depth = max_depth
65
 
66
  model.fit(X, y)
67
  score = round(model.score(X, y), 3)
 
98
  with gr.Column():
99
  yield
100
 
101
+ info = '''
102
+ ## Plot the decision surfaces of ensembles of trees on the Iris dataset
103
+
104
+ This plot compares the decision surfaces learned by a decision tree classifier, a random forest classifier, an extra-trees classifier, and by an AdaBoost classifier.
105
+
106
+ There are in total four features in the Iris dataset. In this example you can select two features at a time for visualization purposes using the dropdown box below.
107
+
108
+ You can also vary the number of estimators in the ensembles and the max depth of the trees using the sliders.
109
+ '''
110
+
111
  with gr.Blocks() as demo:
112
+ gr.Markdown(info)
113
  selections = combinations(FEATURES, 2)
114
  selections = [f'{s[0]}, {s[1]}' for s in selections]
115
  dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
116
+ slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
117
+ slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')
118
 
119
  counter = 0
120
  for _ in iter_grid(2, 2):
 
124
  plot = gr.Plot(label=f'{MODEL_NAMES[counter]}')
125
  fn = partial(create_plot, model_idx=counter)
126
 
127
+ dd.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
128
+ slider_estimators.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
129
+ slider_max_depth.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
130
+ demo.load(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
131
 
132
  counter += 1
133
 
134
+ demo.launch(share=True, debug=True)