from collections import defaultdict

import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize
def pil_concat_v(images):
    """Stack PIL images vertically; width is taken from the first image."""
    width = images[0].width
    height = sum(image.height for image in images)
    dst = Image.new('RGB', (width, height))
    h = 0
    for image in images:
        dst.paste(image, (0, h))
        h += image.height
    return dst
def pil_concat_h(images):
    """Concatenate PIL images horizontally; height is taken from the first image."""
    width = sum(image.width for image in images)
    height = images[0].height
    dst = Image.new('RGB', (width, height))
    w = 0
    for image in images:
        dst.paste(image, (w, 0))
        w += image.width
    return dst
def add_label(image, text, fontsize=12):
    """Append a text strip below the image; the font path is relative to this repo."""
    dst = Image.new('RGB', (image.width, image.height + fontsize * 3))
    dst.paste(image, (0, 0))
    draw = ImageDraw.Draw(dst)
    font = ImageFont.truetype("../misc/fonts/OpenSans.ttf", fontsize)
    draw.text((fontsize, image.height + fontsize), text, (255, 255, 255), font=font)
    return dst
def pil_concat(images, labels=None, col=8, fontsize=12):
    """Tile images into a grid with `col` columns, optionally labeling each one."""
    col = min(col, len(images))
    if labels is not None:
        labeled_images = [add_label(image, labels[image_idx], fontsize=fontsize)
                          for image_idx, image in enumerate(images)]
    else:
        labeled_images = images
    labeled_images_rows = []
    for row_idx in range(int(np.ceil(len(labeled_images) / col))):
        labeled_images_rows.append(pil_concat_h(labeled_images[col * row_idx:col * (row_idx + 1)]))
    return pil_concat_v(labeled_images_rows)
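
# Example (sketch): tile images into a labeled 2-column grid. The file
# names below are hypothetical, not part of this module.
# imgs = [Image.open(p) for p in ("a.png", "b.png", "c.png", "d.png")]
# grid = pil_concat(imgs, labels=["a", "b", "c", "d"], col=2)
# grid.save("grid.png")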
def draw_panoptic_segmentation(model, segmentation, segments_info):
    """Plot a panoptic segmentation map with a per-instance legend."""
    # get the used color map (plt.get_cmap keeps this working on matplotlib >= 3.9,
    # where cm.get_cmap was removed)
    viridis = plt.get_cmap('viridis')
    norm = Normalize(vmin=segmentation.min().item(), vmax=segmentation.max().item())
    fig, ax = plt.subplots()
    ax.imshow(segmentation, cmap=viridis, norm=norm)
    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = model.config.id2label[segment_label_id]
        # count instances per class so repeated labels stay distinguishable
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = viridis(norm(segment_id))
        handles.append(mpatches.Patch(color=color, label=label))
    ax.legend(handles=handles)
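
# Example (sketch): this helper matches the output format of a Hugging Face
# panoptic pipeline such as Mask2Former; the checkpoint name and the
# surrounding calls below are assumptions about the caller's setup.
# from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-tiny-coco-panoptic")
# model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-tiny-coco-panoptic")
# inputs = processor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# result = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
# draw_panoptic_segmentation(model, result["segmentation"], result["segments_info"])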
# map network outputs from [-1, 1] to [0, 1] for display
rescale_ = lambda x: (x + 1.) / 2.
def pil_grid_display(x, mask=None, nrow=4, rescale=True):
    """Render a batch of image tensors as a PIL grid, optionally prepending masks."""
    if rescale:
        x = rescale_(x)
    if mask is not None:
        mask = mask_to_3_channel(mask)
        x = torch.concat([mask, x])
    grid = make_grid(torch.clip(x, 0, 1), nrow=nrow)
    return ToPILImage()(grid)
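
# Example (sketch): show a batch in [-1, 1] with its binary masks
# prepended as a grayscale row; the shapes below are hypothetical.
# x = torch.rand(4, 3, 64, 64) * 2 - 1
# m = (torch.rand(4, 1, 64, 64) > 0.5).float()
# pil_grid_display(x, mask=m, nrow=4)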
def pil_display(x, rescale=True):
    """Render a single image tensor as a PIL image (rescaling at most once)."""
    if rescale:
        x = rescale_(x)
    image = torch.clip(x, 0, 1)
    return ToPILImage()(image)
def mask_to_3_channel(mask):
    """Broadcast a 1-channel mask to 3 channels; accepts CHW or NCHW tensors."""
    if mask.dim() == 3:
        mask_c_idx = 0
    elif mask.dim() == 4:
        mask_c_idx = 1
    else:
        raise ValueError("mask should be a 3d or 4d tensor")
    if mask.shape[mask_c_idx] == 3:
        return mask
    elif mask.shape[mask_c_idx] == 1:
        sizes = [1] * mask.dim()
        sizes[mask_c_idx] = 3
        mask = mask.repeat(*sizes)
    else:
        raise ValueError("mask should have size 1 or 3 in channel dim")
    return mask
def get_first_k_token_head_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False):
    """Visualize per-head attention over the first k tokens as an (n_heads x k) image grid.

    `atts_normed` is expected to have shape (n_heads, h*w, n_tokens) with values in [0, 1].
    """
    n_heads = atts_normed.shape[0]
    att_images = []
    for head_idx in range(n_heads):
        # (h*w, k) -> (h, w, k) -> (k, h, w): one spatial map per token
        atts_head = atts_normed[head_idx, :, :k].reshape(h, w, k).movedim(2, 0)
        for token_idx in range(k):
            att_head_np = atts_head[token_idx].detach().cpu().numpy()
            if max_scale:
                # stretch each map so its maximum reaches full brightness
                att_head_np = att_head_np / att_head_np.max()
            att_image = Image.fromarray((att_head_np * 255).astype(np.uint8))
            # PIL's resize takes (width, height)
            att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
            att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)
def get_first_k_token_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False):
    """Visualize head-averaged attention over the first k tokens as a k-column strip."""
    att_images = []
    # average over heads, then split into one spatial map per token
    atts_head = atts_normed.mean(0)[:, :k].reshape(h, w, k).movedim(2, 0)
    for token_idx in range(k):
        att_head_np = atts_head[token_idx].detach().cpu().numpy()
        if max_scale:
            att_head_np = att_head_np / att_head_np.max()
        att_image = Image.fromarray((att_head_np * 255).astype(np.uint8))
        # PIL's resize takes (width, height)
        att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
        att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)
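
# Example (sketch): render head-averaged maps for the first 4 tokens of a
# 16x16 attention grid; the random tensor stands in for real weights.
# atts = torch.rand(8, 16 * 16, 77)  # (n_heads, h*w, n_tokens), hypothetical
# get_first_k_token_att_maps(atts, k=4, h=16, w=16).save("att_maps.png")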
def draw_bbox(image, bbox):
    """Return a copy of `image` with a red (left, top, right, bottom) box drawn on it."""
    image = image.copy()
    left, top, right, bottom = bbox
    image_draw = ImageDraw.Draw(image)
    image_draw.rectangle(((left, top), (right, bottom)), outline='Red')
    return image
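
# Example (sketch): outline a region on an image; the path and box are
# hypothetical. bbox is (left, top, right, bottom) in pixels.
# img = Image.open("sample.png")
# draw_bbox(img, (16, 16, 120, 96)).save("sample_bbox.png")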