Update app.py
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ import gradio as gr
|
|
20 |
|
21 |
C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
|
22 |
CMAP = ListedColormap([C1, C2, C3])
|
23 |
-
GRANULARITY = 0.
|
24 |
SEED = 1
|
25 |
N_ESTIMATORS = 30
|
26 |
|
@@ -68,14 +68,14 @@ def create_plot(feature_string, n_estimators, max_depth, model_idx):
|
|
68 |
|
69 |
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
70 |
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
71 |
-
xrange = np.arange(x_min, x_max,
|
72 |
-
yrange = np.arange(y_min, y_max,
|
73 |
xx, yy = np.meshgrid(xrange, yrange)
|
74 |
|
75 |
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
|
76 |
Z = Z.reshape(xx.shape)
|
77 |
|
78 |
-
fig = plt.figure()
|
79 |
ax = fig.add_subplot(111)
|
80 |
|
81 |
ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)
|
@@ -88,6 +88,8 @@ def create_plot(feature_string, n_estimators, max_depth, model_idx):
|
|
88 |
ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
|
89 |
ax.legend()
|
90 |
ax.set_title(f'{model_name} | Score: {score}')
|
|
|
|
|
91 |
|
92 |
return fig
|
93 |
|
@@ -106,29 +108,34 @@ This plot compares the **decision surfaces** learned by a decision tree classifi
|
|
106 |
There are in total **four features** in the Iris dataset. In this example you can select **two features at a time** for visualization purposes using the dropdown box below. All features are normalized to zero mean and unit standard deviation.
|
107 |
|
108 |
Play around with the **number of estimators** in the ensembles and the **max depth** of the trees using the sliders.
|
|
|
|
|
109 |
'''
|
110 |
|
111 |
with gr.Blocks() as demo:
|
112 |
-
gr.
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
133 |
|
134 |
demo.launch()
|
|
|
20 |
|
21 |
C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
|
22 |
CMAP = ListedColormap([C1, C2, C3])
|
23 |
+
GRANULARITY = 0.05
|
24 |
SEED = 1
|
25 |
N_ESTIMATORS = 30
|
26 |
|
|
|
68 |
|
69 |
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
70 |
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
71 |
+
xrange = np.arange(x_min, x_max, GRANULARITY)
|
72 |
+
yrange = np.arange(y_min, y_max, GRANULARITY)
|
73 |
xx, yy = np.meshgrid(xrange, yrange)
|
74 |
|
75 |
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
|
76 |
Z = Z.reshape(xx.shape)
|
77 |
|
78 |
+
fig = plt.figure(figsize=(4, 3.5))
|
79 |
ax = fig.add_subplot(111)
|
80 |
|
81 |
ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)
|
|
|
88 |
ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
|
89 |
ax.legend()
|
90 |
ax.set_title(f'{model_name} | Score: {score}')
|
91 |
+
fig.set_tight_layout(True)
|
92 |
+
fig.set_constrained_layout(True)
|
93 |
|
94 |
return fig
|
95 |
|
|
|
108 |
There are in total **four features** in the Iris dataset. In this example you can select **two features at a time** for visualization purposes using the dropdown box below. All features are normalized to zero mean and unit standard deviation.
|
109 |
|
110 |
Play around with the **number of estimators** in the ensembles and the **max depth** of the trees using the sliders.
|
111 |
+
|
112 |
+
Created by [@hubadul](https://huggingface.co/huabdul) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_iris.html).
|
113 |
'''
|
114 |
|
115 |
with gr.Blocks() as demo:
|
116 |
+
with gr.Row():
|
117 |
+
with gr.Column(scale=1):
|
118 |
+
gr.Markdown(info)
|
119 |
+
selections = combinations(FEATURES, 2)
|
120 |
+
selections = [f'{s[0]}, {s[1]}' for s in selections]
|
121 |
+
dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
|
122 |
+
slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
|
123 |
+
slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')
|
124 |
+
|
125 |
+
with gr.Column(scale=2):
|
126 |
+
counter = 0
|
127 |
+
for _ in iter_grid(2, 2):
|
128 |
+
if counter >= len(MODELS):
|
129 |
+
break
|
130 |
+
|
131 |
+
plot = gr.Plot(show_label=False)
|
132 |
+
fn = partial(create_plot, model_idx=counter)
|
133 |
+
|
134 |
+
dd.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
|
135 |
+
slider_estimators.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
|
136 |
+
slider_max_depth.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
|
137 |
+
demo.load(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
|
138 |
+
|
139 |
+
counter += 1
|
140 |
|
141 |
demo.launch()
|