# class-saliency / utils.py
# Uploaded by sayakpaul (HF staff) in commit 3e32c41 ("Upload utils.py").
# Hugging Face page metadata: raw | history | blame; "No virus"; 2.06 kB.
from typing import Dict
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
# Side length (pixels) of the square image the model consumes.
RESOLUTION = 224
# Side length (pixels) of one square patch; the attention map in this module
# is laid out on a (H // PATCH_SIZE, W // PATCH_SIZE) grid.
PATCH_SIZE = 16

# Shared centre crop producing RESOLUTION x RESOLUTION images.
crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)

# Per-channel normalisation using the commonly used ImageNet statistics,
# rescaled from the [0, 1] range to the [0, 255] pixel range. Keras'
# Normalization layer takes a variance (it divides by sqrt(variance)),
# hence the squared standard deviations.
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[(channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)],
)
def preprocess_image(orig_image: Image.Image, size: int):
    """Resize, centre-crop and normalise an image for the model.

    The image is resized (bicubic) to ``int(256 / 224 * size)`` on both sides
    and then centre-cropped to ``size x size`` -- the "resize 256, crop 224"
    evaluation recipe scaled to ``size``.

    Args:
        orig_image: Input image. PIL images in a non-RGB mode (e.g. RGBA or
            grayscale) are converted to RGB first, because the normalisation
            layer has exactly three per-channel statistics. Array inputs are
            used as-is; they are assumed to be (H, W, 3) -- TODO confirm with
            callers.
        size: Side length of the square output crop.

    Returns:
        ``(image, normalized)``: ``image`` is the cropped, un-normalised image
        of shape (size, size, 3) for visualisation (float values; roughly in
        [0, 255] for 8-bit inputs, bicubic resizing may overshoot slightly),
        and ``normalized`` is the normalised batch of shape (1, size, size, 3)
        to feed to the model.
    """
    if isinstance(orig_image, Image.Image) and orig_image.mode != "RGB":
        orig_image = orig_image.convert("RGB")
    image = np.array(orig_image)
    batch = tf.expand_dims(image, 0)  # add batch axis -> (1, H, W, C)

    resize_size = int((256 / 224) * size)
    batch = tf.image.resize(batch, (resize_size, resize_size), method="bicubic")

    # The shared module-level crop layer is fixed to RESOLUTION; build a
    # matching one when a different ``size`` is requested so the output
    # really is ``size x size``.
    if size == RESOLUTION:
        cropper = crop_layer
    else:
        cropper = keras.layers.CenterCrop(size, size)
    batch = cropper(batch)

    # Index the batch axis explicitly instead of squeezing every unit axis.
    return batch.numpy()[0], norm_layer(batch).numpy()
# Reference:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
def get_cls_attention_map(
    preprocessed_image: np.ndarray,
    attn_score_dict: Dict[str, np.ndarray],
    block_key="ca_ffn_block_0_att",
    patch_size: int = PATCH_SIZE,
):
    """Build a class-saliency map from the CLS token's attention over patches.

    Args:
        preprocessed_image: Batched image in (batch, height, width, channels)
            layout; only its spatial size is used, to derive the patch grid.
        attn_score_dict: Mapping from block name to attention scores indexed
            as (batch, heads, queries, keys). Query 0 and key 0 are treated
            as the CLS token.
        block_key: Which entry of ``attn_score_dict`` to visualise.
        patch_size: Side length (pixels) of one patch; defaults to the
            module-level ``PATCH_SIZE``.

    Returns:
        Array of shape (h_feat * patch_size, w_feat * patch_size, 1) holding
        the head-averaged CLS attention, min-max normalised to [0, 1] before
        bicubic upsampling (so values may overshoot slightly). A constant
        attention map yields all zeros instead of NaNs.
    """
    # NHWC layout: axis 1 is height, axis 2 is width.
    h_featmap = preprocessed_image.shape[1] // patch_size
    w_featmap = preprocessed_image.shape[2] // patch_size

    attention_scores = np.asarray(attn_score_dict[block_key])
    num_heads = attention_scores.shape[1]

    # Attention from the CLS query (row 0) to every patch token; key 0 is the
    # CLS token itself and is dropped.
    cls_attention = attention_scores[0, :, 0, 1:]
    # Assumes patch tokens are flattened row-major over the (height, width)
    # grid, as in standard ViT patch embeddings. The grid is therefore
    # (h, w), not (w, h); this only matters for non-square inputs.
    cls_attention = cls_attention.reshape(num_heads, h_featmap, w_featmap)

    saliency = cls_attention.mean(axis=0)
    value_range = saliency.max() - saliency.min()
    if value_range > 0:
        saliency = (saliency - saliency.min()) / value_range
    else:
        # Constant attention: avoid 0 / 0 producing a NaN map.
        saliency = np.zeros_like(saliency)
    saliency = saliency[..., np.newaxis]  # (h, w, 1) for tf.image.resize

    # Upsample the patch grid back to pixel resolution
    # (e.g. 14 x 14 patches of 16 px -> 224 x 224).
    saliency = tf.image.resize(
        saliency,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        method="bicubic",
    )
    return saliency.numpy()