sayakpaul committed
Commit 98e527d
1 Parent(s): 8b8dc74

Upload utils.py

Files changed (1)
  1. utils.py +62 -0
utils.py ADDED
@@ -0,0 +1,62 @@
+ from typing import Dict
+
+ import numpy as np
+ import tensorflow as tf
+ from PIL import Image
+ from tensorflow import keras
+
+ RESOLUTION = 224
+ PATCH_SIZE = 16
+
+
+ crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
+ norm_layer = keras.layers.Normalization(  # ImageNet mean/std in the [0, 255] pixel range.
+     mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+     variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
+ )
+
+
+ def preprocess_image(orig_image: Image.Image, size: int):
+     """Image preprocessing utility."""
+     image = np.array(orig_image)
+     image_resized = tf.expand_dims(image, 0)
+     resize_size = int((256 / 224) * size)  # Standard 256/224 resize-then-crop ratio.
+     image_resized = tf.image.resize(
+         image_resized, (resize_size, resize_size), method="bicubic"
+     )
+     image_resized = crop_layer(image_resized)
+     return orig_image, norm_layer(image_resized).numpy()
+
+
+ # Reference:
+ # https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
+
+
+ def get_cls_attention_map(
+     preprocessed_image: np.ndarray,
+     attn_score_dict: Dict[str, np.ndarray],
+     block_key="ca_ffn_block_0_att",
+ ):
+     """Utility to generate a class-attention map modeling spatial-class relationships."""
+     w_featmap = preprocessed_image.shape[2] // PATCH_SIZE
+     h_featmap = preprocessed_image.shape[1] // PATCH_SIZE
+
+     attention_scores = attn_score_dict[block_key]
+     nh = attention_scores.shape[1]  # Number of attention heads.
+
+     # Take the attention of the CLS token to every patch token.
+     attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)
+
+     # Reshape the attention scores to resemble mini patches.
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+
+     attentions = attentions.transpose((1, 2, 0))  # Channels-last: heads become channels.
+
+     # Resize the attention patches to 224x224 (224 = 14x16).
+     attentions = tf.image.resize(
+         attentions,
+         size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
+         method="bicubic",
+     )
+
+     return attentions.numpy()
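Not part of the commit, but for context: a minimal sketch of how the two utilities chain together. It assumes a CaiT-style Keras model whose forward pass returns both logits and a dictionary of attention scores keyed like "ca_ffn_block_0_att" (the layout get_cls_attention_map expects); load_cait_model below is a hypothetical placeholder, not an API from this repo.

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from utils import RESOLUTION, get_cls_attention_map, preprocess_image

# Hypothetical placeholder: any CaiT-style Keras model that also
# exposes its per-block attention scores.
model = load_cait_model()

image = Image.open("image.png").convert("RGB")
_, preprocessed = preprocess_image(image, RESOLUTION)  # (1, 224, 224, 3)

logits, attn_scores = model(preprocessed)
attn_scores = {k: np.asarray(v) for k, v in attn_scores.items()}

# (224, 224, num_heads): one spatial class-attention map per head.
cls_attn = get_cls_attention_map(preprocessed, attn_scores)

# Save one heatmap per attention head.
num_heads = cls_attn.shape[-1]
fig, axes = plt.subplots(1, num_heads, figsize=(3 * num_heads, 3))
for i, ax in enumerate(np.atleast_1d(axes)):
    ax.imshow(cls_attn[..., i], cmap="inferno")
    ax.set_title(f"head {i}")
    ax.axis("off")
fig.savefig("cls_attention_maps.png", bbox_inches="tight")

The conversion to NumPy before the call matters: get_cls_attention_map uses NumPy-style .reshape on the scores, so TensorFlow tensors coming out of the model need np.asarray first. Likewise, the normalization layer scales the ImageNet mean/std by 255 because preprocess_image feeds it raw uint8 pixel values rather than [0, 1] floats.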