ariG23498 committed
Commit e1c2e43
1 Parent(s): 1f029d6

chore: add utils

Files changed (1)
utils.py +48 -1
utils.py CHANGED
@@ -1,14 +1,17 @@
 # import the necessary packages
 import tensorflow as tf
+from tensorflow import keras
 from tensorflow.keras import layers
 
 from PIL import Image
 from io import BytesIO
 import requests
 import numpy as np
+from matplotlib import pyplot as plt
 
 
 RESOLUTION = 224
+PATCH_SIZE = 16
 
 crop_layer = layers.CenterCrop(RESOLUTION, RESOLUTION)
 norm_layer = layers.Normalization(
@@ -50,4 +53,48 @@ def load_image_from_url(url, model_type):
     response = requests.get(url)
     image = Image.open(BytesIO(response.content))
     preprocessed_image = preprocess_image(image, model_type)
-    return image, preprocessed_image
+    return image, preprocessed_image
+
+
+def attention_heatmap(attention_score_dict, image, model_type="dino", num_heads=12):
+    num_tokens = 2 if "distilled" in model_type else 1
+
+    # Sort the transformer blocks in order of their depth.
+    attention_score_list = list(attention_score_dict.keys())
+    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)
+
+    # Process the attention maps for overlay.
+    w_featmap = image.shape[2] // PATCH_SIZE
+    h_featmap = image.shape[1] // PATCH_SIZE
+    attention_scores = attention_score_dict[attention_score_list[0]]
+
+    # Taking the representations from the CLS token.
+    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)
+
+    # Reshape the attention scores to resemble mini patches.
+    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
+    attentions = attentions.transpose((1, 2, 0))
+
+    # Resize the attention patches to 224x224 (224: 14x16).
+    attentions = tf.image.resize(attentions, size=(
+        h_featmap * PATCH_SIZE,
+        w_featmap * PATCH_SIZE)
+    )
+    return attentions
+
+
+def plot(attentions, image):
+    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
+    img_count = 0
+
+    for i in range(3):
+        for j in range(4):
+            if img_count < len(attentions):
+                axes[i, j].imshow(image[0])
+                axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
+                axes[i, j].title.set_text(f"Attention head: {img_count}")
+                axes[i, j].axis("off")
+                img_count += 1
+
+    plt.tight_layout()
+    plt.savefig("heat_map.png")
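
For context, a minimal usage sketch for the new helpers (not part of this commit). The model handle vit_dino, the image URL, and the attention-dict key format are assumptions for illustration, not something the commit defines; the sketch assumes a ViT whose forward pass also returns per-block attention scores.

# Usage sketch (illustrative only). Assumption: `vit_dino` is a hypothetical
# keras.Model returning (predictions, attention_score_dict), where the dict
# maps block names such as "transformer_block_11_att" to arrays of shape
# (batch, num_heads, num_tokens, num_tokens).
from utils import attention_heatmap, load_image_from_url, plot

image, preprocessed_image = load_image_from_url(
    "https://example.com/sample.jpg",  # placeholder URL
    model_type="dino",
)

# Forward pass that also surfaces the per-block attention scores.
predictions, attention_score_dict = vit_dino.predict(preprocessed_image)

# Shape flow inside attention_heatmap for a 224x224 input with 16x16 patches:
# (1, 12, 197, 197) -> CLS row (12, 196) -> (12, 14, 14) -> (14, 14, 12)
# -> resized to (224, 224, 12), i.e. one heatmap per attention head.
attentions = attention_heatmap(attention_score_dict, preprocessed_image)
plot(attentions, preprocessed_image)  # saves the 3x4 overlay grid as heat_map.png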