import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image

import utils  # bundled Space helpers: preprocess_image, get_cls_attention_map

_RESOLUTION = 224  # input resolution expected by the CaiT-XXS24 checkpoint


def get_model() -> tf.keras.Model:
    """Initializes a tf.keras.Model from HF Hub."""
    inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
    hub_module = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")

    # The hub model returns the classification logits along with dictionaries of
    # self-attention and class-attention scores.
    logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(inputs, training=False)

    return tf.keras.Model(inputs, [logits, sa_atn_score_dict, ca_atn_score_dict])


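# Build the model once at startup so each Gradio request only pays for a
# forward pass.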
_MODEL = get_model()


def plot(attentions: np.ndarray):
    """Plots the attention maps from the individual attention heads."""
    n_heads = attentions.shape[-1]
    fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(13, 13))

    for head in range(n_heads):
        axes[head].imshow(attentions[:, :, head])
        axes[head].title.set_text(f"Attention head: {head}")
        axes[head].axis("off")

    fig.tight_layout()
    return fig


def show_plot(image):
    """Runs inference and builds the plots when the user hits Submit."""
    _, preprocessed_image = utils.preprocess_image(image, _RESOLUTION)
    _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)

    # Visualize the class-attention scores from the first two class-attention
    # blocks of CaiT (see the paper linked in `article` below).
    result_first_block = utils.get_cls_attention_map(
        preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
    )
    result_second_block = utils.get_cls_attention_map(
        preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_1_att"
    )
    return plot(result_first_block), plot(result_second_block)


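# Hypothetical standalone usage, bypassing the UI (assumes the bundled
# ./butterfly.jpg example is available):
#   fig_block_0, fig_block_1 = show_plot(Image.open("./butterfly.jpg"))
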
title = "Generate Class Attention Plots"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.)."

iface = gr.Interface(
    fn=show_plot,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[gr.Plot(label="Block 0 attention"), gr.Plot(label="Block 1 attention")],
    title=title,
    article=article,
    allow_flagging="never",
    examples=[["./butterfly.jpg"]],
)
iface.launch(debug=True)  # debug=True surfaces tracebacks in the console