import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from einops import rearrange
from matplotlib import pyplot as plt


def get_similarity(image_encodings, label_encodings, target_shape, interpolation="bilinear", do_argmax=False):
    """Compute per-pixel similarity between patch encodings and label encodings.

    The patch-token grid is upsampled to ``target_shape`` and every pixel's
    feature vector is dot-multiplied with every label encoding.

    Args:
        image_encodings: tensor of shape ``(b, h*w, d)``. The token count is
            assumed to be a perfect square (square grid, ``h == w``).
        label_encodings: tensor of shape ``(n, d)``, one row per label.
        target_shape: sequence whose last two entries give the output ``(H, W)``.
        interpolation: mode forwarded to ``torch.nn.functional.interpolate``
            (e.g. ``"nearest"`` or ``"bilinear"``).
        do_argmax: if True, return the index of the most similar label per pixel
            instead of the raw scores.

    Returns:
        A CPU tensor of shape ``(b, n, H, W)`` with similarity scores or, when
        ``do_argmax`` is True, a ``(b, 1, H, W)`` float64 tensor of label indices.
    """
    image_encodings = image_encodings.cpu()
    label_encodings = label_encodings.cpu()

    # Tokens are assumed to form a square grid laid out row by row.
    side = int(np.sqrt(image_encodings.shape[-2]))
    image_encodings = rearrange(image_encodings, "b (h w) d -> b d h w", h=side)

    # interpolate() treats channels independently, so one batched call is
    # equivalent to interpolating each channel separately. Requesting the exact
    # output size avoids a float scale_factor flooring to one pixel short.
    out_size = (int(target_shape[-2]), int(target_shape[-1]))
    image_encodings = torch.nn.functional.interpolate(image_encodings, size=out_size, mode=interpolation)

    image_encodings = rearrange(image_encodings, "b d h w -> b h w d")
    similarity = image_encodings @ label_encodings.T  # (b, H, W, n)
    similarity = rearrange(similarity, "b h w n -> b n h w")

    if do_argmax:
        similarity = torch.argmax(similarity, dim=1, keepdim=True).to(torch.float64)
    return similarity


def get_cmap(ncolors):
    """Build a discrete colormap with one colour per label, plus a matching mappable.

    Args:
        ncolors: number of labels (must be at least 1).

    Returns:
        ``(cmap, mappable)``:
            - ``cmap`` is a ``ListedColormap`` with exactly ``ncolors`` entries.
            - ``mappable`` is a ``ScalarMappable`` whose colour limits are
              ``(-0.5, ncolors - 0.5)``, so label ``k`` occupies the colour bin
              centred on ``k``. This suits a colorbar with one tick per label.

    Raises:
        ValueError: if ``ncolors`` is smaller than 1.
    """
    if ncolors < 1:
        raise ValueError(f"ncolors must be >= 1, got {ncolors}")
    if ncolors <= 10:
        colors = [plt.cm.tab10(i) for i in range(ncolors)]
    elif ncolors <= 20:
        colors = [plt.cm.tab20(i) for i in range(ncolors)]
    else:
        # Qualitative maps clamp indices beyond their size to their last colour,
        # which would make every extra label identical; sample a continuous,
        # cyclic map evenly instead.
        colors = [plt.cm.hsv(i / ncolors) for i in range(ncolors)]
    # A ListedColormap holds exactly these colours. The previous
    # LinearSegmentedColormap.from_list(..., N=ncolors) produced the same colours
    # but raised for a single colour.
    cmap = mcolors.ListedColormap(colors, name="custom")
    mappable = cm.ScalarMappable(cmap=cmap)
    mappable.set_array([])
    mappable.set_clim(-0.5, ncolors - 0.5)
    return cmap, mappable


