Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| # -*- coding: utf-8 -*- | |
| """ | |
| @File : visualizer.py | |
| @Time : 2022/04/05 11:39:33 | |
| @Author : Shilong Liu | |
| @Contact : slongliu86@gmail.com | |
| """ | |
| import datetime | |
| import os | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from matplotlib import transforms | |
| from matplotlib.collections import PatchCollection | |
| from matplotlib.patches import Polygon | |
| from pycocotools import mask as maskUtils | |
def renorm(
    img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> torch.FloatTensor:
    """Undo per-channel normalization, returning ``img * std + mean``.

    Args:
        img: tensor of shape (3, H, W) or (B, 3, H, W).
        mean: per-channel means, length 3 (defaults are the commonly used
            ImageNet statistics).
        std: per-channel standard deviations, length 3.

    Returns:
        A new tensor with the same shape as ``img``, on the same device.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel axis is not 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # The channel axis is 0 for (3, H, W) and 1 for (B, 3, H, W).
    channel_axis = img.dim() - 3
    assert img.size(channel_axis) == 3, 'img.size(%d) should be 3 but "%d". (%s)' % (
        channel_axis,
        img.size(channel_axis),
        str(img.size()),
    )
    # Create the statistics on img's device so CUDA inputs work. Use the default
    # float dtype (as torch.Tensor(...) does) so type promotion is unchanged.
    # A (3, 1, 1) view broadcasts over H, W and, for 4-D input, over B.
    dtype = torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device).view(3, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device).view(3, 1, 1)
    return img * std_t + mean_t
class ColorMap:
    """Turn a 2-D uint8 attention map into a single-color RGBA image.

    Every pixel's RGB channels are set to ``basergb``. The attention value
    becomes the fourth (alpha) channel.
    """

    def __init__(self, basergb=[255, 255, 0]):
        # np.array copies its argument, so the list default is never mutated.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return an (h, w, 4) uint8 array for an (h, w) uint8 ``attnmap``."""
        assert attnmap.dtype == np.uint8
        h, w = attnmap.shape
        # Tile the base color over every pixel (read-only view, no copy yet).
        tiled_rgb = np.broadcast_to(self.basergb, (h, w) + self.basergb.shape)
        alpha = attnmap[..., np.newaxis]
        rgba = np.concatenate((tiled_rgb, alpha), axis=-1)
        return rgba.astype(np.uint8)
def rainbow_text(x, y, ls, lc, **kw):
    """Draw the strings ``ls`` side by side, with ``ls[i]`` in color ``lc[i]``.

    Each string is padded with one space on both sides. It is placed
    immediately to the right of the previous one, starting at data
    coordinates ``(x, y)`` of the current axes. Extra keyword arguments
    (font size, family, ...) are forwarded to ``plt.text``. Strings and
    colors are paired with ``zip``, so the shorter of ``ls`` / ``lc`` decides
    how many strings are drawn.
    """
    t = plt.gca().transData
    fig = plt.gcf()
    for s, c in zip(ls, lc):
        text = plt.text(x, y, " " + s + " ", color=c, transform=t, **kw)
        # The text has to be drawn before its on-screen extent is known.
        text.draw(fig.canvas.get_renderer())
        ex = text.get_window_extent()
        # Shift the next string right by this string's width (display dots).
        t = transforms.offset_copy(text.get_transform(), x=ex.width, units="dots")
class COCOVisualizer:
    """Draw images together with box / caption / attention annotations via matplotlib.

    The optional ``coco`` object is only used by :meth:`showAnns`, which reads
    its ``imgs`` mapping and calls its ``loadCats`` method. It is presumably a
    ``pycocotools.coco.COCO`` instance (confirm with callers).
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        # ``tokenlizer`` is accepted for interface compatibility but is unused.
        self.coco = coco

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """Draw ``img`` with the annotations in ``tgt`` and save it as a PNG.

        Args:
            img: normalized tensor(3, H, W); de-normalized with ``renorm``.
            tgt: annotation dict or None; see :meth:`addtgt` for the keys
                drawn. 'image_id' (default 0) goes into the file name. All
                tensors must be on the CPU.
            caption: optional string inserted into the file name.
            dpi: figure resolution.
            savedir: output directory, created if missing.

        The figure is written to
        ``{savedir}/[{caption}-]{image_id}-{timestamp}.png`` and then closed.
        """
        plt.figure(dpi=dpi)
        plt.rcParams["font.size"] = "5"
        ax = plt.gca()
        img = renorm(img).permute(1, 2, 0)
        ax.imshow(img)
        self.addtgt(tgt)

        if tgt is None or "image_id" not in tgt:
            image_id = 0
        else:
            image_id = tgt["image_id"]

        timestamp = str(datetime.datetime.now()).replace(" ", "-")
        if caption is None:
            savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
        else:
            savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
        print("savename: {}".format(savename))
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        plt.savefig(savename)
        plt.close()

    def addtgt(self, tgt):
        """Draw the annotations in ``tgt`` onto the current matplotlib axes.

        Recognized keys:
            'boxes': tensor(N, 4) of boxes normalized by image size in
                (cx, cy, w, h) order; requires 'size' = (H, W).
            'strings_positive': N lists of strings, each prefixed by
                ``int(tgt['labels'][i])`` and drawn at box i.
            'box_label': N labels drawn at each box's top-left corner.
            'caption': used as the axes title.
            'attn': an (attn_map, basergb) pair or a list of such pairs. Each
                map is min-max scaled and overlaid as a single-color heatmap.

        If ``tgt`` is None or has no 'boxes', only the caption (if any) is
        drawn. The axes are always turned off.
        """
        ax = plt.gca()
        if tgt is None or "boxes" not in tgt:
            # Check for None before probing the dict: `"caption" in None` raises.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]
        scale = torch.Tensor([W, H, W, H])
        color = []
        polygons = []
        boxes = []
        for box in tgt["boxes"].cpu():
            # Scale to pixels, then move (cx, cy) to the top-left corner (x, y).
            unnormbbox = box * scale
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            poly = [
                [bbox_x, bbox_y],
                [bbox_x, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y],
            ]
            polygons.append(Polygon(np.array(poly).reshape((4, 2))))
            # Random light color: every channel in [0.4, 1.0).
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            color.append(c)

        # Translucent fill plus a solid outline for every box.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} != {numbox}, "
            for idx, strlist in enumerate(tgt["strings_positive"]):
                cate_id = int(tgt["labels"][idx])
                _string = str(cate_id) + ":" + " ".join(strlist)
                bbox_x, bbox_y, _, _ = boxes[idx]
                ax.text(
                    bbox_x,
                    bbox_y,
                    _string,
                    color="black",
                    bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
                )

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} != {numbox}, "
            for idx, bl in enumerate(tgt["box_label"]):
                bbox_x, bbox_y, _, _ = boxes[idx]
                ax.text(
                    bbox_x,
                    bbox_y,
                    str(bl),
                    color="black",
                    bbox={"facecolor": color[idx], "alpha": 0.6, "pad": 1},
                )

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            attn_items = tgt["attn"]
            # A single (attn_map, basergb) pair is shorthand for a one-item list.
            # Wrap it locally instead of rewriting the caller's dict.
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max scale into [0, 1); the 1e-3 avoids dividing by zero on flat maps.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                ax.imshow(ColorMap(basergb)(attn_map))
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """Display COCO-format annotations on the current axes.

        :param anns (list of dict): annotations to display. Instance
            annotations ('segmentation' and/or 'keypoints') are drawn; caption
            annotations are printed. The kind is decided by ``anns[0]``.
        :param draw_bbox (bool): also outline each annotation's 'bbox' (x, y, w, h).
        :return: 0 if ``anns`` is empty, otherwise None.
        :raises Exception: if ``anns[0]`` is neither an instance nor a caption
            annotation.

        Uncompressed RLE masks and keypoint skeletons need ``self.coco``
        (its ``imgs`` mapping and ``loadCats`` method).
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise Exception("datasetType not supported")

        if datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
            return

        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if "segmentation" in ann:
                segmentation = ann["segmentation"]
                if isinstance(segmentation, list):
                    # Polygons: each is a flat [x0, y0, x1, y1, ...] list.
                    for seg in segmentation:
                        poly = np.array(seg).reshape((len(seg) // 2, 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # RLE mask. Uncompressed RLE (list counts) is converted
                    # first, which needs the image size from the COCO index.
                    if isinstance(segmentation["counts"], list):
                        t = self.coco.imgs[ann["image_id"]]
                        rle = maskUtils.frPyObjects([segmentation], t["height"], t["width"])
                    else:
                        rle = [segmentation]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann.get("iscrowd", 0) == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    else:
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    for i in range(3):
                        img[:, :, i] = color_mask[i]
                    ax.imshow(np.dstack((img, m * 0.5)))
            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # Category skeletons are 1-based; convert to 0-based indices.
                sks = np.array(self.coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
                kp = np.array(ann["keypoints"])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )
            if draw_bbox:
                bbox_x, bbox_y, bbox_w, bbox_h = ann["bbox"]
                poly = [
                    [bbox_x, bbox_y],
                    [bbox_x, bbox_y + bbox_h],
                    [bbox_x + bbox_w, bbox_y + bbox_h],
                    [bbox_x + bbox_w, bbox_y],
                ]
                polygons.append(Polygon(np.array(poly).reshape((4, 2))))
                color.append(c)

        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)