"""Helper functions for Panoptic Narrative Grounding.""" import os from os.path import join, isdir, exists from typing import List import torch from PIL import Image from skimage import io import numpy as np import textwrap import matplotlib.pyplot as plt from matplotlib import transforms from imgaug.augmentables.segmaps import SegmentationMapsOnImage def rainbow_text(x,y,ls,lc,fig, ax,**kw): """ Take a list of strings ``ls`` and colors ``lc`` and place them next to each other, with text ls[i] being shown in color lc[i]. Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib """ t = ax.transAxes for s,c in zip(ls,lc): text = ax.text(x,y,s+" ",color=c, transform=t, **kw) text.draw(fig.canvas.get_renderer()) ex = text.get_window_extent() t = transforms.offset_copy(text._transform, x=ex.width, units='dots') def find_first_index_greater_than(elements, key): return next(x[0] for x in enumerate(elements) if x[1] > key) def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50): char_lengths = np.cumsum([len(x) for x in caption_phrases]) thresholds = [max_char_in_a_line * i for i in range(1, 1 + char_lengths[-1] // max_char_in_a_line)] utt_per_line = [] col_per_line = [] start_index = 0 for t in thresholds: index = find_first_index_greater_than(char_lengths, t) utt_per_line.append(caption_phrases[start_index:index]) col_per_line.append(colors[start_index:index]) start_index = index return utt_per_line, col_per_line def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None): if colors is None: colors = ["black" for _ in range(len(caption_phrases))] fig, axes = plt.subplots(1, 2, figsize=(15, 4)) ax = axes[0] ax.imshow(image) ax.set_xticks([]) ax.set_yticks([]) ax = axes[1] utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50) y = 0.7 for U, C in zip(utt_per_line, col_per_line): rainbow_text( 0., y, U, C, size=15, ax=ax, fig=fig, horizontalalignment='left', verticalalignment='center', ) y -= 0.11 ax.axis("off") fig.tight_layout() plt.show() def show_images_and_caption( images: List, caption_phrases: list, colors: list = None, image_xlabels: List=[], figsize=None, show=False, xlabelsize=14, ): if colors is None: colors = ["black" for _ in range(len(caption_phrases))] caption_phrases[0] = caption_phrases[0].capitalize() if figsize is None: figsize = (5 * len(images) + 8, 4) if image_xlabels is None: image_xlabels = ["" for _ in range(len(images))] fig, axes = plt.subplots(1, len(images) + 1, figsize=figsize) for i, image in enumerate(images): ax = axes[i] ax.imshow(image) ax.set_xticks([]) ax.set_yticks([]) ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize) ax = axes[-1] utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=40) y = 0.7 for U, C in zip(utt_per_line, col_per_line): rainbow_text( 0., y, U, C, size=23, ax=ax, fig=fig, horizontalalignment='left', verticalalignment='center', # weight='bold' ) y -= 0.11 ax.axis("off") fig.tight_layout() if show: plt.show()