# utils.py -- uploaded by sayakpaul (HF staff); commit 98e527d; 1.9 kB.
from typing import Dict
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
# Side length, in pixels, of the square window cut out by `crop_layer`.
RESOLUTION = 224
# Side length, in pixels, of a single patch; used to turn pixel dimensions
# into patch-grid dimensions.
PATCH_SIZE = 16

# Keeps the central RESOLUTION x RESOLUTION region of each image in a batch.
crop_layer = keras.layers.CenterCrop(height=RESOLUTION, width=RESOLUTION)

# Per-channel standardization. The statistics are written on a 0-1 scale and
# multiplied by 255 so they apply to 0-255 pixel values; the layer expects a
# variance, hence the squared (scaled) standard deviations.
# NOTE(review): the values look like the usual ImageNet RGB mean/std -- confirm.
norm_layer = keras.layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[
        (channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)
    ],
)
def preprocess_image(orig_image: Image.Image, size: int):
    """Resize, center-crop and normalize an image into a model-ready batch.

    The image is resized (bicubic) to a square whose side is
    ``int((256 / 224) * size)``, center-cropped to ``size x size`` and
    standardized with ``norm_layer``.

    Args:
        orig_image: Input image. PIL images that are not in RGB mode (e.g.
            RGBA, grayscale, palette) are converted to RGB first, because
            ``norm_layer`` holds three-channel statistics. Any other
            array-like input is passed through ``np.array`` unchanged.
        size: Side length, in pixels, of the square output.

    Returns:
        A ``(orig_image, batch)`` tuple: the untouched input object and a
        NumPy array holding the preprocessed image with a leading batch
        dimension of 1.
    """
    rgb_image = orig_image
    if isinstance(orig_image, Image.Image) and orig_image.mode != "RGB":
        rgb_image = orig_image.convert("RGB")

    # Add the batch axis expected by the Keras preprocessing layers.
    batch = tf.expand_dims(np.array(rgb_image), 0)

    # Resize slightly larger than the target so the center crop trims the
    # border (256 / 224 ratio, i.e. the crop keeps 87.5% of each side).
    resize_size = int((256 / 224) * size)
    batch = tf.image.resize(
        batch, (resize_size, resize_size), method="bicubic"
    )

    # The shared module-level layer only crops to RESOLUTION; honour other
    # requested sizes with a dedicated (weight-free) crop layer so the
    # output really is `size x size`.
    if size == RESOLUTION:
        cropper = crop_layer
    else:
        cropper = keras.layers.CenterCrop(size, size)
    batch = cropper(batch)

    return orig_image, norm_layer(batch).numpy()
# Reference:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
def get_cls_attention_map(
    preprocessed_image: np.ndarray,
    attn_score_dict: Dict[str, np.ndarray],
    block_key: str = "ca_ffn_block_0_att",
) -> np.ndarray:
    """Build a per-head class-attention map at the input image's resolution.

    Args:
        preprocessed_image: Batched image array; axes 1 and 2 are read as
            height and width in pixels.
        attn_score_dict: Mapping from block name to attention scores.
            Scores are indexed as ``[0, :, 0, 1:]``.
            NOTE(review): this assumes a (batch, heads, query, key) layout
            with the class token at index 0 -- confirm against the model.
        block_key: Key of the attention block to visualize.

    Returns:
        Array of shape ``(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE,
        num_heads)`` with the bicubically upsampled attention of each head.

    Raises:
        KeyError: If ``block_key`` is not in ``attn_score_dict``.
        ValueError: If the number of patch tokens does not match the patch
            grid implied by ``preprocessed_image``.
    """
    h_featmap = preprocessed_image.shape[1] // PATCH_SIZE
    w_featmap = preprocessed_image.shape[2] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # First batch item, every head, query 0 (the class token), and all keys
    # except key 0 -- i.e. the class token's attention over the patches.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    num_patches = h_featmap * w_featmap
    if attentions.shape[1] != num_patches:
        raise ValueError(
            f"Expected {num_patches} patch tokens "
            f"({h_featmap}x{w_featmap} grid) but attention scores for "
            f"'{block_key}' have {attentions.shape[1]}."
        )

    # Lay the tokens out as a (height, width) grid so the axes agree with
    # the (height, width) `size` passed to the resize below. The original
    # (width, height) order only worked for square inputs.
    # NOTE(review): assumes patch tokens are flattened row by row -- confirm.
    attentions = attentions.reshape(nh, h_featmap, w_featmap)
    # Heads become the channel axis: (height, width, heads).
    attentions = attentions.transpose((1, 2, 0))

    # Upsample the patch grid back to pixel resolution
    # (each grid cell covers PATCH_SIZE x PATCH_SIZE pixels).
    attentions = tf.image.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        method="bicubic",
    )
    return attentions.numpy()