# Hugging Face Space by merve (HF staff) — commit 6c5ad74: "Update app.py"
"""
Demo is based on https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html
"""
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn.feature_selection import RFE
import matplotlib.pyplot as plt
# Load the hand-written digits and flatten each image so that every pixel
# becomes one feature column (one row per image).
digits = load_digits()
y = digits.target
X = digits.images.reshape(len(digits.images), -1)
# Linear-kernel support vector classifier that RFE uses to rank the pixels.
svc = SVC(C=1, kernel="linear")
def recursive_feature_elimination(n_features_to_select, step, esimator=svc):
    """Rank the digit pixels with recursive feature elimination and plot them.

    Fits ``RFE`` on the module-level flattened digits data (``X``, ``y``) and
    renders the resulting per-pixel ranking as a heat map on the image grid.

    Args:
        n_features_to_select: Number of pixels left when elimination stops;
            these all receive rank 1.
        step: Number of pixels removed at each elimination round (the least
            important ones are removed first).
        esimator: Estimator used to score the features. The (misspelled)
            keyword name is kept so existing keyword callers keep working.

    Returns:
        The matplotlib ``Figure`` containing the ranking plot.

    Raises:
        ValueError: If ``n_features_to_select`` or ``step`` is below 1.
    """
    # UI sliders may hand back floats; RFE treats a float n_features_to_select
    # in (0, 1] (and a float step in (0, 1)) as a *fraction* of the features,
    # so coerce to int to keep the "count" meaning described in the UI.
    n_features_to_select = int(n_features_to_select)
    step = int(step)
    if n_features_to_select < 1 or step < 1:
        raise ValueError("n_features_to_select and step must both be >= 1")

    rfe = RFE(
        estimator=esimator,
        n_features_to_select=n_features_to_select,
        step=step,
    )
    rfe.fit(X, y)
    # ranking_ has one entry per flattened pixel; fold it back into image shape.
    ranking = rfe.ranking_.reshape(digits.images[0].shape)

    # Build the plot on an explicit figure instead of pyplot's global
    # "current figure", which is unsafe when requests are handled concurrently.
    fig, ax = plt.subplots()
    image = ax.matshow(ranking, cmap=plt.cm.Blues)
    fig.colorbar(image, ax=ax)
    ax.set_title("Ranking of pixels with RFE")
    # Drop the figure from pyplot's registry so repeated calls don't leak open
    # figures; the returned Figure object can still be rendered by the caller.
    plt.close(fig)
    return fig
import gradio as gr
title = " Illustration of Recursive feature elimination.🌲 "
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        "This example demonstrates recursive feature elimination. <br>"
        "Dataset is `load_digits()` which is images of size 8x8 images of hand-written digits. <br>"
        "**Parameters** <br> **Number of features to select**: Represents the features left at the end of feature selection process. <br>"
        "**Step**: Number of feature to remove at each iteration, least important are removed. <br>"
    )
    gr.Markdown(
        "Support vector classifier is used as estimator to rank features. <br>"
    )
    gr.Markdown(
        "Demo is based on [sklearn docs](https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html)."
    )
    with gr.Row():
        # Both values are counts that must be at least 1: RFE rejects 0 for
        # either parameter, so the sliders start at 1 instead of 0.
        n_features_to_select = gr.Slider(
            minimum=1, maximum=20, step=1, value=1, label="Number of features to select"
        )
        step = gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Step")
    btn = gr.Button(value="Submit")
    btn.click(
        recursive_feature_elimination,
        inputs=[n_features_to_select, step],
        outputs=gr.Plot(
            label="Recursive feature elimination of pixels in digit classification"
        ),
    )
    gr.Markdown(
        "Plot shows the importance of each pixel in the classification of the digits. <br>"
    )
demo.launch()