""" Dataset object for Panoptic Narrative Grounding. Paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Gonzalez_Panoptic_Narrative_Grounding_ICCV_2021_paper.pdf """ import os from os.path import join, isdir, exists import torch from torch.utils.data import Dataset import cv2 from PIL import Image from skimage import io import numpy as np import textwrap import matplotlib.pyplot as plt from matplotlib import transforms from imgaug.augmentables.segmaps import SegmentationMapsOnImage import matplotlib.colors as mc from clip_grounding.utils.io import load_json from clip_grounding.datasets.png_utils import show_image_and_caption class PNG(Dataset): """Panoptic Narrative Grounding.""" def __init__(self, dataset_root, split) -> None: """ Initializer. Args: dataset_root (str): path to the folder containing PNG dataset split (str): MS-COCO split such as train2017/val2017 """ super().__init__() assert isdir(dataset_root) self.dataset_root = dataset_root assert split in ["val2017"], f"Split {split} not supported. "\ "Currently, only supports split `val2017`." self.split = split self.ann_dir = join(self.dataset_root, "annotations") # feat_dir = join(self.dataset_root, "features") panoptic = load_json(join(self.ann_dir, "panoptic_{:s}.json".format(split))) images = panoptic["images"] self.images_info = {i["id"]: i for i in images} panoptic_anns = panoptic["annotations"] self.panoptic_anns = {int(a["image_id"]): a for a in panoptic_anns} # self.panoptic_pred_path = join( # feat_dir, split, "panoptic_seg_predictions" # ) # assert isdir(self.panoptic_pred_path) panoptic_narratives_path = join(self.dataset_root, "annotations", f"png_coco_{split}.json") self.panoptic_narratives = load_json(panoptic_narratives_path) def __len__(self): return len(self.panoptic_narratives) def get_image_path(self, image_id: str): image_path = join(self.dataset_root, "images", self.split, f"{image_id.zfill(12)}.jpg") return image_path def __getitem__(self, idx: int): narr = self.panoptic_narratives[idx] image_id = narr["image_id"] image_path = self.get_image_path(image_id) assert exists(image_path) image = Image.open(image_path) caption = narr["caption"] # show_single_image(image, title=caption, titlesize=12) segments = narr["segments"] image_id = int(narr["image_id"]) panoptic_ann = self.panoptic_anns[image_id] panoptic_ann = self.panoptic_anns[image_id] segment_infos = {} for s in panoptic_ann["segments_info"]: idi = s["id"] segment_infos[idi] = s image_info = self.images_info[image_id] panoptic_segm = io.imread( join( self.ann_dir, "panoptic_segmentation", self.split, "{:012d}.png".format(image_id), ) ) panoptic_segm = ( panoptic_segm[:, :, 0] + panoptic_segm[:, :, 1] * 256 + panoptic_segm[:, :, 2] * 256 ** 2 ) panoptic_ann = self.panoptic_anns[image_id] # panoptic_pred = io.imread( # join(self.panoptic_pred_path, "{:012d}.png".format(image_id)) # )[:, :, 0] # # select a single utterance to visualize # segment = segments[7] # segment_ids = segment["segment_ids"] # segment_mask = np.zeros((image_info["height"], image_info["width"])) # for segment_id in segment_ids: # segment_id = int(segment_id) # segment_mask[panoptic_segm == segment_id] = 1. 
        utterances = [s["utterance"] for s in segments]

        outputs = []
        for i, segment in enumerate(segments):

            # create a binary segmentation mask on the image
            segment_ids = segment["segment_ids"]

            # if there is no annotation for this phrase, skip it
            if not len(segment_ids):
                continue

            segment_mask = np.zeros((image_info["height"], image_info["width"]))
            for segment_id in segment_ids:
                segment_id = int(segment_id)
                segment_mask[panoptic_segm == segment_id] = 1.

            # store the outputs: a one-hot mask over the utterances marks
            # which phrase this segmentation mask belongs to
            text_mask = np.zeros(len(utterances))
            text_mask[i] = 1.
            segment_data = dict(
                image=image,
                text=utterances,
                image_mask=segment_mask,
                text_mask=text_mask,
                full_caption=caption,
            )
            outputs.append(segment_data)

        return outputs


def overlay_segmask_on_image(image, image_mask, segment_color="red"):
    """Overlays a binary segmentation mask on the image in the given color."""
    segmap = SegmentationMapsOnImage(
        image_mask.astype(np.uint8), shape=image_mask.shape,
    )
    rgb_color = mc.to_rgb(segment_color)
    rgb_color = 255 * np.array(rgb_color)
    image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
    image_with_segmap = Image.fromarray(image_with_segmap)
    return image_with_segmap


def get_text_colors(text, text_mask, segment_color="red"):
    """Returns per-utterance colors: the phrase selected by `text_mask` is highlighted."""
    colors = ["black" for _ in range(len(text))]
    colors[text_mask.nonzero()[0][0]] = segment_color
    return colors


def overlay_relevance_map_on_image(image, heatmap):
    """Blends a relevance heatmap (values in [0, 1]) on top of the image."""
    width, height = image.size

    # resize the heatmap to the image size and colorize it
    heatmap = cv2.resize(heatmap, (width, height))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # blend the heatmap with the original image
    img = np.asarray(image)
    super_img = heatmap * 0.4 + img * 0.6
    super_img = np.uint8(super_img)
    super_img = Image.fromarray(super_img)

    return super_img


def visualize_item(image, text, image_mask, text_mask, segment_color="red"):
    """Shows the image with its segmentation mask and the matching phrase highlighted."""
    image_with_segmap = overlay_segmask_on_image(image, image_mask, segment_color=segment_color)
    colors = get_text_colors(text, text_mask, segment_color=segment_color)
    show_image_and_caption(image_with_segmap, text, colors)


if __name__ == "__main__":
    from clip_grounding.utils.paths import DATASET_ROOTS

    PNG_ROOT = DATASET_ROOTS["PNG"]
    dataset = PNG(dataset_root=PNG_ROOT, split="val2017")
    item = dataset[0]
    sub_item = item[1]
    visualize_item(
        image=sub_item["image"],
        text=sub_item["text"],
        image_mask=sub_item["image_mask"],
        text_mask=sub_item["text_mask"],
        segment_color="red",
    )
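
    # A minimal usage sketch of the standalone helpers, assuming the loaded
    # image is RGB; the output filenames are arbitrary and chosen here only
    # for illustration.
    overlay = overlay_segmask_on_image(sub_item["image"], sub_item["image_mask"], segment_color="red")
    overlay.save("demo_segmask_overlay.png")

    # A hypothetical relevance map (random values, for illustration only);
    # a real map would come from a grounding model such as CLIP.
    dummy_heatmap = np.random.rand(7, 7).astype(np.float32)
    relevance_vis = overlay_relevance_map_on_image(sub_item["image"], dummy_heatmap)
    relevance_vis.save("demo_relevance_overlay.png")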