# prismer/label_prettify.py
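"""Prettify Prismer expert outputs for visualization.

Each helper reads an expert's raw prediction (depth, object detection,
semantic segmentation, or OCR), overlays it on the input RGB image, and
writes the rendered figure back over the expert's label file.
"""
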
import os
import json
import torch
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

from prismer.utils import create_ade20k_label_colormap

# Class-name lookup tables for the detection and segmentation experts,
# plus the ADE20K colour map used to render segmentation masks.
obj_label_map = torch.load('prismer/dataset/detection_features.pt')['labels']
coco_label_map = torch.load('prismer/dataset/coco_features.pt')['labels']
ade_color = create_ade20k_label_colormap()


def islight(rgb):
    # Perceived-brightness check (standard luminance weights) used to decide
    # whether black or white text stays readable on a given colour.
    r, g, b = rgb
    hsp = np.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
    return hsp > 127.5


def depth_prettify(file_path):
    # Re-save the raw depth map with a rainbow colormap for visualization.
    depth = plt.imread(file_path)
    plt.imsave(file_path, depth, cmap='rainbow')


def obj_detection_prettify(rgb_path, path_name):
    # Overlay the instance mask on the RGB image and label each detected
    # object with its class name at the mask's centroid.
    rgb = plt.imread(rgb_path)
    obj_labels = plt.imread(path_name)
    with open(path_name.replace('.png', '.json')) as f:
        obj_labels_dict = json.load(f)

    plt.imshow(rgb)

    # plt.imread scales the uint8 mask to [0, 1]; drop the largest unique
    # value (the unlabeled background) before taking the highest instance index.
    num_objs = np.unique(obj_labels)[:-1].max()
    plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)

    cmap = matplotlib.colormaps.get_cmap('terrain')
    for i in np.unique(obj_labels)[:-1]:
        obj_idx_all = np.where(obj_labels == i)
        x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
        obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
        obj_name = obj_name.split(',')[0]
        # Pick a text colour that contrasts with the overlay colour.
        if islight([c * 255 for c in cmap(i / num_objs)[:3]]):
            plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
        else:
            plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)

    plt.axis('off')
    plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()


def seg_prettify(rgb_path, file_name):
    # Recolour the segmentation map with the ADE20K palette, overlay it on
    # the RGB image, and write each class name at its region's centroid.
    rgb = plt.imread(rgb_path)
    seg_labels = plt.imread(file_name)

    plt.imshow(rgb)

    seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
    for i in np.unique(seg_labels):
        seg_map[seg_labels == i] = ade_color[int(i * 255)]

    plt.imshow(seg_map, alpha=0.8)

    for i in np.unique(seg_labels):
        obj_idx_all = np.where(seg_labels == i)
        x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
        obj_name = coco_label_map[int(i * 255)]
        obj_name = obj_name.split(',')[0]
        # Choose black or white text depending on the underlying region colour.
        if islight(seg_map[int(y), int(x)]):
            plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
        else:
            plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)

    plt.axis('off')
    plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()


def ocr_detection_prettify(rgb_path, file_name):
    # If an OCR mask exists, dim everything except the detected text regions
    # and print the recognised string at each region's centroid; otherwise
    # grey out the image and note that no text was found.
    if os.path.exists(file_name):
        rgb = plt.imread(rgb_path)
        ocr_labels = plt.imread(file_name)
        ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))

        plt.imshow(rgb)
        plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)

        for i in np.unique(ocr_labels)[:-1]:
            text_idx_all = np.where(ocr_labels == i)
            x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
            text = ocr_labels_dict[int(i * 255)]['text']
            plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)

        plt.axis('off')
        plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
        plt.close()
    else:
        rgb = plt.imread(rgb_path)
        ocr_labels = np.ones_like(rgb, dtype=np.float32)

        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)

        x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
        plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)

        plt.axis('off')
        plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
        plt.close()


def label_prettify(rgb_path, expert_paths):
    # Dispatch each expert output to its prettifier based on the file path.
    for expert_path in expert_paths:
        if 'depth' in expert_path:
            depth_prettify(expert_path)
        elif 'seg' in expert_path:
            seg_prettify(rgb_path, expert_path)
        elif 'ocr' in expert_path:
            ocr_detection_prettify(rgb_path, expert_path)
        elif 'obj' in expert_path:
            obj_detection_prettify(rgb_path, expert_path)
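

# A minimal usage sketch. The file layout below is hypothetical and not part
# of this module; label_prettify only requires that each expert path contains
# 'depth', 'seg', 'ocr', or 'obj' and points at that expert's rendered .png
# for the same image as the RGB input.
if __name__ == '__main__':
    example_rgb = 'prismer/images/example.jpg'           # hypothetical input image
    example_experts = [
        'prismer/labels/depth/example.png',               # hypothetical depth map
        'prismer/labels/seg_coco/example.png',            # hypothetical segmentation mask
        'prismer/labels/ocr_detection/example.png',       # hypothetical OCR mask
        'prismer/labels/obj_detection/example.png',       # hypothetical detection mask
    ]
    label_prettify(example_rgb, example_experts)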