# Provenance (copied from the Hugging Face file viewer; page UI text removed):
#   author:  sayakpaul (HF staff)
#   commit:  ccecbb2 — "fix: output grid and caching."
#   size:    4.16 kB
import functools
import glob

import gradio as gr
import matplotlib.pyplot as plt
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F
# Pretrained CaiT-XXS/24 (224px, distilled on ImageNet-1k), switched to
# inference mode. `.eval()` returns the module itself, so calling it as a
# separate statement leaves CAIT_MODEL bound to the same object.
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True)
CAIT_MODEL.eval()

# Preprocessing pipeline derived from the checkpoint's own pretrained config
# (input size, mean/std, interpolation, ...).
_CAIT_DATA_CONFIG = timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
TRANSFORM = timm.data.create_transform(**_CAIT_DATA_CONFIG)

# Side length, in pixels, of one square patch token.
PATCH_SIZE = 16
@functools.lru_cache(maxsize=None)
def create_attn_extractor(block_id=0):
    """Creates a model that produces the softmax attention scores.

    The returned feature extractor runs CAIT_MODEL and exposes a single output,
    keyed ``f"blocks_token_only.{block_id}.attn.softmax"``, holding the
    softmax attention of the requested class-attention block.

    The result is memoized per ``block_id``. Building it FX-traces the whole
    model, and that trace is identical for a given block, so redoing it on
    every request is wasted work. The cache is effectively bounded: only
    valid block ids succeed (and are therefore cached), and there are only
    as many of those as class-attention blocks in the model.

    Args:
        block_id: Index into ``CAIT_MODEL.blocks_token_only``.

    Returns:
        A torchvision FX feature-extractor module (shared between callers
        that pass the same ``block_id``; treat it as read-only).

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    feature_extractor = create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        # Keep PatchEmbed as an opaque leaf instead of tracing into it —
        # presumably because its forward is not FX-traceable; see the
        # discussion linked above.
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
    return feature_extractor
def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepares attention maps so that they can be visualized.

    Args:
        image: The preprocessed input batch. Only its spatial size is read
            (``image.shape[2]`` = height, ``image.shape[3]`` = width), so it
            is assumed to be laid out as (B, C, H, W).
        attn_score_dict: Mapping returned by the feature extractor from
            ``create_attn_extractor``. ``attn_score_dict[block_key]`` is
            indexed as (batch, head, query, key); only sample 0 is used.
        block_key: Name of the softmax node whose scores should be mapped.

    Returns:
        Tensor of shape (num_heads, h_featmap * PATCH_SIZE,
        w_featmap * PATCH_SIZE): one upsampled CLS-attention map per head.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    num_heads = attention_scores.shape[1]

    # Take the CLS-token query (index 0 on the query axis). Drop the first
    # key position so that only the patch tokens remain.
    attentions = attention_scores[0, :, 0, 1:]

    # Patch tokens are flattened row-major (height first, then width), so the
    # grid is (h, w). Reshaping to (w, h) scrambles non-square inputs.
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)

    # Upsample each patch grid back to pixel resolution
    # (e.g. 14 x 16 = 224 per side).
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        # Same mode as the legacy integer code 3, without the deprecated form.
        interpolation=F.InterpolationMode.BICUBIC,
    )
    return attentions
def generate_plot(processed_map):
    """Generates a class attention map plot.

    Lays out one panel per attention head in a single row.

    Args:
        processed_map: Tensor whose first dimension indexes attention heads;
            each ``processed_map[i].numpy()`` is drawn with ``imshow``.

    Returns:
        The matplotlib Figure containing the panels.
    """
    num_heads = processed_map.shape[0]
    # Size the grid to the actual head count instead of a hard-coded 4.
    # A fixed 4 overflows (IndexError) for more heads and leaves blank panels
    # for fewer. `max(..., 1)` keeps plt.subplots valid for an empty input.
    # squeeze=False keeps `axes` 2-D even with one column, so indexing is
    # uniform.
    fig, axes = plt.subplots(
        nrows=1, ncols=max(num_heads, 1), figsize=(13, 13), squeeze=False
    )
    for head_idx, (ax, attn_map) in enumerate(zip(axes[0], processed_map)):
        ax.imshow(attn_map.numpy())
        ax.title.set_text(f"Attention head: {head_idx}")
        ax.axis("off")
    fig.tight_layout()
    return fig
def serialize_images(processed_map):
    """Serializes attention maps.

    Writes one PNG per attention head to the current working directory.

    Args:
        processed_map: Tensor whose first dimension indexes attention heads;
            each ``processed_map[i].numpy()`` is drawn with ``imshow``.

    Returns:
        List of the written paths, ``attention_map_{i}.png``, in head order.
        (The original returned None; returning the paths is additive.)
    """
    print(f"Number of maps: {processed_map.shape[0]}")
    saved_paths = []
    for head_idx, attn_map in enumerate(processed_map):
        # Use a dedicated figure per head and close it afterwards. Drawing on
        # pyplot's implicit current figure stacks every head onto the same
        # axes, and that figure is never released, which leaks memory across
        # requests in a long-running server.
        fig, ax = plt.subplots()
        try:
            ax.imshow(attn_map.numpy())
            ax.set_title(f"Attention head: {head_idx}", fontsize=14)
            ax.axis("off")
            path = f"attention_map_{head_idx}.png"
            fig.savefig(path)
        finally:
            plt.close(fig)
        saved_paths.append(path)
    return saved_paths
def generate_class_attn_map(image, block_id=0):
    """Collates the above utilities together for generating
    a class attention map.

    Args:
        image: Input image as delivered by ``gr.Image(type="pil")``; it is
            fed through TRANSFORM.
        block_id: Index of the class-attention block to visualize.

    Returns:
        List of PNG paths, one per attention head, in head order.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard, TRANSFORM(None) fails with an opaque error.
        raise gr.Error("Please provide an input image.")
    # The slider may deliver a float (e.g. 1.0), which would yield the
    # invalid node name "blocks_token_only.1.0.attn.softmax".
    block_id = int(block_id)

    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)
    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    serialize_images(processed_cls_attn_map)

    # List exactly the files this call wrote (serialize_images names them
    # attention_map_{i}.png) instead of globbing the working directory.
    # A glob also returns stale maps from earlier runs, and its lexicographic
    # sort misorders heads >= 10 ("attention_map_10" < "attention_map_2").
    # NOTE(review): the filenames are shared by all concurrent requests.
    # Per-request output directories would be needed to isolate users.
    return [
        f"attention_map_{i}.png" for i in range(processed_cls_attn_map.shape[0])
    ]
# Page header and footer text shown by the Gradio interface.
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

# Inputs: a PIL image, plus which class-attention block to inspect (0 or 1).
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
]
# Output: one gallery tile per attention head.
# NOTE(review): `.style()` is deprecated in later Gradio 3.x releases and
# removed in 4.x. If the pinned Gradio version supports it, pass columns /
# height / object_fit to the Gallery constructor instead.
_output_gallery = gr.Gallery().style(
    columns=2, height="auto", object_fit="scale-down"
)

iface = gr.Interface(
    fn=generate_class_attn_map,
    inputs=_input_components,
    outputs=_output_gallery,
    title=title,
    article=article,
    allow_flagging="never",
    # Example outputs are precomputed at startup, so the first visitor
    # doesn't wait.
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch(debug=True)