import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import combinations
from functools import partial
plt.rcParams['figure.dpi'] = 100
from sklearn.datasets import load_iris
from sklearn.ensemble import (
    RandomForestClassifier,
    ExtraTreesClassifier,
    AdaBoostClassifier,
)
from sklearn.tree import DecisionTreeClassifier
import gradio as gr
# ========================================
C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
CMAP = ListedColormap([C1, C2, C3])
GRANULARITY = 0.05
SEED = 1
N_ESTIMATORS = 30
FEATURES = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
LABELS = ["Setosa", "Versicolour", "Virginica"]
MODEL_NAMES = ['DecisionTreeClassifier', 'RandomForestClassifier', 'ExtraTreesClassifier', 'AdaBoostClassifier']
iris = load_iris()
MODELS = [
    DecisionTreeClassifier(max_depth=None),
    RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=N_ESTIMATORS),
]
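# Note: these constructor arguments are only defaults; create_plot() below overrides
# n_estimators and max_depth from the UI sliders on every call. AdaBoost boosts shallow
# decision trees, so its depth is set on the wrapped base estimator rather than on the
# ensemble itself.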
# ========================================
def create_plot(feature_string, n_estimators, max_depth, model_idx):
    """Fit the selected model on two Iris features and plot its decision surface."""
    np.random.seed(SEED)

    # Parse the two selected feature names and pick the matching columns.
    feature_list = [s.strip() for s in feature_string.split(',')]
    idx_x = FEATURES.index(feature_list[0])
    idx_y = FEATURES.index(feature_list[1])

    X = iris.data[:, [idx_x, idx_y]]
    y = iris.target

    # Shuffle the samples, then standardize each feature to zero mean and unit variance.
    rnd_idx = np.random.permutation(X.shape[0])
    X = X[rnd_idx]
    y = y[rnd_idx]
    X = (X - X.mean(0)) / X.std(0)

    model_name = MODEL_NAMES[model_idx]
    model = MODELS[model_idx]
    if model_idx != 0:
        model.n_estimators = n_estimators  # ensembles only; the plain tree has no n_estimators
    if model_idx != 3:
        model.max_depth = max_depth
    else:
        model.estimator.max_depth = max_depth  # AdaBoost: depth applies to its base tree
    model.fit(X, y)
    score = round(model.score(X, y), 3)

    # Evaluate the fitted model on a dense grid covering the training data.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xrange = np.arange(x_min, x_max, GRANULARITY)
    yrange = np.arange(y_min, y_max, GRANULARITY)
    xx, yy = np.meshgrid(xrange, yrange)
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    fig = plt.figure(figsize=(4, 3.5))
    ax = fig.add_subplot(111)
    ax.contourf(xx, yy, Z, cmap=CMAP, alpha=0.65)

    # Overlay the training points, one color per class.
    for i, label in enumerate(LABELS):
        X_label = X[y == i, :]
        y_label = y[y == i]
        ax.scatter(X_label[:, 0], X_label[:, 1], c=[[C1], [C2], [C3]][i] * len(y_label),
                   edgecolor='k', s=40, label=label)

    ax.set_xlabel(feature_list[0]); ax.set_ylabel(feature_list[1])
    ax.legend()
    ax.set_title(f'{model_name} | Score: {score}')
    fig.set_tight_layout(True)
    fig.set_constrained_layout(True)
    return fig
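# Rough standalone usage (outside Gradio), assuming the constants defined above:
#   fig = create_plot("Sepal Length, Sepal Width", n_estimators=30, max_depth=10, model_idx=1)
#   fig.savefig("random_forest_surface.png")
# model_idx indexes into MODELS/MODEL_NAMES, so 1 selects the RandomForestClassifier.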
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an n_rows x n_cols grid of gr.Row/gr.Column contexts."""
    # Each yield happens inside a fresh gr.Column nested in the current gr.Row,
    # so the caller can create one component per grid cell.
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield
info = '''
# Plot the decision surfaces of ensembles of trees on the Iris dataset

This plot compares the **decision surfaces** learned by a decision tree classifier, a random forest classifier, an extra-trees classifier, and an AdaBoost classifier.

The Iris dataset has **four features** in total. For visualization, you can select **two features at a time** using the dropdown below. All features are normalized to zero mean and unit standard deviation.

Play around with the **number of estimators** in the ensembles and the **max depth** of the trees using the sliders.

Created by [@huabdul](https://huggingface.co/huabdul) based on the [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_iris.html).
'''
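# UI layout: a left column holds the feature dropdown and the two sliders; a right
# column holds a 2x2 grid of plots, one per model. Every input change re-renders all
# four plots by calling create_plot with the corresponding model_idx.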
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(info)

            selections = combinations(FEATURES, 2)
            selections = [f'{s[0]}, {s[1]}' for s in selections]
            dd = gr.Dropdown(selections, value=selections[0], interactive=True, label="Input features")
            slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
            slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')

        with gr.Column(scale=2):
            counter = 0
            for _ in iter_grid(2, 2):
                if counter >= len(MODELS):
                    break
                plot = gr.Plot(show_label=False)
                fn = partial(create_plot, model_idx=counter)
                dd.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
                slider_estimators.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
                slider_max_depth.change(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
                demo.load(fn, inputs=[dd, slider_estimators, slider_max_depth], outputs=[plot])
                counter += 1

demo.launch()
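# Running `python app.py` starts a local Gradio server; when deployed as a Gradio
# Space on Hugging Face, app.py is executed automatically.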