"""
Demo is based on https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html
"""
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn.feature_selection import RFE
import matplotlib.pyplot as plt
# Hand-written digit images (8 x 8 pixels). Each image is flattened into one
# row so that every pixel becomes a separate feature that RFE can rank.
digits = load_digits()
X = digits.images.reshape((digits.images.shape[0], -1))
y = digits["target"]

# Default estimator wrapped by RFE. A linear kernel is used so the fitted
# model exposes per-feature weights that RFE can use to rank pixels.
svc = SVC(C=1, kernel="linear")
def recursive_feature_elimination(n_features_to_select, step, esimator=svc):
    """Rank the digit pixels with RFE and return a heat-map of the ranking.

    Fits :class:`sklearn.feature_selection.RFE` on the module-level digits
    data (``X``, ``y``) and plots the resulting ``ranking_`` reshaped back to
    the image grid.

    Args:
        n_features_to_select: Number of features left at the end of the
            elimination; these selected features all receive rank 1.
            Coerced to ``int``.
        step: Number of features removed at each iteration (the least
            important ones are removed). Coerced to ``int``.
        esimator: Estimator RFE wraps to score features; defaults to the
            module-level linear ``svc``. The misspelt parameter name is kept
            so existing keyword callers keep working.

    Returns:
        A ``matplotlib.figure.Figure`` holding the ranking heat-map, suitable
        as the value of a ``gr.Plot`` output.
    """
    # Build the figure with the object-oriented API rather than pyplot:
    # pyplot registers every figure in global state (never freed here, so a
    # long-running server leaks them) and that state is not thread-safe
    # across Gradio's worker threads.
    from matplotlib.figure import Figure

    # Slider values may arrive as floats. RFE interprets a float in (0, 1) as
    # a fraction of the features, and recent scikit-learn versions reject
    # other non-integer floats, so coerce to the integer counts intended.
    n_features_to_select = int(n_features_to_select)
    step = int(step)

    # Previously both arguments were ignored (hard-coded to 1), so the UI
    # sliders had no effect; pass the requested values through.
    rfe = RFE(
        estimator=esimator,
        n_features_to_select=n_features_to_select,
        step=step,
    )
    rfe.fit(X, y)
    # One rank per pixel; reshape back to the 2-D image grid for display.
    ranking = rfe.ranking_.reshape(digits.images[0].shape)

    fig = Figure()
    ax = fig.subplots()
    image = ax.matshow(ranking, cmap="Blues")
    fig.colorbar(image, ax=ax)
    ax.set_title("Ranking of pixels with RFE")
    return fig
import gradio as gr
title = " Illustration of Recursive feature elimination.🌲 "
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    # Paragraphs are separated by blank lines ("\n\n"): a single newline does
    # not start a new paragraph in Markdown.
    gr.Markdown(
        " This example shows feature importance computed with Recursive"
        " Feature Elimination (RFE).\n\n"
        " The dataset is `load_digits()`: 8 x 8 images of hand-written"
        " digits.\n\n"
        " **Parameters**\n\n"
        " **Number of features to select** : Number of features left at the"
        " end of the feature selection process.\n\n"
        " **Step** : Number of features removed at each iteration; the least"
        " important features are removed.\n"
    )
    gr.Markdown(
        " A linear Support Vector Classifier is used as the estimator to rank"
        " features.\n"
    )
    gr.Markdown(
        " **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html)**"
    )
    with gr.Row():
        # RFE requires both values to be strictly positive, so the sliders
        # start at 1 (a minimum of 0 made RFE raise).
        n_features_to_select = gr.Slider(
            minimum=1,
            maximum=20,
            step=1,
            value=1,
            label="Number of features to select",
        )
        step = gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Step")
    btn = gr.Button(value="Submit")
    btn.click(
        recursive_feature_elimination,
        inputs=[n_features_to_select, step],
        outputs=gr.Plot(
            label="Recursive feature elimination of pixels in digit classification"
        ),
    )
    gr.Markdown(
        " The plot shows the RFE ranking of each pixel for classifying the"
        " digits: rank 1 marks the selected (most important) pixels, and"
        " higher ranks were eliminated earlier.\n"
    )
demo.launch()