shikunl commited on
Commit
45a5416
·
1 Parent(s): d1c3a3a

Update with md5sum and half precision inference

Browse files
app.py CHANGED
@@ -3,18 +3,8 @@
3
  from __future__ import annotations
4
 
5
  import os
6
- import shutil
7
- import subprocess
8
-
9
  import gradio as gr
10
 
11
- if os.getenv('SYSTEM') == 'spaces':
12
- with open('patch') as f:
13
- subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
14
- shutil.copytree('prismer/helpers/images',
15
- 'prismer/images',
16
- dirs_exist_ok=True)
17
-
18
  from app_caption import create_demo as create_demo_caption
19
  from app_vqa import create_demo as create_demo_vqa
20
  from prismer_model import build_deformable_conv, download_models
@@ -36,7 +26,7 @@ if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
36
  description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
37
 
38
 
39
- with gr.Blocks() as demo:
40
  gr.Markdown(description)
41
  with gr.Tabs():
42
  with gr.TabItem('Zero-shot Image Captioning'):
 
3
  from __future__ import annotations
4
 
5
  import os
 
 
 
6
  import gradio as gr
7
 
 
 
 
 
 
 
 
8
  from app_caption import create_demo as create_demo_caption
9
  from app_vqa import create_demo as create_demo_vqa
10
  from prismer_model import build_deformable_conv, download_models
 
26
  description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
27
 
28
 
29
+ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
30
  gr.Markdown(description)
31
  with gr.Tabs():
32
  with gr.TabItem('Zero-shot Image Captioning'):
app_vqa.py CHANGED
@@ -35,7 +35,7 @@ def create_demo() -> gr.Blocks:
35
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
36
  ex_questions = ['What is the man on the left doing?',
37
  'What is this person doing?',
38
- 'How many cows in this image?',
39
  'What is the type of animal in this image?',
40
  'What toy is it?']
41
  examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
 
35
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
36
  ex_questions = ['What is the man on the left doing?',
37
  'What is this person doing?',
38
+ 'How many cows are in this image?',
39
  'What is the type of animal in this image?',
40
  'What toy is it?']
41
  examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]
label_prettify.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import matplotlib.pyplot as plt
6
  import matplotlib
7
  import numpy as np
 
8
 
9
  from prismer.utils import create_ade20k_label_colormap
10
 
@@ -23,101 +24,109 @@ def islight(rgb):
23
 
24
 
25
  def depth_prettify(file_path):
26
- depth = plt.imread(file_path)
27
- plt.imsave(file_path, depth, cmap='rainbow')
 
 
28
 
29
 
30
  def obj_detection_prettify(rgb_path, path_name):
31
- rgb = plt.imread(rgb_path)
32
- obj_labels = plt.imread(path_name)
33
- obj_labels_dict = json.load(open(path_name.replace('.png', '.json')))
 
 
34
 
35
- plt.imshow(rgb)
36
 
37
- if len(np.unique(obj_labels)) == 1:
38
- plt.axis('off')
39
- plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
40
- plt.close()
41
- else:
42
- num_objs = np.unique(obj_labels)[:-1].max()
43
- plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
44
- cmap = matplotlib.colormaps.get_cmap('terrain')
45
- for i in np.unique(obj_labels)[:-1]:
46
- obj_idx_all = np.where(obj_labels == i)
47
- x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
48
- obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
49
- obj_name = obj_name.split(',')[0]
50
- if islight([c*255 for c in cmap(i / num_objs)[:3]]):
51
- plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
52
- else:
53
- plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
54
-
55
- plt.axis('off')
56
- plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
57
- plt.close()
58
 
59
 
60
  def seg_prettify(rgb_path, file_name):
61
- rgb = plt.imread(rgb_path)
62
- seg_labels = plt.imread(file_name)
 
 
63
 
64
- plt.imshow(rgb)
65
 
66
- seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
67
- for i in np.unique(seg_labels):
68
- seg_map[seg_labels == i] = ade_color[int(i * 255)]
69
 
70
- plt.imshow(seg_map, alpha=0.8)
71
 
72
- for i in np.unique(seg_labels):
73
- obj_idx_all = np.where(seg_labels == i)
74
- if len(obj_idx_all[0]) > 20: # only plot the label with its number of labelled pixel more than 20
75
- obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)
76
- x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
77
- obj_name = coco_label_map[int(i * 255)]
78
- obj_name = obj_name.split(',')[0]
79
- if islight(seg_map[int(y), int(x)]):
80
- plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
81
- else:
82
- plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
83
 
84
- plt.axis('off')
85
- plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
86
- plt.close()
87
 
88
 
