Upload cluster_visualize.py
Browse files- cluster_visualize.py +176 -0
cluster_visualize.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# This file is a modified version of https://github.com/ma-xu/Context-Cluster/blob/main/cluster_visualize.py
|
3 |
+
# It is modified in order to make it compatible with Gradio.
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import context_cluster.models as models
|
7 |
+
import timm
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import argparse
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torchvision.transforms.functional as TransF
|
15 |
+
from torchvision import transforms
|
16 |
+
from einops import rearrange
|
17 |
+
import random
|
18 |
+
from timm.models import load_checkpoint
|
19 |
+
from torchvision.utils import draw_segmentation_masks
|
20 |
+
|
21 |
+
object_categories = []
|
22 |
+
with open("./context_cluster/imagenet1k_id_to_label.txt", "r") as f:
|
23 |
+
for line in f:
|
24 |
+
_, val = line.strip().split(":")
|
25 |
+
object_categories.append(val)
|
26 |
+
|
27 |
+
|
28 |
+
class PredictionArgs:
|
29 |
+
def __init__(self,
|
30 |
+
model,
|
31 |
+
checkpoint,
|
32 |
+
image,
|
33 |
+
shape=224,
|
34 |
+
stage=0,
|
35 |
+
block=0,
|
36 |
+
head=1,
|
37 |
+
resize_img=False,
|
38 |
+
alpha=0.5):
|
39 |
+
"""
|
40 |
+
This class contains all the arguments required for model prediction.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
model: `str` denoting the name of model. ex. 'coc_tiny', 'coc_small', 'coc_medium'.
|
44 |
+
checkpoint: `str` denoting the path of model checkpoint.
|
45 |
+
image: `np.array` denoting the path of image.
|
46 |
+
shape: `int` denoting the dimension of square image.
|
47 |
+
stage: `int` denoting index of visualized stage, 0-3.
|
48 |
+
block: `int` denoting index of visualized stage, -1 is the last block ,2,3,4,1.
|
49 |
+
head: `int` denoting index of visualized head, 0-3 or 0-7.
|
50 |
+
resize_img: Boolean denoting whether to resize img to feature-map size.
|
51 |
+
alpha: `float` denoting transparency, 0-1.
|
52 |
+
"""
|
53 |
+
self.model = model
|
54 |
+
self.checkpoint = checkpoint
|
55 |
+
self.image = image
|
56 |
+
self.shape = shape
|
57 |
+
self.stage = stage
|
58 |
+
self.block = block
|
59 |
+
self.head = head
|
60 |
+
self.resize_img = resize_img
|
61 |
+
self.alpha = alpha
|
62 |
+
assert self.model in timm.list_models(), "Please use a timm pre-trined model, see timm.list_models()"
|
63 |
+
|
64 |
+
# Preprocessing
|
65 |
+
def _preprocess(raw_image):
|
66 |
+
raw_image = cv2.resize(raw_image, (224,) * 2)
|
67 |
+
image = transforms.Compose(
|
68 |
+
[
|
69 |
+
transforms.ToTensor(),
|
70 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
71 |
+
]
|
72 |
+
)(raw_image[..., ::-1].copy())
|
73 |
+
return image, raw_image
|
74 |
+
|
75 |
+
|
76 |
+
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
|
77 |
+
"""
|
78 |
+
return pair-wise similarity matrix between two tensors
|
79 |
+
:param x1: [B,M,D]
|
80 |
+
:param x2: [B,N,D]
|
81 |
+
:return: similarity matrix [B,M,N]
|
82 |
+
"""
|
83 |
+
x1 = F.normalize(x1, dim=-1)
|
84 |
+
x2 = F.normalize(x2, dim=-1)
|
85 |
+
sim = torch.matmul(x1, x2.permute(0, 2, 1))
|
86 |
+
return sim
|
87 |
+
|
88 |
+
|
89 |
+
# forward hook function
|
90 |
+
def get_attention_score(self, input, output):
|
91 |
+
x = input[0] # input tensor in a tuple
|
92 |
+
value = self.v(x)
|
93 |
+
x = self.f(x)
|
94 |
+
x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
|
95 |
+
value = rearrange(value, "b (e c) w h -> (b e) c w h", e=self.heads)
|
96 |
+
if self.fold_w > 1 and self.fold_h > 1:
|
97 |
+
b0, c0, w0, h0 = x.shape
|
98 |
+
assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0, \
|
99 |
+
f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
|
100 |
+
x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w,
|
101 |
+
f2=self.fold_h) # [bs*blocks,c,ks[0],ks[1]]
|
102 |
+
value = rearrange(value, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)
|
103 |
+
b, c, w, h = x.shape
|
104 |
+
centers = self.centers_proposal(x) # [b,c,C_W,C_H], we set M = C_W*C_H and N = w*h
|
105 |
+
value_centers = rearrange(self.centers_proposal(value), 'b c w h -> b (w h) c') # [b,C_W,C_H,c]
|
106 |
+
b, c, ww, hh = centers.shape
|
107 |
+
sim = torch.sigmoid(self.sim_beta +
|
108 |
+
self.sim_alpha * pairwise_cos_sim(
|
109 |
+
centers.reshape(b, c, -1).permute(0, 2, 1),
|
110 |
+
x.reshape(b, c, -1).permute(0, 2,1)
|
111 |
+
)
|
112 |
+
) # [B,M,N]
|
113 |
+
# sololy assign each point to one center
|
114 |
+
sim_max, sim_max_idx = sim.max(dim=1, keepdim=True)
|
115 |
+
mask = torch.zeros_like(sim) # binary #[B,M,N]
|
116 |
+
mask.scatter_(1, sim_max_idx, 1.) # binary #[B,M,N]
|
117 |
+
# changed, for plotting mask.
|
118 |
+
mask = mask.reshape(mask.shape[0], mask.shape[1], w, h) # [(head*fold*fold),m, w,h]
|
119 |
+
mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
|
120 |
+
h0=self.heads, f1=self.fold_w, f2=self.fold_h) # [head, (fold*fold),m, w,h]
|
121 |
+
mask_list = []
|
122 |
+
for i in range(self.fold_w):
|
123 |
+
for j in range(self.fold_h):
|
124 |
+
for k in range(mask.shape[2]):
|
125 |
+
temp = torch.zeros(self.heads, w * self.fold_w, h * self.fold_h)
|
126 |
+
temp[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, i * self.fold_w + j, k, :, :]
|
127 |
+
mask_list.append(temp.unsqueeze(dim=0)) # [1, heads, w, h]
|
128 |
+
|
129 |
+
mask2 = torch.concat(mask_list, dim=0) # [ n, heads, w, h]
|
130 |
+
global attention
|
131 |
+
attention = mask2.detach()
|
132 |
+
|
133 |
+
|
134 |
+
def generate_visualization(args):
|
135 |
+
global attention
|
136 |
+
image, raw_image = _preprocess(args.image)
|
137 |
+
image = image.unsqueeze(dim=0)
|
138 |
+
model = timm.create_model(model_name=args.model, pretrained=True)
|
139 |
+
if args.checkpoint:
|
140 |
+
load_checkpoint(model, args.checkpoint, True)
|
141 |
+
print(f"\n\n==> Loaded checkpoint")
|
142 |
+
else:
|
143 |
+
raise Exception("Checkpoint doesn't exist at specified path: {}".format(args.checkpoint))
|
144 |
+
print(f"\n\n==> NO checkpoint is loaded")
|
145 |
+
model.network[args.stage * 2][args.block].token_mixer.register_forward_hook(get_attention_score)
|
146 |
+
out = model(image)
|
147 |
+
if type(out) is tuple:
|
148 |
+
out = out[0]
|
149 |
+
possibility = torch.softmax(out, dim=1).max() * 100
|
150 |
+
possibility = "{:.3f}".format(possibility)
|
151 |
+
value, index = torch.max(out, dim=1)
|
152 |
+
|
153 |
+
from torchvision.io import read_image
|
154 |
+
img = torch.tensor(raw_image).permute(2, 0, 1)
|
155 |
+
|
156 |
+
# process the attention map
|
157 |
+
attention = attention[:, args.head, :, :]
|
158 |
+
mask = attention.unsqueeze(dim=0)
|
159 |
+
mask = F.interpolate(mask, (img.shape[-2], img.shape[-1]))
|
160 |
+
mask = mask.squeeze(dim=0)
|
161 |
+
mask = mask > 0.5
|
162 |
+
# randomly selected some good colors.
|
163 |
+
colors = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
|
164 |
+
"white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
|
165 |
+
if mask.shape[0] == 4:
|
166 |
+
colors = colors[0:4]
|
167 |
+
if mask.shape[0] > 4:
|
168 |
+
colors = colors * (mask.shape[0] // 16)
|
169 |
+
random.seed(123)
|
170 |
+
random.shuffle(colors)
|
171 |
+
|
172 |
+
img_with_masks = draw_segmentation_masks(img, masks=mask, alpha=args.alpha, colors=colors)
|
173 |
+
img_with_masks = img_with_masks.detach()
|
174 |
+
img_with_masks = TransF.to_pil_image(img_with_masks)
|
175 |
+
img_with_masks = np.asarray(img_with_masks)
|
176 |
+
return img_with_masks, possibility
|