Spaces:
Runtime error
Runtime error
import os | |
import json | |
import random | |
import torch | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import numpy as np | |
from prismer.utils import create_ade20k_label_colormap | |
def _load_labels(path):
    """Return the ``'labels'`` entry of the torch-serialised dict stored at ``path``."""
    return torch.load(path)['labels']


# Label vocabularies used to caption detection / segmentation outputs, plus the
# colour palette used to paint segmentation classes; all built once at import time.
obj_label_map = _load_labels('prismer/dataset/detection_features.pt')
coco_label_map = _load_labels('prismer/dataset/coco_features.pt')
ade_color = create_ade20k_label_colormap()
def islight(rgb):
    """Return True if an RGB colour is perceptually light.

    Brightness is the weighted RMS of the channels ("HSP" model):
    sqrt(0.299*R^2 + 0.587*G^2 + 0.114*B^2), compared against 127.5, the
    midpoint of a 0-255 scale.

    Args:
        rgb: three channel values on a 0-255 scale (list, tuple or array).

    Returns:
        bool: True when the perceived brightness exceeds 127.5.
    """
    # Cast to Python floats before squaring: numpy small-int inputs (e.g. an
    # int16 palette pixel) overflow for values > 181, which can make the sum
    # negative and the brightness NaN, so a bright colour read as "dark".
    r, g, b = (float(c) for c in rgb)
    hsp = np.sqrt(0.299 * r * r + 0.587 * g * g + 0.114 * b * b)
    return bool(hsp > 127.5)
def depth_prettify(file_path):
    """Recolour the depth image at ``file_path`` in place with the 'rainbow' colormap.

    The image is read back with ``plt.imread`` and written straight back to the
    same path. Per matplotlib's docs, ``cmap`` is ignored if the data is RGB(A).
    """
    plt.imsave(file_path, plt.imread(file_path), cmap='rainbow')
def obj_detection_prettify(rgb_path, path_name):
    """Overlay object-detection labels on an RGB image and overwrite ``path_name``.

    ``path_name`` is read back as a label image in which each object has its own
    grey level; the largest grey level is dropped from the object list (treated
    as background). A sibling ``.json`` file maps each object's 8-bit label
    value (as a string key) to an index into ``obj_label_map``. The label image
    is blended over the RGB image with the 'terrain' colormap, and each object is
    captioned at its pixel centroid with the first comma-separated part of its
    ``obj_label_map`` entry. The caption is black or white, whichever contrasts
    with the overlay colour drawn for that object.

    Args:
        rgb_path: path of the source RGB image.
        path_name: path of the object label PNG; replaced by the rendering.
    """
    rgb = plt.imread(rgb_path)
    obj_labels = plt.imread(path_name)
    # splitext swaps only the final extension; str.replace would also rewrite
    # any '.png' occurring in a directory name. `with` closes the file handle.
    with open(os.path.splitext(path_name)[0] + '.json', encoding='utf-8') as f:
        obj_labels_dict = json.load(f)

    # NOTE(review): assumes background is the largest grey level -- confirm.
    obj_ids = np.unique(obj_labels)[:-1]

    plt.imshow(rgb)
    if obj_ids.size > 0:
        num_objs = obj_ids.max()
        overlay = plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
        for i in obj_ids:
            obj_idx_all = np.where(obj_labels == i)
            x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
            # imread yields k / 255 floats; round so float error can never
            # truncate the label id k down to k - 1.
            label_id = int(round(float(i) * 255))
            obj_name = obj_label_map[obj_labels_dict[str(label_id)]]
            obj_name = obj_name.split(',')[0]
            # Take the tint through the overlay's own norm/cmap so the contrast
            # test sees the colour actually drawn (i / num_objs ignored the
            # autoscaled vmin/vmax and divided 0/0 when the only label was 0).
            tint = overlay.cmap(overlay.norm(i))[:3]
            text_color = 'black' if islight([c * 255 for c in tint]) else 'white'
            plt.text(x, y, obj_name, c=text_color, horizontalalignment='center',
                     verticalalignment='center', clip_on=True)
    plt.axis('off')
    plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