89
  def ocr_detection_prettify(rgb_path, file_name):
90
- if os.path.exists(file_name):
91
- rgb = plt.imread(rgb_path)
92
- ocr_labels = plt.imread(file_name)
93
- ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
 
 
94
 
95
- plt.imshow(rgb)
96
- plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
97
 
98
- for i in np.unique(ocr_labels)[:-1]:
99
- text_idx_all = np.where(ocr_labels == i)
100
- x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
101
- text = ocr_labels_dict[int(i * 255)]['text']
102
- plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
103
 
104
- plt.axis('off')
105
- plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
106
- plt.close()
107
- else:
108
- rgb = plt.imread(rgb_path)
109
- ocr_labels = np.ones_like(rgb, dtype=np.float32())
110
 
111
- plt.imshow(rgb)
112
- plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
113
 
114
- x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
115
- plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
116
- plt.axis('off')
117
 
118
- os.makedirs(os.path.dirname(file_name), exist_ok=True)
119
- plt.savefig(file_name, bbox_inches='tight', transparent=True, pad_inches=0)
120
- plt.close()
121
 
122
 
123
  def label_prettify(rgb_path, expert_paths):
@@ -130,4 +139,7 @@ def label_prettify(rgb_path, expert_paths):
130
  ocr_detection_prettify(rgb_path, expert_path)
131
  elif 'obj' in expert_path:
132
  obj_detection_prettify(rgb_path, expert_path)
133
-
 
 
 
 
5
  import matplotlib.pyplot as plt
6
  import matplotlib
7
  import numpy as np
8
+ import shutil
9
 
10
  from prismer.utils import create_ade20k_label_colormap
11
 
 
24
 
25
 
26
  def depth_prettify(file_path):
27
+ pretty_path = file_path.replace('.png', '_p.png')
28
+ if not os.path.exists(pretty_path):
29
+ depth = plt.imread(file_path)
30
+ plt.imsave(pretty_path, depth, cmap='rainbow')
31
 
32
 
33
  def obj_detection_prettify(rgb_path, path_name):
34
+ pretty_path = path_name.replace('.png', '_p.png')
35
+ if not os.path.exists(pretty_path):
36
+ rgb = plt.imread(rgb_path)
37
+ obj_labels = plt.imread(path_name)
38
+ obj_labels_dict = json.load(open(path_name.replace('.png', '.json')))
39
 
40
+ plt.imshow(rgb)
41
 
42
+ if len(np.unique(obj_labels)) == 1:
43
+ plt.axis('off')
44
+ plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
45
+ plt.close()
46
+ else:
47
+ num_objs = np.unique(obj_labels)[:-1].max()
48
+ plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
49
+ cmap = matplotlib.colormaps.get_cmap('terrain')
50
+ for i in np.unique(obj_labels)[:-1]:
51
+ obj_idx_all = np.where(obj_labels == i)
52
+ x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
53
+ obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
54
+ obj_name = obj_name.split(',')[0]
55
+ if islight([c*255 for c in cmap(i / num_objs)[:3]]):
56
+ plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
57
+ else:
58
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
59
+
60
+ plt.axis('off')
61
+ plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
62
+ plt.close()
63
 
64
 
65
  def seg_prettify(rgb_path, file_name):
66
+ pretty_path = file_name.replace('.png', '_p.png')
67
+ if not os.path.exists(pretty_path):
68
+ rgb = plt.imread(rgb_path)
69
+ seg_labels = plt.imread(file_name)
70
 
71
+ plt.imshow(rgb)
72
 
73
+ seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
74
+ for i in np.unique(seg_labels):
75
+ seg_map[seg_labels == i] = ade_color[int(i * 255)]
76
 
77
+ plt.imshow(seg_map, alpha=0.8)
78
 
79
+ for i in np.unique(seg_labels):
80
+ obj_idx_all = np.where(seg_labels == i)
81
+ if len(obj_idx_all[0]) > 20: # only plot the label with its number of labelled pixel more than 20
82
+ obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)
83
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
84
+ obj_name = coco_label_map[int(i * 255)]
85
+ obj_name = obj_name.split(',')[0]
86
+ if islight(seg_map[int(y), int(x)]):
87
+ plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
88
+ else:
89
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
90
 
91
+ plt.axis('off')
92
+ plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
93
+ plt.close()
94
 
95
 
96
  def ocr_detection_prettify(rgb_path, file_name):
97
+ pretty_path = file_name.replace('.png', '_p.png')
98
+ if not os.path.exists(pretty_path):
99
+ if os.path.exists(file_name):
100
+ rgb = plt.imread(rgb_path)
101
+ ocr_labels = plt.imread(file_name)
102
+ ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
103
 
