CurHarsh committed
Commit 34fbab0
1 Parent(s): 5bb056f

Upload cluster_visualize.py

Files changed (1)
  1. cluster_visualize.py +176 -0
cluster_visualize.py ADDED
@@ -0,0 +1,176 @@
# --------------------------------------------------------
# This file is a modified version of https://github.com/ma-xu/Context-Cluster/blob/main/cluster_visualize.py
# It is modified in order to make it compatible with Gradio.
# --------------------------------------------------------

import context_cluster.models  # noqa: F401 (side effect: registers the coc_* models with timm)
import timm
import os
import torch
import cv2
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as TransF
from torchvision import transforms
from einops import rearrange
import random
from timm.models import load_checkpoint
from torchvision.utils import draw_segmentation_masks

# ImageNet-1k class-id -> label mapping for the predicted category
object_categories = []
with open("./context_cluster/imagenet1k_id_to_label.txt", "r") as f:
    for line in f:
        _, val = line.strip().split(":")
        object_categories.append(val)


class PredictionArgs:
    def __init__(self,
                 model,
                 checkpoint,
                 image,
                 shape=224,
                 stage=0,
                 block=0,
                 head=1,
                 resize_img=False,
                 alpha=0.5):
        """
        This class holds all the arguments required for model prediction.

        Args:
            model: `str` name of the model, e.g. 'coc_tiny', 'coc_small', 'coc_medium'.
            checkpoint: `str` path to the model checkpoint.
            image: `np.ndarray` input image (BGR channel order, as read by cv2).
            shape: `int` side length of the square input image.
            stage: `int` index of the visualized stage, 0-3.
            block: `int` index of the visualized block within the stage; -1 is the last block.
            head: `int` index of the visualized head, 0-3 or 0-7 depending on the stage.
            resize_img: `bool` whether to resize the image to the feature-map size.
            alpha: `float` mask transparency, 0-1.
        """
        self.model = model
        self.checkpoint = checkpoint
        self.image = image
        self.shape = shape
        self.stage = stage
        self.block = block
        self.head = head
        self.resize_img = resize_img
        self.alpha = alpha
        assert self.model in timm.list_models(), "Please use a timm pre-trained model, see timm.list_models()"

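# Construction example (illustrative; the checkpoint path and image file are
# hypothetical placeholders, not part of the original file):
#     args = PredictionArgs(model="coc_tiny",
#                           checkpoint="checkpoints/coc_tiny.pth.tar",
#                           image=cv2.imread("sample.jpg"),
#                           stage=0, block=-1, head=1, alpha=0.5)
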
# Preprocessing: resize to the model input size and normalize with ImageNet statistics
def _preprocess(raw_image, shape=224):
    raw_image = cv2.resize(raw_image, (shape,) * 2)
    image = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )(raw_image[..., ::-1].copy())  # BGR -> RGB before normalization
    return image, raw_image

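# Shape note (assuming the default shape=224): _preprocess returns a normalized
# float tensor of shape [3, 224, 224] together with the resized HWC uint8 image
# of shape (224, 224, 3).
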
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
    """
    Return the pair-wise cosine-similarity matrix between two tensors.
    :param x1: [B, M, D]
    :param x2: [B, N, D]
    :return: similarity matrix [B, M, N]
    """
    x1 = F.normalize(x1, dim=-1)
    x2 = F.normalize(x2, dim=-1)
    sim = torch.matmul(x1, x2.permute(0, 2, 1))
    return sim

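# A minimal shape/consistency check (illustrative, not in the original file):
#     x1 = torch.randn(2, 49, 64)     # [B, M, D]
#     x2 = torch.randn(2, 196, 64)    # [B, N, D]
#     sim = pairwise_cos_sim(x1, x2)  # -> [2, 49, 196], values in [-1, 1]
#     # matches F.cosine_similarity(x1.unsqueeze(2), x2.unsqueeze(1), dim=-1)
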
# Forward hook (`self` is the hooked token_mixer module): recompute the cluster
# assignments and stash the binary point-to-center masks in the module-level
# `attention` tensor.
def get_attention_score(self, input, output):
    x = input[0]  # input tensor arrives in a tuple
    value = self.v(x)
    x = self.f(x)
    x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
    value = rearrange(value, "b (e c) w h -> (b e) c w h", e=self.heads)
    if self.fold_w > 1 and self.fold_h > 1:
        b0, c0, w0, h0 = x.shape
        assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0, \
            f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
        # split the feature map into fold_w * fold_h local regions
        x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w,
                      f2=self.fold_h)  # [bs*blocks, c, ks[0], ks[1]]
        value = rearrange(value, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)
    b, c, w, h = x.shape
    centers = self.centers_proposal(x)  # [b, c, C_W, C_H]; we set M = C_W*C_H and N = w*h
    value_centers = rearrange(self.centers_proposal(value), 'b c w h -> b (w h) c')  # [b, M, c] (unused in this hook)
    b, c, ww, hh = centers.shape
    sim = torch.sigmoid(
        self.sim_beta + self.sim_alpha * pairwise_cos_sim(
            centers.reshape(b, c, -1).permute(0, 2, 1),
            x.reshape(b, c, -1).permute(0, 2, 1)
        )
    )  # [B, M, N]
    # assign each point to exactly one center (hard assignment)
    sim_max, sim_max_idx = sim.max(dim=1, keepdim=True)
    mask = torch.zeros_like(sim)        # binary [B, M, N]
    mask.scatter_(1, sim_max_idx, 1.)   # binary [B, M, N]
    # changed from the original forward: reshape the mask for plotting
    mask = mask.reshape(mask.shape[0], mask.shape[1], w, h)  # [(head*fold*fold), m, w, h]
    mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
                     h0=self.heads, f1=self.fold_w, f2=self.fold_h)  # [head, fold*fold, m, w, h]
    # paste each local-region mask back into a full-size canvas
    mask_list = []
    for i in range(self.fold_w):
        for j in range(self.fold_h):
            for k in range(mask.shape[2]):
                temp = torch.zeros(self.heads, w * self.fold_w, h * self.fold_h)
                temp[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, i * self.fold_w + j, k, :, :]
                mask_list.append(temp.unsqueeze(dim=0))  # [1, heads, W, H]

    mask2 = torch.cat(mask_list, dim=0)  # [n, heads, W, H]
    global attention
    attention = mask2.detach()

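# After the hook fires, `attention` holds a binary tensor of shape
# [fold_w*fold_h*M, heads, W, H]: one full-resolution mask per
# (local region, proposed center) pair.
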
def generate_visualization(args):
    global attention
    image, raw_image = _preprocess(args.image, args.shape)
    image = image.unsqueeze(dim=0)
    model = timm.create_model(model_name=args.model, pretrained=True)
    if args.checkpoint and os.path.exists(args.checkpoint):
        load_checkpoint(model, args.checkpoint, True)
        print("\n\n==> Loaded checkpoint")
    else:
        raise FileNotFoundError("Checkpoint doesn't exist at specified path: {}".format(args.checkpoint))
    # register the forward hook on the chosen stage/block, then run one forward
    # pass so the hook fills the module-level `attention` tensor
    model.network[args.stage * 2][args.block].token_mixer.register_forward_hook(get_attention_score)
    out = model(image)
    if type(out) is tuple:
        out = out[0]
    possibility = torch.softmax(out, dim=1).max() * 100  # top-1 confidence in percent
    possibility = "{:.3f}".format(possibility)
    value, index = torch.max(out, dim=1)  # predicted class index (label: object_categories[index])

    img = torch.tensor(raw_image).permute(2, 0, 1)  # HWC uint8 -> CHW for torchvision

    # process the attention map: keep one head, upsample to the image size, binarize
    attention = attention[:, args.head, :, :]
    mask = attention.unsqueeze(dim=0)
    mask = F.interpolate(mask, (img.shape[-2], img.shape[-1]))
    mask = mask.squeeze(dim=0)
    mask = mask > 0.5
    # randomly selected some good colors
    colors = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
              "white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
    if mask.shape[0] == 4:
        colors = colors[0:4]
    if mask.shape[0] > 4:
        # round up so there is at least one color per mask (the original integer
        # division produced an empty list for 5-15 masks)
        colors = colors * ((mask.shape[0] + len(colors) - 1) // len(colors))
    random.seed(123)
    random.shuffle(colors)

    img_with_masks = draw_segmentation_masks(img, masks=mask, alpha=args.alpha, colors=colors)
    img_with_masks = img_with_masks.detach()
    img_with_masks = TransF.to_pil_image(img_with_masks)
    img_with_masks = np.asarray(img_with_masks)
    return img_with_masks, possibility
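
A minimal end-to-end usage sketch (not part of the commit; the checkpoint path and image file are placeholders to adapt):

import cv2
from cluster_visualize import PredictionArgs, generate_visualization

# load a BGR image with OpenCV and describe the run
args = PredictionArgs(model="coc_tiny",
                      checkpoint="checkpoints/coc_tiny.pth.tar",
                      image=cv2.imread("sample.jpg"),
                      stage=0, block=-1, head=1, alpha=0.5)

# returns the cluster-overlay image (np.ndarray) and the top-1 confidence string
overlay, confidence = generate_visualization(args)
print(f"top-1 confidence: {confidence}%")
cv2.imwrite("overlay.png", overlay)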