class-saliency / app.py
sayakpaul's picture
sayakpaul HF staff
Upload app.py
ed1b9ba
raw
history blame
No virus
1.84 kB
import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image
import utils
# Side length, in pixels, of the square model input. Used for the Keras Input
# shape in get_model() and passed to utils.preprocess_image() in show_plot().
_RESOLUTION = 224
def get_model() -> tf.keras.Model:
    """Wrap the pretrained Hub module in a functional ``tf.keras.Model``.

    The module is fetched with ``from_pretrained_keras`` and called in
    inference mode on a symbolic ``(_RESOLUTION, _RESOLUTION, 3)`` input.

    Returns:
        A model mapping one image batch to the three outputs of the Hub
        module, in the order the module returns them. Judging by the
        original variable names these are the logits followed by two
        attention-score dictionaries (self-attention, then class-attention)
        -- TODO confirm against the checkpoint.
    """
    image_input = tf.keras.Input(shape=(_RESOLUTION, _RESOLUTION, 3))
    cait = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")

    # Unpacking into exactly three names fails fast if the module's output
    # structure ever changes.
    class_logits, self_attn_scores, cls_attn_scores = cait(
        image_input, training=False
    )

    return tf.keras.Model(
        inputs=image_input,
        outputs=[class_logits, self_attn_scores, cls_attn_scores],
    )
# Build the model once at import time; show_plot() reuses this single instance
# for every prediction.
_MODEL = get_model()
def show_plot(image):
    """Function to be called when user hits submit on the UI.

    Runs the model on ``image``, turns the ``"ca_ffn_block_0_att"`` attention
    scores into a colour heatmap, and superimposes that heatmap on the
    original image.

    Args:
        image: The input image as delivered by the Gradio ``Image`` input
            (configured with ``type="pil"``).

    Returns:
        A ``PIL.Image.Image`` (8-bit RGB) holding the blended saliency map.
    """
    original_image, preprocessed_image = utils.preprocess_image(
        image, _RESOLUTION
    )
    _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)

    # Compute the saliency map and superimpose.
    result_first_block = utils.get_cls_attention_map(
        image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
    )
    heatmap = cv2.applyColorMap(
        np.uint8(255 * result_first_block), cv2.COLORMAP_CIVIDIS
    )
    # OpenCV colour maps come back in BGR channel order, while the input and
    # output of this function are PIL images (RGB). Convert before blending
    # so the channels line up.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # NOTE(review): assumes utils.preprocess_image returns the original image
    # as an array in the 0-255 range whose shape matches the heatmap -- TODO
    # confirm against utils.
    original_image = original_image / 255.0
    saliency_map = heatmap + original_image

    # Rescale to [0, 1]; skip the division for an all-zero map to avoid NaNs.
    peak = np.max(saliency_map)
    if peak > 0:
        saliency_map = saliency_map / peak

    # Pillow cannot build an RGB image from a float array, so convert to
    # 8-bit before handing the data to Image.fromarray.
    return Image.fromarray(np.uint8(255 * saliency_map))
# Text shown on the demo page.
title = "Generate Class Saliency Plots"
article = (
    "Class saliency maps as investigated in "
    "[Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) "
    "(Touvron et al.)."
)

# Wire show_plot() into a one-image-in, one-image-out Gradio interface and
# start serving it.
iface = gr.Interface(
    fn=show_plot,
    inputs=gr.inputs.Image(type="pil", label="Input Image"),
    outputs="image",
    title=title,
    article=article,
    allow_flagging="never",
    examples=[["./butterfly.jpg"]],
)
iface.launch()