import glob

import gradio as gr
import matplotlib.pyplot as plt
import timm
import torch
from timm import create_model
from timm.layers import PatchEmbed  # timm>=0.9; `timm.models.layers` is deprecated.
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

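# Distilled CaiT-XXS24 (ImageNet-1k) from timm, kept in eval mode since the app
# only runs inference. The preprocessing transform (resize, crop, normalize) is
# derived from the model's pretrained config.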
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
TRANSFORM = timm.data.create_transform(
    **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
)

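# cait_xxs24_224 embeds non-overlapping 16x16 patches of a 224x224 input,
# yielding a 14x14 grid of patch tokens.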
PATCH_SIZE = 16


def create_attn_extractor(block_id=0):
    """Creates a model that produces the softmax attention scores.
    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    feature_extractor = create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
    return feature_extractor


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepares attention maps so that they can be visualized."""
    w_featmap = image.shape[3] // PATCH_SIZE
    h_featmap = image.shape[2] // PATCH_SIZE

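    # Scores arrive with shape (batch, num_heads, 1, 1 + num_patches): the CLS
    # token is the only query in a CaiT class-attention block.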
    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # Attention paid by the CLS token (the sole query in class attention) to
    # each image patch; index 0 along the key axis is CLS itself, so skip it.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Reshape the flat scores into the spatial patch grid (rows x columns).
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample each head's map back to the input resolution (14 * 16 = 224).
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=InterpolationMode.BICUBIC,  # The old int code 3 is deprecated.
    )

    return attentions


def generate_plot(processed_map):
    """Generates a single figure with one panel per attention head."""
    num_heads = processed_map.shape[0]
    fig, axes = plt.subplots(nrows=1, ncols=num_heads, figsize=(13, 13))

    for i in range(num_heads):
        axes[i].imshow(processed_map[i].numpy())
        axes[i].title.set_text(f"Attention head: {i}")
        axes[i].axis("off")

    fig.tight_layout()
    return fig


def serialize_images(processed_map):
    """Saves each head's attention map to disk as a PNG."""
    print(f"Number of maps: {processed_map.shape[0]}")
    for i in range(processed_map.shape[0]):
        plt.figure()
        plt.imshow(processed_map[i].numpy())
        plt.title(f"Attention head: {i}", fontsize=14)
        plt.axis("off")
        plt.savefig(fname=f"attention_map_{i}.png")
        plt.close()  # Release the figure so state does not leak across calls.


def generate_class_attn_map(image, block_id=0):
    """Collates the above utilities together for generating
    a class attention map."""
    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)

    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)

    serialize_images(processed_cls_attn_map)
    all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
    print(f"Number of images: {len(all_attn_img_paths)}")
    return all_attn_img_paths

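# Quick local sanity check (a hypothetical snippet, outside the Gradio flow;
# assumes the bundled ./bird.png example is present):
#
#     from PIL import Image
#     paths = generate_class_attn_map(Image.open("./bird.png"), block_id=0)
#     print(paths)  # One PNG per attention head, e.g. attention_map_0.png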

title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

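# The slider exposes the two class-attention blocks of cait_xxs24
# (blocks_token_only.0 and .1); the resulting per-head maps are shown
# in a two-column gallery.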
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    outputs=gr.Gallery(columns=2, height="auto", object_fit="scale-down"),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch(debug=True)