Spaces:
Runtime error
Runtime error
Add label prettify
Browse files- label_prettify.py +126 -0
- prismer_model.py +3 -1
- requirements.txt +1 -0
label_prettify.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import matplotlib
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from prismer.utils import create_ade20k_label_colormap
|
11 |
+
|
12 |
+
obj_label_map = torch.load('prismer/dataset/detection_features.pt')['labels']
|
13 |
+
coco_label_map = torch.load('prismer/dataset/coco_features.pt')['labels']
|
14 |
+
ade_color = create_ade20k_label_colormap()
|
15 |
+
|
16 |
+
|
17 |
+
def islight(rgb):
|
18 |
+
r, g, b = rgb
|
19 |
+
hsp = np.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
|
20 |
+
if hsp > 127.5:
|
21 |
+
return True
|
22 |
+
else:
|
23 |
+
return False
|
24 |
+
|
25 |
+
|
26 |
+
def depth_prettify(file_path):
|
27 |
+
depth = plt.imread(file_path)
|
28 |
+
plt.imsave(file_path, depth, cmap='rainbow')
|
29 |
+
|
30 |
+
|
31 |
+
def obj_detection_prettify(rgb_path, path_name):
|
32 |
+
rgb = plt.imread(rgb_path)
|
33 |
+
obj_labels = plt.imread(path_name)
|
34 |
+
obj_labels_dict = json.load(open(path_name.replace('.png', '.json')))
|
35 |
+
|
36 |
+
plt.imshow(rgb)
|
37 |
+
|
38 |
+
num_objs = np.unique(obj_labels)[:-1].max()
|
39 |
+
plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
|
40 |
+
cmap = matplotlib.colormaps.get_cmap('terrain')
|
41 |
+
for i in np.unique(obj_labels)[:-1]:
|
42 |
+
obj_idx_all = np.where(obj_labels == i)
|
43 |
+
x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
|
44 |
+
obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
|
45 |
+
obj_name = obj_name.split(',')[0]
|
46 |
+
if islight([c*255 for c in cmap(i / num_objs)[:3]]):
|
47 |
+
plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
48 |
+
else:
|
49 |
+
plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
50 |
+
|
51 |
+
plt.axis('off')
|
52 |
+
plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
|
53 |
+
plt.close()
|
54 |
+
|
55 |
+
|
56 |
+
def seg_prettify(rgb_path, file_name):
|
57 |
+
rgb = plt.imread(rgb_path)
|
58 |
+
seg_labels = plt.imread(file_name)
|
59 |
+
|
60 |
+
plt.imshow(rgb)
|
61 |
+
|
62 |
+
seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
|
63 |
+
for i in np.unique(seg_labels):
|
64 |
+
seg_map[seg_labels == i] = ade_color[int(i * 255)]
|
65 |
+
|
66 |
+
plt.imshow(seg_map, alpha=0.8)
|
67 |
+
|
68 |
+
for i in np.unique(seg_labels):
|
69 |
+
obj_idx_all = np.where(seg_labels == i)
|
70 |
+
x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
|
71 |
+
obj_name = coco_label_map[int(i * 255)]
|
72 |
+
obj_name = obj_name.split(',')[0]
|
73 |
+
if islight(seg_map[int(y), int(x)]):
|
74 |
+
plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
75 |
+
else:
|
76 |
+
plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
77 |
+
|
78 |
+
plt.axis('off')
|
79 |
+
plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
|
80 |
+
plt.close()
|
81 |
+
|
82 |
+
|
83 |
+
def ocr_detection_prettify(rgb_path, file_name):
|
84 |
+
if os.path.exists(file_name):
|
85 |
+
rgb = plt.imread(rgb_path)
|
86 |
+
ocr_labels = plt.imread(file_name)
|
87 |
+
ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
|
88 |
+
|
89 |
+
plt.imshow(rgb)
|
90 |
+
plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
|
91 |
+
|
92 |
+
for i in np.unique(ocr_labels)[:-1]:
|
93 |
+
text_idx_all = np.where(ocr_labels == i)
|
94 |
+
x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
|
95 |
+
text = ocr_labels_dict[int(i * 255)]['text']
|
96 |
+
plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
97 |
+
|
98 |
+
plt.axis('off')
|
99 |
+
plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
|
100 |
+
plt.close()
|
101 |
+
else:
|
102 |
+
rgb = plt.imread(rgb_path)
|
103 |
+
ocr_labels = np.ones_like(rgb, dtype=np.float32())
|
104 |
+
|
105 |
+
plt.imshow(rgb)
|
106 |
+
plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
|
107 |
+
|
108 |
+
x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
|
109 |
+
plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
|
110 |
+
|
111 |
+
plt.axis('off')
|
112 |
+
plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
|
113 |
+
plt.close()
|
114 |
+
|
115 |
+
|
116 |
+
def label_prettify(rgb_path, expert_paths):
|
117 |
+
for expert_path in expert_paths:
|
118 |
+
if 'depth' in expert_path:
|
119 |
+
depth_prettify(expert_path)
|
120 |
+
elif 'seg' in expert_path:
|
121 |
+
seg_prettify(rgb_path, expert_path)
|
122 |
+
elif 'ocr' in expert_path:
|
123 |
+
ocr_detection_prettify(rgb_path, expert_path)
|
124 |
+
elif 'obj' in expert_path:
|
125 |
+
obj_detection_prettify(rgb_path, expert_path)
|
126 |
+
|
prismer_model.py
CHANGED
@@ -9,6 +9,7 @@ import sys
|
|
9 |
|
10 |
import cv2
|
11 |
import torch
|
|
|
12 |
|
13 |
repo_dir = pathlib.Path(__file__).parent
|
14 |
submodule_dir = repo_dir / 'prismer'
|
@@ -53,7 +54,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
53 |
env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
|
54 |
else:
|
55 |
env['PYTHONPATH'] = submodule_dir.as_posix()
|
56 |
-
subprocess.run(shlex.split(f'
|
57 |
|
58 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
59 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
@@ -108,4 +109,5 @@ class Model:
|
|
108 |
def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
|
109 |
out_paths = run_experts(image_path)
|
110 |
# caption = self.run_caption_model(model_name)
|
|
|
111 |
return None, *out_paths
|
|
|
9 |
|
10 |
import cv2
|
11 |
import torch
|
12 |
+
from label_prettify import label_prettify
|
13 |
|
14 |
repo_dir = pathlib.Path(__file__).parent
|
15 |
submodule_dir = repo_dir / 'prismer'
|
|
|
54 |
env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
|
55 |
else:
|
56 |
env['PYTHONPATH'] = submodule_dir.as_posix()
|
57 |
+
subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
|
58 |
|
59 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
60 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
|
|
109 |
def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
|
110 |
out_paths = run_experts(image_path)
|
111 |
# caption = self.run_caption_model(model_name)
|
112 |
+
label_prettify(image_path, out_paths)
|
113 |
return None, *out_paths
|
requirements.txt
CHANGED
@@ -21,3 +21,4 @@ torch==1.13.1
|
|
21 |
torchvision==0.14.1
|
22 |
transformers==4.26.1
|
23 |
yacs==0.1.8
|
|
|
|
21 |
torchvision==0.14.1
|
22 |
transformers==4.26.1
|
23 |
yacs==0.1.8
|
24 |
+
matplotlib=3.7.0
|