104
+ plt.imshow(rgb)
105
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
106
 
107
+ for i in np.unique(ocr_labels)[:-1]:
108
+ text_idx_all = np.where(ocr_labels == i)
109
+ x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
110
+ text = ocr_labels_dict[int(i * 255)]['text']
111
+ plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
112
 
113
+ plt.axis('off')
114
+ plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
115
+ plt.close()
116
+ else:
117
+ rgb = plt.imread(rgb_path)
118
+ ocr_labels = np.ones_like(rgb, dtype=np.float32())
119
 
120
+ plt.imshow(rgb)
121
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
122
 
123
+ x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
124
+ plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
125
+ plt.axis('off')
126
 
127
+ os.makedirs(os.path.dirname(file_name), exist_ok=True)
128
+ plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
129
+ plt.close()
130
 
131
 
132
  def label_prettify(rgb_path, expert_paths):
 
139
  ocr_detection_prettify(rgb_path, expert_path)
140
  elif 'obj' in expert_path:
141
  obj_detection_prettify(rgb_path, expert_path)
142
+ else:
143
+ pretty_path = expert_path.replace('.png', '_p.png')
144
+ if not os.path.exists(pretty_path):
145
+ shutil.copyfile(expert_path, pretty_path)
prismer/configs/experts.yaml CHANGED
@@ -1,2 +1,3 @@
1
- data_path: 'helpers'
2
- save_path: 'helpers/labels'
 
 
1
+ data_path: helpers
2
+ im_name: 87dfaeb4978ce05aa7be5e5b4cc1273a
3
+ save_path: helpers/labels
prismer/dataset/caption_dataset.py CHANGED
@@ -32,10 +32,7 @@ class Caption(Dataset):
32
  elif self.dataset == 'nocaps':
33
  self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
34
  elif self.dataset == 'demo':
35
- data_folders = glob.glob(f'{self.data_path}/*/')
36
- self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
37
- self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
38
- self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]
39
 
40
  def __len__(self):
41
  return len(self.data_list)
@@ -50,10 +47,11 @@ class Caption(Dataset):
50
  elif self.dataset == 'demo':
51
  img_path_split = self.data_list[index]['image'].split('/')
52
  img_name = img_path_split[-2] + '/' + img_path_split[-1]
53
- image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
54
 
55
  experts = self.transform(image, labels)
56
  experts = post_label_process(experts, labels_info)
 
57
 
58
  if self.train:
59
  caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
 
32
  elif self.dataset == 'nocaps':
33
  self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
34
  elif self.dataset == 'demo':
35
+ self.data_list = [{'image': f'helpers/images/{config["im_name"]}.jpg'}]
 
 
 
36
 
37
  def __len__(self):
38
  return len(self.data_list)
 
47
  elif self.dataset == 'demo':
48
  img_path_split = self.data_list[index]['image'].split('/')
49
  img_name = img_path_split[-2] + '/' + img_path_split[-1]
50
+ image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
51
 
52
  experts = self.transform(image, labels)
53
  experts = post_label_process(experts, labels_info)
54
+ experts['rgb'] = experts['rgb'].half()
55
 
56
  if self.train:
57
  caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
prismer/dataset/utils.py CHANGED
@@ -5,6 +5,7 @@
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import os
 
8
  import re
9
  import json
10
  import torch
@@ -14,10 +15,12 @@ import torchvision.transforms as transforms
14
  import torchvision.transforms.functional as transforms_f
15
  from dataset.randaugment import RandAugment
16
 
17
- COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
18
- ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
19
- DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
20
- BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
 
 
21
 
22
 
23
  class Transform:
@@ -119,7 +122,8 @@ def post_label_process(inputs, labels_info):
119
  for exp in inputs:
120
  if exp in ['depth', 'normal', 'edge']: # remap to -1 to 1 range
121
  inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
122
-
 
123
  elif exp == 'seg_coco': # in-paint with CLIP features
124
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
125
  for l in inputs[exp].unique():
@@ -127,7 +131,7 @@ def post_label_process(inputs, labels_info):
127
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
128
  else:
129
  text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
130
- inputs[exp] = text_emb
131
 
132
  elif exp == 'seg_ade': # in-paint with CLIP features
133
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -136,7 +140,7 @@ def post_label_process(inputs, labels_info):
136
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
137
  else:
138
  text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
139
- inputs[exp] = text_emb
140
 
141
  elif exp == 'obj_detection': # in-paint with CLIP features
142
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -146,7 +150,7 @@ def post_label_process(inputs, labels_info):
146
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
147
  else:
