"""Image/visualization helpers: PIL concatenation grids, attention-map rendering,
panoptic-segmentation drawing, and tensor-to-PIL display utilities."""

from collections import defaultdict

import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize
from matplotlib import cm


def pil_concat_v(images):
    """Stack PIL images vertically. Canvas width is taken from the first image;
    wider images are clipped, narrower ones leave black padding."""
    width = images[0].width
    height = sum(image.height for image in images)
    dst = Image.new('RGB', (width, height))
    h = 0
    for image in images:
        dst.paste(image, (0, h))
        h += image.height
    return dst


def pil_concat_h(images):
    """Stack PIL images horizontally. Canvas height is taken from the first image;
    taller images are clipped, shorter ones leave black padding."""
    width = sum(image.width for image in images)
    height = images[0].height
    dst = Image.new('RGB', (width, height))
    w = 0
    for image in images:
        dst.paste(image, (w, 0))
        w += image.width
    return dst


def add_label(image, text, fontsize=12):
    """Return a copy of `image` with `text` drawn in a black strip (3*fontsize tall)
    appended below it.

    NOTE(review): font path is hard-coded relative to the CWD — breaks when run
    from another directory.
    """
    dst = Image.new('RGB', (image.width, image.height + fontsize * 3))
    dst.paste(image, (0, 0))
    draw = ImageDraw.Draw(dst)
    font = ImageFont.truetype("../misc/fonts/OpenSans.ttf", fontsize)
    draw.text((fontsize, image.height + fontsize), text, (255, 255, 255), font=font)
    return dst


def pil_concat(images, labels=None, col=8, fontsize=12):
    """Arrange `images` in a grid with `col` columns (row-major order).

    If `labels` is given it must be parallel to `images`; each image gets its
    label drawn underneath via `add_label`.
    """
    col = min(col, len(images))
    if labels is not None:
        labeled_images = [add_label(image, labels[image_idx], fontsize=fontsize)
                          for image_idx, image in enumerate(images)]
    else:
        labeled_images = images
    labeled_images_rows = []
    for row_idx in range(int(np.ceil(len(labeled_images) / col))):
        labeled_images_rows.append(pil_concat_h(labeled_images[col * row_idx:col * (row_idx + 1)]))
    return pil_concat_v(labeled_images_rows)


def draw_panoptic_segmentation(model, segmentation, segments_info):
    """Render a panoptic segmentation map with a per-instance legend.

    `segmentation` is a 2D tensor of segment ids; `segments_info` is a list of
    dicts with 'id' and 'label_id' keys; `model.config.id2label` maps label ids
    to names. Draws onto a new matplotlib figure (no return value).
    """
    # NOTE(review): cm.get_cmap was deprecated in matplotlib 3.7 and removed in
    # 3.9 — prefer matplotlib.colormaps['viridis'] when the pinned version allows.
    viridis = cm.get_cmap('viridis')
    # Map segment ids onto the colormap's [0, 1] domain.
    norm = Normalize(vmin=segmentation.min().item(), vmax=segmentation.max().item())
    fig, ax = plt.subplots()
    ax.imshow(segmentation, cmap=viridis, norm=norm)
    # Count instances per class so repeated classes get "-0", "-1", ... suffixes.
    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = model.config.id2label[segment_label_id]
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        # Legend patch color matches the color imshow assigned to this segment id.
        color = viridis(norm(segment_id))
        handles.append(mpatches.Patch(color=color, label=label))
    ax.legend(handles=handles)


# Map values from [-1, 1] to [0, 1] (standard de-normalization for generator output).
rescale_ = lambda x: (x + 1.) / 2.


def pil_grid_display(x, mask=None, nrow=4, rescale=True):
    """Display a batch tensor `x` as a PIL image grid.

    If `mask` is given it is broadcast to 3 channels and prepended to the batch.
    With `rescale=True`, `x` is assumed to be in [-1, 1] and is mapped to [0, 1].
    """
    if rescale:
        x = rescale_(x)
    if mask is not None:
        mask = mask_to_3_channel(mask)
        x = torch.concat([mask, x])
    grid = make_grid(torch.clip(x, 0, 1), nrow=nrow)
    return ToPILImage()(grid)


