class-saliency / app.py
sayakpaul's picture
sayakpaul HF staff
Update app.py
4a0ced2
raw
history blame
No virus
1.72 kB
import gradio as gr
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image
import utils
_RESOLUTION = 224
def get_model() -> tf.keras.Model:
    """Build a functional Keras model around the CaiT checkpoint on the HF Hub.

    The pretrained module ``probing-vits/cait_xxs24_224_classification`` is
    called once on a symbolic ``(_RESOLUTION, _RESOLUTION, 3)`` input with
    ``training=False``. Its three outputs become the outputs of the returned
    model.

    Returns:
        A ``tf.keras.Model`` whose outputs are, in order, what this function
        unpacks as the logits, the self-attention score dict and the
        class-attention score dict.
    """
    image_input = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
    cait_module = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")
    # Unpacking requires the module to yield exactly three outputs.
    logits, self_attn_scores, cls_attn_scores = cait_module(
        image_input, training=False
    )
    return tf.keras.Model(
        inputs=image_input,
        outputs=[logits, self_attn_scores, cls_attn_scores],
    )
# Built once at import time so every call to show_plot reuses the same model
# instead of reconstructing it per request.
_MODEL: tf.keras.Model = get_model()
def show_plot(image):
    """Plot the class-attention saliency map of ``image`` over the image itself.

    Called by the Gradio interface when the user hits submit.

    Args:
        image: The uploaded image, delivered as a PIL image by the
            ``type="pil"`` input component. It is ``None`` if the form is
            submitted without an image.

    Returns:
        A ``matplotlib.figure.Figure`` showing the original image with the
        saliency map for the ``"ca_ffn_block_0_att"`` class-attention entry
        superimposed. Returns ``None`` when no image was provided.
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates
    # in a global registry until it is explicitly closed, so calling
    # plt.figure() per request in a long-running server leaks one figure per
    # call, and that registry is shared by all request threads.
    from matplotlib.figure import Figure

    if image is None:
        # Nothing was uploaded; render an empty output instead of crashing
        # inside preprocessing.
        return None

    original_image, preprocessed_image = utils.preprocess_image(
        image, _RESOLUTION
    )
    # Only the class-attention scores are needed; the other outputs are
    # discarded.
    _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)

    # Compute the saliency map and superimpose it on the original image.
    saliency_attention = utils.get_cls_attention_map(
        preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
    )

    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(original_image.astype("int32"))
    ax.imshow(saliency_attention.squeeze(), cmap="cividis", alpha=0.9)
    ax.set_axis_off()
    return fig
# UI text shown above (title) and below (article) the interface.
title = "Generate Class Saliency Plots"
article = (
    "Class saliency maps as investigated in "
    "[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) "
    "(Touvron et al.)."
)

# One PIL image in, one matplotlib plot out; flagging is disabled and a
# bundled example image is offered.
iface = gr.Interface(
    fn=show_plot,
    inputs=gr.inputs.Image(type="pil", label="Input Image"),
    outputs=gr.outputs.Plot(type="auto"),
    title=title,
    article=article,
    allow_flagging="never",
    examples=[["./butterfly.jpg"]],
)

# Start serving the interface with debug mode enabled.
iface.launch(debug=True)