Spaces:
Runtime error
Runtime error
import glob | |
import os | |
import json | |
import torch | |
import random | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from utils import create_ade20k_label_colormap | |
# Class-name vocabularies used to caption the rendered labels.
# NOTE(review): torch.load unpickles these files -- only load trusted data.
# map_location='cpu' lets the files load on a CPU-only host even if they were
# saved with GPU tensors; only the 'labels' entry is used here.
obj_label_map = torch.load('dataset/detection_features.pt', map_location='cpu')['labels']
coco_label_map = torch.load('dataset/coco_features.pt', map_location='cpu')['labels']
ade_color = create_ade20k_label_colormap()  # RGB colour per segmentation id
file_path = 'helpers/images'  # input RGB images (captions live beside them as .txt)
expert_path = 'helpers/labels'  # root directory of the per-expert label folders
plt.ioff()  # non-interactive pyplot: figures are rendered via savefig/show only
def get_label_path(file_name, expert_name, with_suffix=False, root=None):
    """Return the path of the *expert_name* label belonging to *file_name*.

    The image's final extension is replaced by ``.png`` (or by ``_.png`` when
    *with_suffix* is true, the name used for the prettified rendering), and the
    result is joined under ``<root>/<expert_name>/``.

    Args:
        file_name: image path, e.g. ``'helpers/images/cat.jpg'``.
        expert_name: expert sub-directory, e.g. ``'depth'``.
        with_suffix: select the prettified ``_.png`` variant.
        root: label root directory; defaults to the module-level ``expert_path``.

    Returns:
        The label path as a string.
    """
    if root is None:
        root = expert_path
    suffix = '_.png' if with_suffix else '.png'
    # splitext drops only the final extension. Joining split('.') pieces with ''
    # also deleted every other dot ('a.b.jpg' -> 'ab.png'), and turned
    # './img/x.jpg' into the absolute '/img/x.png', which made os.path.join
    # discard the label root.
    label_name = os.path.splitext(file_name)[0] + suffix
    # NOTE(review): file_name is joined as-is, so any directory prefix it carries
    # is nested under root/expert_name -- confirm this matches the on-disk layout.
    return os.path.join(root, expert_name, label_name)
def depth_prettify(file_name):
    """Save a rainbow-colour-mapped copy of the depth label for *file_name*.

    The raw label is read from the ``depth`` expert folder and the coloured
    version is written alongside it under the ``_.png`` name.
    """
    src, dst = (get_label_path(file_name, 'depth', pretty_flag) for pretty_flag in (False, True))
    plt.imsave(dst, plt.imread(src), cmap='rainbow')
