import functools
import glob

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F

CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
TRANSFORM = timm.data.create_transform(
    **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
)
PATCH_SIZE = 16


def _force_explicit_softmax(model):
    """Make every class-attention layer of ``model`` compute an explicit softmax.

    Recent timm versions may give the class-attention module a ``fused_attn``
    flag that routes the computation through a fused scaled-dot-product kernel;
    the FX-traced graph then contains no ``attn.softmax`` node for
    ``create_feature_extractor`` to return. The flag is only changed where it
    exists, so versions without it are unaffected.
    """
    for block in model.blocks_token_only:
        if hasattr(block.attn, "fused_attn"):
            block.attn.fused_attn = False


_force_explicit_softmax(CAIT_MODEL)


@functools.lru_cache(maxsize=8)
def create_attn_extractor(block_id=0):
    """Creates a model that produces the softmax attention scores.

    The FX trace is expensive, so one extractor per ``block_id`` is cached and
    reused across requests (the key space is just the handful of block ids).

    Args:
        block_id: Index into ``CAIT_MODEL.blocks_token_only`` whose attention
            softmax output should be returned.

    Returns:
        A traced module whose output dict is keyed by
        ``"blocks_token_only.{block_id}.attn.softmax"``.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    return create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        # PatchEmbed is kept opaque so FX does not trace into it.
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepares attention maps so that they can be visualized.

    Args:
        image: Batched image tensor indexed as (B, C, H, W); only H and W are
            read, to work out the patch-grid size.
        attn_score_dict: Mapping from traced node name to attention scores,
            as produced by an extractor from ``create_attn_extractor``. Scores
            are indexed as (batch, head, query, key), with the CLS token at
            key position 0.
        block_key: Entry of ``attn_score_dict`` to visualise.

    Returns:
        Tensor of shape (num_heads, h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        i.e. 224x224 per head for a 224x224 input.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    num_heads = attention_scores.shape[1]

    # Attention from the CLS query (query index 0) to every patch token; key 0
    # is the CLS token itself and is dropped.
    attentions = attention_scores[0, :, 0, 1:].reshape(num_heads, -1)

    # Patch tokens are flattened row by row, so the grid is (rows, cols) =
    # (h_featmap, w_featmap). The previous (w, h) order only happened to work
    # because inputs are square.
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)

    # Upsample each head's patch grid back to pixel resolution.
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=F.InterpolationMode.BICUBIC,
    )
    return attentions


def generate_plot(processed_map):
    """Generates a class attention map plot with one panel per attention head.

    Args:
        processed_map: Per-head maps shaped (num_heads, H, W), e.g. the output
            of ``get_cls_attention_map``.

    Returns:
        The matplotlib ``Figure`` containing the panels.
    """
    num_heads = processed_map.shape[0]
    # squeeze=False keeps ``axes`` 2-D even for a single head, and sizing the
    # grid from the data avoids an IndexError when there are more than 4 heads.
    fig, axes = plt.subplots(
        nrows=1, ncols=num_heads, figsize=(13, 13), squeeze=False
    )
    for head_idx, ax in enumerate(axes[0]):
        ax.imshow(processed_map[head_idx].numpy())
        ax.title.set_text(f"Attention head: {head_idx}")
        ax.axis("off")
    fig.tight_layout()
    return fig


def serialize_images(processed_map):
    """Serializes attention maps, one PNG per head, into the working directory.

    Args:
        processed_map: Per-head maps shaped (num_heads, H, W).

    Returns:
        List of written file paths, ordered by head index.
    """
    paths = []
    for head_idx in range(processed_map.shape[0]):
        # A fresh figure per head: reusing the implicit current figure would
        # stack every head on top of the previous ones.
        fig, ax = plt.subplots()
        ax.imshow(processed_map[head_idx].numpy())
        ax.set_title(f"Attention head: {head_idx}", fontsize=14)
        path = f"attention_map_{head_idx}.png"
        fig.savefig(fname=path)
        plt.close(fig)  # release the figure; long-running apps otherwise leak them
        paths.append(path)
    return paths


def generate_class_attn_map(image, block_id=0):
    """Collates the above utilities together for generating a class attention map.

    Args:
        image: Input image accepted by ``TRANSFORM`` (a PIL image in the app).
        block_id: Index of the class-attention block to visualise. May arrive
            as a float (e.g. ``0.0``) from a Gradio slider.

    Returns:
        Paths of the per-head attention-map PNGs, ordered by head index.
    """
    # "blocks_token_only.0.0.attn.softmax" is not a valid node name.
    block_id = int(block_id)
    if hasattr(image, "convert"):
        # PNG inputs may be RGBA or palette images; the model expects 3 channels.
        image = image.convert("RGB")

    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)
    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    # Use the paths that were just written rather than globbing the directory,
    # which could pick up stale files and sorts "10" before "2".
    return serialize_images(processed_cls_attn_map)


title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."
# Gradio UI: an input image plus a slider selecting which class-attention
# (token-only) block to visualise; the gallery shows one map per head.
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        # gr.Image replaces the deprecated gr.inputs.Image namespace, which was
        # removed in Gradio 4.
        gr.Image(type="pil", label="Input Image"),
        # NOTE(review): range 0-1 presumes the model has two token-only blocks;
        # confirm against len(CAIT_MODEL.blocks_token_only) if the model changes.
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    # Component.style() is deprecated (removed in Gradio 4); layout options are
    # passed to the constructor instead.
    outputs=gr.Gallery(columns=2, height="auto"),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch()