# attention-rollout/utils.py
from typing import Dict
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras

RESOLUTION = 224

crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
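# ImageNet mean/std (0.485/0.456/0.406 and 0.229/0.224/0.225) scaled to the
# [0, 255] pixel range, used for DeiT/DINO-style preprocessing.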
norm_layer = keras.layers.Normalization(
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
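# The original ViT rescales pixels from [0, 255] to [-1, 1].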
rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def preprocess_image(orig_image: Image.Image, model_type: str, size=RESOLUTION):
"""Image preprocessing utility."""
# Turn the image into a numpy array and add batch dim.
image = np.array(orig_image)
image = tf.expand_dims(image, 0)
    # If the model is the original ViT, rescale the image to [-1, 1].
if model_type == "original_vit":
image = rescale_layer(image)
# Resize the image using bicubic interpolation.
resize_size = int((256 / 224) * size)
image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")
# Crop the image.
preprocessed_image = crop_layer(image)
    # If the model is DeiT or DINO, normalize with the ImageNet statistics.
    if model_type != "original_vit":
        preprocessed_image = norm_layer(preprocessed_image)
return orig_image, preprocessed_image.numpy()


def attention_rollout_map(
    image: Image.Image, attention_score_dict: Dict[str, np.ndarray], model_type: str
):
"""Computes attention rollout results.
Reference:
https://arxiv.org/abs/2005.00928
Code copied and modified from here:
https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb
"""
num_cls_tokens = 2 if "distilled" in model_type else 1
# Stack the individual attention matrices from individual transformer blocks.
    attn_mat = tf.stack(list(attention_score_dict.values()))
    # Remove the unit batch dimension.
    attn_mat = tf.squeeze(attn_mat, axis=1)
# Average the attention weights across all heads.
attn_mat = tf.reduce_mean(attn_mat, axis=1)
# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_attn = tf.eye(attn_mat.shape[1])
aug_attn_mat = attn_mat + residual_attn
aug_attn_mat = (
aug_attn_mat / tf.reduce_sum(aug_attn_mat, axis=-1)[..., None]
)
aug_attn_mat = aug_attn_mat.numpy()
# Recursively multiply the weight matrices.
joint_attentions = np.zeros(aug_attn_mat.shape)
joint_attentions[0] = aug_attn_mat[0]
for n in range(1, aug_attn_mat.shape[0]):
joint_attentions[n] = np.matmul(
aug_attn_mat[n], joint_attentions[n - 1]
)
# Attention from the output token to the input space.
v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_attn_mat.shape[-1] - num_cls_tokens))
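    # Row 0 is the [CLS] token; drop the class/distillation token columns and
    # fold the remaining patch attention back into the 2D patch grid.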
mask = v[0, num_cls_tokens:].reshape(grid_size, grid_size)
mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
result = (mask * image).astype("uint8")
return result
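

# A minimal, self-contained usage sketch (not part of the original utilities).
# It fabricates random attention scores purely to illustrate the input layout
# `attention_rollout_map` expects: a dict mapping block names to arrays of
# shape (1, num_heads, num_tokens, num_tokens). A real ViT/DeiT/DINO model
# would supply these scores from its attention layers instead.
if __name__ == "__main__":
    num_blocks, num_heads = 12, 12
    num_cls_tokens = 1  # 2 for a distilled DeiT model
    num_tokens = (RESOLUTION // 16) ** 2 + num_cls_tokens

    rng = np.random.default_rng(0)
    dummy_scores = {
        f"block_{i}": tf.nn.softmax(
            rng.standard_normal((1, num_heads, num_tokens, num_tokens), dtype=np.float32),
            axis=-1,
        ).numpy()
        for i in range(num_blocks)
    }

    dummy_image = Image.fromarray(
        rng.integers(0, 256, (RESOLUTION, RESOLUTION, 3), dtype=np.uint8)
    )
    _, preprocessed = preprocess_image(dummy_image, model_type="original_vit")
    rollout = attention_rollout_map(dummy_image, dummy_scores, model_type="original_vit")
    print(preprocessed.shape, rollout.shape)  # (1, 224, 224, 3) (224, 224, 3)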