"""
Demo is based on https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html
"""

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from sklearn.datasets import load_digits
from sklearn.feature_selection import RFE
from sklearn.svm import SVC

# Load the digits dataset and flatten each image into a single row of pixel features.
digits = load_digits()
X = digits.images.reshape((len(digits.images), -1))
y = digits.target

# Linear SVC used (by default) as the estimator RFE refits while ranking the pixels.
svc = SVC(kernel="linear", C=1)


def recursive_feature_elimination(n_features_to_select, step, esimator=svc):
    """Rank the digit pixels with RFE and return a heat-map figure of the ranking.

    Args:
        n_features_to_select: Number of pixels left at the end of the elimination.
            Coerced to ``int`` because a float in (0, 1] would be interpreted by
            sklearn as a *fraction* of the features rather than a count.
        step: Number of pixels removed at each iteration (the least important
            ones are removed). Coerced to ``int`` for the same reason.
        esimator: Estimator handed to ``RFE`` to score the pixels; defaults to the
            module-level linear ``svc``. (The misspelled name is kept so existing
            keyword callers keep working.)

    Returns:
        A ``matplotlib.figure.Figure`` showing each pixel's RFE rank, laid out in
        the shape of one digit image. Pixels with rank 1 are the selected ones.
    """
    rfe = RFE(
        estimator=esimator,
        n_features_to_select=int(n_features_to_select),
        step=int(step),
    )
    rfe.fit(X, y)
    ranking = rfe.ranking_.reshape(digits.images[0].shape)

    # Build the figure without pyplot: pyplot keeps every figure it creates in a
    # global registry, so creating figures per request through it leaks memory.
    fig = Figure()
    ax = fig.subplots()
    image = ax.matshow(ranking, cmap=plt.cm.Blues)
    fig.colorbar(image, ax=ax)
    ax.set_title("Ranking of pixels with RFE")
    return fig


title = " Illustration of Recursive feature elimination.🌲 "
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        "This example demonstrates recursive feature elimination.\n\n"
        "Dataset is `load_digits()`, which consists of 8x8 images of hand-written digits.\n\n"
        "**Parameters**\n\n"
        "**Number of features to select**: Represents the number of features left at the end of the feature selection process.\n\n"
        "**Step**: Number of features to remove at each iteration; the least important are removed.\n"
    )
    gr.Markdown(
        "Support vector classifier is used as estimator to rank features.\n"
    )
    gr.Markdown(
        "Demo is based on [sklearn docs](https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html)."
    )
    with gr.Row():
        # RFE requires both values to be at least 1; at most every pixel can be kept.
        n_features_to_select = gr.Slider(
            minimum=1,
            maximum=X.shape[1],
            step=1,
            value=1,
            label="Number of features to select",
        )
        step = gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Step")

    btn = gr.Button(value="Submit")
    btn.click(
        recursive_feature_elimination,
        inputs=[n_features_to_select, step],
        outputs=gr.Plot(
            label="Recursive feature elimination of pixels in digit classification"
        ),
    )

    gr.Markdown(
        "Plot shows the RFE ranking of each pixel in the classification of the digits: "
        "rank 1 marks the selected pixels, higher ranks were eliminated earlier.\n"
    )

if __name__ == "__main__":
    demo.launch()