# Captured from a Hugging Face Spaces page; the Space's status was shown as
# "Runtime error".
import functools
import glob

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F
# Pretrained CaiT checkpoint loaded from the timm hub and switched to
# inference mode. `.eval()` returns the module itself, so calling it on its
# own line binds the same object as chaining it onto `create_model`.
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True)
CAIT_MODEL.eval()

# Inference-time preprocessing built by timm from the checkpoint's
# pretrained configuration.
TRANSFORM = timm.data.create_transform(
    **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
)

# Side length, in pixels, of one image patch (224 = 14 x 16 for this model).
PATCH_SIZE = 16
@functools.lru_cache(maxsize=8)
def create_attn_extractor(block_id=0):
    """Create (and memoize) a model that produces class-attention softmax scores.

    The returned torchvision feature extractor runs ``CAIT_MODEL`` and returns
    a dict whose key ``f"blocks_token_only.{block_id}.attn.softmax"`` holds
    the softmax attention scores of that class-attention block.

    Building the extractor requires an FX trace of the whole model, so the
    result is cached per ``block_id`` rather than rebuilt on every request.

    Args:
        block_id: Index into ``CAIT_MODEL.blocks_token_only``. Integral floats
            (e.g. ``1.0`` from a UI slider) are accepted.

    Returns:
        The feature-extractor module built by ``create_feature_extractor``.

    Raises:
        ValueError: If ``block_id`` is not a valid block index.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    # Normalize to int so the node name is "blocks_token_only.1..." and not
    # "blocks_token_only.1.0...", which names no node in the graph.
    block_id = int(block_id)
    num_blocks = len(CAIT_MODEL.blocks_token_only)
    if not 0 <= block_id < num_blocks:
        raise ValueError(
            f"block_id must be in [0, {num_blocks - 1}], got {block_id}"
        )
    return create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        # Treat PatchEmbed as an opaque leaf during FX tracing. Presumably
        # this is because its forward has shape checks that cannot be traced
        # symbolically (see the discussion linked above).
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Turn class-attention scores into per-head, image-sized heat maps.

    Only the first image of the batch is used.

    Args:
        image: Batched image tensor. ``image.shape[2]`` is read as the height
            and ``image.shape[3]`` as the width, i.e. it assumes a
            (B, C, H, W) layout.
        attn_score_dict: Output of the attention extractor.
            ``attn_score_dict[block_key]`` is indexed as
            ``[batch, head, query, key]``, with the CLS token at index 0.
        block_key: Key of the softmax node to visualize.

    Returns:
        Tensor of shape ``(num_heads, h_featmap * PATCH_SIZE,
        w_featmap * PATCH_SIZE)``. It holds the CLS token's attention to
        every patch, upsampled with bicubic interpolation.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    num_heads = attention_scores.shape[1]

    # First image, every head, the CLS query (index 0), and its attention to
    # every patch token. Key index 0 is the CLS token itself, so it is skipped.
    attentions = attention_scores[0, :, 0, 1:]

    # Patch tokens are flattened row-major (rows run along the height), so the
    # grid must be rebuilt as (h, w). Reshaping as (w, h) only happens to work
    # for square inputs.
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)

    # Upsample each head's patch grid back to pixel resolution.
    return F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=F.InterpolationMode.BICUBIC,
    )
def generate_plot(processed_map):
    """Plot every attention head's map side by side in a single figure.

    Args:
        processed_map: Per-head maps indexed as ``processed_map[head]``, each
            supporting ``.numpy()`` (e.g. the output of
            ``get_cls_attention_map``).

    Returns:
        The matplotlib ``Figure``, with one titled, axis-less subplot per head.
    """
    num_heads = processed_map.shape[0]
    # Use one column per head. A fixed column count breaks models with a
    # different number of heads. squeeze=False keeps `axes` 2-D even when
    # there is a single head, so it is always indexable.
    fig, axes = plt.subplots(
        nrows=1, ncols=num_heads, figsize=(13, 13), squeeze=False
    )
    for head, ax in enumerate(axes[0]):
        ax.imshow(processed_map[head].numpy())
        ax.title.set_text(f"Attention head: {head}")
        ax.axis("off")
    fig.tight_layout()
    return fig
def serialize_images(processed_map):
    """Save each attention head's map as its own titled PNG file.

    Files are written to the current working directory as
    ``attention_map_{i}.png`` for head ``i``; existing files are overwritten.

    Args:
        processed_map: Per-head maps indexed as ``processed_map[head]``, each
            supporting ``.numpy()``.

    Returns:
        List of the written file paths, in head order.
    """
    paths = []
    for head in range(processed_map.shape[0]):
        fig, ax = plt.subplots()
        try:
            ax.imshow(processed_map[head].numpy())
            ax.set_title(f"Attention head: {head}")
            ax.axis("off")
            path = f"attention_map_{head}.png"
            # savefig (unlike plt.imsave) keeps the title in the saved image.
            fig.savefig(path, bbox_inches="tight")
        finally:
            # pyplot keeps every figure alive until it is closed. Without
            # this, figures would accumulate across requests.
            plt.close(fig)
        paths.append(path)
    return paths
def generate_class_attn_map(image, block_id=0):
    """Compute class-attention maps for ``image`` and save them as PNGs.

    Collates the utilities above: preprocess, extract the chosen block's
    attention, build per-head maps, and write them to disk.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was
            uploaded.
        block_id: Index of the class-attention block to visualize. It may
            arrive from the UI slider as an integral float.

    Returns:
        Paths of the saved per-head attention maps, ordered by head index.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Normalize so the node key below matches the traced graph's node name.
    block_id = int(block_id)

    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)
    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    serialize_images(processed_cls_attn_map)

    # serialize_images writes attention_map_{i}.png per head into the working
    # directory. Building the list from the head count, instead of globbing,
    # keeps stray files from other runs out of the result. It also gives
    # numeric order, whereas a sorted glob puts "10" before "2".
    # NOTE(review): fixed filenames in a shared directory can race between
    # concurrent requests; consider a per-request temporary directory.
    return [
        f"attention_map_{i}.png" for i in range(processed_cls_attn_map.shape[0])
    ]
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        # `gr.inputs.*` was removed in Gradio 4; `gr.Image` is the
        # supported component in both Gradio 3 and 4.
        gr.Image(type="pil", label="Input Image"),
        # Derive the upper bound from the model so the UI can only request
        # class-attention blocks that exist.
        gr.Slider(
            0,
            len(CAIT_MODEL.blocks_token_only) - 1,
            value=0,
            step=1,
            label="Block ID",
            info="Transformer Block ID",
        ),
    ],
    # `.style()` was removed in Gradio 4; layout options are now passed to
    # the constructor (`columns` replaces `grid`).
    outputs=gr.Gallery(columns=2, height="auto"),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)

if __name__ == "__main__":
    iface.launch()