def pil_display(x, rescale=True):
    """Convert a single image tensor to PIL.

    With `rescale=True`, `x` is assumed to be in [-1, 1] and is mapped to [0, 1].
    BUGFIX: the original applied `rescale_` a second time inside the clip,
    double-rescaling the input when rescale=True.
    """
    if rescale:
        x = rescale_(x)
    image = torch.clip(x, 0, 1)
    return ToPILImage()(image)


def mask_to_3_channel(mask):
    """Broadcast a 1-channel mask tensor to 3 channels.

    Accepts CHW (3d) or NCHW (4d) tensors; 3-channel input is returned unchanged.
    Raises for any other rank or channel count.
    """
    if mask.dim() == 3:
        mask_c_idx = 0
    elif mask.dim() == 4:
        mask_c_idx = 1
    else:
        raise Exception("mask should be a 3d or 4d tensor")
    if mask.shape[mask_c_idx] == 3:
        return mask
    elif mask.shape[mask_c_idx] == 1:
        # Repeat only along the channel dim, keeping every other dim at size 1x.
        sizes = [1] * mask.dim()
        sizes[mask_c_idx] = 3
        mask = mask.repeat(*sizes)
    else:
        raise Exception("mask should have size 1 in channel dim")
    return mask


def get_first_k_token_head_att_maps(atts_normed, k, h, w,
                                    output_h=256, output_w=256,
                                    labels=None, max_scale=False):
    """Render per-head attention maps for the first `k` tokens as a PIL grid.

    `atts_normed` is expected to be (n_heads, h*w, n_tokens) with values in
    [0, 1] — TODO confirm against caller. Each head contributes one grid row
    of `k` maps. With `max_scale`, each map is normalized by its own max.

    BUGFIX: `labels` was accepted but silently dropped (pil_concat was called
    with labels=None); it is now forwarded. `Image.resize` takes (width,
    height), so the output size tuple is now (output_w, output_h).
    """
    n_heads = atts_normed.shape[0]
    att_images = []
    for head_idx in range(n_heads):
        # (h*w, k) -> (h, w, k) -> (k, h, w) so each token slice is a 2D map.
        atts_head = atts_normed[head_idx, :, :k].reshape(h, w, k).movedim(2, 0)
        for token_idx in range(k):
            att_head_np = atts_head[token_idx].detach().cpu().numpy()
            if max_scale:
                att_head_np = att_head_np / att_head_np.max()
            att_image = Image.fromarray((att_head_np * 255).astype(np.uint8))
            att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
            att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)


def get_first_k_token_att_maps(atts_normed, k, h, w,
                               output_h=256, output_w=256,
                               labels=None, max_scale=False):
    """Render head-averaged attention maps for the first `k` tokens as a PIL row.

    Same input layout as `get_first_k_token_head_att_maps`, but attention is
    averaged over heads first. Same BUGFIXes: `labels` forwarded, resize tuple
    ordered (width, height).
    """
    att_images = []
    atts_head = atts_normed.mean(0)[:, :k].reshape(h, w, k).movedim(2, 0)
    for token_idx in range(k):
        att_head_np = atts_head[token_idx].detach().cpu().numpy()
        if max_scale:
            att_head_np = att_head_np / att_head_np.max()
        att_image = Image.fromarray((att_head_np * 255).astype(np.uint8))
        att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
        att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)


def draw_bbox(image, bbox):
    """Return a copy of `image` with a red rectangle drawn at
    `bbox = (left, top, right, bottom)`."""
    image = image.copy()
    left, top, right, bottom = bbox
    image_draw = ImageDraw.Draw(image)
    image_draw.rectangle(((left, top), (right, bottom)), outline='Red')
    return image