import glob
import gradio as gr
import matplotlib.pyplot as plt
import timm
import torch
from timm import create_model
from timm.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()

# Build the preprocessing pipeline from the model's pretrained config.
TRANSFORM = timm.data.create_transform(
    **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
)
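
# The resolved transform typically performs a resize, center crop, and
# normalization per the pretrained config. A quick check of the output
# (illustrative; assumes the bundled ./bird.png example):
#
#   from PIL import Image
#   tensor = TRANSFORM(Image.open("./bird.png").convert("RGB"))
#   print(tensor.shape)  # torch.Size([3, 224, 224])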

PATCH_SIZE = 16


def create_attn_extractor(block_id=0):
    """Creates a model that produces the softmax attention scores.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    feature_extractor = create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
    return feature_extractor
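
# A quick shape sanity check (illustrative, commented out): for
# "cait_xxs24_224.fb_dist_in1k", the class-attention blocks use 4 heads and a
# 224x224 input yields 196 patch tokens, so the hooked node should emit a
# tensor of shape [1, 4, 1, 197] (batch, heads, CLS query, CLS + patch keys):
#
#   extractor = create_attn_extractor(block_id=0)
#   with torch.no_grad():
#       scores = extractor(torch.randn(1, 3, 224, 224))
#   print(scores["blocks_token_only.0.attn.softmax"].shape)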


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepares attention maps so that they can be visualized."""
    w_featmap = image.shape[3] // PATCH_SIZE
    h_featmap = image.shape[2] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # Taking the representations from the CLS token.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)
    print(attentions.shape)

    # Reshape the attention scores to resemble mini patches. The patch tokens
    # are laid out row-major, so the grid is (height, width).
    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    print(attentions.shape)

    # Resize the attention patches to 224x224 (224 = 14x16).
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=InterpolationMode.BICUBIC,
    )
    print(attentions.shape)

    return attentions
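
# Shape progression for a 224x224 input (illustrative): the hooked output of
# shape [1, 4, 1, 197] is sliced to (4, 196), reshaped to a (4, 14, 14) patch
# grid, and upsampled to (4, 224, 224) so each head's map overlays the image:
#
#   out = create_attn_extractor(0)(image_tensor)
#   maps = get_cls_attention_map(image_tensor, out)
#   print(maps.shape)  # torch.Size([4, 224, 224])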


def generate_plot(processed_map):
    """Generates a class attention map plot."""
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))

    for i in range(processed_map.shape[0]):
        axes[i].imshow(processed_map[i].numpy())
        axes[i].title.set_text(f"Attention head: {i}")
        axes[i].axis("off")

    fig.tight_layout()
    return fig
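
# generate_plot is an alternative to serialize_images for interactive use; a
# hypothetical call (not wired into the app below) would be:
#
#   fig = generate_plot(processed_cls_attn_map)
#   fig.savefig("attention_maps.png", bbox_inches="tight")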


def serialize_images(processed_map):
    """Serializes attention maps."""
    print(f"Number of maps: {processed_map.shape[0]}")

    for i in range(processed_map.shape[0]):
        plt.imshow(processed_map[i].numpy())
        plt.title(f"Attention head: {i}", fontsize=14)
        plt.axis("off")
        plt.savefig(fname=f"attention_map_{i}.png")
    # Close the implicit figure so matplotlib state does not leak across calls.
    plt.close()


def generate_class_attn_map(image, block_id=0):
    """Collates the above utilities together for generating
    a class attention map."""
    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)

    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)

    serialize_images(processed_cls_attn_map)
    all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
    print(f"Number of images: {len(all_attn_img_paths)}")
    return all_attn_img_paths
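
# Example end-to-end usage outside of Gradio (illustrative; assumes the
# bundled ./bird.png example image):
#
#   from PIL import Image
#   paths = generate_class_attn_map(Image.open("./bird.png").convert("RGB"))
#   print(paths)  # ['attention_map_0.png', ..., 'attention_map_3.png']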
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."
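
# As described in the paper, the class-attention layers compute
# softmax(q_cls @ K^T / sqrt(d_head)): only the CLS token forms a query, and
# it attends over itself plus all patch tokens, which is why the hooked
# softmax output has a query length of 1.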
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    # `Gallery.style()` was removed in newer Gradio releases; the layout
    # options are now constructor arguments.
    outputs=gr.Gallery(columns=2, height="auto", object_fit="scale-down"),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch(debug=True)