Spaces:
Runtime error
Runtime error
"""Helper functions for Panoptic Narrative Grounding.""" | |
import os | |
from os.path import join, isdir, exists | |
from typing import List | |
import torch | |
from PIL import Image | |
from skimage import io | |
import numpy as np | |
import textwrap | |
import matplotlib.pyplot as plt | |
from matplotlib import transforms | |
from imgaug.augmentables.segmaps import SegmentationMapsOnImage | |
def rainbow_text(x, y, ls, lc, fig, ax, **kw):
    """
    Draw the strings ``ls`` next to each other on ``ax``, each in its own color.

    Text ``ls[i]`` is shown in color ``lc[i]``. A single space is appended to
    every string, and each string starts where the previous one ended.
    Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib

    Args:
        x, y: Position of the first string, in axes coordinates
            (``ax.transAxes``).
        ls: Strings to draw, in order.
        lc: Colors paired with ``ls`` by position; pairing stops at the
            shorter of the two sequences.
        fig: Figure whose canvas provides the renderer used to measure text.
        ax: Axes that receives the text artists.
        **kw: Extra keyword arguments forwarded to ``ax.text``.
    """
    transform = ax.transAxes
    for string, color in zip(ls, lc):
        text = ax.text(x, y, string + " ", color=color, transform=transform, **kw)
        # The text has to be drawn once before its on-screen extent is known.
        text.draw(fig.canvas.get_renderer())
        extent = text.get_window_extent()
        # Shift the next string right by the rendered width of this one.
        # Use the public accessor rather than the private ``_transform``.
        transform = transforms.offset_copy(
            text.get_transform(), x=extent.width, units='dots'
        )
def find_first_index_greater_than(elements, key):
    """Return the index of the first item in ``elements`` greater than ``key``.

    Raises:
        StopIteration: If no item in ``elements`` is greater than ``key``.
    """
    for index, element in enumerate(elements):
        if element > key:
            return index
    raise StopIteration
def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50):
    """Group caption phrases (and their colors) into display lines.

    Phrases are kept whole. Line breaks are placed at multiples of
    ``max_char_in_a_line`` in the running character count: a line is closed
    just before the first phrase whose cumulative length exceeds the next
    multiple, and that phrase starts the following line. Line lengths are
    therefore only approximately bounded by ``max_char_in_a_line``.

    Args:
        caption_phrases: Sequence of strings making up the caption.
        colors: Sequence of colors, sliced with the same indices as
            ``caption_phrases``.
        max_char_in_a_line: Character budget used to place line breaks.

    Returns:
        Tuple ``(utt_per_line, col_per_line)``: two lists with one entry per
        line, holding the slice of phrases and the matching slice of colors.
        Every phrase appears in exactly one line; empty lines are not emitted.
        Both lists are empty when ``caption_phrases`` is empty.
    """
    if len(caption_phrases) == 0:
        return [], []

    char_lengths = np.cumsum([len(x) for x in caption_phrases])
    total_chars = int(char_lengths[-1])
    thresholds = [
        max_char_in_a_line * i
        for i in range(1, 1 + total_chars // max_char_in_a_line)
    ]

    utt_per_line = []
    col_per_line = []
    start_index = 0
    for t in thresholds:
        # First position whose cumulative length is strictly greater than t.
        # char_lengths is non-decreasing, so a binary search is valid; when
        # nothing exceeds t this yields len(char_lengths) instead of raising.
        index = int(np.searchsorted(char_lengths, t, side="right"))
        if index > start_index:
            utt_per_line.append(caption_phrases[start_index:index])
            col_per_line.append(colors[start_index:index])
            start_index = index

    # Phrases after the last threshold form the final line; without this the
    # tail of the caption (or a whole short caption) would be dropped.
    if start_index < len(caption_phrases):
        utt_per_line.append(caption_phrases[start_index:])
        col_per_line.append(colors[start_index:])

    return utt_per_line, col_per_line
def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None):
    """Display ``image`` beside its caption, one color per caption phrase.

    The figure has two panels: the image (ticks removed) on the left and the
    caption on the right. The caption is wrapped with
    ``split_caption_phrases`` and drawn line by line with ``rainbow_text``.
    The figure is shown with ``plt.show()``.

    Args:
        image: Image passed to ``imshow``.
        caption_phrases: Caption split into phrases.
        colors: One color per phrase; all phrases are black when omitted.
    """
    if colors is None:
        colors = ["black"] * len(caption_phrases)

    fig, (image_ax, text_ax) = plt.subplots(1, 2, figsize=(15, 4))

    # Left panel: the image without tick marks.
    image_ax.imshow(image)
    image_ax.set_xticks([])
    image_ax.set_yticks([])

    # Right panel: the caption, one wrapped line per row, top to bottom.
    lines, line_colors = split_caption_phrases(
        caption_phrases, colors, max_char_in_a_line=50
    )
    y = 0.7
    for phrases, phrase_colors in zip(lines, line_colors):
        rainbow_text(
            0., y,
            phrases,
            phrase_colors,
            size=15, ax=text_ax, fig=fig,
            horizontalalignment='left',
            verticalalignment='center',
        )
        y -= 0.11
    text_ax.axis("off")

    fig.tight_layout()
    plt.show()
