|
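"""
Gradio demo of recursive feature elimination (RFE): ranks the pixels of the
scikit-learn hand-written digits images by their usefulness to a linear SVC.

Demo is based on https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html
"""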
|
from sklearn.svm import SVC |
|
from sklearn.datasets import load_digits |
|
from sklearn.feature_selection import RFE |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Digits dataset: each sample image is flattened into a single row so that
# every pixel becomes one feature column of X.
digits = load_digits()
X = digits.images.reshape(digits.images.shape[0], -1)
y = digits.target

# Linear-kernel SVC; used as the default estimator that RFE refits at each
# elimination round to score the remaining pixels.
svc = SVC(C=1, kernel="linear")
|
|
|
|
|
def recursive_feature_elimination(n_features_to_select, step, esimator=svc):
    """Rank the digit-image pixels with RFE and plot the ranking as a heat map.

    Args:
        n_features_to_select: Number of pixels (features) RFE keeps at the end
            of elimination; must be >= 1.
        step: Number of pixels removed at each elimination iteration; must be
            >= 1.
        esimator: Estimator RFE refits at each iteration to score features.
            The (misspelled) keyword is kept for backward compatibility;
            defaults to the module-level linear ``svc``.

    Returns:
        A matplotlib ``Figure`` showing ``rfe.ranking_`` reshaped to the shape
        of one digit image (per the scikit-learn docs, selected features get
        rank 1).
    """
    # A Figure built directly (not via pyplot) is not registered with pyplot's
    # global figure manager, so repeated requests do not pile up open figures.
    from matplotlib.figure import Figure, figaspect

    # Slider values may arrive as floats; RFE treats a float step /
    # n_features_to_select as a fraction, so coerce to whole counts.
    n_features_to_select = int(n_features_to_select)
    step = int(step)

    rfe = RFE(
        estimator=esimator, n_features_to_select=n_features_to_select, step=step
    )
    rfe.fit(X, y)

    # ranking_ has one entry per flattened pixel; fold it back to image shape.
    ranking = rfe.ranking_.reshape(digits.images[0].shape)

    # figaspect matches the sizing plt.matshow used for its own new figure.
    fig = Figure(figsize=figaspect(ranking))
    ax = fig.subplots()
    image = ax.matshow(ranking, cmap="Blues")
    fig.colorbar(image, ax=ax)
    ax.set_title("Ranking of pixels with RFE")

    return fig
|
|
|
|
|
import gradio as gr |
|
|
|
title = " Illustration of Recursive feature elimination.🌲 "

# Build the Gradio UI: description, parameter sliders, a submit button and the
# plot produced by recursive_feature_elimination.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        "This example demonstrates recursive feature elimination. <br>"
        "Dataset is `load_digits()` which is images of size 8x8 images of hand-written digits. <br>"
        "**Parameters** <br> **Number of features to select**: Represents the features left at the end of feature selection process. <br>"
        "**Step**: Number of feature to remove at each iteration, least important are removed. <br>"
    )

    gr.Markdown(
        "Support vector classifier is used as estimator to rank features. <br>"
    )

    gr.Markdown(
        "Demo is based on [sklearn docs](https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html)."
    )

    with gr.Row():
        # RFE rejects 0 for both n_features_to_select and step, so the
        # sliders start at 1 instead of 0.
        n_features_to_select = gr.Slider(
            minimum=1, maximum=20, step=1, value=1, label="Number of features to select"
        )
        step = gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Step")

    btn = gr.Button(value="Submit")

    # Created explicitly (same position in the layout as before) so the
    # output component is a named, reusable object.
    plot = gr.Plot(
        label="Recursive feature elimination of pixels in digit classification"
    )

    btn.click(
        recursive_feature_elimination,
        inputs=[n_features_to_select, step],
        outputs=plot,
    )

    gr.Markdown(
        "Plot shows the importance of each pixel in the classification of the digits. <br>"
    )

demo.launch()
|
|