def seg_prettify(rgb_path, file_name):
    """Overlay a semantic-segmentation map on an RGB image and overwrite ``file_name``.

    ``file_name`` is read back as a label image whose grey levels encode class
    ids (id = round(value * 255)). Each class is painted with its ``ade_color``
    entry and blended over the RGB image. Every class covering more than 20
    pixels is captioned at one randomly chosen pixel of that class with the
    first comma-separated part of its ``coco_label_map`` entry, in black or
    white for contrast with the paint colour.

    Args:
        rgb_path: path of the source RGB image.
        file_name: path of the segmentation label PNG; replaced by the rendering.
    """
    rgb = plt.imread(rgb_path)
    seg_labels = plt.imread(file_name)
    # NOTE(review): assumes seg_labels is a single-channel (H, W) map -- confirm.
    class_values = np.unique(seg_labels)  # computed once, reused by both passes

    plt.imshow(rgb)
    seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
    for i in class_values:
        seg_map[seg_labels == i] = ade_color[int(round(float(i) * 255))]
    plt.imshow(seg_map, alpha=0.8)

    for i in class_values:
        obj_idx_all = np.where(seg_labels == i)
        if len(obj_idx_all[0]) > 20:  # only caption classes with more than 20 labelled pixels
            obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)
            x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
            obj_name = coco_label_map[int(round(float(i) * 255))]
            obj_name = obj_name.split(',')[0]
            # .tolist() hands islight plain Python ints: squaring int16 channel
            # values above 181 overflows and made bright colours test as dark.
            pixel = seg_map[int(y), int(x)].tolist()
            text_color = 'black' if islight(pixel) else 'white'
            plt.text(x, y, obj_name, c=text_color, horizontalalignment='center',
                     verticalalignment='center', clip_on=True)
    plt.axis('off')
    plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
def ocr_detection_prettify(rgb_path, file_name):
    """Overlay OCR results on an RGB image and save the rendering to ``file_name``.

    If ``file_name`` exists, it is read as a label image in which each text
    region has its own grey level; the largest grey level is skipped (treated
    as background). A sibling ``.pt`` file (loaded with ``torch.load``) is
    indexed by each region's 8-bit label value, and its ``'text'`` entry is
    drawn in white at the region centroid.

    Otherwise an all-ones overlay (white for RGB input) captioned
    'No text detected' is rendered and saved to ``file_name``. The parent
    directory is created first if needed.

    Args:
        rgb_path: path of the source RGB image.
        file_name: path of the OCR label PNG; written or overwritten with the rendering.
    """
    rgb = plt.imread(rgb_path)
    if os.path.exists(file_name):
        ocr_labels = plt.imread(file_name)
        # splitext swaps only the final extension (str.replace would also
        # rewrite '.png' inside directory names).
        ocr_labels_dict = torch.load(os.path.splitext(file_name)[0] + '.pt')
        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
        for i in np.unique(ocr_labels)[:-1]:
            text_idx_all = np.where(ocr_labels == i)
            x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
            # imread yields k / 255 floats; round rather than truncate the id.
            text = ocr_labels_dict[int(round(float(i) * 255))]['text']
            plt.text(x, y, text, c='white', horizontalalignment='center',
                     verticalalignment='center', clip_on=True)
    else:
        # dtype must be the type np.float32, not the scalar instance np.float32().
        ocr_labels = np.ones_like(rgb, dtype=np.float32)
        plt.imshow(rgb)
        plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
        x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
        plt.text(x, y, 'No text detected', c='black', horizontalalignment='center',
                 verticalalignment='center', clip_on=True)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        out_dir = os.path.dirname(file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    plt.axis('off')
    plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
def label_prettify(rgb_path, expert_paths):
    """Render every expert output in ``expert_paths`` as a viewable image, in place.

    Each path is routed by substring, checked in this order: 'depth', 'seg',
    'ocr', 'obj'. The first match wins, and paths matching none are left
    untouched. Depth maps are recoloured on their own; the other experts are
    drawn over the RGB image at ``rgb_path``.
    """
    for path in expert_paths:
        if 'depth' in path:
            depth_prettify(path)
            continue
        overlay_handlers = (
            ('seg', seg_prettify),
            ('ocr', ocr_detection_prettify),
            ('obj', obj_detection_prettify),
        )
        for keyword, handler in overlay_handlers:
            if keyword in path:
                handler(rgb_path, path)
                break