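"""Visualization utilities: PIL image grids with optional labels, panoptic
segmentation overlays, attention-map rendering, and bounding boxes."""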
from collections import defaultdict
import torch
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import Normalize
from matplotlib import colormaps  # matplotlib.cm.get_cmap was removed in matplotlib 3.9
def pil_concat_v(images):
    # Stack images vertically; the width is taken from the first image.
    width = images[0].width
    height = sum(image.height for image in images)
    dst = Image.new('RGB', (width, height))
    h = 0
    for image in images:
        dst.paste(image, (0, h))
        h += image.height
    return dst
def pil_concat_h(images):
    # Stack images horizontally; the height is taken from the first image.
    width = sum(image.width for image in images)
    height = images[0].height
    dst = Image.new('RGB', (width, height))
    w = 0
    for image in images:
        dst.paste(image, (w, 0))
        w += image.width
    return dst
def add_label(image, text, fontsize=12):
    # Extend the canvas below the image and draw `text` on the new strip in white.
    dst = Image.new('RGB', (image.width, image.height + fontsize * 3))
    dst.paste(image, (0, 0))
    draw = ImageDraw.Draw(dst)
    try:
        font = ImageFont.truetype("../misc/fonts/OpenSans.ttf", fontsize)
    except OSError:
        # Fall back to PIL's built-in font when the bundled font is unavailable.
        font = ImageFont.load_default()
    draw.text((fontsize, image.height + fontsize), text, (255, 255, 255), font=font)
    return dst
def pil_concat(images, labels=None, col=8, fontsize=12):
    # Lay images out in a grid with `col` columns, optionally labelling each one.
    col = min(col, len(images))
    if labels is not None:
        labeled_images = [add_label(image, labels[image_idx], fontsize=fontsize)
                          for image_idx, image in enumerate(images)]
    else:
        labeled_images = images
    rows = []
    for row_idx in range(int(np.ceil(len(labeled_images) / col))):
        rows.append(pil_concat_h(labeled_images[col * row_idx:col * (row_idx + 1)]))
    return pil_concat_v(rows)
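# A minimal usage sketch (assumption: the images are equally sized PIL images
# and the labels are arbitrary strings):
#   imgs = [Image.new('RGB', (64, 64), c) for c in ('red', 'green', 'blue', 'white')]
#   pil_concat(imgs, labels=['a', 'b', 'c', 'd'], col=2).save('grid.png')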
def draw_panoptic_segmentation(model, segmentation, segments_info):
    # Colour each segment id with the viridis colormap and add one legend entry
    # per instance.
    viridis = colormaps['viridis']
    norm = Normalize(vmin=segmentation.min().item(), vmax=segmentation.max().item())
    fig, ax = plt.subplots()
    ax.imshow(segmentation, cmap=viridis, norm=norm)
    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = model.config.id2label[segment_label_id]
        # Suffix each label with a per-class instance counter, e.g. "car-0", "car-1".
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = viridis(norm(segment_id))
        handles.append(mpatches.Patch(color=color, label=label))
    ax.legend(handles=handles)
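# Hypothetical usage with a Hugging Face panoptic model such as Mask2Former
# (`processor`, `inputs`, and `image` are assumed to come from the surrounding
# transformers pipeline):
#   outputs = model(**inputs)
#   result = processor.post_process_panoptic_segmentation(
#       outputs, target_sizes=[image.size[::-1]])[0]
#   draw_panoptic_segmentation(model, result['segmentation'], result['segments_info'])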
def rescale_(x):
    # Map tensors from [-1, 1] to [0, 1].
    return (x + 1.) / 2.
def pil_grid_display(x, mask=None, nrow=4, rescale=True):
    # Arrange a batch of image tensors into a grid and return it as a PIL image.
    if rescale:
        x = rescale_(x)
    if mask is not None:
        # Prepend the mask (broadcast to 3 channels) so it appears in the grid.
        mask = mask_to_3_channel(mask)
        x = torch.concat([mask, x])
    grid = make_grid(torch.clip(x, 0, 1), nrow=nrow)
    return ToPILImage()(grid)
def pil_display(x, rescale=True):
    # Convert a single image tensor to a PIL image, optionally rescaling from [-1, 1].
    if rescale:
        x = rescale_(x)
    image = torch.clip(x, 0, 1)
    return ToPILImage()(image)
def mask_to_3_channel(mask):
    # Accept a (C, H, W) or (N, C, H, W) mask and repeat a single channel to 3.
    if mask.dim() == 3:
        mask_c_idx = 0
    elif mask.dim() == 4:
        mask_c_idx = 1
    else:
        raise ValueError("mask should be a 3d or 4d tensor")
    if mask.shape[mask_c_idx] == 3:
        return mask
    elif mask.shape[mask_c_idx] == 1:
        sizes = [1] * mask.dim()
        sizes[mask_c_idx] = 3
        mask = mask.repeat(*sizes)
    else:
        raise ValueError("mask should have size 1 or 3 in the channel dim")
    return mask
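# Shape behaviour:
#   mask_to_3_channel(torch.ones(1, 8, 8)).shape     -> (3, 8, 8)
#   mask_to_3_channel(torch.ones(4, 1, 8, 8)).shape  -> (4, 3, 8, 8)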
def get_first_k_token_head_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False):
    # Render the attention over the first k tokens as one grayscale map per
    # (head, token) pair: one row per head, one column per token.
    n_heads = atts_normed.shape[0]
    att_images = []
    for head_idx in range(n_heads):
        atts_head = atts_normed[head_idx, :, :k].reshape(h, w, k).movedim(2, 0)
        for token_idx in range(k):
            att_head_np = atts_head[token_idx].detach().cpu().numpy()
            if max_scale:
                # Stretch each map so its maximum lands at full intensity.
                att_head_np = att_head_np / att_head_np.max()
            att_image = Image.fromarray((att_head_np * 255).astype(np.uint8))
            # PIL's resize expects (width, height).
            att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
            att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)
def get_first_k_token_att_maps(atts_normed, k, h, w, output_h=256, output_w=256, labels=None, max_scale=False):
    # Same as above, but averages the attention over heads before rendering.
    att_images = []
    atts_mean = atts_normed.mean(0)[:, :k].reshape(h, w, k).movedim(2, 0)
    for token_idx in range(k):
        att_np = atts_mean[token_idx].detach().cpu().numpy()
        if max_scale:
            att_np = att_np / att_np.max()
        att_image = Image.fromarray((att_np * 255).astype(np.uint8))
        att_image = att_image.resize((output_w, output_h), Image.Resampling.NEAREST)
        att_images.append(att_image)
    return pil_concat(att_images, col=k, labels=labels)
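# Hypothetical usage (assumption: `atts_normed` holds attention weights of shape
# (n_heads, h*w, n_tokens), already normalized to [0, 1]):
#   maps = get_first_k_token_att_maps(atts_normed, k=8, h=16, w=16)
#   maps.save('att_maps.png')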
def draw_bbox(image, bbox):
    # Draw a red rectangle on a copy of the image; bbox is (left, top, right, bottom).
    image = image.copy()
    left, top, right, bottom = bbox
    image_draw = ImageDraw.Draw(image)
    image_draw.rectangle(((left, top), (right, bottom)), outline='Red')
    return image
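# A quick smoke test on random data (a sketch; the output file names are
# arbitrary and the random batch stands in for real images in [-1, 1]):
if __name__ == "__main__":
    x = torch.rand(8, 3, 64, 64) * 2. - 1.
    pil_grid_display(x, nrow=4).save("grid_preview.png")
    mask = torch.ones(8, 1, 64, 64)
    pil_grid_display(x, mask=mask, nrow=4).save("grid_with_mask.png")
    draw_bbox(pil_display(x[0]), (8, 8, 48, 48)).save("bbox_preview.png")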