from typing import Dict

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras

RESOLUTION = 224
PATCH_SIZE = 16

crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
# ImageNet mean/std, scaled to the [0, 255] pixel range.
norm_layer = keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
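
# Quick sanity-check sketch (illustrative, not part of the original module):
# norm_layer standardizes raw [0, 255] pixels with the ImageNet statistics
# baked in above, so no `adapt()` call is needed.
#
#   dummy = tf.fill((1, RESOLUTION, RESOLUTION, 3), 127.5)
#   out = norm_layer(dummy)
#   # Each channel becomes roughly (127.5 - mean_c) / std_c, e.g. about
#   # 0.07 for the red channel: (127.5 - 0.485 * 255) / (0.229 * 255).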

def preprocess_image(orig_image: Image.Image, size: int):
    """Resizes, center-crops, and normalizes a PIL image for the model."""
    image = np.array(orig_image)
    image_resized = tf.expand_dims(image, 0)

    # Resize to a slightly larger resolution (256/224 ratio) before
    # center-cropping, following the standard ImageNet evaluation recipe.
    resize_size = int((256 / 224) * size)
    image_resized = tf.image.resize(
        image_resized, (resize_size, resize_size), method="bicubic"
    )
    image_resized = crop_layer(image_resized)

    # Return the cropped image (for display) and its normalized
    # counterpart (for model input).
    return image_resized.numpy().squeeze(), norm_layer(image_resized).numpy()
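
# Minimal usage sketch (not part of the original module); the image path
# is hypothetical:
#
#   pil_image = Image.open("example.jpg").convert("RGB")
#   cropped, model_input = preprocess_image(pil_image, RESOLUTION)
#   # cropped:     (224, 224, 3) array, suitable for display.
#   # model_input: (1, 224, 224, 3) normalized batch, suitable for inference.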

# Reference:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
def get_cls_attention_map(
    preprocessed_image: np.ndarray,
    attn_score_dict: Dict[str, np.ndarray],
    block_key="ca_ffn_block_0_att",
):
    """Generates a class-attention map relating the CLS token to spatial patches."""
    # Number of patches along each spatial dimension.
    w_featmap = preprocessed_image.shape[2] // PATCH_SIZE
    h_featmap = preprocessed_image.shape[1] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # Take the attention from the CLS token (query index 0) to all
    # image patches (key indices 1 onwards).
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches to 224x224 (224 = 14 * 16).
    attentions = tf.image.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        method="bicubic",
    )
    return attentions.numpy()
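
# End-to-end sketch (assumptions noted inline): it presumes a `model` whose
# forward pass also returns a dict of attention scores keyed by block names
# such as "ca_ffn_block_0_att"; that output format is an assumption, and
# matplotlib is used only for display.
#
#   import matplotlib.pyplot as plt
#
#   cropped, model_input = preprocess_image(pil_image, RESOLUTION)
#   _, attn_score_dict = model.predict(model_input)  # assumed output format
#   cls_attn = get_cls_attention_map(model_input, attn_score_dict)
#
#   # Overlay one attention head's map on the cropped image.
#   plt.imshow(cropped.astype("uint8"))
#   plt.imshow(cls_attn[..., 0], alpha=0.6, cmap="viridis")
#   plt.axis("off")
#   plt.show()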