import streamlit as st import pandas as pd import numpy as np import random from backend.utils import make_grid, load_dataset, load_model, load_images from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img from transformers import AutoFeatureExtractor, AutoModelForImageClassification import torch from matplotlib.backends.backend_agg import RendererAgg _lock = RendererAgg.lock st.set_page_config(layout='wide') BACKGROUND_COLOR = '#bcd0e7' st.title('Feature attribution visualization with SmoothGrad') st.write("""> **Which features are responsible for the current prediction of ConvNeXt?** In machine learning, it is helpful to identify the significant features of the input (e.g., pixels for images) that affect the model's prediction. If the model makes an incorrect prediction, we might want to determine which features contributed to the mistake. To do this, we can generate a feature importance mask, which is a grayscale image with the same size as the original image. The brightness of each pixel in the mask represents the importance of that feature to the model's prediction. There are various methods to calculate an image sensitivity mask for a specific prediction. One simple way is to use the gradient of a class prediction neuron concerning the input pixels, indicating how the prediction is affected by small pixel changes. However, this method usually produces a noisy mask. To reduce the noise, the SmoothGrad technique as described in [SmoothGrad: Removing noise by adding noise](https://arxiv.org/abs/1706.03825) by Daniel _et al_ is used, which adds Gaussian noise to multiple copies of the image and averages the resulting gradients. """) instruction_text = """Users need to input the model(s), type of image set and image set setting to use this functionality. 1. Choose model: Users can choose one or more models for comparison. There are 3 models supported: [ConvNeXt](https://huggingface.co/facebook/convnext-tiny-224), [ResNet](https://huggingface.co/microsoft/resnet-50) and [MobileNet](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/). These 3 models have similar number of parameters. 2. Choose type of Image set: There are 2 types of Image set. They are _User-defined set_ and _Random set_. 3. Image set setting: If users choose _User-defined set_ in Image set, users need to enter a list of image IDs separated by commas (,). For example, `0,1,4,7` is a valid input. Check the page [ImageNet1k](/ImageNet1k) to see all the Image IDs. If users choose _Random set_ in Image set, users just need to choose the number of random images to display here. """ with st.expander("See more instruction", expanded=False): st.write(instruction_text) imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv') # --------------------------- LOAD function ----------------------------- images = [] image_ids = [] # INPUT ------------------------------ st.header('Input') with st.form('smooth_grad_form'): st.markdown('**Model and Input Setting**') selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet']) selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set']) summit_button = st.form_submit_button('Set') if summit_button: setting_container = st.container() # for id in image_ids: # images = load_images(image_ids) with st.form('2nd_form'): st.markdown('**Image set setting**') if selected_image_set == 'Random set': no_images = st.slider('Number of images', 1, 50, value=10) image_ids = random.sample(list(range(50_000)), k=no_images) else: text = st.text_area('Specific Image IDs', value='0') image_ids = list(map(lambda x: int(x.strip()), text.split(','))) run_button = st.form_submit_button('Display output') if run_button: for id in image_ids: images = load_images(image_ids) st.header('Output') models = {} feature_extractors = {} for i, model_name in enumerate(selected_models): models[model_name], feature_extractors[model_name] = load_model(model_name) # DISPLAY ---------------------------------- if run_button: header_cols = st.columns([1, 1] + [2]*len(selected_models)) header_cols[0].markdown(f'