def obj_detection_prettify(file_name):
    """Overlay the object-detection label on the RGB image and save the figure.

    The label PNG holds per-pixel ids that are recovered as ``value * 255``.
    A ``.json`` sidecar maps each id (as a string) to an index into
    ``obj_label_map``. The label is tinted over the image with the ``terrain``
    colormap, and each object's class name is written at a random pixel of that
    object. The figure is saved to the ``_.png`` variant of the label path.
    """
    label_path = get_label_path(file_name, 'obj_detection')
    save_path = get_label_path(file_name, 'obj_detection', True)
    rgb = plt.imread(file_name)
    obj_labels = plt.imread(label_path)
    with open(os.path.splitext(label_path)[0] + '.json') as f:
        obj_labels_dict = json.load(f)

    # Round rather than truncate: float error in value * 255 (e.g. 6.9999995)
    # would otherwise shift an id down by one.
    label_ids = np.rint(obj_labels * 255).astype(int)
    # The largest id is skipped. NOTE(review): presumably background -- confirm.
    obj_ids = np.unique(label_ids)[:-1]

    plt.imshow(rgb)
    if obj_ids.size:  # an image with no objects has nothing to overlay or caption
        # vmax sits one id step above the largest object id, on the image's [0, 1] scale.
        plt.imshow(obj_labels, cmap='terrain', vmax=(obj_ids.max() + 1) / 255., alpha=0.5)
    for obj_id in obj_ids:
        ys, xs = np.where(label_ids == obj_id)[:2]
        # randrange excludes len(ys); randint(0, len) could index past the end.
        pick = random.randrange(len(ys))
        obj_name = obj_label_map[obj_labels_dict[str(int(obj_id))]]
        plt.text(xs[pick], ys[pick], obj_name, c='white',
                 horizontalalignment='center', verticalalignment='center')
    plt.axis('off')
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
def seg_prettify(file_name):
    """Overlay the COCO segmentation label on the RGB image and save the figure.

    Segment ids are recovered from the label PNG as ``value * 255``. Each
    segment is painted with its ``ade_color`` entry at 50% opacity and
    captioned with its ``coco_label_map`` name at a random pixel of that
    segment. The figure is saved to the ``_.png`` variant of the label path.
    """
    label_path = get_label_path(file_name, 'seg_coco')
    save_path = get_label_path(file_name, 'seg_coco', True)
    rgb = plt.imread(file_name)
    seg_labels = plt.imread(label_path)
    # Round rather than truncate: float error in value * 255 would otherwise
    # shift an id down by one (wrong colour and wrong class name).
    seg_ids = np.rint(seg_labels * 255).astype(int)
    unique_ids = np.unique(seg_ids)

    plt.imshow(rgb)
    # NOTE(review): assumes a single-channel label, so seg_ids is 2-D (H, W).
    seg_map = np.zeros(seg_ids.shape + (3,), dtype=np.uint8)
    for seg_id in unique_ids:
        seg_map[seg_ids == seg_id] = ade_color[int(seg_id)]
    plt.imshow(seg_map, alpha=0.5)
    for seg_id in unique_ids:
        ys, xs = np.where(seg_ids == seg_id)[:2]
        # randrange excludes len(ys); randint(0, len) could index past the end.
        pick = random.randrange(len(ys))
        obj_name = coco_label_map[int(seg_id)]
        plt.text(xs[pick], ys[pick], obj_name, c='white',
                 horizontalalignment='center', verticalalignment='center')
    plt.axis('off')
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
def ocr_detection_prettify(file_name):
    """Render the OCR label of *file_name* over the RGB image and save the figure.

    If the label PNG exists:
    - pixels with a non-zero label value are drawn as a grey-scale mask at 80% opacity;
    - each label id except the largest (recovered as ``value * 255``) is captioned
      at its pixel centroid with the ``'text'`` entry from the ``.pt`` sidecar.

    If the label PNG does not exist, the image gets a uniform overlay and
    ``'No text detected'`` is written in the centre.

    Either way, the figure is saved to the ``_.png`` variant of the label path.
    """
    label_path = get_label_path(file_name, 'ocr_detection')
    save_path = get_label_path(file_name, 'ocr_detection', True)
    rgb = plt.imread(file_name)
    if os.path.exists(label_path):
        ocr_labels = plt.imread(label_path)
        # NOTE(review): torch.load unpickles the sidecar -- only use trusted label files.
        # map_location='cpu' keeps loading working on CPU-only hosts.
        ocr_labels_dict = torch.load(os.path.splitext(label_path)[0] + '.pt', map_location='cpu')
        plt.imshow(rgb)
        plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
        # Round rather than truncate value * 255 so float error cannot shift an id down by one.
        label_ids = np.rint(ocr_labels * 255).astype(int)
        # The largest id is skipped. NOTE(review): presumably background -- confirm.
        for text_id in np.unique(label_ids)[:-1]:
            ys, xs = np.where(label_ids == text_id)[:2]
            text = ocr_labels_dict[int(text_id)]['text']
            plt.text(xs.mean(), ys.mean(), text, c='white',
                     horizontalalignment='center', verticalalignment='center')
    else:
        plt.imshow(rgb)
        # dtype takes the type np.float32 itself, not an instance np.float32().
        plt.imshow(np.ones_like(rgb, dtype=np.float32), cmap='gray', alpha=0.8)
        plt.text(rgb.shape[1] / 2, rgb.shape[0] / 2, 'No text detected', c='black',
                 horizontalalignment='center', verticalalignment='center')
    plt.axis('off')
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.close()
# Input images. glob returns files in an arbitrary, filesystem-dependent order,
# so sort them to make the processing and plotting order deterministic.
im_list = sorted(glob.glob(file_path + '/*.jpg') + glob.glob(file_path + '/*.png') + glob.glob(file_path + '/*.jpeg'))
# prettify labels first:
for im_path in im_list:
    depth_prettify(im_path)
    seg_prettify(im_path)
    ocr_detection_prettify(im_path)
    obj_detection_prettify(im_path)
# Expert name -> whether to show the prettified '_.png' rendering (True) or the
# raw label image (False).
pretty = {'depth': True, 'normal': False, 'edge': False,
          'obj_detection': True, 'ocr_detection': True, 'seg_coco': True}
# plot expert labels
for im_path in im_list:
    # One panel for the RGB image plus one per expert.
    fig, axs = plt.subplots(1, len(pretty) + 1, figsize=(20, 4))
    rgb = plt.imread(im_path)
    axs[0].imshow(rgb)
    axs[0].axis('off')
    axs[0].set_title('RGB')
    for ax, (label_name, with_suffix) in zip(axs[1:], pretty.items()):
        label = plt.imread(get_label_path(im_path, label_name, with_suffix=with_suffix))
        # Edge labels are drawn in grayscale; every other label uses the default colormap.
        ax.imshow(label, cmap='gray' if label_name == 'edge' else None)
        ax.axis('off')
        ax.set_title(label_name)
    # Caption sits next to the image with a .txt extension; splitext keeps any
    # other dots in the path intact.
    caption_path = os.path.splitext(im_path)[0] + '.txt'
    with open(caption_path) as f:
        caption = f.readline()  # first line only
    plt.suptitle(caption)
    plt.tight_layout()
    plt.show()
    # Release the figure; with plt.ioff() and a non-interactive backend,
    # show() does not close it, so figures would accumulate across images.
    plt.close(fig)