Commit
·
f811cf6
1
Parent(s):
99addf7
App itself
Browse files
app.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
import pandas as pd
|
4 |
+
from sklearn.pipeline import Pipeline
|
5 |
+
from sklearn.impute import SimpleImputer
|
6 |
+
from sklearn.datasets import fetch_openml
|
7 |
+
from sklearn.compose import ColumnTransformer
|
8 |
+
from sklearn.preprocessing import OrdinalEncoder
|
9 |
+
from sklearn.ensemble import RandomForestClassifier
|
10 |
+
from sklearn.model_selection import train_test_split
|
11 |
+
|
12 |
+
import utils
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def app_fn(seed: int, n_cat: int, n_estimators: int, min_samples_leaf: int):
    """Fit a Random Forest on the Titanic dataset plus two random features
    and return feature-importance figures.

    Args:
        seed: RNG seed used for the random features, the train/test split
            and the forest itself.
        n_cat: Number of distinct categories in the injected ``random_cat``
            feature.
        n_estimators: Number of trees in the forest.
        min_samples_leaf: Minimum number of samples required at a leaf node.

    Returns:
        Tuple of three figures: MDI importances, permutation importances on
        the train set, permutation importances on the test set.
    """
    # Gradio sliders may deliver floats; numpy's randint and the sklearn
    # integer hyperparameters require genuine ints, so normalize up front.
    seed = int(seed)
    n_cat = int(n_cat)
    n_estimators = int(n_estimators)
    min_samples_leaf = int(min_samples_leaf)

    X, y = fetch_openml(
        "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas"
    )

    rng = np.random.RandomState(seed=seed)

    # Inject two uninformative features: one categorical, one numerical
    # (standard normal), to contrast MDI with permutation importance.
    X["random_cat"] = rng.randint(n_cat, size=X.shape[0])
    X["random_num"] = rng.randn(X.shape[0])

    categorical_columns = ["pclass", "sex", "embarked", "random_cat"]
    numerical_columns = ["age", "sibsp", "parch", "fare", "random_num"]

    X = X[categorical_columns + numerical_columns]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=seed
    )

    # Map unknown and missing categories to -1 so prediction never fails
    # on categories absent from the training split.
    categorical_encoder = OrdinalEncoder(
        handle_unknown="use_encoded_value", unknown_value=-1, encoded_missing_value=-1
    )
    numerical_pipe = SimpleImputer(strategy="mean")

    preprocessing = ColumnTransformer(
        [
            ("cat", categorical_encoder, categorical_columns),
            ("num", numerical_pipe, numerical_columns),
        ],
        verbose_feature_names_out=False,
    )

    clf = Pipeline(
        [
            ("preprocess", preprocessing),
            (
                "classifier",
                RandomForestClassifier(
                    random_state=seed,
                    n_estimators=n_estimators,
                    min_samples_leaf=min_samples_leaf,
                ),
            ),
        ]
    )

    clf.fit(X_train, y_train)

    # Plot helpers live in the project-local `utils` module.
    fig_mdi = utils.plot_rf_importance(clf)
    fig_perm_train = utils.plot_permutation_boxplot(clf, X_train, y_train, set_="train set")
    fig_perm_test = utils.plot_permutation_boxplot(clf, X_test, y_test, set_="test set")

    return fig_mdi, fig_perm_train, fig_perm_test
|
66 |
+
|
67 |
+
|
68 |
+
title = "Permutation Importance vs Random Forest Feature Importance (MDI)"

# Build the Gradio UI. NOTE: the original used the removed `gr.inputs.*`
# namespace and `default=` keyword (Gradio 1.x/2.x); modern Gradio exposes
# components at the top level with `value=` for the initial value, and
# `gr.Button` takes its caption as `value`, not `label`.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        """
        ### This demo compares the feature importances of a Random Forest classifier using the Mean Decrease Impurity (MDI) method and the Permutation Importance method. \
        To showcase the difference between the two methods, we add two random features to the Titanic dataset. \
        The first random feature is categorical and the second one is numerical. \
        The categorical feature can have its number of categories changed \
        and the numerical feature is sampled from a Standard Normal Distribution. \
        Random Forest hyperparameters can also be changed to verify the impact of model complexity on the feature importances.

        [Original Example](https://scikit-learn.org/stable/auto_examples/inspection/plot_permutation_importance.html#sphx-glr-auto-examples-inspection-plot-permutation-importance-py)
        """
    )

    with gr.Row():
        seed = gr.Slider(minimum=0, maximum=42, step=1, value=42, label="Seed")
        n_cat = gr.Slider(minimum=2, maximum=30, step=1, value=3, label="# Cats in random_cat")
        n_estimators = gr.Slider(minimum=5, maximum=150, step=5, value=100, label="# Trees in the forest")
        min_samples_leaf = gr.Slider(minimum=1, maximum=30, step=5, value=1, label="Minimum # samples required to be at a leaf node")

    btn = gr.Button("Run")

    fig_mdi = gr.Plot(label="Mean Decrease Impurity (MDI)")

    with gr.Row():
        fig_perm_train = gr.Plot(label="Permutation Importance (Train)")
        fig_perm_test = gr.Plot(label="Permutation Importance (Test)")

    # Run on button click and once on page load so the demo is never empty.
    btn.click(fn=app_fn, outputs=[fig_mdi, fig_perm_train, fig_perm_test], inputs=[seed, n_cat, n_estimators, min_samples_leaf])
    demo.load(fn=app_fn, outputs=[fig_mdi, fig_perm_train, fig_perm_test], inputs=[seed, n_cat, n_estimators, min_samples_leaf])

demo.launch()
|