Spaces:
Runtime error
Runtime error
File size: 2,118 Bytes
0dd57bb ced60bc 806234a ced60bc 0dd57bb ced60bc 0dd57bb ced60bc 0dd57bb ced60bc f03296b 3c1eaee f03296b 3c1eaee f03296b 3c1eaee f03296b 3c1eaee f03296b ced60bc 1c951f1 ced60bc 12e48bb ced60bc 12e48bb ced60bc f03296b ced60bc f03296b ced60bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub.keras_mixin import from_pretrained_keras
from matplotlib.figure import Figure
from PIL import Image

import utils
# Side length, in pixels, of the square RGB input the model is built for; it
# matches the "224" in the default checkpoint name below.
_RESOLUTION = 224
_DEFAULT_MODEL_ID = "probing-vits/cait_xxs24_224_classification"


def get_model(
    model_id: str = _DEFAULT_MODEL_ID, resolution: int = _RESOLUTION
) -> tf.keras.Model:
    """Build a tf.keras.Model around a checkpoint loaded from the HF Hub.

    The Hub module is called once on a symbolic input with ``training=False``.
    It must return three outputs: logits plus two attention-score collections.
    The names used below suggest self-attention and class-attention score
    dicts; confirm against the checkpoint. The resulting functional model
    exposes all three outputs so callers can read the attention maps.

    Args:
        model_id: Hugging Face Hub repository passed to
            ``from_pretrained_keras``.
        resolution: Height and width of the square, 3-channel model input.

    Returns:
        A ``tf.keras.Model`` mapping a ``(batch, resolution, resolution, 3)``
        input to ``[logits, sa_atn_score_dict, ca_atn_score_dict]``.
    """
    inputs = tf.keras.Input((resolution, resolution, 3))
    hub_module = from_pretrained_keras(model_id)
    logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(
        inputs, training=False
    )
    return tf.keras.Model(inputs, [logits, sa_atn_score_dict, ca_atn_score_dict])


# Built once at import time so every request reuses the same loaded model.
_MODEL = get_model()
def plot(attentions: np.ndarray):
    """Plot each attention head's map side by side in a single-row figure.

    Args:
        attentions: Array whose last axis indexes attention heads. The slice
            ``attentions[:, :, h]`` is drawn as an image for head ``h``, so the
            array is expected to be rank 3 (rows, cols, heads).

    Returns:
        A ``matplotlib.figure.Figure`` with one titled subplot per head.
    """
    num_heads = attentions.shape[-1]
    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), which leaks memory in a
    # long-running server. A bare Figure is freed like any other object and
    # needs no GUI backend. tight_layout=True applies the same layout
    # adjustment as fig.tight_layout(), but at draw time.
    fig = Figure(figsize=(13, 13), tight_layout=True)
    # One column per head so the grid always matches the data. max(..., 1)
    # avoids a ValueError for a zero-head array. squeeze=False keeps `axes`
    # 2-D even when there is a single column, so [0] is always the row.
    axes = fig.subplots(nrows=1, ncols=max(num_heads, 1), squeeze=False)[0]
    for head, ax in zip(range(num_heads), axes):
        ax.imshow(attentions[:, :, head])
        ax.title.set_text(f"Attention head: {head}")
        ax.axis("off")
    return fig
def show_plot(image):
    """Gradio callback: turn an input image into class-attention figures.

    The image is preprocessed to ``_RESOLUTION`` and run through ``_MODEL``.
    The class-attention maps for the ``"ca_ffn_block_0_att"`` and
    ``"ca_ffn_block_1_att"`` entries of the model's third output are then
    drawn with ``plot``.

    Returns:
        A 2-tuple of figures: first block, then second block.
    """
    _unused, model_input = utils.preprocess_image(image, _RESOLUTION)
    _logits, _sa_scores, class_attn_scores = _MODEL.predict(model_input)

    # Compute both maps before plotting either, in block order.
    attention_maps = [
        utils.get_cls_attention_map(model_input, class_attn_scores, block_key=key)
        for key in ("ca_ffn_block_0_att", "ca_ffn_block_1_att")
    ]
    return tuple(plot(attn_map) for attn_map in attention_maps)
title = "Generate Class Attention Plots"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.)."

# show_plot returns two figures (first and second class-attention block), so
# the interface declares two plot outputs. A single "plot" output would not
# match the callback's return arity.
iface = gr.Interface(
    show_plot,
    # gr.inputs.* is the deprecated component namespace (removed in Gradio 4);
    # the top-level component accepts the same arguments.
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Plot(label="Class attention: block 0"),
        gr.Plot(label="Class attention: block 1"),
    ],
    title=title,
    article=article,
    allow_flagging="never",
    examples=[["./butterfly.jpg"]],
)
iface.launch()
|