# --------------------------------------------------------
# This file is a modified version of https://github.com/ma-xu/Context-Cluster/blob/main/cluster_visualize.py
# It is modified in order to make it compatible with Gradio.
# --------------------------------------------------------
import models
import timm
import os
import torch
import argparse
import cv2
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as TransF
from torchvision import transforms
from einops import rearrange
import random
from timm.models import load_checkpoint
from torchvision.utils import draw_segmentation_masks
object_categories = []
with open("imagenet1k_id_to_label.txt", "r") as f:
    for line in f:
        _, val = line.strip().split(":")
        object_categories.append(val)
class PredictionArgs:
    def __init__(self,
                 model,
                 checkpoint,
                 image,
                 shape=224,
                 stage=0,
                 block=0,
                 head=1,
                 resize_img=False,
                 alpha=0.5):
        """
        This class contains all the arguments required for model prediction.
        Args:
            model: `str` denoting the name of the model, e.g. 'coc_tiny', 'coc_small', 'coc_medium'.
            checkpoint: `str` denoting the path of the model checkpoint.
            image: `np.ndarray` denoting the input image.
            shape: `int` denoting the side length of the square input image.
            stage: `int` denoting index of the visualized stage, 0-3.
            block: `int` denoting index of the visualized block within the stage; -1 is the last block.
            head: `int` denoting index of the visualized head, 0-3 or 0-7.
            resize_img: Boolean denoting whether to resize the image to the feature-map size.
            alpha: `float` denoting transparency, 0-1.
        """
        self.model = model
        self.checkpoint = checkpoint
        self.image = image
        self.shape = shape
        self.stage = stage
        self.block = block
        self.head = head
        self.resize_img = resize_img
        self.alpha = alpha
        assert self.model in timm.list_models(), "Please use a timm pre-trained model, see timm.list_models()"
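# Example construction (the checkpoint path and image array below are placeholders):
#   args = PredictionArgs("coc_tiny", "checkpoints/coc_tiny.pth", img_array, stage=0, block=-1, head=0)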
# Preprocessing
def _preprocess(raw_image):
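    # Resize to 224x224, reverse the channel order (BGR -> RGB for cv2-loaded images), and apply
    # ImageNet normalization; the un-normalized resized image is also returned for plotting.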
    raw_image = cv2.resize(raw_image, (224,) * 2)
    image = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )(raw_image[..., ::-1].copy())
    return image, raw_image
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
"""
return pair-wise similarity matrix between two tensors
:param x1: [B,M,D]
:param x2: [B,N,D]
:return: similarity matrix [B,M,N]
"""
x1 = F.normalize(x1, dim=-1)
x2 = F.normalize(x2, dim=-1)
sim = torch.matmul(x1, x2.permute(0, 2, 1))
return sim
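# e.g. pairwise_cos_sim(torch.randn(2, 49, 64), torch.randn(2, 196, 64)) has shape [2, 49, 196]:
# one cosine similarity per (center, point) pair.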
# forward hook function
def get_attention_score(self, input, output):
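    # Re-runs the token_mixer's clustering on its input to recover the hard point-to-center
    # assignment masks; the result is stored in the module-level `attention` variable so
    # generate_visualization can use it after the forward pass.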
    x = input[0]  # input tensor in a tuple
    value = self.v(x)
    x = self.f(x)
    x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
    value = rearrange(value, "b (e c) w h -> (b e) c w h", e=self.heads)
    if self.fold_w > 1 and self.fold_h > 1:
        b0, c0, w0, h0 = x.shape
        assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0, \
            f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
        x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w,
                      f2=self.fold_h)  # [bs*blocks,c,ks[0],ks[1]]
        value = rearrange(value, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)
    b, c, w, h = x.shape
    centers = self.centers_proposal(x)  # [b,c,C_W,C_H], we set M = C_W*C_H and N = w*h
    value_centers = rearrange(self.centers_proposal(value), 'b c w h -> b (w h) c')  # [b, C_W*C_H, c]
    b, c, ww, hh = centers.shape
    sim = torch.sigmoid(self.sim_beta +
                        self.sim_alpha * pairwise_cos_sim(
                            centers.reshape(b, c, -1).permute(0, 2, 1),
                            x.reshape(b, c, -1).permute(0, 2, 1)
                        )
                        )  # [B,M,N]
    # solely assign each point to one center
    sim_max, sim_max_idx = sim.max(dim=1, keepdim=True)
    mask = torch.zeros_like(sim)  # binary #[B,M,N]
    mask.scatter_(1, sim_max_idx, 1.)  # binary #[B,M,N]
    # changed, for plotting mask.
    mask = mask.reshape(mask.shape[0], mask.shape[1], w, h)  # [(head*fold*fold), m, w, h]
    mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
                     h0=self.heads, f1=self.fold_w, f2=self.fold_h)  # [head, (fold*fold), m, w, h]
    mask_list = []
    for i in range(self.fold_w):
        for j in range(self.fold_h):
            for k in range(mask.shape[2]):
                temp = torch.zeros(self.heads, w * self.fold_w, h * self.fold_h)
                temp[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, i * self.fold_w + j, k, :, :]
                mask_list.append(temp.unsqueeze(dim=0))  # [1, heads, w, h]
    mask2 = torch.concat(mask_list, dim=0)  # [n, heads, w, h]
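    # stash the masks at module scope so generate_visualization can read them after model(image) returns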
    global attention
    attention = mask2.detach()
def generate_visualization(args):
    global attention
    image, raw_image = _preprocess(args.image)
    image = image.unsqueeze(dim=0)
    model = timm.create_model(model_name=args.model, pretrained=True)
    if args.checkpoint and os.path.exists(args.checkpoint):
        load_checkpoint(model, args.checkpoint, True)
        print("\n\n==> Loaded checkpoint")
    elif args.checkpoint:
        raise Exception("Checkpoint doesn't exist at specified path: {}".format(args.checkpoint))
    else:
        print("\n\n==> NO checkpoint is loaded")
    model.network[args.stage * 2][args.block].token_mixer.register_forward_hook(get_attention_score)
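    # a single forward pass triggers the hook above and fills the global `attention`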
    out = model(image)
    if isinstance(out, tuple):
        out = out[0]
    possibility = torch.softmax(out, dim=1).max() * 100
    possibility = "{:.3f}".format(possibility)
    value, index = torch.max(out, dim=1)
    img = torch.tensor(raw_image).permute(2, 0, 1)
    # process the attention map
    attention = attention[:, args.head, :, :]
    mask = attention.unsqueeze(dim=0)
    mask = F.interpolate(mask, (img.shape[-2], img.shape[-1]))
    mask = mask.squeeze(dim=0)
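    # draw_segmentation_masks expects boolean masks, so threshold the interpolated values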
    mask = mask > 0.5
    # randomly selected some good colors.
    colors = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
              "white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
    if mask.shape[0] == 4:
        colors = colors[0:4]
    if mask.shape[0] > 4:
        colors = colors * (mask.shape[0] // 16)
    random.seed(123)
    random.shuffle(colors)
    img_with_masks = draw_segmentation_masks(img, masks=mask, alpha=args.alpha, colors=colors)
    img_with_masks = img_with_masks.detach()
    img_with_masks = TransF.to_pil_image(img_with_masks)
    img_with_masks = np.asarray(img_with_masks)
    return img_with_masks, possibility
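

# Minimal usage sketch (not part of the upstream script): build a PredictionArgs instance and run
# generate_visualization on a locally loaded image. The model name follows the docstring above;
# the checkpoint and image paths are placeholders, assumed to exist on disk.
if __name__ == "__main__":
    example_image = cv2.imread("example.jpg")  # hypothetical input image, loaded with OpenCV
    example_args = PredictionArgs(
        model="coc_tiny",                       # any Context-Cluster model registered with timm
        checkpoint="checkpoints/coc_tiny.pth",  # hypothetical checkpoint path
        image=example_image,
        stage=0,
        block=-1,
        head=0,
        alpha=0.5,
    )
    overlay, confidence = generate_visualization(example_args)
    print(f"top-1 confidence: {confidence}%, overlay shape: {overlay.shape}")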