import gradio as gr
from huggingface_hub import from_pretrained_keras
from PIL import Image

import utils

_MODEL = from_pretrained_keras("probing-vits/vit_b16_patch16_224_i21k_i1k")


def show_rollout(image):
    """Runs the ViT on `image` and returns its attention rollout map."""
    _, preprocessed_image = utils.preprocess_image(image, "original_vit")
    _, attention_scores_dict = _MODEL.predict(preprocessed_image)
    result = utils.attention_rollout_map(
        image, attention_scores_dict, "original_vit"
    )
    return Image.fromarray(result)


title = "Generate Attention Rollout Plots"
article = "Attention Rollout was proposed by [Abnar et al.](https://arxiv.org/abs/2005.00928) to quantify the information that flows through self-attention layers. In the original ViT paper ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)), the authors use it to investigate the representations learned by ViTs. The model used in the backend is a ViT B-16 model. For more details about it, refer to [this notebook](https://github.com/sayakpaul/probing-vits/blob/main/notebooks/load-jax-weights-vitb16.ipynb)."

iface = gr.Interface(
    show_rollout,
    gr.inputs.Image(type="pil", label="Input Image"),
    "image",
    title=title,
    article=article,
    allow_flagging="never",
    # Each example is a list of inputs; with a single image input,
    # every example must be its own one-element list.
    examples=[["alligator.jpg"], ["bulbul.png"]],
)
iface.launch()
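
# ---------------------------------------------------------------------------
# For reference: a minimal sketch of the rollout computation that
# `utils.attention_rollout_map` performs, following Abnar et al.
# (https://arxiv.org/abs/2005.00928). This is an illustrative
# reimplementation, not the app's actual code path; the dict key ordering,
# array shapes, and CLS-token position below are assumptions.
import numpy as np


def _attention_rollout_sketch(attention_scores_dict):
    """Recursively multiplies head-averaged attention matrices, mixing in
    the identity to account for the transformer's residual connections."""
    rollout = None
    # Assumes lexicographic key order matches the layer order.
    for _, scores in sorted(attention_scores_dict.items()):
        # Assumed shape: (batch, num_heads, num_tokens, num_tokens).
        attn = scores.mean(axis=1)[0]  # average attention over heads
        attn = 0.5 * attn + 0.5 * np.eye(attn.shape[-1])  # residual branch
        attn = attn / attn.sum(axis=-1, keepdims=True)  # re-normalize rows
        rollout = attn if rollout is None else attn @ rollout
    # Attention flowing from the CLS token (assumed index 0) to each patch.
    return rollout[0, 1:]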