148
  text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
149
- inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}
150
 
151
  elif exp == 'ocr_detection': # in-paint with CLIP features
152
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -156,7 +160,7 @@ def post_label_process(inputs, labels_info):
156
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
157
  else:
158
  text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
159
- inputs[exp] = text_emb
160
  return inputs
161
 
162
 
 
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import os
8
+ import pathlib
9
  import re
10
  import json
11
  import torch
 
15
  import torchvision.transforms.functional as transforms_f
16
  from dataset.randaugment import RandAugment
17
 
18
+ cur_dir = pathlib.Path(__file__).parent
19
+
20
+ COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
21
+ ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
22
+ DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
23
+ BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
24
 
25
 
26
  class Transform:
 
122
  for exp in inputs:
123
  if exp in ['depth', 'normal', 'edge']: # remap to -1 to 1 range
124
  inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
125
+ inputs[exp] = inputs[exp].half()
126
+
127
  elif exp == 'seg_coco': # in-paint with CLIP features
128
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
129
  for l in inputs[exp].unique():
 
131
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
132
  else:
133
  text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
134
+ inputs[exp] = text_emb.half()
135
 
136
  elif exp == 'seg_ade': # in-paint with CLIP features
137
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
 
140
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
141
  else:
142
  text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
143
+ inputs[exp] = text_emb.half()
144
 
145
  elif exp == 'obj_detection': # in-paint with CLIP features
146
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
 
150
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
151
  else:
152
  text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
153
+ inputs[exp] = {'label': text_emb.half(), 'instance': inputs[exp].half()}
154
 
155
  elif exp == 'ocr_detection': # in-paint with CLIP features
156
  text_emb = torch.empty([64, *inputs[exp].shape[1:]])
 
160
  text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
161
  else:
162
  text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
163
+ inputs[exp] = text_emb.half()
164
  return inputs
165
 
166
 
prismer/experts/depth/generate_dataset.py CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
 
15
 
16
  class Dataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
  self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
 
24
  def __len__(self):
25
  return len(self.data_list)
@@ -29,4 +27,4 @@ class Dataset(Dataset):
29
  image = Image.open(image_path).convert('RGB')
30
  img_size = [image.size[0], image.size[1]]
31
  image = self.transform(image)
32
- return image, image_path, img_size
 
14
 
15
 
16
  class Dataset(Dataset):
17
+ def __init__(self, config, transform):
18
+ self.data_path = config['data_path']
19
  self.transform = transform
20
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
21
 
22
  def __len__(self):
23
  return len(self.data_list)
 
27
  image = Image.open(image_path).convert('RGB')
28
  img_size = [image.size[0], image.size[1]]
29
  image = self.transform(image)
30
+ return image.half(), image_path, img_size
prismer/experts/edge/generate_dataset.py CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
 
15
 
16
  class Dataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
  self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
 
24
  def __len__(self):
25
  return len(self.data_list)
@@ -29,4 +27,4 @@ class Dataset(Dataset):
29
  image = Image.open(image_path).convert('RGB')
30
  img_size = [image.size[0], image.size[1]]
31
  image = self.transform(image)
32
- return torch.flip(image, dims=(0, )) * 255., image_path, img_size
 
14
 
15
 
16
  class Dataset(Dataset):
17
+ def __init__(self, config, transform):
18
+ self.data_path = config['data_path']
19
  self.transform = transform
20
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
21
 
22
  def __len__(self):
23
  return len(self.data_list)
 
27
  image = Image.open(image_path).convert('RGB')
28
  img_size = [image.size[0], image.size[1]]
29
  image = self.transform(image)
30
+ return torch.flip(image.half(), dims=(0, )) * 255., image_path, img_size
prismer/experts/generate_depth.py CHANGED
@@ -21,11 +21,10 @@ model, transform = load_expert_model(task='depth')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
- data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'depth')
26
 
27
  batch_size = 64
