# import the necessary packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from PIL import Image
from io import BytesIO
import requests
import numpy as np
from matplotlib import pyplot as plt

RESOLUTION = 224
PATCH_SIZE = 16

crop_layer = layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
rescale_layer = layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def preprocess_image(image, model_type, size=RESOLUTION):
    # Turn the image into a numpy array and add a batch dimension.
    image = np.array(image)
    image = tf.expand_dims(image, 0)

    # If the model type is the original ViT, rescale the image to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize the image using bicubic interpolation.
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")

    # Crop the image to RESOLUTION x RESOLUTION.
    image = crop_layer(image)

    # If the model type is DeiT or DINO, normalize the image.
    if model_type != "original_vit":
        image = norm_layer(image)

    return image.numpy()


def load_image_from_url(url, model_type):
    # Credit: Willi Gierke
    response = requests.get(url)
    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image


def attention_heatmap(attention_score_dict, image, model_type="dino", num_heads=12):
    num_tokens = 2 if "distilled" in model_type else 1

    # Sort the transformer blocks in order of their depth.
    attention_score_list = list(attention_score_dict.keys())
    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)

    # Process the attention maps for overlay.
    w_featmap = image.shape[2] // PATCH_SIZE
    h_featmap = image.shape[1] // PATCH_SIZE
    attention_scores = attention_score_dict[attention_score_list[0]]

    # Take the representations from the CLS token.
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches to 224x224 (224 = 14 x 16).
    attentions = tf.image.resize(
        attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
    )
    return attentions


def plot(attentions, image):
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
    img_count = 0

    for i in range(3):
        for j in range(4):
            if img_count < attentions.shape[-1]:
                # Overlay the attention map of each head on the input image.
                axes[i, j].imshow(image[0])
                axes[i, j].imshow(
                    attentions[..., img_count], cmap="inferno", alpha=0.6
                )
                axes[i, j].title.set_text(f"Attention head: {img_count}")
                axes[i, j].axis("off")
                img_count += 1

    plt.tight_layout()
    plt.savefig("heat_map.png")
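

# ---------------------------------------------------------------------------
# Minimal, self-contained usage sketch (illustrative only, not part of the
# original listing). No real ViT/DINO model is loaded here; instead we
# fabricate a random attention-score dictionary with the layout the helpers
# above expect -- (batch, num_heads, num_tokens, num_tokens) -- just to
# exercise `attention_heatmap` and `plot`. With an actual model,
# `attention_score_dict` would come from its forward pass, and the block key
# name used below is a hypothetical placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_heads = 12
    num_patches = (RESOLUTION // PATCH_SIZE) ** 2  # 14 x 14 = 196 patches
    num_tokens = num_patches + 1  # plus the CLS token

    # Dummy preprocessed image and dummy attention scores for one block.
    dummy_image = np.random.uniform(
        size=(1, RESOLUTION, RESOLUTION, 3)
    ).astype("float32")
    dummy_scores = {
        "transformer_block_11_att": np.random.uniform(
            size=(1, num_heads, num_tokens, num_tokens)
        ).astype("float32")
    }

    attentions = attention_heatmap(dummy_scores, dummy_image, model_type="dino")
    plot(attentions, dummy_image)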