huabdul commited on
Commit
4bd4fb5
1 Parent(s): 5836e70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -20,7 +20,7 @@ import gradio as gr
20
 
21
  C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
22
  CMAP = ListedColormap([C1, C2, C3])
23
- GRANULARITY = 0.01
24
  SEED = 1
25
  N_ESTIMATORS = 30
26
 
@@ -68,14 +68,14 @@ def create_plot(feature_string, n_estimators, max_depth, model_idx):
68
 
69
  x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
70
  y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
71
- xrange = np.arange(x_min, x_max, 0.1)
72
- yrange = np.arange(y_min, y_max, 0.1)
73
  xx, yy = np.meshgrid(xrange, yrange)
74
 
75
  Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
76
  Z = Z.reshape(xx.shape)
77
 
78
- fig = plt.figure()
79
  ax = fig.add_subplot(111)
80
 
81
  ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)
@@ -88,6 +88,8 @@ def create_plot(feature_string, n_estimators, max_depth, model_idx):
88
  ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
89
  ax.legend()
90
  ax.set_title(f'{model_name} | Score: {score}')
 
 
91
 
92
  return fig
93
 
@@ -106,29 +108,34 @@ This plot compares the **decision surfaces** learned by a decision tree classifi
106
  There are in total **four features** in the Iris dataset. In this example you can select **two features at a time** for visualization purposes using the dropdown box below. All features are normalized to zero mean and unit standard deviation.
107
 
108
  Play around with the **number of estimators** in the ensembles and the **max depth** of the trees using the sliders.
 
 
109
  '''
110
 
111
  with gr.Blocks() as demo:
112
- gr.Markdown(info)
113
- selections = combinations(FEATURES, 2)
114
- selections = [f'{s[0]}, {s[1]}' for s in selections]
115
- dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
116
- slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
117
- slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')
118
-
119
- counter = 0
120
- for _ in iter_grid(2, 2):
121
- if counter >= len(MODELS):
122
- break
123
-
124
- plot = gr.Plot(label=f'{MODEL_NAMES[counter]}')
125
- fn = partial(create_plot, model_idx=counter)
126
-
127
- dd.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
128
- slider_estimators.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
129
- slider_max_depth.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
130
- demo.load(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
131
-
132
- counter += 1
 
 
 
133
 
134
  demo.launch()
 
20
 
21
  C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
22
  CMAP = ListedColormap([C1, C2, C3])
23
+ GRANULARITY = 0.05
24
  SEED = 1
25
  N_ESTIMATORS = 30
26
 
 
68
 
69
  x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
70
  y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
71
+ xrange = np.arange(x_min, x_max, GRANULARITY)
72
+ yrange = np.arange(y_min, y_max, GRANULARITY)
73
  xx, yy = np.meshgrid(xrange, yrange)
74
 
75
  Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
76
  Z = Z.reshape(xx.shape)
77
 
78
+ fig = plt.figure(figsize=(4, 3.5))
79
  ax = fig.add_subplot(111)
80
 
81
  ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)
 
88
  ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
89
  ax.legend()
90
  ax.set_title(f'{model_name} | Score: {score}')
91
+ fig.set_tight_layout(True)
92
+ fig.set_constrained_layout(True)
93
 
94
  return fig
95
 
 
108
  There are in total **four features** in the Iris dataset. In this example you can select **two features at a time** for visualization purposes using the dropdown box below. All features are normalized to zero mean and unit standard deviation.
109
 
110
  Play around with the **number of estimators** in the ensembles and the **max depth** of the trees using the sliders.
111
+
112
+ Created by [@hubadul](https://huggingface.co/huabdul) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_iris.html).
113
  '''
114
 
115
  with gr.Blocks() as demo:
116
+ with gr.Row():
117
+ with gr.Column(scale=1):
118
+ gr.Markdown(info)
119
+ selections = combinations(FEATURES, 2)
120
+ selections = [f'{s[0]}, {s[1]}' for s in selections]
121
+ dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
122
+ slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
123
+ slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')
124
+
125
+ with gr.Column(scale=2):
126
+ counter = 0
127
+ for _ in iter_grid(2, 2):
128
+ if counter >= len(MODELS):
129
+ break
130
+
131
+ plot = gr.Plot(show_label=False)
132
+ fn = partial(create_plot, model_idx=counter)
133
+
134
+ dd.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
135
+ slider_estimators.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
136
+ slider_max_depth.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
137
+ demo.load(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
138
+
139
+ counter += 1
140
 
141
  demo.launch()