28
- dataset = Dataset(data_path, transform)
29
  data_loader = torch.utils.data.DataLoader(
30
  dataset=dataset,
31
  batch_size=batch_size,
 
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 
24
  save_path = os.path.join(config['save_path'], 'depth')
25
 
26
  batch_size = 64
27
+ dataset = Dataset(config, transform)
28
  data_loader = torch.utils.data.DataLoader(
29
  dataset=dataset,
30
  batch_size=batch_size,
prismer/experts/generate_edge.py CHANGED
@@ -23,11 +23,10 @@ model, transform = load_expert_model(task='edge')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
- data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'edge')
28
 
29
  batch_size = 64
30
- dataset = Dataset(data_path, transform)
31
  data_loader = torch.utils.data.DataLoader(
32
  dataset=dataset,
33
  batch_size=batch_size,
 
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 
26
  save_path = os.path.join(config['save_path'], 'edge')
27
 
28
  batch_size = 64
29
+ dataset = Dataset(config, transform)
30
  data_loader = torch.utils.data.DataLoader(
31
  dataset=dataset,
32
  batch_size=batch_size,
prismer/experts/generate_normal.py CHANGED
@@ -23,11 +23,10 @@ model, transform = load_expert_model(task='normal')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
- data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'normal')
28
 
29
  batch_size = 64
30
- dataset = CustomDataset(data_path, transform)
31
  data_loader = torch.utils.data.DataLoader(
32
  dataset=dataset,
33
  batch_size=batch_size,
 
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 
26
  save_path = os.path.join(config['save_path'], 'normal')
27
 
28
  batch_size = 64
29
+ dataset = CustomDataset(config, transform)
30
  data_loader = torch.utils.data.DataLoader(
31
  dataset=dataset,
32
  batch_size=batch_size,
prismer/experts/generate_objdet.py CHANGED
@@ -26,9 +26,8 @@ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = config['save_path']
28
 
29
- depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])
30
  batch_size = 32
31
- dataset = Dataset(data_path, depth_path, transform)
32
  data_loader = torch.utils.data.DataLoader(
33
  dataset=dataset,
34
  batch_size=batch_size,
 
26
  data_path = config['data_path']
27
  save_path = config['save_path']
28
 
 
29
  batch_size = 32
30
+ dataset = Dataset(config, transform)
31
  data_loader = torch.utils.data.DataLoader(
32
  dataset=dataset,
33
  batch_size=batch_size,
prismer/experts/generate_ocrdet.py CHANGED
@@ -27,11 +27,10 @@ accelerator = Accelerator(mixed_precision='fp16')
27
  pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
 
29
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
30
- data_path = config['data_path']
31
  save_path = os.path.join(config['save_path'], 'ocr_detection')
32
 
33
  batch_size = 32
34
- dataset = Dataset(data_path, transform)
35
  data_loader = torch.utils.data.DataLoader(
36
  dataset=dataset,
37
  batch_size=batch_size,
 
27
  pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
 
29
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 
30
  save_path = os.path.join(config['save_path'], 'ocr_detection')
31
 
32
  batch_size = 32
33
+ dataset = Dataset(config, transform)
34
  data_loader = torch.utils.data.DataLoader(
35
  dataset=dataset,
36
  batch_size=batch_size,
prismer/experts/generate_segmentation.py CHANGED
@@ -21,11 +21,10 @@ model, transform = load_expert_model(task='seg_coco')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
- data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'seg_coco')
26
 
27
  batch_size = 4
28
- dataset = Dataset(data_path, transform)
29
  data_loader = torch.utils.data.DataLoader(
30
  dataset=dataset,
31
  batch_size=batch_size,
 
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
  config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 
24
  save_path = os.path.join(config['save_path'], 'seg_coco')
25
 
26
  batch_size = 4
27
+ dataset = Dataset(config, transform)
28
  data_loader = torch.utils.data.DataLoader(
29
  dataset=dataset,
30
  batch_size=batch_size,
prismer/experts/model_bank.py CHANGED
@@ -131,6 +131,8 @@ def load_expert_model(task=None):
131
  model = None
132
  transform = None
133
 
 
 
134
  model.eval()
135
  return model, transform
136
 
 
131
  model = None
132
  transform = None
133
 
134
+ if 'seg' not in task:
135
+ model = model.half()
136
  model.eval()
137
  return model, transform
138
 
prismer/experts/normal/generate_dataset.py CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
 
15
 
16
  class CustomDataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
  self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
 
24
  def __len__(self):
25
  return len(self.data_list)
@@ -29,6 +27,6 @@ class CustomDataset(Dataset):
29
  image = Image.open(image_path).convert('RGB')
30
  img_size = [image.size[0], image.size[1]]
31
  image = self.transform(image)
32
- return image, image_path, img_size
33
 
34
 
 
14
 
15
 
16
  class CustomDataset(Dataset):
17
+ def __init__(self, config, transform):
18
+ self.data_path = config['data_path']
19
  self.transform = transform
20
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
21
 
22
  def __len__(self):
23
  return len(self.data_list)
 
27
  image = Image.open(image_path).convert('RGB')
28
  img_size = [image.size[0], image.size[1]]
29
  image = self.transform(image)
30
+ return image.half(), image_path, img_size
31
 
32
 
prismer/experts/obj_detection/generate_dataset.py CHANGED
@@ -5,6 +5,7 @@
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import glob
 
8
  import torch
9
 
10
  from torch.utils.data import Dataset
@@ -15,13 +16,11 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
15
 
16
 
17
  class Dataset(Dataset):
18
- def __init__(self, data_path, depth_path, transform):
19
- self.data_path = data_path
20
- self.depth_path = depth_path
21
  self.transform = transform
22
- data_folders = glob.glob(f'{data_path}/*/')
23
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
24
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
25
 
26
  def __len__(self):
27
  return len(self.data_list)
@@ -43,7 +42,7 @@ class Dataset(Dataset):
43
  depth = self.transform(depth)
44
  depth = torch.tensor(np.array(depth)).float() / 255.
45
  img_size = image.shape
46
- return {"image": image, "height": img_size[1], "width": img_size[2],
47
  "true_height": true_img_size[0], "true_width": true_img_size[1],
48
  'image_path': image_path, 'depth': depth}
49
 
 
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import glob
8
+ import os
9
  import torch
10
 
11
  from torch.utils.data import Dataset
 
16
 
17
 
18
  class Dataset(Dataset):
19
+ def __init__(self, config, transform):
20
+ self.data_path = config['data_path']
21
+ self.depth_path = os.path.join(config['save_path'], 'depth', self.data_path.split('/')[-1])
22
  self.transform = transform
23
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
24
 
25
  def __len__(self):
26
  return len(self.data_list)
 
42
  depth = self.transform(depth)
43
  depth = torch.tensor(np.array(depth)).float() / 255.
44
  img_size = image.shape
45
+ return {"image": image.half(), "height": img_size[1], "width": img_size[2],
46
  "true_height": true_img_size[0], "true_width": true_img_size[1],
47
  'image_path': image_path, 'depth': depth}
48
 
prismer/experts/ocr_detection/generate_dataset.py CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
 
15
 
16
  class Dataset(Dataset):
17
- def __init__(self, data_path, transform):
18
- self.data_path = data_path
19
  self.transform = transform
20
- data_folders = glob.glob(f'{data_path}/*/')
21
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
 
24
  def __len__(self):
25
  return len(self.data_list)
@@ -30,7 +28,7 @@ class Dataset(Dataset):
30
 
31
  image, scale_w, scale_h, original_w, original_h = resize(original_image)
32
  image = self.transform(image)
33
- return image, image_path, scale_w, scale_h, original_w, original_h
34
 
35
 
36
  def resize(im):
 
14
 
15
 
16
  class Dataset(Dataset):
17
+ def __init__(self, config, transform):
18
+ self.data_path = config['data_path']
19
  self.transform = transform
20
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
21
 
22
  def __len__(self):
23
  return len(self.data_list)
 
28
 
29
  image, scale_w, scale_h, original_w, original_h = resize(original_image)
30
  image = self.transform(image)
31
+ return image.half(), image_path, scale_w, scale_h, original_w, original_h
32
 
33
 
34
  def resize(im):
prismer/experts/segmentation/generate_dataset.py CHANGED
@@ -16,12 +16,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
16
 
17
 
18
  class Dataset(Dataset):
19
- def __init__(self, data_path, transform):
20
- self.data_path = data_path
21
  self.transform = transform
22
- data_folders = glob.glob(f'{data_path}/*/')
23
- self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
24
- self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
25
 
26
  def __len__(self):
27
  return len(self.data_list)
 
16
 
17
 
18
  class Dataset(Dataset):
19
+ def __init__(self, config, transform):
20
+ self.data_path = config['data_path']
21
  self.transform = transform
22
+ self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
 
23
 
24
  def __len__(self):
25
  return len(self.data_list)
prismer/helpers/images/COCO_test2015_000000000014.jpg DELETED
Binary file (169 kB)
 
prismer/helpers/images/COCO_test2015_000000000016.jpg DELETED
Binary file (231 kB)
 
prismer/helpers/images/COCO_test2015_000000000019.jpg DELETED
Binary file (285 kB)
 
prismer/helpers/images/COCO_test2015_000000000128.jpg DELETED
Binary file (212 kB)
 
prismer/helpers/images/COCO_test2015_000000000155.jpg DELETED
Binary file (79.7 kB)
 
prismer/helpers/intro.png DELETED
Binary file (405 kB)
 
prismer/model/prismer.py CHANGED
@@ -5,6 +5,7 @@
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import json
 
8
  import torch.nn as nn
9
 
10
  from model.modules.vit import load_encoder
@@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
12
  from transformers import RobertaTokenizer, RobertaConfig
13
 
14
 
 
 
 
15
  class Prismer(nn.Module):
16
  def __init__(self, config):
17
  super().__init__()
@@ -26,7 +30,7 @@ class Prismer(nn.Module):
26
  elif exp in ['obj_detection', 'ocr_detection']:
27
  self.experts[exp] = 64
28
 
29
- prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
30
  roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
31
 
32
  self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
@@ -35,7 +39,7 @@ class Prismer(nn.Module):
35
 
36
  self.prepare_to_train(config['freeze'])
37
  self.ignored_modules = self.get_ignored_modules(config['freeze'])
38
-
39
  def prepare_to_train(self, mode='none'):
40
  for name, params in self.named_parameters():
41
  if mode == 'freeze_lang':
 
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import json
8
+ import pathlib
9
  import torch.nn as nn
10
 
11
  from model.modules.vit import load_encoder
 
13
  from transformers import RobertaTokenizer, RobertaConfig
14
 
15
 
16
+ cur_dir = pathlib.Path(__file__).parent
17
+
18
+
19
  class Prismer(nn.Module):
20
  def __init__(self, config):
21
  super().__init__()
 
30
  elif exp in ['obj_detection', 'ocr_detection']:
31
  self.experts[exp] = 64
32
 
33
+ prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
34
  roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
35
 
36
  self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
 
39
 
40
  self.prepare_to_train(config['freeze'])
41
  self.ignored_modules = self.get_ignored_modules(config['freeze'])
42
+
43
  def prepare_to_train(self, mode='none'):
44
  for name, params in self.named_parameters():
45
  if mode == 'freeze_lang':
prismer_model.py CHANGED
@@ -7,6 +7,12 @@ import shlex
7
  import shutil
8
  import subprocess
9
  import sys
 
 
 
 
 
 
10
 
11
  import cv2
12
  import torch
@@ -55,27 +61,43 @@ def run_expert(expert_name: str):
55
  check=True)
56
 
57
 
58
- def run_experts(image_path: str) -> tuple[str | None, ...]:
59
- helper_dir = submodule_dir / 'helpers'
60
- shutil.rmtree(helper_dir, ignore_errors=True)
61
- image_dir = helper_dir / 'images'
62
- image_dir.mkdir(parents=True, exist_ok=True)
63
- out_path = image_dir / 'image.jpg'
64
- cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
65
 
66
- # expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
67
- # run_expert('depth')
68
- # with concurrent.futures.ProcessPoolExecutor() as executor:
69
- # executor.map(run_expert, expert_names)
70
-
71
- # no parallelization just to be safe
72
- expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
73
- for exp in expert_names:
74
- run_expert(exp)
75
 
 
 
 
76
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
77
- results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
78
- return tuple(path.as_posix() for path in results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
  class Model:
@@ -126,20 +148,28 @@ class Model:
126
  len(model.expert_encoder.positional_embedding))
127
 
128
  model.load_state_dict(state_dict)
 
129
  model.eval()
130
 
131
  self.config = config
132
- self.model = model
133
  self.tokenizer = model.tokenizer
134
  self.exp_name = exp_name
135
  self.mode = mode
136
 
137
  @torch.inference_mode()
138
- def run_caption_model(self, exp_name: str) -> str:
139
  self.set_model(exp_name, 'caption')
 
140
  _, test_dataset = create_dataset('caption', self.config)
141
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
142
  experts, _ = next(iter(test_loader))
 
 
 
 
 
 
143
  captions = self.model(experts, train=False, prefix=self.config['prefix'])
144
  captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
145
  caption = captions.to(experts['rgb'].device)[0]
@@ -148,17 +178,23 @@ class Model:
148
  return caption
149
 
150
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
151
- out_paths = run_experts(image_path)
152
- caption = self.run_caption_model(model_name)
153
- label_prettify(image_path, out_paths)
154
- return caption, *out_paths
155
 
156
  @torch.inference_mode()
157
- def run_vqa_model(self, exp_name: str, question: str) -> str:
158
  self.set_model(exp_name, 'vqa')
 
159
  _, test_dataset = create_dataset('caption', self.config)
160
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
161
  experts, _ = next(iter(test_loader))
 
 
 
 
 
 
162
  question = pre_question(question)
163
  answer = self.model(experts, [question], train=False, inference='generate')
164
  answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
@@ -168,7 +204,6 @@ class Model:
168
  return answer
169
 
170
  def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
171
- out_paths = run_experts(image_path)
172
- answer = self.run_vqa_model(model_name, question)
173
- label_prettify(image_path, out_paths)
174
- return answer, *out_paths
 
7
  import shutil
8
  import subprocess
9
  import sys
10
+ import hashlib
11
+ from typing import Tuple
12
+ try:
13
+ import ruamel_yaml as yaml
14
+ except ModuleNotFoundError:
15
+ import ruamel.yaml as yaml
16
 
17
  import cv2
18
  import torch
 
61
  check=True)
62
 
63
 
64
+ def compute_md5(image_path: str) -> str:
65
+ with open(image_path, 'rb') as f:
66
+ s = f.read()
67
+ return hashlib.md5(s).hexdigest()
 
 
 
68
 
 
 
 
 
 
 
 
 
 
69
 
70
+ def run_experts(image_path: str) -> Tuple[str, Tuple[str, ...]]:
71
+ im_name = compute_md5(image_path)
72
+ out_path = submodule_dir / 'helpers' / 'images' / f'{im_name}.jpg'
73
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
74
+ results = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}.png' for key in keys]
75
+ results_pretty = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}_p.png' for key in keys]
76
+ out_paths = tuple(path.as_posix() for path in results)
77
+ pretty_paths = tuple(path.as_posix() for path in results_pretty)
78
+
79
+ config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
80
+ config['im_name'] = im_name
81
+ with open('prismer/configs/experts.yaml', 'w') as yaml_file:
82
+ yaml.dump(config, yaml_file, default_flow_style=False)
83
+
84
+ if not os.path.exists(out_paths[0]):
85
+ cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
86
+
87
+ # paralleled inference
88
+ expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
89
+ run_expert('depth')
90
+ with concurrent.futures.ProcessPoolExecutor() as executor:
91
+ executor.map(run_expert, expert_names)
92
+ executor.shutdown(wait=True)
93
+
94
+ # no parallelization just to be safe
95
+ # expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
96
+ # for exp in expert_names:
97
+ # run_expert(exp)
98
+
99
+ label_prettify(image_path, out_paths)
100
+ return im_name, pretty_paths
101
 
102
 
103
  class Model:
 
148
  len(model.expert_encoder.positional_embedding))
149
 
150
  model.load_state_dict(state_dict)
151
+ model = model.half()
152
  model.eval()
153
 
154
  self.config = config
155
+ self.model = model.to('cuda:0')
156
  self.tokenizer = model.tokenizer
157
  self.exp_name = exp_name
158
  self.mode = mode
159
 
160
  @torch.inference_mode()
161
+ def run_caption_model(self, exp_name: str, im_name: str) -> str:
162
  self.set_model(exp_name, 'caption')
163
+ self.config['im_name'] = im_name
164
  _, test_dataset = create_dataset('caption', self.config)
165
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
166
  experts, _ = next(iter(test_loader))
167
+ for exp in experts:
168
+ if exp == 'obj_detection':
169
+ experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
170
+ experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
171
+ else:
172
+ experts[exp] = experts[exp].to('cuda:0')
173
  captions = self.model(experts, train=False, prefix=self.config['prefix'])
174
  captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
175
  caption = captions.to(experts['rgb'].device)[0]
 
178
  return caption
179
 
180
  def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
181
+ im_name, pretty_paths = run_experts(image_path)
182
+ caption = self.run_caption_model(model_name, im_name)
183
+ return caption, *pretty_paths
 
184
 
185
  @torch.inference_mode()
186
+ def run_vqa_model(self, exp_name: str, im_name: str, question: str) -> str:
187
  self.set_model(exp_name, 'vqa')
188
+ self.config['im_name'] = im_name
189
  _, test_dataset = create_dataset('caption', self.config)
190
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
191
  experts, _ = next(iter(test_loader))
192
+ for exp in experts:
193
+ if exp == 'obj_detection':
194
+ experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
195
+ experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
196
+ else:
197
+ experts[exp] = experts[exp].to('cuda:0')
198
  question = pre_question(question)
199
  answer = self.model(experts, [question], train=False, inference='generate')
200
  answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
 
204
  return answer
205
 
206
  def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
207
+ im_name, pretty_paths = run_experts(image_path)
208
+ answer = self.run_vqa_model(model_name, im_name, question)
209
+ return answer, *pretty_paths
 
requirements.txt CHANGED
@@ -6,7 +6,7 @@ fire==0.5.0
6
  geffnet==1.0.2
7
  git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13
8
  git+https://github.com/openai/CLIP.git@a9b1bf5
9
- gradio==3.20.1
10
  huggingface-hub==0.12.1
11
  opencv-python-headless==4.7.0.72
12
  pyclipper==1.3.0.post4
 
6
  geffnet==1.0.2
7
  git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13
8
  git+https://github.com/openai/CLIP.git@a9b1bf5
9
+ gradio==3.24.1
10
  huggingface-hub==0.12.1
11
  opencv-python-headless==4.7.0.72
12
  pyclipper==1.3.0.post4