import os
import json
import random
import torch
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import shutil
from prismer.utils import create_ade20k_label_colormap

matplotlib.use('agg')

# Lookup tables loaded once at import time.
# NOTE(review): these paths are relative, so they resolve against the current
# working directory (presumably the repository root) — confirm with callers.
obj_label_map = torch.load('prismer/dataset/detection_features.pt')['labels']
coco_label_map = torch.load('prismer/dataset/coco_features.pt')['labels']
ade_color = create_ade20k_label_colormap()


def _pretty_path(path):
    """Return the path of the rendered ("pretty") image for a label PNG."""
    return path.replace('.png', '_p.png')


def _label_id(value):
    """Convert a label pixel read by ``plt.imread`` back to its integer id.

    ``plt.imread`` returns 8-bit PNG pixels scaled to [0, 1] (assumes the label
    PNGs are 8-bit). Multiplying by 255 in floating point can land just below
    the intended integer, so round instead of truncating with ``int()``.
    """
    return int(round(float(value) * 255))


def islight(rgb):
    """Return True if an RGB colour with channels in 0-255 is perceived as light.

    Uses the HSP perceived-brightness formula
    sqrt(0.299 R^2 + 0.587 G^2 + 0.114 B^2) with a threshold of 127.5 (half of
    255). Callers use it to pick black text on light colours, white otherwise.
    """
    # Cast to float first: callers may pass narrow integer pixels (seg_prettify
    # passes int16 values), where r * r would overflow for channels above 181.
    r, g, b = (float(c) for c in rgb)
    hsp = np.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
    return bool(hsp > 127.5)


def depth_prettify(file_path):
    """Save a 'rainbow'-colormapped copy of a depth PNG as ``*_p.png``.

    Does nothing if the pretty image already exists.
    """
    pretty_path = _pretty_path(file_path)
    if os.path.exists(pretty_path):
        return
    depth = plt.imread(file_path)
    plt.imsave(pretty_path, depth, cmap='rainbow')


def obj_detection_prettify(rgb_path, path_name):
    """Overlay an object-detection label map and object names on the RGB image.

    Args:
        rgb_path: path of the source RGB image.
        path_name: path of the label PNG. A sibling ``.json`` file maps
            ``str(label_id)`` to an index into ``obj_label_map``.

    The result is written to ``*_p.png`` next to ``path_name``. Does nothing if
    that file already exists. The label PNG itself is never modified.
    """
    pretty_path = _pretty_path(path_name)
    if os.path.exists(pretty_path):
        return
    rgb = plt.imread(rgb_path)
    obj_labels = plt.imread(path_name)
    with open(path_name.replace('.png', '.json')) as f:
        obj_labels_dict = json.load(f)
    plt.imshow(rgb)

    label_values = np.unique(obj_labels)
    if len(label_values) > 1:
        # The largest value (presumably background) is drawn but not labelled.
        object_values = label_values[:-1]
        num_objs = object_values.max()
        # vmax sits one label step above the largest object value.
        plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
        cmap = matplotlib.colormaps['terrain']
        for i in object_values:
            ys, xs = np.where(obj_labels == i)
            x, y = xs.mean(), ys.mean()
            obj_name = obj_label_map[obj_labels_dict[str(_label_id(i))]]
            obj_name = obj_name.split(',')[0]
            # Guard the single-object case (num_objs == 0), where i / num_objs
            # would be NaN.
            shade = i / num_objs if num_objs > 0 else 0.0
            text_color = 'black' if islight([c * 255 for c in cmap(shade)[:3]]) else 'white'
            plt.text(x, y, obj_name, c=text_color, horizontalalignment='center',
                     verticalalignment='center', clip_on=True)

    plt.axis('off')
    # Always write to pretty_path; never overwrite the label file itself.
    plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()


def seg_prettify(rgb_path, file_name):
    """Overlay an ADE20K-coloured segmentation map and class names on the RGB image.

    Each class id is coloured with ``ade_color`` and named with ``coco_label_map``.
    Names are placed at a random pixel of the class, and only for classes
    covering more than 20 pixels. The result is written to ``*_p.png``. Does
    nothing if that file already exists.
    """
    pretty_path = _pretty_path(file_name)
    if os.path.exists(pretty_path):
        return
    rgb = plt.imread(rgb_path)
    seg_labels = plt.imread(file_name)
    plt.imshow(rgb)

    label_values = np.unique(seg_labels)
    seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
    for i in label_values:
        seg_map[seg_labels == i] = ade_color[_label_id(i)]
    plt.imshow(seg_map, alpha=0.8)

    for i in label_values:
        ys, xs = np.where(seg_labels == i)
        # only plot the label with its number of labelled pixel more than 20
        if len(ys) <= 20:
            continue
        obj_idx = random.randint(0, len(ys) - 1)
        x, y = xs[obj_idx], ys[obj_idx]
        obj_name = coco_label_map[_label_id(i)].split(',')[0]
        text_color = 'black' if islight(seg_map[int(y), int(x)]) else 'white'
        plt.text(x, y, obj_name, c=text_color, horizontalalignment='center',
                 verticalalignment='center', clip_on=True)

    plt.axis('off')
    plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()


def ocr_detection_prettify(rgb_path, file_name):
    """Overlay OCR text regions and their recognised text on the RGB image.

    If ``file_name`` exists, its sibling ``.pt`` file is loaded. That file is
    expected to map each integer label id to a dict with a ``'text'`` entry.
    If ``file_name`` is missing, the image is veiled in white and stamped with
    'No text detected'. The result is written to ``*_p.png``. Does nothing if
    that file already exists.
    """
    pretty_path = _pretty_path(file_name)
    if os.path.exists(pretty_path):
        return

    if os.path.exists(file_name):
        rgb = plt.imread(rgb_path)
        ocr_labels = plt.imread(file_name)
        # NOTE(review): torch.load unpickles data; only use on trusted local files.
        ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
        # The largest value (presumably background) is not labelled.
        for i in np.unique(ocr_labels)[:-1]:
            ys, xs = np.where(ocr_labels == i)
            x, y = xs.mean(), ys.mean()
            text = ocr_labels_dict[_label_id(i)]['text']
            plt.text(x, y, text, c='white', horizontalalignment='center',
                     verticalalignment='center', clip_on=True)
    else:
        rgb = plt.imread(rgb_path)
        ocr_labels = np.ones_like(rgb, dtype=np.float32)
        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
        x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
        plt.text(x, y, 'No text detected', c='black', horizontalalignment='center',
                 verticalalignment='center', clip_on=True)
        # The label directory may not exist yet. A bare filename has no
        # directory part, and os.makedirs('') would raise.
        out_dir = os.path.dirname(file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    plt.axis('off')
    plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()


def label_prettify(rgb_path, expert_paths):
    """Render a ``*_p.png`` visualisation for every expert label file.

    The renderer is chosen by substring match on the full path, and the first
    match wins: 'depth', 'seg', 'ocr', 'obj'. Any other file is copied verbatim
    to its pretty path if that path does not exist yet.
    """
    for expert_path in expert_paths:
        if 'depth' in expert_path:
            depth_prettify(expert_path)
        elif 'seg' in expert_path:
            seg_prettify(rgb_path, expert_path)
        elif 'ocr' in expert_path:
            ocr_detection_prettify(rgb_path, expert_path)
        elif 'obj' in expert_path:
            obj_detection_prettify(rgb_path, expert_path)
        else:
            pretty_path = _pretty_path(expert_path)
            if not os.path.exists(pretty_path):
                shutil.copyfile(expert_path, pretty_path)