shikunl commited on
Commit
ad3ee60
β€’
1 Parent(s): 1dd8a60

Add label prettify

Browse files
Files changed (3) hide show
  1. label_prettify.py +126 -0
  2. prismer_model.py +3 -1
  3. requirements.txt +1 -0
label_prettify.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib
8
+ import numpy as np
9
+
10
+ from prismer.utils import create_ade20k_label_colormap
11
+
12
+ obj_label_map = torch.load('prismer/dataset/detection_features.pt')['labels']
13
+ coco_label_map = torch.load('prismer/dataset/coco_features.pt')['labels']
14
+ ade_color = create_ade20k_label_colormap()
15
+
16
+
17
+ def islight(rgb):
18
+ r, g, b = rgb
19
+ hsp = np.sqrt(0.299 * (r * r) + 0.587 * (g * g) + 0.114 * (b * b))
20
+ if hsp > 127.5:
21
+ return True
22
+ else:
23
+ return False
24
+
25
+
26
+ def depth_prettify(file_path):
27
+ depth = plt.imread(file_path)
28
+ plt.imsave(file_path, depth, cmap='rainbow')
29
+
30
+
31
+ def obj_detection_prettify(rgb_path, path_name):
32
+ rgb = plt.imread(rgb_path)
33
+ obj_labels = plt.imread(path_name)
34
+ obj_labels_dict = json.load(open(path_name.replace('.png', '.json')))
35
+
36
+ plt.imshow(rgb)
37
+
38
+ num_objs = np.unique(obj_labels)[:-1].max()
39
+ plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
40
+ cmap = matplotlib.colormaps.get_cmap('terrain')
41
+ for i in np.unique(obj_labels)[:-1]:
42
+ obj_idx_all = np.where(obj_labels == i)
43
+ x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
44
+ obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
45
+ obj_name = obj_name.split(',')[0]
46
+ if islight([c*255 for c in cmap(i / num_objs)[:3]]):
47
+ plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
48
+ else:
49
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
50
+
51
+ plt.axis('off')
52
+ plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
53
+ plt.close()
54
+
55
+
56
+ def seg_prettify(rgb_path, file_name):
57
+ rgb = plt.imread(rgb_path)
58
+ seg_labels = plt.imread(file_name)
59
+
60
+ plt.imshow(rgb)
61
+
62
+ seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
63
+ for i in np.unique(seg_labels):
64
+ seg_map[seg_labels == i] = ade_color[int(i * 255)]
65
+
66
+ plt.imshow(seg_map, alpha=0.8)
67
+
68
+ for i in np.unique(seg_labels):
69
+ obj_idx_all = np.where(seg_labels == i)
70
+ x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
71
+ obj_name = coco_label_map[int(i * 255)]
72
+ obj_name = obj_name.split(',')[0]
73
+ if islight(seg_map[int(y), int(x)]):
74
+ plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
75
+ else:
76
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
77
+
78
+ plt.axis('off')
79
+ plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
80
+ plt.close()
81
+
82
+
83
+ def ocr_detection_prettify(rgb_path, file_name):
84
+ if os.path.exists(file_name):
85
+ rgb = plt.imread(rgb_path)
86
+ ocr_labels = plt.imread(file_name)
87
+ ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
88
+
89
+ plt.imshow(rgb)
90
+ plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
91
+
92
+ for i in np.unique(ocr_labels)[:-1]:
93
+ text_idx_all = np.where(ocr_labels == i)
94
+ x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
95
+ text = ocr_labels_dict[int(i * 255)]['text']
96
+ plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
97
+
98
+ plt.axis('off')
99
+ plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
100
+ plt.close()
101
+ else:
102
+ rgb = plt.imread(rgb_path)
103
+ ocr_labels = np.ones_like(rgb, dtype=np.float32())
104
+
105
+ plt.imshow(rgb)
106
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
107
+
108
+ x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
109
+ plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
110
+
111
+ plt.axis('off')
112
+ plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
113
+ plt.close()
114
+
115
+
116
+ def label_prettify(rgb_path, expert_paths):
117
+ for expert_path in expert_paths:
118
+ if 'depth' in expert_path:
119
+ depth_prettify(expert_path)
120
+ elif 'seg' in expert_path:
121
+ seg_prettify(rgb_path, expert_path)
122
+ elif 'ocr' in expert_path:
123
+ ocr_detection_prettify(rgb_path, expert_path)
124
+ elif 'obj' in expert_path:
125
+ obj_detection_prettify(rgb_path, expert_path)
126
+
prismer_model.py CHANGED
@@ -9,6 +9,7 @@ import sys
9
 
10
  import cv2
11
  import torch
 
12
 
13
  repo_dir = pathlib.Path(__file__).parent
14
  submodule_dir = repo_dir / 'prismer'
@@ -53,7 +54,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
53
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
54
  else:
55
  env['PYTHONPATH'] = submodule_dir.as_posix()
56
- subprocess.run(shlex.split(f'accelerate experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
57
 
58
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
59
  results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
@@ -108,4 +109,5 @@ class Model:
108
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
109
  out_paths = run_experts(image_path)
110
  # caption = self.run_caption_model(model_name)
 
111
  return None, *out_paths
 
9
 
10
  import cv2
11
  import torch
12
+ from label_prettify import label_prettify
13
 
14
  repo_dir = pathlib.Path(__file__).parent
15
  submodule_dir = repo_dir / 'prismer'
 
54
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
55
  else:
56
  env['PYTHONPATH'] = submodule_dir.as_posix()
57
+ subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
58
 
59
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
60
  results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
 
109
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
110
  out_paths = run_experts(image_path)
111
  # caption = self.run_caption_model(model_name)
112
+ label_prettify(image_path, out_paths)
113
  return None, *out_paths
requirements.txt CHANGED
@@ -21,3 +21,4 @@ torch==1.13.1
21
  torchvision==0.14.1
22
  transformers==4.26.1
23
  yacs==0.1.8
 
 
21
  torchvision==0.14.1
22
  transformers==4.26.1
23
  yacs==0.1.8
24
+ matplotlib=3.7.0