def show_images_and_caption(
    images: List,
    caption_phrases: list,
    colors: list = None,
    image_xlabels: List = None,
    figsize=None,
    show=False,
    xlabelsize=14,
):
    """Display several images in a row, followed by a color-coded caption.

    One panel is created per image (ticks removed, optional x-label below),
    plus a final panel holding the caption. The caption is wrapped with
    ``split_caption_phrases`` and drawn line by line with ``rainbow_text``.

    Args:
        images: Images passed to ``imshow``, one panel each.
        caption_phrases: Caption split into phrases. The first phrase is
            rendered through ``str.capitalize`` (which also lower-cases the
            rest of that phrase). The caller's list is not modified.
        colors: One color per phrase; all phrases are black when omitted.
        image_xlabels: One x-label per image. ``None`` or an empty sequence
            means no labels.
        figsize: Figure size; defaults to ``(5 * len(images) + 8, 4)``.
        show: When true, call ``plt.show()`` after building the figure.
        xlabelsize: Font size of the image x-labels.
    """
    # Work on a copy so capitalizing the first phrase does not leak back into
    # the caller's list, and tolerate an empty caption.
    caption_phrases = list(caption_phrases)
    if caption_phrases:
        caption_phrases[0] = caption_phrases[0].capitalize()

    if colors is None:
        colors = ["black"] * len(caption_phrases)

    if figsize is None:
        figsize = (5 * len(images) + 8, 4)

    # An omitted or empty label list means "no labels"; previously the empty
    # default was indexed directly and raised IndexError.
    if image_xlabels is None or len(image_xlabels) == 0:
        image_xlabels = [""] * len(images)

    # squeeze=False always yields a 2-D array of axes, so indexing also works
    # when there are no images and only the caption panel exists.
    fig, axes_grid = plt.subplots(
        1, len(images) + 1, figsize=figsize, squeeze=False
    )
    axes = axes_grid[0]

    for i, image in enumerate(images):
        ax = axes[i]
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize)

    ax = axes[-1]
    if caption_phrases:
        utt_per_line, col_per_line = split_caption_phrases(
            caption_phrases, colors, max_char_in_a_line=40
        )
    else:
        utt_per_line, col_per_line = [], []

    y = 0.7
    for U, C in zip(utt_per_line, col_per_line):
        rainbow_text(
            0., y,
            U,
            C,
            size=23, ax=ax, fig=fig,
            horizontalalignment='left',
            verticalalignment='center',
        )
        y -= 0.11
    ax.axis("off")

    fig.tight_layout()
    if show:
        plt.show()