Spaces:
Runtime error
Runtime error
paul hilders
commited on
Commit
•
5587e3a
1
Parent(s):
bbedbdc
Add Utils files
Browse files- CLIP_explainability/utils.py +152 -0
- clip_grounding/datasets/png.py +231 -0
- clip_grounding/datasets/png_utils.py +135 -0
- clip_grounding/evaluation/clip_on_png.py +362 -0
- clip_grounding/evaluation/qualitative_results.py +93 -0
- clip_grounding/utils/image.py +46 -0
- clip_grounding/utils/io.py +116 -0
- clip_grounding/utils/log.py +57 -0
- clip_grounding/utils/paths.py +10 -0
- clip_grounding/utils/visualize.py +183 -0
CLIP_explainability/utils.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import CLIP.clip as clip
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from captum.attr import visualization
|
8 |
+
import os
|
9 |
+
|
10 |
+
|
11 |
+
from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
12 |
+
_tokenizer = _Tokenizer()
|
13 |
+
|
14 |
+
#@title Control context expansion (number of attention layers to consider)
|
15 |
+
#@title Number of layers for image Transformer
|
16 |
+
start_layer = 11#@param {type:"number"}
|
17 |
+
|
18 |
+
#@title Number of layers for text Transformer
|
19 |
+
start_layer_text = 11#@param {type:"number"}
|
20 |
+
|
21 |
+
|
22 |
+
def interpret(image, texts, model, device):
|
23 |
+
batch_size = texts.shape[0]
|
24 |
+
images = image.repeat(batch_size, 1, 1, 1)
|
25 |
+
logits_per_image, logits_per_text = model(images, texts)
|
26 |
+
probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
|
27 |
+
index = [i for i in range(batch_size)]
|
28 |
+
one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
|
29 |
+
one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
|
30 |
+
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
31 |
+
one_hot = torch.sum(one_hot.to(device) * logits_per_image)
|
32 |
+
model.zero_grad()
|
33 |
+
|
34 |
+
image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
|
35 |
+
num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
|
36 |
+
R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
|
37 |
+
R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
|
38 |
+
for i, blk in enumerate(image_attn_blocks):
|
39 |
+
if i < start_layer:
|
40 |
+
continue
|
41 |
+
grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
|
42 |
+
cam = blk.attn_probs.detach()
|
43 |
+
cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
|
44 |
+
grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
|
45 |
+
cam = grad * cam
|
46 |
+
cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
|
47 |
+
cam = cam.clamp(min=0).mean(dim=1)
|
48 |
+
R = R + torch.bmm(cam, R)
|
49 |
+
image_relevance = R[:, 0, 1:]
|
50 |
+
|
51 |
+
|
52 |
+
text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
|
53 |
+
num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
|
54 |
+
R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
|
55 |
+
R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
|
56 |
+
for i, blk in enumerate(text_attn_blocks):
|
57 |
+
if i < start_layer_text:
|
58 |
+
continue
|
59 |
+
grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
|
60 |
+
cam = blk.attn_probs.detach()
|
61 |
+
cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
|
62 |
+
grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
|
63 |
+
cam = grad * cam
|
64 |
+
cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
|
65 |
+
cam = cam.clamp(min=0).mean(dim=1)
|
66 |
+
R_text = R_text + torch.bmm(cam, R_text)
|
67 |
+
text_relevance = R_text
|
68 |
+
|
69 |
+
return text_relevance, image_relevance
|
70 |
+
|
71 |
+
|
72 |
+
def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
73 |
+
# create heatmap from mask on image
|
74 |
+
def show_cam_on_image(img, mask):
|
75 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
76 |
+
heatmap = np.float32(heatmap) / 255
|
77 |
+
cam = heatmap + np.float32(img)
|
78 |
+
cam = cam / np.max(cam)
|
79 |
+
return cam
|
80 |
+
|
81 |
+
# plt.axis('off')
|
82 |
+
# f, axarr = plt.subplots(1,2)
|
83 |
+
# axarr[0].imshow(orig_image)
|
84 |
+
|
85 |
+
if show:
|
86 |
+
fig, axs = plt.subplots(1, 2)
|
87 |
+
axs[0].imshow(orig_image);
|
88 |
+
axs[0].axis('off');
|
89 |
+
|
90 |
+
image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
91 |
+
image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
|
92 |
+
image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
|
93 |
+
image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
|
94 |
+
image = image[0].permute(1, 2, 0).data.cpu().numpy()
|
95 |
+
image = (image - image.min()) / (image.max() - image.min())
|
96 |
+
vis = show_cam_on_image(image, image_relevance)
|
97 |
+
vis = np.uint8(255 * vis)
|
98 |
+
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
99 |
+
|
100 |
+
if show:
|
101 |
+
# axar[1].imshow(vis)
|
102 |
+
axs[1].imshow(vis);
|
103 |
+
axs[1].axis('off');
|
104 |
+
# plt.imshow(vis)
|
105 |
+
|
106 |
+
return image_relevance
|
107 |
+
|
108 |
+
|
109 |
+
def show_heatmap_on_text(text, text_encoding, R_text, show=True):
|
110 |
+
CLS_idx = text_encoding.argmax(dim=-1)
|
111 |
+
R_text = R_text[CLS_idx, 1:CLS_idx]
|
112 |
+
text_scores = R_text / R_text.sum()
|
113 |
+
text_scores = text_scores.flatten()
|
114 |
+
# print(text_scores)
|
115 |
+
text_tokens=_tokenizer.encode(text)
|
116 |
+
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
117 |
+
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
|
118 |
+
|
119 |
+
if show:
|
120 |
+
visualization.visualize_text(vis_data_records)
|
121 |
+
|
122 |
+
return text_scores, text_tokens_decoded
|
123 |
+
|
124 |
+
|
125 |
+
def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
|
126 |
+
return show_image_relevance(image_relevance, image, orig_image, device, show=show)
|
127 |
+
|
128 |
+
|
129 |
+
def show_txt_heatmap(text, text_encoding, R_text, show=True):
|
130 |
+
return show_heatmap_on_text(text, text_encoding, R_text, show=show)
|
131 |
+
|
132 |
+
|
133 |
+
def load_dataset():
|
134 |
+
dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments' + '.pt')
|
135 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
136 |
+
|
137 |
+
data = torch.load(dataset_path, map_location=device)
|
138 |
+
|
139 |
+
return data
|
140 |
+
|
141 |
+
|
142 |
+
class color:
|
143 |
+
PURPLE = '\033[95m'
|
144 |
+
CYAN = '\033[96m'
|
145 |
+
DARKCYAN = '\033[36m'
|
146 |
+
BLUE = '\033[94m'
|
147 |
+
GREEN = '\033[92m'
|
148 |
+
YELLOW = '\033[93m'
|
149 |
+
RED = '\033[91m'
|
150 |
+
BOLD = '\033[1m'
|
151 |
+
UNDERLINE = '\033[4m'
|
152 |
+
END = '\033[0m'
|
clip_grounding/datasets/png.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dataset object for Panoptic Narrative Grounding.
|
3 |
+
|
4 |
+
Paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Gonzalez_Panoptic_Narrative_Grounding_ICCV_2021_paper.pdf
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from os.path import join, isdir, exists
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
import cv2
|
13 |
+
from PIL import Image
|
14 |
+
from skimage import io
|
15 |
+
import numpy as np
|
16 |
+
import textwrap
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
from matplotlib import transforms
|
19 |
+
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
|
20 |
+
import matplotlib.colors as mc
|
21 |
+
|
22 |
+
from clip_grounding.utils.io import load_json
|
23 |
+
from clip_grounding.datasets.png_utils import show_image_and_caption
|
24 |
+
|
25 |
+
|
26 |
+
class PNG(Dataset):
|
27 |
+
"""Panoptic Narrative Grounding."""
|
28 |
+
|
29 |
+
def __init__(self, dataset_root, split) -> None:
|
30 |
+
"""
|
31 |
+
Initializer.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
dataset_root (str): path to the folder containing PNG dataset
|
35 |
+
split (str): MS-COCO split such as train2017/val2017
|
36 |
+
"""
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
assert isdir(dataset_root)
|
40 |
+
self.dataset_root = dataset_root
|
41 |
+
|
42 |
+
assert split in ["val2017"], f"Split {split} not supported. "\
|
43 |
+
"Currently, only supports split `val2017`."
|
44 |
+
self.split = split
|
45 |
+
|
46 |
+
self.ann_dir = join(self.dataset_root, "annotations")
|
47 |
+
# feat_dir = join(self.dataset_root, "features")
|
48 |
+
|
49 |
+
panoptic = load_json(join(self.ann_dir, "panoptic_{:s}.json".format(split)))
|
50 |
+
images = panoptic["images"]
|
51 |
+
self.images_info = {i["id"]: i for i in images}
|
52 |
+
panoptic_anns = panoptic["annotations"]
|
53 |
+
self.panoptic_anns = {int(a["image_id"]): a for a in panoptic_anns}
|
54 |
+
|
55 |
+
# self.panoptic_pred_path = join(
|
56 |
+
# feat_dir, split, "panoptic_seg_predictions"
|
57 |
+
# )
|
58 |
+
# assert isdir(self.panoptic_pred_path)
|
59 |
+
|
60 |
+
panoptic_narratives_path = join(self.dataset_root, "annotations", f"png_coco_{split}.json")
|
61 |
+
self.panoptic_narratives = load_json(panoptic_narratives_path)
|
62 |
+
|
63 |
+
def __len__(self):
|
64 |
+
return len(self.panoptic_narratives)
|
65 |
+
|
66 |
+
def get_image_path(self, image_id: str):
|
67 |
+
image_path = join(self.dataset_root, "images", self.split, f"{image_id.zfill(12)}.jpg")
|
68 |
+
return image_path
|
69 |
+
|
70 |
+
def __getitem__(self, idx: int):
|
71 |
+
narr = self.panoptic_narratives[idx]
|
72 |
+
|
73 |
+
image_id = narr["image_id"]
|
74 |
+
image_path = self.get_image_path(image_id)
|
75 |
+
assert exists(image_path)
|
76 |
+
|
77 |
+
image = Image.open(image_path)
|
78 |
+
caption = narr["caption"]
|
79 |
+
|
80 |
+
# show_single_image(image, title=caption, titlesize=12)
|
81 |
+
|
82 |
+
segments = narr["segments"]
|
83 |
+
|
84 |
+
image_id = int(narr["image_id"])
|
85 |
+
panoptic_ann = self.panoptic_anns[image_id]
|
86 |
+
panoptic_ann = self.panoptic_anns[image_id]
|
87 |
+
segment_infos = {}
|
88 |
+
for s in panoptic_ann["segments_info"]:
|
89 |
+
idi = s["id"]
|
90 |
+
segment_infos[idi] = s
|
91 |
+
|
92 |
+
image_info = self.images_info[image_id]
|
93 |
+
panoptic_segm = io.imread(
|
94 |
+
join(
|
95 |
+
self.ann_dir,
|
96 |
+
"panoptic_segmentation",
|
97 |
+
self.split,
|
98 |
+
"{:012d}.png".format(image_id),
|
99 |
+
)
|
100 |
+
)
|
101 |
+
panoptic_segm = (
|
102 |
+
panoptic_segm[:, :, 0]
|
103 |
+
+ panoptic_segm[:, :, 1] * 256
|
104 |
+
+ panoptic_segm[:, :, 2] * 256 ** 2
|
105 |
+
)
|
106 |
+
|
107 |
+
panoptic_ann = self.panoptic_anns[image_id]
|
108 |
+
# panoptic_pred = io.imread(
|
109 |
+
# join(self.panoptic_pred_path, "{:012d}.png".format(image_id))
|
110 |
+
# )[:, :, 0]
|
111 |
+
|
112 |
+
|
113 |
+
# # select a single utterance to visualize
|
114 |
+
# segment = segments[7]
|
115 |
+
# segment_ids = segment["segment_ids"]
|
116 |
+
# segment_mask = np.zeros((image_info["height"], image_info["width"]))
|
117 |
+
# for segment_id in segment_ids:
|
118 |
+
# segment_id = int(segment_id)
|
119 |
+
# segment_mask[panoptic_segm == segment_id] = 1.
|
120 |
+
|
121 |
+
utterances = [s["utterance"] for s in segments]
|
122 |
+
outputs = []
|
123 |
+
for i, segment in enumerate(segments):
|
124 |
+
|
125 |
+
# create segmentation mask on image
|
126 |
+
segment_ids = segment["segment_ids"]
|
127 |
+
|
128 |
+
# if no annotation for this word, skip
|
129 |
+
if not len(segment_ids):
|
130 |
+
continue
|
131 |
+
|
132 |
+
segment_mask = np.zeros((image_info["height"], image_info["width"]))
|
133 |
+
for segment_id in segment_ids:
|
134 |
+
segment_id = int(segment_id)
|
135 |
+
segment_mask[panoptic_segm == segment_id] = 1.
|
136 |
+
|
137 |
+
# store the outputs
|
138 |
+
text_mask = np.zeros(len(utterances))
|
139 |
+
text_mask[i] = 1.
|
140 |
+
segment_data = dict(
|
141 |
+
image=image,
|
142 |
+
text=utterances,
|
143 |
+
image_mask=segment_mask,
|
144 |
+
text_mask=text_mask,
|
145 |
+
full_caption=caption,
|
146 |
+
)
|
147 |
+
outputs.append(segment_data)
|
148 |
+
|
149 |
+
# # visualize segmentation mask with associated text
|
150 |
+
# segment_color = "red"
|
151 |
+
# segmap = SegmentationMapsOnImage(
|
152 |
+
# segment_mask.astype(np.uint8), shape=segment_mask.shape,
|
153 |
+
# )
|
154 |
+
# image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, COLORS[segment_color]])[0]
|
155 |
+
# image_with_segmap = Image.fromarray(image_with_segmap)
|
156 |
+
|
157 |
+
# colors = ["black" for _ in range(len(utterances))]
|
158 |
+
# colors[i] = segment_color
|
159 |
+
# show_image_and_caption(image_with_segmap, utterances, colors)
|
160 |
+
|
161 |
+
return outputs
|
162 |
+
|
163 |
+
|
164 |
+
def overlay_segmask_on_image(image, image_mask, segment_color="red"):
|
165 |
+
segmap = SegmentationMapsOnImage(
|
166 |
+
image_mask.astype(np.uint8), shape=image_mask.shape,
|
167 |
+
)
|
168 |
+
rgb_color = mc.to_rgb(segment_color)
|
169 |
+
rgb_color = 255 * np.array(rgb_color)
|
170 |
+
image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
|
171 |
+
image_with_segmap = Image.fromarray(image_with_segmap)
|
172 |
+
return image_with_segmap
|
173 |
+
|
174 |
+
|
175 |
+
def get_text_colors(text, text_mask, segment_color="red"):
|
176 |
+
colors = ["black" for _ in range(len(text))]
|
177 |
+
colors[text_mask.nonzero()[0][0]] = segment_color
|
178 |
+
return colors
|
179 |
+
|
180 |
+
|
181 |
+
def overlay_relevance_map_on_image(image, heatmap):
|
182 |
+
width, height = image.size
|
183 |
+
|
184 |
+
# resize the heatmap to image size
|
185 |
+
heatmap = cv2.resize(heatmap, (width, height))
|
186 |
+
heatmap = np.uint8(255 * heatmap)
|
187 |
+
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
188 |
+
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
189 |
+
|
190 |
+
# create overlapped super image
|
191 |
+
img = np.asarray(image)
|
192 |
+
super_img = heatmap * 0.4 + img * 0.6
|
193 |
+
super_img = np.uint8(super_img)
|
194 |
+
super_img = Image.fromarray(super_img)
|
195 |
+
|
196 |
+
return super_img
|
197 |
+
|
198 |
+
|
199 |
+
def visualize_item(image, text, image_mask, text_mask, segment_color="red"):
|
200 |
+
|
201 |
+
segmap = SegmentationMapsOnImage(
|
202 |
+
image_mask.astype(np.uint8), shape=image_mask.shape,
|
203 |
+
)
|
204 |
+
rgb_color = mc.to_rgb(segment_color)
|
205 |
+
rgb_color = 255 * np.array(rgb_color)
|
206 |
+
image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
|
207 |
+
image_with_segmap = Image.fromarray(image_with_segmap)
|
208 |
+
|
209 |
+
colors = ["black" for _ in range(len(text))]
|
210 |
+
|
211 |
+
text_idx = text_mask.argmax()
|
212 |
+
colors[text_idx] = segment_color
|
213 |
+
show_image_and_caption(image_with_segmap, text, colors)
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
from clip_grounding.utils.paths import REPO_PATH, DATASET_ROOTS
|
219 |
+
|
220 |
+
PNG_ROOT = DATASET_ROOTS["PNG"]
|
221 |
+
dataset = PNG(dataset_root=PNG_ROOT, split="val2017")
|
222 |
+
|
223 |
+
item = dataset[0]
|
224 |
+
sub_item = item[1]
|
225 |
+
visualize_item(
|
226 |
+
image=sub_item["image"],
|
227 |
+
text=sub_item["text"],
|
228 |
+
image_mask=sub_item["image_mask"],
|
229 |
+
text_mask=sub_item["text_mask"],
|
230 |
+
segment_color="red",
|
231 |
+
)
|
clip_grounding/datasets/png_utils.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Helper functions for Panoptic Narrative Grounding."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
from os.path import join, isdir, exists
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from skimage import io
|
10 |
+
import numpy as np
|
11 |
+
import textwrap
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from matplotlib import transforms
|
14 |
+
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
|
15 |
+
|
16 |
+
|
17 |
+
def rainbow_text(x,y,ls,lc,fig, ax,**kw):
|
18 |
+
"""
|
19 |
+
Take a list of strings ``ls`` and colors ``lc`` and place them next to each
|
20 |
+
other, with text ls[i] being shown in color lc[i].
|
21 |
+
|
22 |
+
Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib
|
23 |
+
"""
|
24 |
+
t = ax.transAxes
|
25 |
+
|
26 |
+
for s,c in zip(ls,lc):
|
27 |
+
|
28 |
+
text = ax.text(x,y,s+" ",color=c, transform=t, **kw)
|
29 |
+
text.draw(fig.canvas.get_renderer())
|
30 |
+
ex = text.get_window_extent()
|
31 |
+
t = transforms.offset_copy(text._transform, x=ex.width, units='dots')
|
32 |
+
|
33 |
+
|
34 |
+
def find_first_index_greater_than(elements, key):
|
35 |
+
return next(x[0] for x in enumerate(elements) if x[1] > key)
|
36 |
+
|
37 |
+
|
38 |
+
def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50):
|
39 |
+
char_lengths = np.cumsum([len(x) for x in caption_phrases])
|
40 |
+
thresholds = [max_char_in_a_line * i for i in range(1, 1 + char_lengths[-1] // max_char_in_a_line)]
|
41 |
+
|
42 |
+
utt_per_line = []
|
43 |
+
col_per_line = []
|
44 |
+
start_index = 0
|
45 |
+
for t in thresholds:
|
46 |
+
index = find_first_index_greater_than(char_lengths, t)
|
47 |
+
utt_per_line.append(caption_phrases[start_index:index])
|
48 |
+
col_per_line.append(colors[start_index:index])
|
49 |
+
start_index = index
|
50 |
+
|
51 |
+
return utt_per_line, col_per_line
|
52 |
+
|
53 |
+
|
54 |
+
def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None):
|
55 |
+
|
56 |
+
if colors is None:
|
57 |
+
colors = ["black" for _ in range(len(caption_phrases))]
|
58 |
+
|
59 |
+
fig, axes = plt.subplots(1, 2, figsize=(15, 4))
|
60 |
+
|
61 |
+
ax = axes[0]
|
62 |
+
ax.imshow(image)
|
63 |
+
ax.set_xticks([])
|
64 |
+
ax.set_yticks([])
|
65 |
+
|
66 |
+
ax = axes[1]
|
67 |
+
utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50)
|
68 |
+
y = 0.7
|
69 |
+
for U, C in zip(utt_per_line, col_per_line):
|
70 |
+
rainbow_text(
|
71 |
+
0., y,
|
72 |
+
U,
|
73 |
+
C,
|
74 |
+
size=15, ax=ax, fig=fig,
|
75 |
+
horizontalalignment='left',
|
76 |
+
verticalalignment='center',
|
77 |
+
)
|
78 |
+
y -= 0.11
|
79 |
+
|
80 |
+
ax.axis("off")
|
81 |
+
|
82 |
+
fig.tight_layout()
|
83 |
+
plt.show()
|
84 |
+
|
85 |
+
|
86 |
+
def show_images_and_caption(
|
87 |
+
images: List,
|
88 |
+
caption_phrases: list,
|
89 |
+
colors: list = None,
|
90 |
+
image_xlabels: List=[],
|
91 |
+
figsize=None,
|
92 |
+
show=False,
|
93 |
+
xlabelsize=14,
|
94 |
+
):
|
95 |
+
|
96 |
+
if colors is None:
|
97 |
+
colors = ["black" for _ in range(len(caption_phrases))]
|
98 |
+
caption_phrases[0] = caption_phrases[0].capitalize()
|
99 |
+
|
100 |
+
if figsize is None:
|
101 |
+
figsize = (5 * len(images) + 8, 4)
|
102 |
+
|
103 |
+
if image_xlabels is None:
|
104 |
+
image_xlabels = ["" for _ in range(len(images))]
|
105 |
+
|
106 |
+
fig, axes = plt.subplots(1, len(images) + 1, figsize=figsize)
|
107 |
+
|
108 |
+
for i, image in enumerate(images):
|
109 |
+
ax = axes[i]
|
110 |
+
ax.imshow(image)
|
111 |
+
ax.set_xticks([])
|
112 |
+
ax.set_yticks([])
|
113 |
+
ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize)
|
114 |
+
|
115 |
+
ax = axes[-1]
|
116 |
+
utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=40)
|
117 |
+
y = 0.7
|
118 |
+
for U, C in zip(utt_per_line, col_per_line):
|
119 |
+
rainbow_text(
|
120 |
+
0., y,
|
121 |
+
U,
|
122 |
+
C,
|
123 |
+
size=23, ax=ax, fig=fig,
|
124 |
+
horizontalalignment='left',
|
125 |
+
verticalalignment='center',
|
126 |
+
# weight='bold'
|
127 |
+
)
|
128 |
+
y -= 0.11
|
129 |
+
|
130 |
+
ax.axis("off")
|
131 |
+
|
132 |
+
fig.tight_layout()
|
133 |
+
|
134 |
+
if show:
|
135 |
+
plt.show()
|
clip_grounding/evaluation/clip_on_png.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Evaluates cross-modal correspondence of CLIP on PNG images."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from os.path import join, exists
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
warnings.filterwarnings('ignore')
|
9 |
+
|
10 |
+
from clip_grounding.utils.paths import REPO_PATH
|
11 |
+
sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import CLIP.clip as clip
|
15 |
+
from PIL import Image
|
16 |
+
import numpy as np
|
17 |
+
import cv2
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from captum.attr import visualization
|
20 |
+
from torchmetrics import JaccardIndex
|
21 |
+
from collections import defaultdict
|
22 |
+
from IPython.core.display import display, HTML
|
23 |
+
from skimage import filters
|
24 |
+
|
25 |
+
from CLIP_explainability.utils import interpret, show_img_heatmap, show_txt_heatmap, color, _tokenizer
|
26 |
+
from clip_grounding.datasets.png import PNG
|
27 |
+
from clip_grounding.utils.image import pad_to_square
|
28 |
+
from clip_grounding.utils.visualize import show_grid_of_images
|
29 |
+
from clip_grounding.utils.log import tqdm_iterator, print_update
|
30 |
+
|
31 |
+
|
32 |
+
# global usage
|
33 |
+
# specify device
|
34 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
35 |
+
|
36 |
+
# load CLIP model
|
37 |
+
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
38 |
+
|
39 |
+
|
40 |
+
def show_cam(mask):
|
41 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
42 |
+
heatmap = np.float32(heatmap) / 255
|
43 |
+
cam = heatmap
|
44 |
+
cam = cam / np.max(cam)
|
45 |
+
return cam
|
46 |
+
|
47 |
+
|
48 |
+
def interpret_and_generate(model, img, texts, orig_image, return_outputs=False, show=True):
|
49 |
+
text = clip.tokenize(texts).to(device)
|
50 |
+
R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
|
51 |
+
batch_size = text.shape[0]
|
52 |
+
|
53 |
+
outputs = []
|
54 |
+
for i in range(batch_size):
|
55 |
+
text_scores, text_tokens_decoded = show_txt_heatmap(texts[i], text[i], R_text[i], show=show)
|
56 |
+
image_relevance = show_img_heatmap(R_image[i], img, orig_image=orig_image, device=device, show=show)
|
57 |
+
plt.show()
|
58 |
+
outputs.append({"text_scores": text_scores, "image_relevance": image_relevance, "tokens_decoded": text_tokens_decoded})
|
59 |
+
|
60 |
+
if return_outputs:
|
61 |
+
return outputs
|
62 |
+
|
63 |
+
|
64 |
+
def process_entry_text_to_image(entry, unimodal=False):
|
65 |
+
image = entry['image']
|
66 |
+
text_mask = entry['text_mask']
|
67 |
+
text = entry['text']
|
68 |
+
orig_image = pad_to_square(image)
|
69 |
+
|
70 |
+
img = preprocess(orig_image).unsqueeze(0).to(device)
|
71 |
+
text_index = text_mask.argmax()
|
72 |
+
texts = [text[text_index]] if not unimodal else ['']
|
73 |
+
|
74 |
+
return img, texts, orig_image
|
75 |
+
|
76 |
+
|
77 |
+
def preprocess_ground_truth_mask(mask, resize_shape):
|
78 |
+
mask = Image.fromarray(mask.astype(np.uint8) * 255)
|
79 |
+
mask = pad_to_square(mask, color=0)
|
80 |
+
mask = mask.resize(resize_shape)
|
81 |
+
mask = np.asarray(mask) / 255.
|
82 |
+
return mask
|
83 |
+
|
84 |
+
|
85 |
+
def apply_otsu_threshold(relevance_map):
|
86 |
+
threshold = filters.threshold_otsu(relevance_map)
|
87 |
+
otsu_map = (relevance_map > threshold).astype(np.uint8)
|
88 |
+
return otsu_map
|
89 |
+
|
90 |
+
|
91 |
+
def evaluate_text_to_image(method, dataset, debug=False):
|
92 |
+
|
93 |
+
instance_level_metrics = defaultdict(list)
|
94 |
+
entry_level_metrics = defaultdict(list)
|
95 |
+
|
96 |
+
jaccard = JaccardIndex(num_classes=2)
|
97 |
+
jaccard = jaccard.to(device)
|
98 |
+
|
99 |
+
num_iter = len(dataset)
|
100 |
+
if debug:
|
101 |
+
num_iter = 100
|
102 |
+
|
103 |
+
iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
|
104 |
+
for idx in iterator:
|
105 |
+
instance = dataset[idx]
|
106 |
+
|
107 |
+
instance_iou = 0.
|
108 |
+
for entry in instance:
|
109 |
+
|
110 |
+
# preprocess the image and text
|
111 |
+
unimodal = True if method == "clip-unimodal" else False
|
112 |
+
test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=unimodal)
|
113 |
+
|
114 |
+
if method in ["clip", "clip-unimodal"]:
|
115 |
+
|
116 |
+
# compute the relevance scores
|
117 |
+
outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
|
118 |
+
|
119 |
+
# use the image relevance score to compute IoU w.r.t. ground truth segmentation masks
|
120 |
+
|
121 |
+
# NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
|
122 |
+
relevance_map = outputs[0]["image_relevance"]
|
123 |
+
elif method == "random":
|
124 |
+
relevance_map = np.random.uniform(low=0., high=1., size=tuple(test_img.shape[2:]))
|
125 |
+
|
126 |
+
otsu_relevance_map = apply_otsu_threshold(relevance_map)
|
127 |
+
|
128 |
+
ground_truth_mask = entry["image_mask"]
|
129 |
+
ground_truth_mask = preprocess_ground_truth_mask(ground_truth_mask, relevance_map.shape)
|
130 |
+
|
131 |
+
entry_iou = jaccard(
|
132 |
+
torch.from_numpy(otsu_relevance_map).to(device),
|
133 |
+
torch.from_numpy(ground_truth_mask.astype(np.uint8)).to(device),
|
134 |
+
)
|
135 |
+
entry_iou = entry_iou.item()
|
136 |
+
instance_iou += (entry_iou / len(entry))
|
137 |
+
|
138 |
+
entry_level_metrics["iou"].append(entry_iou)
|
139 |
+
|
140 |
+
# capture instance (image-sentence pair) level IoU
|
141 |
+
instance_level_metrics["iou"].append(instance_iou)
|
142 |
+
|
143 |
+
average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
|
144 |
+
|
145 |
+
return (
|
146 |
+
average_metrics,
|
147 |
+
instance_level_metrics,
|
148 |
+
entry_level_metrics
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def process_entry_image_to_text(entry, unimodal=False):
|
153 |
+
|
154 |
+
if not unimodal:
|
155 |
+
if len(np.asarray(entry["image"]).shape) == 3:
|
156 |
+
mask = np.repeat(np.expand_dims(entry['image_mask'], -1), 3, axis=-1)
|
157 |
+
else:
|
158 |
+
mask = np.asarray(entry['image_mask'])
|
159 |
+
|
160 |
+
masked_image = (mask * np.asarray(entry['image'])).astype(np.uint8)
|
161 |
+
masked_image = Image.fromarray(masked_image)
|
162 |
+
orig_image = pad_to_square(masked_image)
|
163 |
+
img = preprocess(orig_image).unsqueeze(0).to(device)
|
164 |
+
else:
|
165 |
+
orig_image_shape = max(np.asarray(entry['image']).shape[:2])
|
166 |
+
orig_image = Image.fromarray(np.zeros((orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
|
167 |
+
# orig_image = Image.fromarray(np.random.randint(0, 256, (orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
|
168 |
+
img = preprocess(orig_image).unsqueeze(0).to(device)
|
169 |
+
|
170 |
+
texts = [' '.join(entry['text'])]
|
171 |
+
|
172 |
+
return img, texts, orig_image
|
173 |
+
|
174 |
+
|
175 |
+
def process_text_mask(text, text_mask, tokens):
|
176 |
+
|
177 |
+
token_level_mask = np.zeros(len(tokens))
|
178 |
+
|
179 |
+
for label, subtext in zip(text_mask, text):
|
180 |
+
|
181 |
+
subtext_tokens=_tokenizer.encode(subtext)
|
182 |
+
subtext_tokens_decoded=[_tokenizer.decode([a]) for a in subtext_tokens]
|
183 |
+
|
184 |
+
if label == 1:
|
185 |
+
start = tokens.index(subtext_tokens_decoded[0])
|
186 |
+
end = tokens.index(subtext_tokens_decoded[-1])
|
187 |
+
token_level_mask[start:end + 1] = 1
|
188 |
+
|
189 |
+
return token_level_mask
|
190 |
+
|
191 |
+
|
192 |
+
def evaluate_image_to_text(method, dataset, debug=False, clamp_sentence_len=70):
|
193 |
+
|
194 |
+
instance_level_metrics = defaultdict(list)
|
195 |
+
entry_level_metrics = defaultdict(list)
|
196 |
+
|
197 |
+
# skipped if text length > 77 which is CLIP limit
|
198 |
+
num_entries_skipped = 0
|
199 |
+
num_total_entries = 0
|
200 |
+
|
201 |
+
num_iter = len(dataset)
|
202 |
+
if debug:
|
203 |
+
num_iter = 100
|
204 |
+
|
205 |
+
jaccard_image_to_text = JaccardIndex(num_classes=2).to(device)
|
206 |
+
|
207 |
+
iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
|
208 |
+
for idx in iterator:
|
209 |
+
instance = dataset[idx]
|
210 |
+
|
211 |
+
instance_iou = 0.
|
212 |
+
for entry in instance:
|
213 |
+
num_total_entries += 1
|
214 |
+
|
215 |
+
# preprocess the image and text
|
216 |
+
unimodal = True if method == "clip-unimodal" else False
|
217 |
+
img, texts, orig_image = process_entry_image_to_text(entry, unimodal=unimodal)
|
218 |
+
|
219 |
+
appx_total_sent_len = np.sum([len(x.split(" ")) for x in texts])
|
220 |
+
if appx_total_sent_len > clamp_sentence_len:
|
221 |
+
# print(f"Skipping an entry since it's text has appx"\
|
222 |
+
# " {appx_total_sent_len} while CLIP cannot process beyond {clamp_sentence_len}")
|
223 |
+
num_entries_skipped += 1
|
224 |
+
continue
|
225 |
+
|
226 |
+
# compute the relevance scores
|
227 |
+
if method in ["clip", "clip-unimodal"]:
|
228 |
+
try:
|
229 |
+
outputs = interpret_and_generate(model, img, texts, orig_image, return_outputs=True, show=False)
|
230 |
+
except:
|
231 |
+
num_entries_skipped += 1
|
232 |
+
continue
|
233 |
+
elif method == "random":
|
234 |
+
text = texts[0]
|
235 |
+
text_tokens = _tokenizer.encode(text)
|
236 |
+
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
237 |
+
outputs = [
|
238 |
+
{
|
239 |
+
"text_scores": np.random.uniform(low=0., high=1., size=len(text_tokens_decoded)),
|
240 |
+
"tokens_decoded": text_tokens_decoded,
|
241 |
+
}
|
242 |
+
]
|
243 |
+
|
244 |
+
# use the text relevance score to compute IoU w.r.t. ground truth text masks
|
245 |
+
# NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
|
246 |
+
token_relevance_scores = outputs[0]["text_scores"]
|
247 |
+
if isinstance(token_relevance_scores, torch.Tensor):
|
248 |
+
token_relevance_scores = token_relevance_scores.cpu().numpy()
|
249 |
+
token_relevance_scores = apply_otsu_threshold(token_relevance_scores)
|
250 |
+
token_ground_truth_mask = process_text_mask(entry["text"], entry["text_mask"], outputs[0]["tokens_decoded"])
|
251 |
+
|
252 |
+
entry_iou = jaccard_image_to_text(
|
253 |
+
torch.from_numpy(token_relevance_scores).to(device),
|
254 |
+
torch.from_numpy(token_ground_truth_mask.astype(np.uint8)).to(device),
|
255 |
+
)
|
256 |
+
entry_iou = entry_iou.item()
|
257 |
+
|
258 |
+
instance_iou += (entry_iou / len(entry))
|
259 |
+
entry_level_metrics["iou"].append(entry_iou)
|
260 |
+
|
261 |
+
# capture instance (image-sentence pair) level IoU
|
262 |
+
instance_level_metrics["iou"].append(instance_iou)
|
263 |
+
|
264 |
+
print(f"CAUTION: Skipped {(num_entries_skipped / num_total_entries) * 100} % since these had length > 77 (CLIP limit).")
|
265 |
+
average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
|
266 |
+
|
267 |
+
return (
|
268 |
+
average_metrics,
|
269 |
+
instance_level_metrics,
|
270 |
+
entry_level_metrics
|
271 |
+
)
|
272 |
+
|
273 |
+
|
274 |
+
if __name__ == "__main__":
|
275 |
+
|
276 |
+
import argparse
|
277 |
+
parser = argparse.ArgumentParser("Evaluate Image-to-Text & Text-to-Image model")
|
278 |
+
parser.add_argument(
|
279 |
+
"--eval_method", type=str, default="clip",
|
280 |
+
choices=["clip", "random", "clip-unimodal"],
|
281 |
+
help="Evaluation method to use",
|
282 |
+
)
|
283 |
+
parser.add_argument(
|
284 |
+
"--ignore_cache", action="store_true",
|
285 |
+
help="Ignore cache and force re-generation of the results",
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--debug", action="store_true",
|
289 |
+
help="Run evaluation on a small subset of the dataset",
|
290 |
+
)
|
291 |
+
args = parser.parse_args()
|
292 |
+
|
293 |
+
print_update("Using evaluation method: {}".format(args.eval_method))
|
294 |
+
|
295 |
+
|
296 |
+
clip.clip._MODELS = {
|
297 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
298 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
299 |
+
}
|
300 |
+
|
301 |
+
# specify device
|
302 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
303 |
+
|
304 |
+
# load CLIP model
|
305 |
+
print_update("Loading CLIP model...")
|
306 |
+
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
307 |
+
print()
|
308 |
+
|
309 |
+
# load PNG dataset
|
310 |
+
print_update("Loading PNG dataset...")
|
311 |
+
dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")
|
312 |
+
print()
|
313 |
+
|
314 |
+
# evaluate
|
315 |
+
|
316 |
+
# save metrics
|
317 |
+
metrics_dir = join(REPO_PATH, "outputs")
|
318 |
+
os.makedirs(metrics_dir, exist_ok=True)
|
319 |
+
|
320 |
+
metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_text2image_metrics.pt")
|
321 |
+
if (not exists(metrics_path)) or args.ignore_cache:
|
322 |
+
print_update("Computing metrics for text-to-image grounding")
|
323 |
+
average_metrics, instance_level_metrics, entry_level_metrics = evaluate_text_to_image(
|
324 |
+
args.eval_method, dataset, debug=args.debug,
|
325 |
+
)
|
326 |
+
metrics = {
|
327 |
+
"average_metrics": average_metrics,
|
328 |
+
"instance_level_metrics":instance_level_metrics,
|
329 |
+
"entry_level_metrics": entry_level_metrics
|
330 |
+
}
|
331 |
+
|
332 |
+
torch.save(metrics, metrics_path)
|
333 |
+
print("TEXT2IMAGE METRICS SAVED TO:", metrics_path)
|
334 |
+
else:
|
335 |
+
print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
|
336 |
+
metrics = torch.load(metrics_path)
|
337 |
+
average_metrics = metrics["average_metrics"]
|
338 |
+
print("TEXT2IMAGE METRICS:", np.round(average_metrics["iou"], 4))
|
339 |
+
|
340 |
+
print()
|
341 |
+
|
342 |
+
metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_image2text_metrics.pt")
|
343 |
+
if (not exists(metrics_path)) or args.ignore_cache:
|
344 |
+
print_update("Computing metrics for image-to-text grounding")
|
345 |
+
average_metrics, instance_level_metrics, entry_level_metrics = evaluate_image_to_text(
|
346 |
+
args.eval_method, dataset, debug=args.debug,
|
347 |
+
)
|
348 |
+
|
349 |
+
torch.save(
|
350 |
+
{
|
351 |
+
"average_metrics": average_metrics,
|
352 |
+
"instance_level_metrics":instance_level_metrics,
|
353 |
+
"entry_level_metrics": entry_level_metrics
|
354 |
+
},
|
355 |
+
metrics_path,
|
356 |
+
)
|
357 |
+
print("IMAGE2TEXT METRICS SAVED TO:", metrics_path)
|
358 |
+
else:
|
359 |
+
print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
|
360 |
+
metrics = torch.load(metrics_path)
|
361 |
+
average_metrics = metrics["average_metrics"]
|
362 |
+
print("IMAGE2TEXT METRICS:", np.round(average_metrics["iou"], 4))
|
clip_grounding/evaluation/qualitative_results.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Converts notebook for qualitative results to a python script."""
|
2 |
+
import sys
|
3 |
+
from os.path import join
|
4 |
+
|
5 |
+
from clip_grounding.utils.paths import REPO_PATH
|
6 |
+
sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
|
7 |
+
|
8 |
+
import os
|
9 |
+
import torch
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
from matplotlib.patches import Patch
|
13 |
+
import CLIP.clip as clip
|
14 |
+
import cv2
|
15 |
+
from PIL import Image
|
16 |
+
from glob import glob
|
17 |
+
from natsort import natsorted
|
18 |
+
|
19 |
+
from clip_grounding.utils.paths import REPO_PATH
|
20 |
+
from clip_grounding.utils.io import load_json
|
21 |
+
from clip_grounding.utils.visualize import set_latex_fonts, show_grid_of_images
|
22 |
+
from clip_grounding.utils.image import pad_to_square
|
23 |
+
from clip_grounding.datasets.png_utils import show_images_and_caption
|
24 |
+
from clip_grounding.datasets.png import (
|
25 |
+
PNG,
|
26 |
+
visualize_item,
|
27 |
+
overlay_segmask_on_image,
|
28 |
+
overlay_relevance_map_on_image,
|
29 |
+
get_text_colors,
|
30 |
+
)
|
31 |
+
from clip_grounding.evaluation.clip_on_png import (
|
32 |
+
process_entry_image_to_text,
|
33 |
+
process_entry_text_to_image,
|
34 |
+
interpret_and_generate,
|
35 |
+
)
|
36 |
+
|
37 |
+
# load dataset
|
38 |
+
dataset = PNG(dataset_root=join(REPO_PATH, "data/panoptic_narrative_grounding"), split="val2017")
|
39 |
+
|
40 |
+
# load CLIP model
|
41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
42 |
+
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
43 |
+
|
44 |
+
|
45 |
+
def visualize_entry_text_to_image(entry, pad_images=True, figsize=(18, 5)):
|
46 |
+
test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=False)
|
47 |
+
outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
|
48 |
+
relevance_map = outputs[0]["image_relevance"]
|
49 |
+
|
50 |
+
image_with_mask = overlay_segmask_on_image(entry["image"], entry["image_mask"])
|
51 |
+
if pad_images:
|
52 |
+
image_with_mask = pad_to_square(image_with_mask)
|
53 |
+
|
54 |
+
image_with_relevance_map = overlay_relevance_map_on_image(entry["image"], relevance_map)
|
55 |
+
if pad_images:
|
56 |
+
image_with_relevance_map = pad_to_square(image_with_relevance_map)
|
57 |
+
|
58 |
+
text_colors = get_text_colors(entry["text"], entry["text_mask"])
|
59 |
+
|
60 |
+
show_images_and_caption(
|
61 |
+
[image_with_mask, image_with_relevance_map],
|
62 |
+
entry["text"], text_colors, figsize=figsize,
|
63 |
+
image_xlabels=["Ground truth segmentation", "Predicted relevance map"]
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
def create_and_save_gif(filenames, save_path, **kwargs):
|
68 |
+
import imageio
|
69 |
+
images = []
|
70 |
+
for filename in filenames:
|
71 |
+
images.append(imageio.imread(filename))
|
72 |
+
imageio.mimsave(save_path, images, **kwargs)
|
73 |
+
|
74 |
+
|
75 |
+
idx = 100
|
76 |
+
instance = dataset[idx]
|
77 |
+
|
78 |
+
instance_dir = join(REPO_PATH, "figures", f"instance-{idx}")
|
79 |
+
os.makedirs(instance_dir, exist_ok=True)
|
80 |
+
|
81 |
+
for i, entry in enumerate(instance):
|
82 |
+
del entry["full_caption"]
|
83 |
+
|
84 |
+
visualize_entry_text_to_image(entry, pad_images=False, figsize=(19, 4))
|
85 |
+
|
86 |
+
save_path = instance_dir
|
87 |
+
plt.savefig(join(instance_dir, f"viz-{i}.png"), bbox_inches="tight")
|
88 |
+
|
89 |
+
|
90 |
+
filenames = natsorted(glob(join(instance_dir, "viz-*.png")))
|
91 |
+
save_path = join(REPO_PATH, "media", "sample.gif")
|
92 |
+
|
93 |
+
create_and_save_gif(filenames, save_path, duration=3)
|
clip_grounding/utils/image.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Image operations."""
|
2 |
+
from copy import deepcopy
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def center_crop(im: Image):
|
7 |
+
width, height = im.size
|
8 |
+
new_width = width if width < height else height
|
9 |
+
new_height = height if height < width else width
|
10 |
+
|
11 |
+
left = (width - new_width)/2
|
12 |
+
top = (height - new_height)/2
|
13 |
+
right = (width + new_width)/2
|
14 |
+
bottom = (height + new_height)/2
|
15 |
+
|
16 |
+
# Crop the center of the image
|
17 |
+
im = im.crop((left, top, right, bottom))
|
18 |
+
|
19 |
+
return im
|
20 |
+
|
21 |
+
|
22 |
+
def pad_to_square(im: Image, color=(0, 0, 0)):
|
23 |
+
im = deepcopy(im)
|
24 |
+
width, height = im.size
|
25 |
+
|
26 |
+
vert_pad = (max(width, height) - height) // 2
|
27 |
+
hor_pad = (max(width, height) - width) // 2
|
28 |
+
|
29 |
+
if len(im.mode) == 3:
|
30 |
+
color = (0, 0, 0)
|
31 |
+
elif len(im.mode) == 1:
|
32 |
+
color = 0
|
33 |
+
else:
|
34 |
+
raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
|
35 |
+
|
36 |
+
return add_margin(im, vert_pad, hor_pad, vert_pad, hor_pad, color=color)
|
37 |
+
|
38 |
+
|
39 |
+
def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
|
40 |
+
"""Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
|
41 |
+
width, height = pil_img.size
|
42 |
+
new_width = width + right + left
|
43 |
+
new_height = height + top + bottom
|
44 |
+
result = Image.new(pil_img.mode, (new_width, new_height), color)
|
45 |
+
result.paste(pil_img, (left, top))
|
46 |
+
return result
|
clip_grounding/utils/io.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for input-output loading/saving.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Any, List
|
6 |
+
import yaml
|
7 |
+
import pickle
|
8 |
+
import json
|
9 |
+
|
10 |
+
|
11 |
+
class PrettySafeLoader(yaml.SafeLoader):
|
12 |
+
"""Custom loader for reading YAML files"""
|
13 |
+
def construct_python_tuple(self, node):
|
14 |
+
return tuple(self.construct_sequence(node))
|
15 |
+
|
16 |
+
|
17 |
+
PrettySafeLoader.add_constructor(
|
18 |
+
u'tag:yaml.org,2002:python/tuple',
|
19 |
+
PrettySafeLoader.construct_python_tuple
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def load_yml(path: str, loader_type: str = 'default'):
|
24 |
+
"""Read params from a yml file.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
path (str): path to the .yml file
|
28 |
+
loader_type (str, optional): type of loader used to load yml files. Defaults to 'default'.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
Any: object (typically dict) loaded from .yml file
|
32 |
+
"""
|
33 |
+
assert loader_type in ['default', 'safe']
|
34 |
+
|
35 |
+
loader = yaml.Loader if (loader_type == "default") else PrettySafeLoader
|
36 |
+
|
37 |
+
with open(path, 'r') as f:
|
38 |
+
data = yaml.load(f, Loader=loader)
|
39 |
+
|
40 |
+
return data
|
41 |
+
|
42 |
+
|
43 |
+
def save_yml(data: dict, path: str):
|
44 |
+
"""Save params in the given yml file path.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
data (dict): data object to save
|
48 |
+
path (str): path to .yml file to be saved
|
49 |
+
"""
|
50 |
+
with open(path, 'w') as f:
|
51 |
+
yaml.dump(data, f, default_flow_style=False)
|
52 |
+
|
53 |
+
|
54 |
+
def load_pkl(path: str, encoding: str = "ascii") -> Any:
|
55 |
+
"""Loads a .pkl file.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
path (str): path to the .pkl file
|
59 |
+
encoding (str, optional): encoding to use for loading. Defaults to "ascii".
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
Any: unpickled object
|
63 |
+
"""
|
64 |
+
return pickle.load(open(path, "rb"), encoding=encoding)
|
65 |
+
|
66 |
+
|
67 |
+
def save_pkl(data: Any, path: str) -> None:
|
68 |
+
"""Saves given object into .pkl file
|
69 |
+
|
70 |
+
Args:
|
71 |
+
data (Any): object to be saved
|
72 |
+
path (str): path to the location to be saved at
|
73 |
+
"""
|
74 |
+
with open(path, 'wb') as f:
|
75 |
+
pickle.dump(data, f)
|
76 |
+
|
77 |
+
|
78 |
+
def load_json(path: str) -> dict:
|
79 |
+
"""Helper to load json file"""
|
80 |
+
with open(path, 'rb') as f:
|
81 |
+
data = json.load(f)
|
82 |
+
return data
|
83 |
+
|
84 |
+
|
85 |
+
def save_json(data: dict, path: str):
|
86 |
+
"""Helper to save `dict` as .json file."""
|
87 |
+
with open(path, 'w') as f:
|
88 |
+
json.dump(data, f)
|
89 |
+
|
90 |
+
|
91 |
+
def load_txt(path: str) -> List:
|
92 |
+
"""Loads lines of a .txt file.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
path (str): path to the .txt file
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
List: lines of .txt file
|
99 |
+
"""
|
100 |
+
with open(path) as f:
|
101 |
+
lines = f.read().splitlines()
|
102 |
+
return lines
|
103 |
+
|
104 |
+
|
105 |
+
def save_txt(data: dict, path: str):
|
106 |
+
"""Writes data (lines) to a txt file.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
data (dict): List of strings
|
110 |
+
path (str): path to .txt file
|
111 |
+
"""
|
112 |
+
assert isinstance(data, list)
|
113 |
+
|
114 |
+
lines = "\n".join(data)
|
115 |
+
with open(path, "w") as f:
|
116 |
+
f.write(str(lines))
|
clip_grounding/utils/log.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utilities for logging"""
|
2 |
+
import logging
|
3 |
+
from tqdm import tqdm
|
4 |
+
from termcolor import colored
|
5 |
+
|
6 |
+
|
7 |
+
def color(string: str, color_name: str = 'yellow') -> str:
|
8 |
+
"""Returns colored string for output to terminal"""
|
9 |
+
return colored(string, color_name)
|
10 |
+
|
11 |
+
|
12 |
+
def print_update(message: str, width: int = 140, fillchar: str = ":", color="yellow") -> str:
|
13 |
+
"""Prints an update message
|
14 |
+
|
15 |
+
Args:
|
16 |
+
message (str): message
|
17 |
+
width (int): width of new update message
|
18 |
+
fillchar (str): character to be filled to L and R of message
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
str: print-ready update message
|
22 |
+
"""
|
23 |
+
message = message.center(len(message) + 2, " ")
|
24 |
+
print(colored(message.center(width, fillchar), color))
|
25 |
+
|
26 |
+
|
27 |
+
def set_logger(log_path):
|
28 |
+
"""Set the logger to log info in terminal and file `log_path`.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
log_path (str): path to the log file
|
32 |
+
"""
|
33 |
+
logger = logging.getLogger()
|
34 |
+
logger.setLevel(logging.INFO)
|
35 |
+
|
36 |
+
if not logger.handlers:
|
37 |
+
# Logging to a file
|
38 |
+
file_handler = logging.FileHandler(log_path)
|
39 |
+
file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
|
40 |
+
logger.addHandler(file_handler)
|
41 |
+
|
42 |
+
# Logging to console
|
43 |
+
stream_handler = logging.StreamHandler()
|
44 |
+
stream_handler.setFormatter(logging.Formatter('%(message)s'))
|
45 |
+
logger.addHandler(stream_handler)
|
46 |
+
|
47 |
+
|
48 |
+
def tqdm_iterator(items, desc=None, bar_format=None, **kwargs):
|
49 |
+
tqdm._instances.clear()
|
50 |
+
iterator = tqdm(
|
51 |
+
items,
|
52 |
+
desc=desc,
|
53 |
+
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}',
|
54 |
+
**kwargs,
|
55 |
+
)
|
56 |
+
|
57 |
+
return iterator
|
clip_grounding/utils/paths.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Path helpers for the relfm project."""
|
2 |
+
from os.path import join, abspath, dirname
|
3 |
+
|
4 |
+
|
5 |
+
REPO_PATH = dirname(dirname(dirname(abspath(__file__))))
|
6 |
+
DATA_ROOT = join(REPO_PATH, "data")
|
7 |
+
|
8 |
+
DATASET_ROOTS = {
|
9 |
+
"PNG": join(DATA_ROOT, "panoptic_narrative_grounding"),
|
10 |
+
}
|
clip_grounding/utils/visualize.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Helpers for visualization"""
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
# define predominanat colors
|
10 |
+
COLORS = {
|
11 |
+
"pink": (242, 116, 223),
|
12 |
+
"cyan": (46, 242, 203),
|
13 |
+
"red": (255, 0, 0),
|
14 |
+
"green": (0, 255, 0),
|
15 |
+
"blue": (0, 0, 255),
|
16 |
+
"yellow": (255, 255, 0),
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, titlesize=18, cmap: str = None, ticks=False, save=False, save_path=None):
|
21 |
+
"""Show a single image."""
|
22 |
+
fig, ax = plt.subplots(1, 1, figsize=figsize)
|
23 |
+
|
24 |
+
if isinstance(image, Image.Image):
|
25 |
+
image = np.asarray(image)
|
26 |
+
|
27 |
+
ax.set_title(title, fontsize=titlesize)
|
28 |
+
ax.imshow(image, cmap=cmap)
|
29 |
+
|
30 |
+
if not ticks:
|
31 |
+
ax.set_xticks([])
|
32 |
+
ax.set_yticks([])
|
33 |
+
|
34 |
+
if save:
|
35 |
+
plt.savefig(save_path, bbox_inches='tight')
|
36 |
+
|
37 |
+
plt.show()
|
38 |
+
|
39 |
+
|
40 |
+
def show_grid_of_images(
|
41 |
+
images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8),
|
42 |
+
cmap=None, subtitles=None, title=None, subtitlesize=18,
|
43 |
+
save=False, save_path=None, titlesize=20,
|
44 |
+
):
|
45 |
+
"""Show a grid of images."""
|
46 |
+
n_cols = min(n_cols, len(images))
|
47 |
+
|
48 |
+
copy_of_images = images.copy()
|
49 |
+
for i, image in enumerate(copy_of_images):
|
50 |
+
if isinstance(image, Image.Image):
|
51 |
+
image = np.asarray(image)
|
52 |
+
images[i] = image
|
53 |
+
|
54 |
+
if subtitles is None:
|
55 |
+
subtitles = [None] * len(images)
|
56 |
+
|
57 |
+
n_rows = int(np.ceil(len(images) / n_cols))
|
58 |
+
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
|
59 |
+
for i, ax in enumerate(axes.flat):
|
60 |
+
if i < len(images):
|
61 |
+
if len(images[i].shape) == 2 and cmap is None:
|
62 |
+
cmap="gray"
|
63 |
+
ax.imshow(images[i], cmap=cmap)
|
64 |
+
ax.set_title(subtitles[i], fontsize=subtitlesize)
|
65 |
+
ax.axis('off')
|
66 |
+
fig.set_tight_layout(True)
|
67 |
+
plt.suptitle(title, y=0.8, fontsize=titlesize)
|
68 |
+
|
69 |
+
if save:
|
70 |
+
plt.savefig(save_path, bbox_inches='tight')
|
71 |
+
plt.close()
|
72 |
+
else:
|
73 |
+
plt.show()
|
74 |
+
|
75 |
+
|
76 |
+
def show_keypoint_matches(
|
77 |
+
img1, kp1, img2, kp2, matches,
|
78 |
+
K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
|
79 |
+
choose_matches="random",
|
80 |
+
):
|
81 |
+
"""Displays matches found in the pair of images"""
|
82 |
+
if choose_matches == "random":
|
83 |
+
selected_matches = np.random.choice(matches, K)
|
84 |
+
elif choose_matches == "all":
|
85 |
+
K = len(matches)
|
86 |
+
selected_matches = matches
|
87 |
+
elif choose_matches == "topk":
|
88 |
+
selected_matches = matches[:K]
|
89 |
+
else:
|
90 |
+
raise ValueError(f"Unknown value for choose_matches: {choose_matches}")
|
91 |
+
|
92 |
+
# color each match with a different color
|
93 |
+
cmap = matplotlib.cm.get_cmap('gist_rainbow', K)
|
94 |
+
colors = [[int(x*255) for x in cmap(i)[:3]] for i in np.arange(0,K)]
|
95 |
+
drawMatches_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})
|
96 |
+
|
97 |
+
img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
|
98 |
+
show_single_image(
|
99 |
+
img3,
|
100 |
+
figsize=figsize,
|
101 |
+
title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
|
102 |
+
)
|
103 |
+
return img3
|
104 |
+
|
105 |
+
|
106 |
+
def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="numpy"):
|
107 |
+
"""
|
108 |
+
Draw keypoints on image.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
image: Image to draw keypoints on.
|
112 |
+
kps: Keypoints to draw. Note these should be in (x, y) format.
|
113 |
+
"""
|
114 |
+
if isinstance(image, Image.Image):
|
115 |
+
image = np.asarray(image)
|
116 |
+
|
117 |
+
for kp in kps:
|
118 |
+
image = cv2.circle(
|
119 |
+
image, (int(kp[0]), int(kp[1])), radius=radius, color=color, thickness=thickness)
|
120 |
+
|
121 |
+
if return_as == "PIL":
|
122 |
+
return Image.fromarray(image)
|
123 |
+
|
124 |
+
return image
|
125 |
+
|
126 |
+
|
127 |
+
def get_concat_h(im1, im2):
|
128 |
+
"""Concatenate two images horizontally"""
|
129 |
+
dst = Image.new('RGB', (im1.width + im2.width, im1.height))
|
130 |
+
dst.paste(im1, (0, 0))
|
131 |
+
dst.paste(im2, (im1.width, 0))
|
132 |
+
return dst
|
133 |
+
|
134 |
+
|
135 |
+
def get_concat_v(im1, im2):
|
136 |
+
"""Concatenate two images vertically"""
|
137 |
+
dst = Image.new('RGB', (im1.width, im1.height + im2.height))
|
138 |
+
dst.paste(im1, (0, 0))
|
139 |
+
dst.paste(im2, (0, im1.height))
|
140 |
+
return dst
|
141 |
+
|
142 |
+
|
143 |
+
def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8), return_images=False, save=False, save_path="sample.png"):
|
144 |
+
assert len(images) == len(kps)
|
145 |
+
|
146 |
+
# generate
|
147 |
+
images_with_kps = []
|
148 |
+
for i in range(len(images)):
|
149 |
+
img_with_kps = draw_kps_on_image(images[i], kps[i], radius=radius, color=color, return_as="PIL")
|
150 |
+
images_with_kps.append(img_with_kps)
|
151 |
+
|
152 |
+
# show
|
153 |
+
show_grid_of_images(images_with_kps, n_cols=len(images), figsize=figsize, save=save, save_path=save_path)
|
154 |
+
|
155 |
+
if return_images:
|
156 |
+
return images_with_kps
|
157 |
+
|
158 |
+
|
159 |
+
def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
|
160 |
+
try:
|
161 |
+
plt.rcParams.update({
|
162 |
+
"text.usetex": usetex,
|
163 |
+
"font.family": "serif",
|
164 |
+
"font.serif": ["Computer Modern Roman"],
|
165 |
+
"font.size": fontsize,
|
166 |
+
**kwargs,
|
167 |
+
})
|
168 |
+
if show_sample:
|
169 |
+
plt.figure()
|
170 |
+
plt.title("Sample $y = x^2$")
|
171 |
+
plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
|
172 |
+
plt.grid()
|
173 |
+
plt.show()
|
174 |
+
except:
|
175 |
+
print("Failed to setup LaTeX fonts. Proceeding without.")
|
176 |
+
pass
|
177 |
+
|
178 |
+
|
179 |
+
def get_colors(num_colors, palette="jet"):
|
180 |
+
cmap = plt.get_cmap(palette)
|
181 |
+
colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]
|
182 |
+
return colors
|
183 |
+
|