|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
from sklearn.inspection import permutation_importance |
|
|
|
def plot_rf_importance(clf):
    """Build a horizontal bar chart of the final step's feature importances.

    ``clf`` is indexed like a scikit-learn ``Pipeline``: ``clf[:-1]`` (every
    step except the last) provides the feature names through
    ``get_feature_names_out()``, and ``clf[-1]`` (the last step) provides
    ``feature_importances_``.

    Args:
        clf: Fitted, sliceable pipeline-like object whose last step exposes
            ``feature_importances_``.

    Returns:
        The Plotly figure, with features ordered by ascending importance and
        the legend hidden.
    """
    names = clf[:-1].get_feature_names_out()
    estimator = clf[-1]

    # Ascending sort, matching the order the values are handed to the chart.
    ranked = pd.Series(estimator.feature_importances_, index=names)
    ranked = ranked.sort_values(ascending=True)

    chart = px.bar(
        ranked,
        orientation="h",
        title="Random Forest Feature Importances (MDI)",
    )
    # update_layout returns the figure it was called on, so it can be
    # returned directly.
    return chart.update_layout(
        showlegend=False,
        xaxis_title="Importance",
        yaxis_title="Feature",
    )
|
|
|
def plot_permutation_boxplot(
    clf,
    X: "pd.DataFrame | np.ndarray",
    y: "np.ndarray",
    set_: "str | None" = None,
    *,
    n_repeats: int = 10,
    random_state: "int | None" = 42,
    n_jobs: "int | None" = 2,
):
    """Box-plot the permutation importance of each feature of ``clf``.

    Calls ``sklearn.inspection.permutation_importance`` on ``(X, y)`` and
    draws one horizontal box per feature, ordered by mean importance, with a
    dashed vertical reference line at zero.

    Args:
        clf: Fitted estimator (or pipeline) passed to
            ``permutation_importance``.
        X: Feature matrix. When it has a ``columns`` attribute (e.g. a
            DataFrame) those labels name the features; otherwise generic
            ``feature_<i>`` names are generated from the column positions.
        y: Targets passed through to ``permutation_importance``.
        set_: Optional label appended to the figure title (for example the
            name of the data split being evaluated).
        n_repeats: Forwarded to ``permutation_importance``.
        random_state: Forwarded to ``permutation_importance``.
        n_jobs: Forwarded to ``permutation_importance``.

    Returns:
        The Plotly figure.
    """
    result = permutation_importance(
        clf,
        X,
        y,
        n_repeats=n_repeats,
        random_state=random_state,
        n_jobs=n_jobs,
    )

    order = result.importances_mean.argsort()

    # The original unconditionally read X.columns, which raises
    # AttributeError for plain arrays even though the signature advertised
    # np.ndarray. Fall back to positional names in that case.
    if hasattr(X, "columns"):
        feature_names = X.columns[order]
    else:
        feature_names = [f"feature_{i}" for i in order]

    importances = pd.DataFrame(
        result.importances[order].T,
        columns=feature_names,
    )

    # Pin the melted column names: by default melt() uses columns.name as the
    # variable column when the column index is named, which would break the
    # "variable" lookup below.
    long_form = importances.melt(var_name="variable", value_name="value")
    fig = px.box(long_form, y="variable", x="value")

    # Dashed reference line at zero importance. y runs from -1 to the number
    # of features so the line extends beyond the first and last box.
    fig.add_shape(
        type="line",
        x0=0,
        y0=-1,
        x1=0,
        y1=len(importances.columns),
        opacity=0.5,
        line=dict(dash="dash"),
    )

    # Pad the x-range by 0.005 on both sides, and always include a little
    # negative space so the zero line is visible.
    lowest = importances.min().min()
    x_min = lowest - 0.005 if lowest < 0 else -0.005
    x_max = importances.max().max() + 0.005
    fig.update_xaxes(range=[x_min, x_max])

    # Only append the label when one is given, so the title has no trailing
    # space.
    title = f"Permutation Importances {set_}" if set_ else "Permutation Importances"
    fig.update_layout(
        title=title,
        xaxis_title="Importance",
        yaxis_title="Feature",
        showlegend=False,
    )

    return fig