def vis_prediction(sample_text, img_arr, similarity):
    """Show an image next to the same image overlaid with a per-pixel label map.

    Args:
        sample_text: label names. Label index ``k`` is drawn with the ``k``-th
            colour and annotated with ``sample_text[k]`` on the colorbar.
        img_arr: image accepted by ``Axes.imshow``.
        similarity: 2-D label-index map. It is assumed to hold integer values in
            ``[0, len(sample_text))``, e.g. a squeezed argmax output of
            ``get_similarity``.

    Returns:
        The matplotlib ``Figure``.
    """
    n_labels = len(sample_text)
    cmap, mappable = get_cmap(n_labels)

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(img_arr)
    axs[1].imshow(img_arr)
    # Centre label k inside its own colour bin (k - 0.5, k + 0.5). With
    # vmin=0 / vmax=N an integer label sits on a bin edge, and float rounding
    # can select the neighbouring colour.
    axs[1].imshow(
        similarity,
        cmap=cmap,
        interpolation="nearest",
        vmin=-0.5,
        vmax=n_labels - 0.5,
        alpha=0.5,
    )
    axs[0].axis("off")
    axs[1].axis("off")

    fig.subplots_adjust(bottom=0.2)
    cbar_ax = fig.add_axes([0.0, 0.85, 1.0, 0.05])
    # The colormap comes from the mappable; a separate cmap kwarg would be ignored.
    colorbar = fig.colorbar(mappable, cax=cbar_ax, orientation="horizontal")
    # One tick per label, at the centre of that label's colour bin.
    colorbar.set_ticks(np.arange(n_labels))
    colorbar.set_ticklabels(sample_text)
    return fig


class DummyArgs:
    """Plain attribute container: ``DummyArgs(a=1).a == 1``.

    Presumably used where an argparse-style namespace is expected.
    """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"


# Per-channel normalisation constants used by OpenAI CLIP's preprocessing.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def get_transform(size=(224, 224), mean=_CLIP_MEAN, std=_CLIP_STD):
    """Return a torchvision pipeline: resize, convert to tensor, normalise.

    Args:
        size: target size passed to ``torchvision.transforms.Resize``.
        mean: per-channel mean for ``Normalize`` (defaults to CLIP's values).
        std: per-channel std for ``Normalize`` (defaults to CLIP's values).

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=mean, std=std),
    ])
    return transform


def ade_palette():
    """ADE20K palette that maps each class to RGB values."""
    # 150 [R, G, B] entries, one per class index; a new list is returned on every call.
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
            [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7],
            [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51],
            [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
            [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255],
            [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10],
            [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
            [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15],
            [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0],
            [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112],
            [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0],
            [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173],
            [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184],
            [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194],
            [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255],
            [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170],
            [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255],
            [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255],
            [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235],
            [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0],
            [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255],
            [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0],
            [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]]


def get_cmap_image(legend, font_path='arial.ttf', font_size=16):
    """Render a legend image with one colour swatch and text label per row.

    Args:
        legend: mapping ``label -> colour``. List or NumPy-array colours (such as
            entries of ``ade_palette()``) are converted to tuples before drawing.
        font_path: TrueType font for the labels. If it cannot be loaded,
            Pillow's built-in default font is used instead.
        font_size: point size for the TrueType font.

    Returns:
        A white RGB ``PIL.Image`` 200 px wide and ``20 * len(legend)`` px tall.
    """
    row_height = 20
    width = 200
    height = len(legend) * row_height

    img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        # The font file is not available (e.g. Arial is usually missing on
        # Linux); keep rendering with Pillow's built-in font instead of failing.
        font = ImageFont.load_default()

    for row, (label, color) in enumerate(legend.items()):
        y = row * row_height
        if isinstance(color, (list, np.ndarray)):
            # Tuples are Pillow's documented colour form; convert defensively.
            color = tuple(color)
        draw.rectangle((0, y, row_height, y + row_height), fill=color)
        draw.text((30, y), label, font=font, fill=(0, 0, 0))
    return img