shikunl commited on
Commit
087df0e
β€’
1 Parent(s): 61a0078

Add prismer

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. prismer/dataset/__init__.py +32 -0
  2. prismer/dataset/ade_features.pt +3 -0
  3. prismer/dataset/background_features.pt +3 -0
  4. prismer/dataset/caption_dataset.py +63 -0
  5. prismer/dataset/clip_pca.pkl +3 -0
  6. prismer/dataset/coco_features.pt +3 -0
  7. prismer/dataset/detection_features.pt +3 -0
  8. prismer/dataset/randaugment.py +265 -0
  9. prismer/dataset/utils.py +192 -0
  10. prismer/dataset/vqa_dataset.py +53 -0
  11. prismer/demo.py +77 -0
  12. prismer/demo_vis.py +161 -0
  13. prismer/download_checkpoints.py +124 -0
  14. prismer/experts/depth/base_model.py +16 -0
  15. prismer/experts/depth/blocks.py +383 -0
  16. prismer/experts/depth/generate_dataset.py +32 -0
  17. prismer/experts/depth/models.py +124 -0
  18. prismer/experts/depth/vit.py +576 -0
  19. prismer/experts/edge/generate_dataset.py +32 -0
  20. prismer/experts/edge/images.py +50 -0
  21. prismer/experts/edge/model.py +286 -0
  22. prismer/experts/generate_depth.py +56 -0
  23. prismer/experts/generate_edge.py +57 -0
  24. prismer/experts/generate_normal.py +58 -0
  25. prismer/experts/generate_objdet.py +115 -0
  26. prismer/experts/generate_ocrdet.py +86 -0
  27. prismer/experts/generate_segmentation.py +56 -0
  28. prismer/experts/model_bank.py +139 -0
  29. prismer/experts/normal/generate_dataset.py +34 -0
  30. prismer/experts/normal/models/NNET.py +22 -0
  31. prismer/experts/normal/models/baseline.py +85 -0
  32. prismer/experts/normal/models/submodules/decoder.py +202 -0
  33. prismer/experts/normal/models/submodules/encoder.py +32 -0
  34. prismer/experts/normal/models/submodules/submodules.py +140 -0
  35. prismer/experts/normal/utils/losses.py +178 -0
  36. prismer/experts/normal/utils/utils.py +191 -0
  37. prismer/experts/obj_detection/configs/Base-CRCNN-COCO.yaml +48 -0
  38. prismer/experts/obj_detection/configs/O365_CRFR50_CAS_2x.yaml +15 -0
  39. prismer/experts/obj_detection/configs/OID_CRFR50_CAS_2x.yaml +22 -0
  40. prismer/experts/obj_detection/configs/Partitioned_COIM_R50_6x+2x.yaml +28 -0
  41. prismer/experts/obj_detection/configs/Partitioned_COI_R50_2x.yaml +29 -0
  42. prismer/experts/obj_detection/configs/Partitioned_COI_R50_6x.yaml +28 -0
  43. prismer/experts/obj_detection/configs/Partitioned_COI_R50_8x.yaml +28 -0
  44. prismer/experts/obj_detection/configs/Partitioned_COI_RS101_2x.yaml +46 -0
  45. prismer/experts/obj_detection/configs/Unified_human_OCI_R50_2x.yaml +29 -0
  46. prismer/experts/obj_detection/configs/Unified_learned_OCIM_R50_6x+2x.yaml +35 -0
  47. prismer/experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml +46 -0
  48. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_2x.yaml +29 -0
  49. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_6x.yaml +29 -0
  50. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_8x.yaml +29 -0
prismer/dataset/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ from torch.utils.data import DataLoader
8
+
9
+ from dataset.vqa_dataset import VQA
10
+ from dataset.caption_dataset import Caption
11
+
12
+
13
+ def create_dataset(dataset, config):
14
+ if dataset == 'vqa':
15
+ train_dataset = VQA(config, train=True)
16
+ test_dataset = VQA(config, train=False)
17
+ return train_dataset, test_dataset
18
+
19
+ elif dataset == 'caption':
20
+ train_dataset = Caption(config, train=True)
21
+ test_dataset = Caption(config, train=False)
22
+ return train_dataset, test_dataset
23
+
24
+
25
+ def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
26
+ data_loader = DataLoader(dataset,
27
+ batch_size=batch_size,
28
+ num_workers=num_workers,
29
+ collate_fn=collate_fn,
30
+ shuffle=True if train else False,
31
+ drop_last=True if train else False)
32
+ return data_loader
prismer/dataset/ade_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cb5c802ef9eaf9c9213219bdc5f8f2d6579400275ab443c43905fdcdd805e5a
3
+ size 41899
prismer/dataset/background_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f5bf8e158715633a9e0f307be1947ca74c294e522b705e1e61bd49361088092
3
+ size 1003
prismer/dataset/caption_dataset.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from dataset.utils import *
11
+ from PIL import ImageFile
12
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
13
+
14
+
15
+ class Caption(Dataset):
16
+ def __init__(self, config, train=True):
17
+ self.data_path = config['data_path']
18
+ self.label_path = config['label_path']
19
+ self.experts = config['experts']
20
+ self.prefix = config['prefix']
21
+ self.dataset = config['dataset']
22
+ self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
23
+ self.train = train
24
+
25
+ if train:
26
+ self.data_list = []
27
+ if self.dataset in ['coco', 'nocaps']:
28
+ self.data_list += json.load(open(os.path.join(self.data_path, 'coco_karpathy_train.json'), 'r'))
29
+ else:
30
+ if self.dataset == 'coco':
31
+ self.data_list = json.load(open(os.path.join(self.data_path, 'coco_karpathy_test.json'), 'r'))
32
+ elif self.dataset == 'nocaps':
33
+ self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
34
+ elif self.dataset == 'demo':
35
+ data_folders = glob.glob(f'{self.data_path}/*/')
36
+ self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
37
+ self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
38
+ self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]
39
+
40
+ def __len__(self):
41
+ return len(self.data_list)
42
+
43
+ def __getitem__(self, index):
44
+ data = self.data_list[index]
45
+
46
+ if self.dataset == 'coco':
47
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
48
+ elif self.dataset == 'nocaps':
49
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'nocaps', self.experts)
50
+ elif self.dataset == 'demo':
51
+ img_path_split = self.data_list[index]['image'].split('/')
52
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
53
+ image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
54
+
55
+ experts = self.transform(image, labels)
56
+ experts = post_label_process(experts, labels_info)
57
+
58
+ if self.train:
59
+ caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
60
+ return experts, caption
61
+ else:
62
+ return experts, index
63
+
prismer/dataset/clip_pca.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:064950869bce7d622ff7e733dc64c8e1be18be926c3f64e66ae57164df3f8bc5
3
+ size 401612
prismer/dataset/coco_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccf18221afe8dddef3ffb9daad31d5c7a92cdc2f2f434d77cbeb48031bc75756
3
+ size 36651
prismer/dataset/detection_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c38ba9352b2a9f832b14fdc19ac407527ffeaa2903958a73f6eb649f78119c76
3
+ size 198443
prismer/dataset/randaugment.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+
11
+ from PIL import Image, ImageOps, ImageEnhance, ImageDraw
12
+
13
+
14
+ fillmask = {'depth': 0, 'normal': 0, 'edge': 0, 'seg_coco': 255, 'seg_ade': 255,
15
+ 'obj_detection': 255, 'ocr_detection': 255}
16
+ fillcolor = (0, 0, 0)
17
+
18
+
19
+ def affine_transform(pair, affine_params):
20
+ img, label = pair
21
+ img = img.transform(img.size, Image.AFFINE, affine_params,
22
+ resample=Image.BILINEAR, fillcolor=fillcolor)
23
+ if label is not None:
24
+ for exp in label:
25
+ label[exp] = label[exp].transform(label[exp].size, Image.AFFINE, affine_params,
26
+ resample=Image.NEAREST, fillcolor=fillmask[exp])
27
+ return img, label
28
+
29
+
30
+ def ShearX(pair, v): # [-0.3, 0.3]
31
+ assert -0.3 <= v <= 0.3
32
+ if random.random() > 0.5:
33
+ v = -v
34
+ return affine_transform(pair, (1, v, 0, 0, 1, 0))
35
+
36
+
37
+ def ShearY(pair, v): # [-0.3, 0.3]
38
+ assert -0.3 <= v <= 0.3
39
+ if random.random() > 0.5:
40
+ v = -v
41
+ return affine_transform(pair, (1, 0, 0, v, 1, 0))
42
+
43
+
44
+ def TranslateX(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
45
+ assert -0.45 <= v <= 0.45
46
+ if random.random() > 0.5:
47
+ v = -v
48
+ img, _ = pair
49
+ v = v * img.size[0]
50
+ return affine_transform(pair, (1, 0, v, 0, 1, 0))
51
+
52
+
53
+ def TranslateY(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
54
+ assert -0.45 <= v <= 0.45
55
+ if random.random() > 0.5:
56
+ v = -v
57
+ img, _ = pair
58
+ v = v * img.size[1]
59
+ return affine_transform(pair, (1, 0, 0, 0, 1, v))
60
+
61
+
62
+ def TranslateXAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
63
+ assert 0 <= v <= 10
64
+ if random.random() > 0.5:
65
+ v = -v
66
+ return affine_transform(pair, (1, 0, v, 0, 1, 0))
67
+
68
+
69
+ def TranslateYAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
70
+ assert 0 <= v <= 10
71
+ if random.random() > 0.5:
72
+ v = -v
73
+ return affine_transform(pair, (1, 0, 0, 0, 1, v))
74
+
75
+
76
+ def Rotate(pair, v): # [-30, 30]
77
+ assert -30 <= v <= 30
78
+ if random.random() > 0.5:
79
+ v = -v
80
+ img, label = pair
81
+ img = img.rotate(v, fillcolor=fillcolor)
82
+ if label is not None:
83
+ for exp in label:
84
+ label[exp] = label[exp].rotate(v, resample=Image.NEAREST, fillcolor=fillmask[exp])
85
+ return img, label
86
+
87
+
88
+ def AutoContrast(pair, _):
89
+ img, label = pair
90
+ return ImageOps.autocontrast(img), label
91
+
92
+
93
+ def Invert(pair, _):
94
+ img, label = pair
95
+ return ImageOps.invert(img), label
96
+
97
+
98
+ def Equalize(pair, _):
99
+ img, label = pair
100
+ return ImageOps.equalize(img), label
101
+
102
+
103
+ def Flip(pair, _): # not from the paper
104
+ img, label = pair
105
+ return ImageOps.mirror(img), ImageOps.mirror(label)
106
+
107
+
108
+ def Solarize(pair, v): # [0, 256]
109
+ img, label = pair
110
+ assert 0 <= v <= 256
111
+ return ImageOps.solarize(img, v), label
112
+
113
+
114
+ def Posterize(pair, v): # [4, 8]
115
+ img, label = pair
116
+ assert 4 <= v <= 8
117
+ v = int(v)
118
+ return ImageOps.posterize(img, v), label
119
+
120
+
121
+ def Posterize2(pair, v): # [0, 4]
122
+ img, label = pair
123
+ assert 0 <= v <= 4
124
+ v = int(v)
125
+ return ImageOps.posterize(img, v), label
126
+
127
+
128
+ def Contrast(pair, v): # [0.1,1.9]
129
+ img, label = pair
130
+ assert 0.1 <= v <= 1.9
131
+ return ImageEnhance.Contrast(img).enhance(v), label
132
+
133
+
134
+ def Color(pair, v): # [0.1,1.9]
135
+ img, label = pair
136
+ assert 0.1 <= v <= 1.9
137
+ return ImageEnhance.Color(img).enhance(v), label
138
+
139
+
140
+ def Brightness(pair, v): # [0.1,1.9]
141
+ img, label = pair
142
+ assert 0.1 <= v <= 1.9
143
+ return ImageEnhance.Brightness(img).enhance(v), label
144
+
145
+
146
+ def Sharpness(pair, v): # [0.1,1.9]
147
+ img, label = pair
148
+ assert 0.1 <= v <= 1.9
149
+ return ImageEnhance.Sharpness(img).enhance(v), label
150
+
151
+
152
+ def Cutout(pair, v): # [0, 60] => percentage: [0, 0.2]
153
+ assert 0.0 <= v <= 0.2
154
+ if v <= 0.:
155
+ return pair
156
+ img, label = pair
157
+ v = v * img.size[0]
158
+ return CutoutAbs(img, v), label
159
+
160
+
161
+ def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
162
+ # assert 0 <= v <= 20
163
+ if v < 0:
164
+ return img
165
+ w, h = img.size
166
+ x0 = np.random.uniform(w)
167
+ y0 = np.random.uniform(h)
168
+
169
+ x0 = int(max(0, x0 - v / 2.))
170
+ y0 = int(max(0, y0 - v / 2.))
171
+ x1 = min(w, x0 + v)
172
+ y1 = min(h, y0 + v)
173
+
174
+ xy = (x0, y0, x1, y1)
175
+ color = (125, 123, 114)
176
+ # color = (0, 0, 0)
177
+ img = img.copy()
178
+ ImageDraw.Draw(img).rectangle(xy, color)
179
+ return img
180
+
181
+
182
+ def Identity(pair, v):
183
+ return pair
184
+
185
+
186
+ def augment_list(): # 16 oeprations and their ranges
187
+ # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
188
+ l = [
189
+ (Identity, 0., 1.0),
190
+ (ShearX, 0., 0.3), # 0
191
+ (ShearY, 0., 0.3), # 1
192
+ (TranslateX, 0., 0.33), # 2
193
+ (TranslateY, 0., 0.33), # 3
194
+ (Rotate, 0, 30), # 4
195
+ (AutoContrast, 0, 1), # 5
196
+ # (Invert, 0, 1), # 6
197
+ (Equalize, 0, 1), # 7
198
+ # (Solarize, 0, 110), # 8
199
+ # (Posterize, 4, 8), # 9
200
+ # (Color, 0.1, 1.9), # 11
201
+ (Brightness, 0.1, 1.9), # 12
202
+ (Sharpness, 0.1, 1.9), # 13
203
+ ]
204
+ return l
205
+
206
+
207
+ class Lighting(object):
208
+ """Lighting noise(AlexNet - style PCA - based noise)"""
209
+
210
+ def __init__(self, alphastd, eigval, eigvec):
211
+ self.alphastd = alphastd
212
+ self.eigval = torch.Tensor(eigval)
213
+ self.eigvec = torch.Tensor(eigvec)
214
+
215
+ def __call__(self, img):
216
+ if self.alphastd == 0:
217
+ return img
218
+
219
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
220
+ rgb = self.eigvec.type_as(img).clone() \
221
+ .mul(alpha.view(1, 3).expand(3, 3)) \
222
+ .mul(self.eigval.view(1, 3).expand(3, 3)) \
223
+ .sum(1).squeeze()
224
+
225
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
226
+
227
+
228
+ class CutoutDefault(object):
229
+ """
230
+ Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
231
+ """
232
+ def __init__(self, length):
233
+ self.length = length
234
+
235
+ def __call__(self, img):
236
+ h, w = img.size(1), img.size(2)
237
+ mask = np.ones((h, w), np.float32)
238
+ y = np.random.randint(h)
239
+ x = np.random.randint(w)
240
+
241
+ y1 = np.clip(y - self.length // 2, 0, h)
242
+ y2 = np.clip(y + self.length // 2, 0, h)
243
+ x1 = np.clip(x - self.length // 2, 0, w)
244
+ x2 = np.clip(x + self.length // 2, 0, w)
245
+
246
+ mask[y1: y2, x1: x2] = 0.
247
+ mask = torch.from_numpy(mask)
248
+ mask = mask.expand_as(img)
249
+ img *= mask
250
+ return img
251
+
252
+
253
+ class RandAugment:
254
+ def __init__(self, n, m):
255
+ self.n = n
256
+ self.m = m # [0, 10]
257
+ self.augment_list = augment_list()
258
+
259
+ def __call__(self, img, label):
260
+ pair = img, label
261
+ ops = random.choices(self.augment_list, k=self.n)
262
+ for op, minval, maxval in ops:
263
+ val = (float(self.m) / 10) * float(maxval - minval) + minval
264
+ pair = op(pair, val)
265
+ return pair
prismer/dataset/utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import os
8
+ import re
9
+ import json
10
+ import torch
11
+ import PIL.Image as Image
12
+ import numpy as np
13
+ import torchvision.transforms as transforms
14
+ import torchvision.transforms.functional as transforms_f
15
+ import pathlib
16
+ from dataset.randaugment import RandAugment
17
+
18
+
19
+ cur_dir = pathlib.Path(__file__).parent
20
+
21
+ COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
22
+ ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
23
+ DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
24
+ BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
25
+
26
+
27
+ class Transform:
28
+ def __init__(self, resize_resolution=384, scale_size=[0.5, 1.0], train=False):
29
+ self.resize_size = [resize_resolution, resize_resolution]
30
+ self.scale_size = scale_size
31
+ self.train = train
32
+ self.randaugment = RandAugment(2, 5)
33
+
34
+ def __call__(self, image, labels):
35
+ if self.train:
36
+ # random resize crop
37
+ i, j, h, w = transforms.RandomResizedCrop.get_params(img=image, scale=self.scale_size, ratio=[3. / 4, 4. / 3])
38
+ image = transforms_f.crop(image, i, j, h, w)
39
+ if labels is not None:
40
+ for exp in labels:
41
+ labels[exp] = transforms_f.crop(labels[exp], i, j, h, w)
42
+
43
+ # resize to the defined shape
44
+ image = transforms_f.resize(image, self.resize_size, transforms_f.InterpolationMode.BICUBIC)
45
+ if labels is not None:
46
+ for exp in labels:
47
+ labels[exp] = transforms_f.resize(labels[exp], [224, 224], transforms_f.InterpolationMode.NEAREST)
48
+
49
+ if self.train:
50
+ # random flipping
51
+ if torch.rand(1) > 0.5:
52
+ image = transforms_f.hflip(image)
53
+ if labels is not None:
54
+ for exp in labels:
55
+ labels[exp] = transforms_f.hflip(labels[exp])
56
+
57
+ # random augmentation
58
+ image, labels = self.randaugment(image, labels)
59
+
60
+ # transform to tensor
61
+ image = transforms_f.to_tensor(image)
62
+ if labels is not None:
63
+ for exp in labels:
64
+ if exp in ['depth', 'normal', 'edge']:
65
+ labels[exp] = transforms_f.to_tensor(labels[exp])
66
+ else:
67
+ labels[exp] = (transforms_f.to_tensor(labels[exp]) * 255).long()
68
+
69
+ # apply normalisation:
70
+ image = transforms_f.normalize(image, mean=[0.48145466, 0.4578275, 0.40821073],
71
+ std=[0.26862954, 0.26130258, 0.27577711])
72
+ if labels is not None:
73
+ return {'rgb': image, **labels}
74
+ else:
75
+ return{'rgb': image}
76
+
77
+
78
+ def get_expert_labels(data_path, label_path, image_path, dataset, experts):
79
+ image_full_path = os.path.join(data_path, dataset, image_path)
80
+ image = Image.open(image_full_path).convert('RGB')
81
+ if experts != 'none':
82
+ labels = {}
83
+ labels_info = {}
84
+ ps = image_path.split('.')[-1]
85
+ for exp in experts:
86
+ if exp in ['seg_coco', 'seg_ade', 'edge', 'depth']:
87
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
88
+ if os.stat(label_full_path).st_size > 0:
89
+ labels[exp] = Image.open(label_full_path).convert('L')
90
+ else:
91
+ labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0]])).convert('L')
92
+ elif exp == 'normal':
93
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
94
+ if os.stat(label_full_path).st_size > 0:
95
+ labels[exp] = Image.open(label_full_path).convert('RGB')
96
+ else:
97
+ labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0], 3])).convert('RGB')
98
+ elif exp == 'obj_detection':
99
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
100
+ if os.stat(label_full_path).st_size > 0:
101
+ labels[exp] = Image.open(label_full_path).convert('L')
102
+ else:
103
+ labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
104
+ label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.json'))
105
+ labels_info[exp] = json.load(open(label_info_path, 'r'))
106
+ elif exp == 'ocr_detection':
107
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
108
+ label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.pt'))
109
+ if os.path.exists(label_info_path):
110
+ labels[exp] = Image.open(label_full_path).convert('L')
111
+ labels_info[exp] = torch.load(label_info_path)
112
+ else:
113
+ labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
114
+ labels_info[exp] = None
115
+
116
+ else:
117
+ labels, labels_info = None, None
118
+ return image, labels, labels_info
119
+
120
+
121
+ def post_label_process(inputs, labels_info):
122
+ eps = 1e-6
123
+ for exp in inputs:
124
+ if exp in ['depth', 'normal', 'edge']: # remap to -1 to 1 range
125
+ inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
126
+
127
+ elif exp == 'seg_coco': # in-paint with CLIP features
128
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
129
+ for l in inputs[exp].unique():
130
+ if l == 255:
131
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
132
+ else:
133
+ text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
134
+ inputs[exp] = text_emb
135
+
136
+ elif exp == 'seg_ade': # in-paint with CLIP features
137
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
138
+ for l in inputs[exp].unique():
139
+ if l == 255:
140
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
141
+ else:
142
+ text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
143
+ inputs[exp] = text_emb
144
+
145
+ elif exp == 'obj_detection': # in-paint with CLIP features
146
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
147
+ label_map = labels_info[exp]
148
+ for l in inputs[exp].unique():
149
+ if l == 255:
150
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
151
+ else:
152
+ text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
153
+ inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}
154
+
155
+ elif exp == 'ocr_detection': # in-paint with CLIP features
156
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
157
+ label_map = labels_info[exp]
158
+ for l in inputs[exp].unique():
159
+ if l == 255:
160
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
161
+ else:
162
+ text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
163
+ inputs[exp] = text_emb
164
+ return inputs
165
+
166
+
167
+ def pre_caption(caption, max_words=50):
168
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.capitalize()) # remove special characters
169
+ caption = re.sub(r"\s{2,}", ' ', caption) # remove two white spaces
170
+
171
+ caption = caption.rstrip('\n') # remove \num_ans_per_q symbol
172
+ caption = caption.strip(' ') # remove leading and trailing white spaces
173
+
174
+ # truncate caption to the max words
175
+ caption_words = caption.split(' ')
176
+ if len(caption_words) > max_words:
177
+ caption = ' '.join(caption_words[:max_words])
178
+ return caption
179
+
180
+
181
+ def pre_question(question, max_words=50):
182
+ question = re.sub(r"([.!\"()*#:;~])", ' ', question.capitalize()) # remove special characters
183
+ question = question.strip()
184
+
185
+ # truncate question
186
+ question_words = question.split(' ')
187
+ if len(question_words) > max_words:
188
+ question = ' '.join(question_words[:max_words])
189
+ if question[-1] != '?':
190
+ question += '?'
191
+ return question
192
+
prismer/dataset/vqa_dataset.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ from torch.utils.data import Dataset
8
+ from dataset.utils import *
9
+ from PIL import ImageFile
10
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
11
+
12
+
13
+ class VQA(Dataset):
14
+ def __init__(self, config, train=True):
15
+ self.data_path = config['data_path']
16
+ self.label_path = config['label_path']
17
+ self.experts = config['experts']
18
+ self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
19
+ self.train = train
20
+
21
+ if train:
22
+ self.data_list = []
23
+ if 'vqav2' in config['datasets']:
24
+ self.data_list += json.load(open(os.path.join(self.data_path, 'vqav2_train_val.json'), 'r'))
25
+ if 'vg' in config['datasets']:
26
+ self.data_list += json.load(open(os.path.join(self.data_path, 'vg_qa.json'), 'r'))
27
+ else:
28
+ self.data_list = json.load(open(os.path.join(self.data_path, 'vqav2_test.json'), 'r'))
29
+ self.answer_list = json.load(open(os.path.join(self.data_path, 'answer_list.json'), 'r'))
30
+
31
+ def __len__(self):
32
+ return len(self.data_list)
33
+
34
+ def __getitem__(self, index):
35
+ data = self.data_list[index]
36
+
37
+ if data['dataset'] == 'vqa':
38
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
39
+ elif data['dataset'] == 'vg':
40
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vg', self.experts)
41
+
42
+ experts = self.transform(image, labels)
43
+ experts = post_label_process(experts, labels_info)
44
+
45
+ if self.train:
46
+ question = pre_question(data['question'], max_words=30)
47
+ answers = data['answer']
48
+ weights = torch.tensor(data['weight']) if data['dataset'] != 'vg' else torch.tensor(0.2)
49
+ return experts, question, answers, weights
50
+ else:
51
+ question = pre_question(data['question'], max_words=30)
52
+ question_id = data['question_id']
53
+ return experts, index, question, question_id
prismer/demo.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+ try:
5
+ import ruamel_yaml as yaml
6
+ except ModuleNotFoundError:
7
+ import ruamel.yaml as yaml
8
+
9
+
10
+ from model.prismer_caption import PrismerCaption
11
+ from dataset import create_dataset, create_loader
12
+ from tqdm import tqdm
13
+
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--mode', default='')
16
+ parser.add_argument('--port', default='')
17
+
18
+ parser.add_argument('--exp_name', default='', type=str)
19
+ args = parser.parse_args()
20
+
21
+ # load config
22
+ config = yaml.load(open('configs/caption.yaml', 'r'), Loader=yaml.Loader)['demo']
23
+
24
+ # generate expert labels
25
+ if len(config['experts']) > 0:
26
+ script_name = f'python experts/generate_depth.py'
27
+ os.system(script_name)
28
+ print('***** Generated Depth *****')
29
+
30
+ script_name = f'python experts/generate_edge.py'
31
+ os.system(script_name)
32
+ print('***** Generated Edge *****')
33
+
34
+ script_name = f'python experts/generate_normal.py'
35
+ os.system(script_name)
36
+ print('***** Generated Surface Normals *****')
37
+
38
+ script_name = f'python experts/generate_objdet.py'
39
+ os.system(script_name)
40
+ print('***** Generated Object Detection Labels *****')
41
+
42
+ script_name = f'python experts/generate_ocrdet.py'
43
+ os.system(script_name)
44
+ print('***** Generated OCR Detection Labels *****')
45
+
46
+ script_name = f'python experts/generate_segmentation.py'
47
+ os.system(script_name)
48
+ print('***** Generated Segmentation Labels *****')
49
+
50
+ # load datasets
51
+ _, test_dataset = create_dataset('caption', config)
52
+ test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
53
+
54
+ # load pre-trained model
55
+ model = PrismerCaption(config)
56
+ state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cuda:0')
57
+ model.load_state_dict(state_dict)
58
+ tokenizer = model.tokenizer
59
+
60
+ # inference
61
+ model.eval()
62
+ with torch.no_grad():
63
+ for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
64
+ captions = model(experts, train=False, prefix=config['prefix'])
65
+
66
+ captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
67
+ caption = captions.to(experts['rgb'].device)[0]
68
+
69
+ caption = tokenizer.decode(caption, skip_special_tokens=True)
70
+ caption = caption.capitalize() + '.'
71
+
72
+ # save caption
73
+ save_path = test_loader.dataset.data_list[data_ids[0]]['image'].replace('jpg', 'txt')
74
+ with open(save_path, 'w') as f:
75
+ f.write(caption)
76
+
77
+ print('All Done.')
prismer/demo_vis.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ from utils import create_ade20k_label_colormap
10
+
11
+ obj_label_map = torch.load('dataset/detection_features.pt')['labels']
12
+ coco_label_map = torch.load('dataset/coco_features.pt')['labels']
13
+ ade_color = create_ade20k_label_colormap()
14
+
15
+ file_path = 'helpers/images'
16
+ expert_path = 'helpers/labels'
17
+ plt.ioff()
18
+
19
+
20
+ def get_label_path(file_name, expert_name, with_suffix=False):
21
+ file_suffix = '.png' if not with_suffix else '_.png'
22
+ label_name = ''.join(file_name.split('.')[:-1] + [file_suffix])
23
+ label_path = os.path.join(expert_path, expert_name, label_name)
24
+ return label_path
25
+
26
+
27
+ def depth_prettify(file_name):
28
+ label_path = get_label_path(file_name, 'depth')
29
+ save_path = get_label_path(file_name, 'depth', True)
30
+ depth = plt.imread(label_path)
31
+ plt.imsave(save_path, depth, cmap='rainbow')
32
+
33
+
34
+ def obj_detection_prettify(file_name):
35
+ label_path = get_label_path(file_name, 'obj_detection')
36
+ save_path = get_label_path(file_name, 'obj_detection', True)
37
+
38
+ rgb = plt.imread(file_name)
39
+ obj_labels = plt.imread(label_path)
40
+ obj_labels_dict = json.load(open(label_path.replace('.png', '.json')))
41
+
42
+ plt.imshow(rgb)
43
+
44
+ num_objs = np.unique(obj_labels)[:-1].max()
45
+ plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.5)
46
+
47
+ for i in np.unique(obj_labels)[:-1]:
48
+ obj_idx_all = np.where(obj_labels == i)
49
+ obj_idx = random.randint(0, len(obj_idx_all[0]))
50
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
51
+ obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
52
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
53
+
54
+ plt.axis('off')
55
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
56
+ plt.close()
57
+
58
+
59
+ def seg_prettify(file_name):
60
+ label_path = get_label_path(file_name, 'seg_coco')
61
+ save_path = get_label_path(file_name, 'seg_coco', True)
62
+
63
+ rgb = plt.imread(file_name)
64
+ seg_labels = plt.imread(label_path)
65
+
66
+ plt.imshow(rgb)
67
+
68
+ seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
69
+ for i in np.unique(seg_labels):
70
+ seg_map[seg_labels == i] = ade_color[int(i * 255)]
71
+
72
+ plt.imshow(seg_map, alpha=0.5)
73
+
74
+ for i in np.unique(seg_labels):
75
+ obj_idx_all = np.where(seg_labels == i)
76
+ obj_idx = random.randint(0, len(obj_idx_all[0]))
77
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
78
+ obj_name = coco_label_map[int(i * 255)]
79
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
80
+
81
+ plt.axis('off')
82
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
83
+ plt.close()
84
+
85
+
86
+ def ocr_detection_prettify(file_name):
87
+ label_path = get_label_path(file_name, 'ocr_detection')
88
+ save_path = get_label_path(file_name, 'ocr_detection', True)
89
+
90
+ if os.path.exists(label_path):
91
+ rgb = plt.imread(file_name)
92
+ ocr_labels = plt.imread(label_path)
93
+ ocr_labels_dict = torch.load(label_path.replace('.png', '.pt'))
94
+
95
+ plt.imshow(rgb)
96
+ plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
97
+
98
+ for i in np.unique(ocr_labels)[:-1]:
99
+ text_idx_all = np.where(ocr_labels == i)
100
+ x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
101
+ text = ocr_labels_dict[int(i * 255)]['text']
102
+ plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center')
103
+
104
+ plt.axis('off')
105
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
106
+ plt.close()
107
+ else:
108
+ rgb = plt.imread(file_name)
109
+ ocr_labels = np.ones_like(rgb, dtype=np.float32())
110
+
111
+ plt.imshow(rgb)
112
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
113
+
114
+ x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
115
+ plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center')
116
+
117
+ plt.axis('off')
118
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
119
+ plt.close()
120
+
121
+
122
+ im_list = glob.glob(file_path + '/*.jpg') + glob.glob(file_path + '/*.png') + glob.glob(file_path + '/*.jpeg')
123
+
124
+ # prettify labels first:
125
+ for i in range(len(im_list)):
126
+ depth_prettify(im_list[i])
127
+ seg_prettify(im_list[i])
128
+ ocr_detection_prettify(im_list[i])
129
+ obj_detection_prettify(im_list[i])
130
+
131
+ pretty = {'depth': True, 'normal': False, 'edge': False,
132
+ 'obj_detection': True, 'ocr_detection': True, 'seg_coco': True}
133
+
134
+ # plot expert labels
135
+ for im_path in im_list:
136
+ fig, axs = plt.subplots(1, 7, figsize=(20, 4))
137
+ rgb = plt.imread(im_path)
138
+ axs[0].imshow(rgb)
139
+ axs[0].axis('off')
140
+ axs[0].set_title('RGB')
141
+
142
+ for j in range(6):
143
+ label_name = list(pretty.keys())[j]
144
+ label_path = get_label_path(im_path, label_name, with_suffix=pretty[label_name])
145
+ label = plt.imread(label_path)
146
+ if label_name != 'edge':
147
+ axs[j + 1].imshow(label)
148
+ else:
149
+ axs[j + 1].imshow(label, cmap='gray')
150
+
151
+ axs[j + 1].axis('off')
152
+ axs[j + 1].set_title(label_name)
153
+
154
+ caption_path = ''.join(im_path.split('.')[:-1] + ['.txt'])
155
+ with open(caption_path) as f:
156
+ caption = f.readlines()[0]
157
+
158
+ plt.suptitle(caption)
159
+ plt.tight_layout()
160
+
161
+ plt.show()
prismer/download_checkpoints.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata
2
+ from huggingface_hub.utils import disable_progress_bars
3
+ from pathlib import Path
4
+ from rich.progress import Progress
5
+ from fire import Fire
6
+ from typing import Union, List
7
+
8
+ _EXPERTS = [
9
+ "10_model.pth",
10
+ "Unified_learned_OCIM_RS200_6x+2x.pth",
11
+ "dpt_hybrid-midas-501f0c75.pt",
12
+ "icdar2015_hourglass88.pth",
13
+ "model_final_e0c58e.pkl",
14
+ "model_final_f07440.pkl",
15
+ "scannet.pt",
16
+ ]
17
+
18
+ _MODELS = [
19
+ "vqa_prismer_base",
20
+ "vqa_prismer_large",
21
+ "vqa_prismerz_base",
22
+ "vqa_prismerz_large",
23
+ "caption_prismerz_base",
24
+ "caption_prismerz_large",
25
+ "caption_prismer_base",
26
+ "caption_prismer_large",
27
+ "pretrain_prismer_base",
28
+ "pretrain_prismer_large",
29
+ "pretrain_prismerz_base",
30
+ "pretrain_prismerz_large",
31
+ ]
32
+
33
+ _REPO_ID = "lorenmt/prismer"
34
+
35
+
36
+ def download_checkpoints(
37
+ download_experts: bool = False,
38
+ download_models: Union[bool, List] = False,
39
+ hide_tqdm: bool = False,
40
+ force_redownload: bool = False,
41
+ ):
42
+ if hide_tqdm:
43
+ disable_progress_bars()
44
+ # Convert to list and check for invalid names
45
+ download_experts = _EXPERTS if download_experts else []
46
+ if download_models:
47
+ # only download single model
48
+ if isinstance(download_models, str):
49
+ download_models = [download_models]
50
+
51
+ assert all([m in _MODELS for m in download_models]), f"Invalid model name. Must be one of {_MODELS}"
52
+ download_models = _MODELS if isinstance(download_models, bool) else download_models
53
+ else:
54
+ download_models = []
55
+
56
+ # Check if files already exist
57
+ if not force_redownload:
58
+ download_experts = [e for e in download_experts if not Path(f"./experts/expert_weights/{e}").exists()]
59
+ download_models = [m for m in download_models if not Path(f"{m}/pytorch_model.bin").exists()]
60
+
61
+ assert download_experts or download_models, "Nothing to download."
62
+
63
+ with Progress() as progress:
64
+ # Calculate total download size
65
+ progress.print("[blue]Calculating download size...")
66
+ total_size = 0
67
+ for expert in download_experts:
68
+ url = hf_hub_url(
69
+ filename=expert,
70
+ repo_id=_REPO_ID,
71
+ subfolder="expert_weights"
72
+ )
73
+ total_size += get_hf_file_metadata(url).size
74
+
75
+ for model in download_models:
76
+ url = hf_hub_url(
77
+ filename=f"pytorch_model.bin",
78
+ repo_id=_REPO_ID,
79
+ subfolder=model
80
+ )
81
+ total_size += get_hf_file_metadata(url).size
82
+ progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")
83
+
84
+ # Download files
85
+ total_files = len(download_experts) + len(download_models)
86
+ total_task = progress.add_task(f"[green]Downloading files", total=total_files)
87
+ if download_experts:
88
+ expert_task = progress.add_task(
89
+ f"[green]Downloading experts...", total=len(download_experts)
90
+ )
91
+ out_folder = Path("experts/expert_weights")
92
+ out_folder.mkdir(parents=True, exist_ok=True)
93
+ for expert in download_experts:
94
+ path = Path(hf_hub_download(
95
+ filename=expert,
96
+ repo_id=_REPO_ID,
97
+ subfolder="expert_weights"
98
+ ))
99
+ path.resolve().rename(out_folder/path.name)
100
+ path.unlink()
101
+ progress.advance(expert_task)
102
+ progress.advance(total_task)
103
+
104
+ if download_models:
105
+ model_task = progress.add_task(
106
+ f"[green]Downloading models...", total=len(download_models)
107
+ )
108
+ for model in download_models:
109
+ path = Path(hf_hub_download(
110
+ filename=f"pytorch_model.bin",
111
+ repo_id=_REPO_ID,
112
+ subfolder=model
113
+ ))
114
+ out_folder = Path("./logging")/model
115
+ out_folder.mkdir(parents=True, exist_ok=True)
116
+ path.resolve().rename(out_folder/"pytorch_model.bin")
117
+ path.unlink()
118
+ progress.advance(model_task)
119
+ progress.advance(total_task)
120
+ progress.print("[green]Done!")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ Fire(download_checkpoints)
prismer/experts/depth/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device("cpu"))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
prismer/experts/depth/blocks.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ ):
24
+ if backbone == "vitl16_384":
25
+ pretrained = _make_pretrained_vitl16_384(
26
+ use_pretrained,
27
+ hooks=hooks,
28
+ use_readout=use_readout,
29
+ enable_attention_hooks=enable_attention_hooks,
30
+ )
31
+ scratch = _make_scratch(
32
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
34
+ elif backbone == "vitb_rn50_384":
35
+ pretrained = _make_pretrained_vitb_rn50_384(
36
+ use_pretrained,
37
+ hooks=hooks,
38
+ use_vit_only=use_vit_only,
39
+ use_readout=use_readout,
40
+ enable_attention_hooks=enable_attention_hooks,
41
+ )
42
+ scratch = _make_scratch(
43
+ [256, 512, 768, 768], features, groups=groups, expand=expand
44
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
45
+ elif backbone == "vitb16_384":
46
+ pretrained = _make_pretrained_vitb16_384(
47
+ use_pretrained,
48
+ hooks=hooks,
49
+ use_readout=use_readout,
50
+ enable_attention_hooks=enable_attention_hooks,
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
55
+ elif backbone == "resnext101_wsl":
56
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57
+ scratch = _make_scratch(
58
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
59
+ ) # efficientnet_lite3
60
+ else:
61
+ print(f"Backbone '{backbone}' not implemented")
62
+ assert False
63
+
64
+ return pretrained, scratch
65
+
66
+
67
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68
+ scratch = nn.Module()
69
+
70
+ out_shape1 = out_shape
71
+ out_shape2 = out_shape
72
+ out_shape3 = out_shape
73
+ out_shape4 = out_shape
74
+ if expand == True:
75
+ out_shape1 = out_shape
76
+ out_shape2 = out_shape * 2
77
+ out_shape3 = out_shape * 4
78
+ out_shape4 = out_shape * 8
79
+
80
+ scratch.layer1_rn = nn.Conv2d(
81
+ in_shape[0],
82
+ out_shape1,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1,
86
+ bias=False,
87
+ groups=groups,
88
+ )
89
+ scratch.layer2_rn = nn.Conv2d(
90
+ in_shape[1],
91
+ out_shape2,
92
+ kernel_size=3,
93
+ stride=1,
94
+ padding=1,
95
+ bias=False,
96
+ groups=groups,
97
+ )
98
+ scratch.layer3_rn = nn.Conv2d(
99
+ in_shape[2],
100
+ out_shape3,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups,
106
+ )
107
+ scratch.layer4_rn = nn.Conv2d(
108
+ in_shape[3],
109
+ out_shape4,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1,
113
+ bias=False,
114
+ groups=groups,
115
+ )
116
+
117
+ return scratch
118
+
119
+
120
+ def _make_resnet_backbone(resnet):
121
+ pretrained = nn.Module()
122
+ pretrained.layer1 = nn.Sequential(
123
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124
+ )
125
+
126
+ pretrained.layer2 = resnet.layer2
127
+ pretrained.layer3 = resnet.layer3
128
+ pretrained.layer4 = resnet.layer4
129
+
130
+ return pretrained
131
+
132
+
133
+ def _make_pretrained_resnext101_wsl(use_pretrained):
134
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135
+ return _make_resnet_backbone(resnet)
136
+
137
+
138
+ class Interpolate(nn.Module):
139
+ """Interpolation module."""
140
+
141
+ def __init__(self, scale_factor, mode, align_corners=False):
142
+ """Init.
143
+
144
+ Args:
145
+ scale_factor (float): scaling
146
+ mode (str): interpolation mode
147
+ """
148
+ super(Interpolate, self).__init__()
149
+
150
+ self.interp = nn.functional.interpolate
151
+ self.scale_factor = scale_factor
152
+ self.mode = mode
153
+ self.align_corners = align_corners
154
+
155
+ def forward(self, x):
156
+ """Forward pass.
157
+
158
+ Args:
159
+ x (tensor): input
160
+
161
+ Returns:
162
+ tensor: interpolated data_list
163
+ """
164
+
165
+ x = self.interp(
166
+ x,
167
+ scale_factor=self.scale_factor,
168
+ mode=self.mode,
169
+ align_corners=self.align_corners,
170
+ )
171
+
172
+ return x
173
+
174
+
175
+ class ResidualConvUnit(nn.Module):
176
+ """Residual convolution module."""
177
+
178
+ def __init__(self, features):
179
+ """Init.
180
+
181
+ Args:
182
+ features (int): number of features
183
+ """
184
+ super().__init__()
185
+
186
+ self.conv1 = nn.Conv2d(
187
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
188
+ )
189
+
190
+ self.conv2 = nn.Conv2d(
191
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
192
+ )
193
+
194
+ self.relu = nn.ReLU(inplace=True)
195
+
196
+ def forward(self, x):
197
+ """Forward pass.
198
+
199
+ Args:
200
+ x (tensor): input
201
+
202
+ Returns:
203
+ tensor: output
204
+ """
205
+ out = self.relu(x)
206
+ out = self.conv1(out)
207
+ out = self.relu(out)
208
+ out = self.conv2(out)
209
+
210
+ return out + x
211
+
212
+
213
+ class FeatureFusionBlock(nn.Module):
214
+ """Feature fusion block."""
215
+
216
+ def __init__(self, features):
217
+ """Init.
218
+
219
+ Args:
220
+ features (int): number of features
221
+ """
222
+ super(FeatureFusionBlock, self).__init__()
223
+
224
+ self.resConfUnit1 = ResidualConvUnit(features)
225
+ self.resConfUnit2 = ResidualConvUnit(features)
226
+
227
+ def forward(self, *xs):
228
+ """Forward pass.
229
+
230
+ Returns:
231
+ tensor: output
232
+ """
233
+ output = xs[0]
234
+
235
+ if len(xs) == 2:
236
+ output += self.resConfUnit1(xs[1])
237
+
238
+ output = self.resConfUnit2(output)
239
+
240
+ output = nn.functional.interpolate(
241
+ output, scale_factor=2, mode="bilinear", align_corners=True
242
+ )
243
+
244
+ return output
245
+
246
+
247
+ class ResidualConvUnit_custom(nn.Module):
248
+ """Residual convolution module."""
249
+
250
+ def __init__(self, features, activation, bn):
251
+ """Init.
252
+
253
+ Args:
254
+ features (int): number of features
255
+ """
256
+ super().__init__()
257
+
258
+ self.bn = bn
259
+
260
+ self.groups = 1
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ features,
264
+ features,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ bias=not self.bn,
269
+ groups=self.groups,
270
+ )
271
+
272
+ self.conv2 = nn.Conv2d(
273
+ features,
274
+ features,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=not self.bn,
279
+ groups=self.groups,
280
+ )
281
+
282
+ if self.bn == True:
283
+ self.bn1 = nn.BatchNorm2d(features)
284
+ self.bn2 = nn.BatchNorm2d(features)
285
+
286
+ self.activation = activation
287
+
288
+ self.skip_add = nn.quantized.FloatFunctional()
289
+
290
+ def forward(self, x):
291
+ """Forward pass.
292
+
293
+ Args:
294
+ x (tensor): input
295
+
296
+ Returns:
297
+ tensor: output
298
+ """
299
+
300
+ out = self.activation(x)
301
+ out = self.conv1(out)
302
+ if self.bn == True:
303
+ out = self.bn1(out)
304
+
305
+ out = self.activation(out)
306
+ out = self.conv2(out)
307
+ if self.bn == True:
308
+ out = self.bn2(out)
309
+
310
+ if self.groups > 1:
311
+ out = self.conv_merge(out)
312
+
313
+ return self.skip_add.add(out, x)
314
+
315
+ # return out + x
316
+
317
+
318
+ class FeatureFusionBlock_custom(nn.Module):
319
+ """Feature fusion block."""
320
+
321
+ def __init__(
322
+ self,
323
+ features,
324
+ activation,
325
+ deconv=False,
326
+ bn=False,
327
+ expand=False,
328
+ align_corners=True,
329
+ ):
330
+ """Init.
331
+
332
+ Args:
333
+ features (int): number of features
334
+ """
335
+ super(FeatureFusionBlock_custom, self).__init__()
336
+
337
+ self.deconv = deconv
338
+ self.align_corners = align_corners
339
+
340
+ self.groups = 1
341
+
342
+ self.expand = expand
343
+ out_features = features
344
+ if self.expand == True:
345
+ out_features = features // 2
346
+
347
+ self.out_conv = nn.Conv2d(
348
+ features,
349
+ out_features,
350
+ kernel_size=1,
351
+ stride=1,
352
+ padding=0,
353
+ bias=True,
354
+ groups=1,
355
+ )
356
+
357
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
358
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
359
+
360
+ self.skip_add = nn.quantized.FloatFunctional()
361
+
362
+ def forward(self, *xs):
363
+ """Forward pass.
364
+
365
+ Returns:
366
+ tensor: output
367
+ """
368
+ output = xs[0]
369
+
370
+ if len(xs) == 2:
371
+ res = self.resConfUnit1(xs[1])
372
+ output = self.skip_add.add(output, res)
373
+ # output += res
374
+
375
+ output = self.resConfUnit2(output)
376
+
377
+ output = nn.functional.interpolate(
378
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
379
+ )
380
+
381
+ output = self.out_conv(output)
382
+
383
+ return output
prismer/experts/depth/generate_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from PIL import ImageFile
11
+ from dataset.utils import *
12
+
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+
15
+
16
+ class Dataset(Dataset):
17
+ def __init__(self, data_path, transform):
18
+ self.data_path = data_path
19
+ self.transform = transform
20
+ data_folders = glob.glob(f'{data_path}/*/')
21
+ self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
+ self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
+
24
+ def __len__(self):
25
+ return len(self.data_list)
26
+
27
+ def __getitem__(self, index):
28
+ image_path = self.data_list[index]
29
+ image = Image.open(image_path).convert('RGB')
30
+ img_size = [image.size[0], image.size[1]]
31
+ image = self.transform(image)
32
+ return image, image_path, img_size
prismer/experts/depth/models.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .base_model import BaseModel
6
+ from .blocks import (
7
+ FeatureFusionBlock,
8
+ FeatureFusionBlock_custom,
9
+ Interpolate,
10
+ _make_encoder,
11
+ forward_vit,
12
+ )
13
+
14
+
15
+ def _make_fusion_block(features, use_bn):
16
+ return FeatureFusionBlock_custom(
17
+ features,
18
+ nn.ReLU(False),
19
+ deconv=False,
20
+ bn=use_bn,
21
+ expand=False,
22
+ align_corners=True,
23
+ )
24
+
25
+
26
+ class DPT(BaseModel):
27
+ def __init__(
28
+ self,
29
+ head,
30
+ features=256,
31
+ backbone="vitb_rn50_384",
32
+ readout="project",
33
+ channels_last=False,
34
+ use_bn=False,
35
+ enable_attention_hooks=False,
36
+ ):
37
+
38
+ super(DPT, self).__init__()
39
+
40
+ self.channels_last = channels_last
41
+
42
+ hooks = {
43
+ "vitb_rn50_384": [0, 1, 8, 11],
44
+ "vitb16_384": [2, 5, 8, 11],
45
+ "vitl16_384": [5, 11, 17, 23],
46
+ }
47
+
48
+ # Instantiate backbone and reassemble blocks
49
+ self.pretrained, self.scratch = _make_encoder(
50
+ backbone,
51
+ features,
52
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
53
+ groups=1,
54
+ expand=False,
55
+ exportable=False,
56
+ hooks=hooks[backbone],
57
+ use_readout=readout,
58
+ enable_attention_hooks=enable_attention_hooks,
59
+ )
60
+
61
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
62
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
63
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
64
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
65
+
66
+ self.scratch.output_conv = head
67
+
68
+ def forward(self, x):
69
+ if self.channels_last == True:
70
+ x.contiguous(memory_format=torch.channels_last)
71
+
72
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
73
+
74
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
75
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
76
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
77
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
78
+
79
+ path_4 = self.scratch.refinenet4(layer_4_rn)
80
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
81
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
82
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
83
+
84
+ out = self.scratch.output_conv(path_1)
85
+
86
+ return out
87
+
88
+
89
+ class DPTDepthModel(DPT):
90
+ def __init__(
91
+ self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
92
+ ):
93
+ features = kwargs["features"] if "features" in kwargs else 256
94
+
95
+ self.scale = scale
96
+ self.shift = shift
97
+ self.invert = invert
98
+
99
+ head = nn.Sequential(
100
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
101
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
102
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103
+ nn.ReLU(True),
104
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105
+ nn.ReLU(True) if non_negative else nn.Identity(),
106
+ nn.Identity(),
107
+ )
108
+
109
+ super().__init__(head, **kwargs)
110
+
111
+ if path is not None:
112
+ self.load(path)
113
+
114
+ def forward(self, x):
115
+ inv_depth = super().forward(x).squeeze(dim=1)
116
+
117
+ if self.invert:
118
+ depth = self.scale * inv_depth + self.shift
119
+ depth[depth < 1e-8] = 1e-8
120
+ depth = 1.0 / depth
121
+ return depth
122
+ else:
123
+ return inv_depth
124
+
prismer/experts/depth/vit.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+
9
+ activations = {}
10
+
11
+
12
+ def get_activation(name):
13
+ def hook(model, input, output):
14
+ activations[name] = output
15
+
16
+ return hook
17
+
18
+
19
+ attention = {}
20
+
21
+
22
+ def get_attention(name):
23
+ def hook(module, input, output):
24
+ x = input[0]
25
+ B, N, C = x.shape
26
+ qkv = (
27
+ module.qkv(x)
28
+ .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29
+ .permute(2, 0, 3, 1, 4)
30
+ )
31
+ q, k, v = (
32
+ qkv[0],
33
+ qkv[1],
34
+ qkv[2],
35
+ ) # make torchscript happy (cannot use tensor as tuple)
36
+
37
+ attn = (q @ k.transpose(-2, -1)) * module.scale
38
+
39
+ attn = attn.softmax(dim=-1) # [:,:,1,1:]
40
+ attention[name] = attn
41
+
42
+ return hook
43
+
44
+
45
+ def get_mean_attention_map(attn, token, shape):
46
+ attn = attn[:, :, token, 1:]
47
+ attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48
+ attn = torch.nn.functional.interpolate(
49
+ attn, size=shape[2:], mode="bicubic", align_corners=False
50
+ ).squeeze(0)
51
+
52
+ all_attn = torch.mean(attn, 0)
53
+
54
+ return all_attn
55
+
56
+
57
+ class Slice(nn.Module):
58
+ def __init__(self, start_index=1):
59
+ super(Slice, self).__init__()
60
+ self.start_index = start_index
61
+
62
+ def forward(self, x):
63
+ return x[:, self.start_index :]
64
+
65
+
66
+ class AddReadout(nn.Module):
67
+ def __init__(self, start_index=1):
68
+ super(AddReadout, self).__init__()
69
+ self.start_index = start_index
70
+
71
+ def forward(self, x):
72
+ if self.start_index == 2:
73
+ readout = (x[:, 0] + x[:, 1]) / 2
74
+ else:
75
+ readout = x[:, 0]
76
+ return x[:, self.start_index :] + readout.unsqueeze(1)
77
+
78
+
79
+ class ProjectReadout(nn.Module):
80
+ def __init__(self, in_features, start_index=1):
81
+ super(ProjectReadout, self).__init__()
82
+ self.start_index = start_index
83
+
84
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85
+
86
+ def forward(self, x):
87
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88
+ features = torch.cat((x[:, self.start_index :], readout), -1)
89
+
90
+ return self.project(features)
91
+
92
+
93
+ class Transpose(nn.Module):
94
+ def __init__(self, dim0, dim1):
95
+ super(Transpose, self).__init__()
96
+ self.dim0 = dim0
97
+ self.dim1 = dim1
98
+
99
+ def forward(self, x):
100
+ x = x.transpose(self.dim0, self.dim1)
101
+ return x
102
+
103
+
104
+ def forward_vit(pretrained, x):
105
+ b, c, h, w = x.shape
106
+
107
+ glob = pretrained.model.forward_flex(x)
108
+
109
+ layer_1 = pretrained.activations["1"]
110
+ layer_2 = pretrained.activations["2"]
111
+ layer_3 = pretrained.activations["3"]
112
+ layer_4 = pretrained.activations["4"]
113
+
114
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118
+
119
+ unflatten = nn.Sequential(
120
+ nn.Unflatten(
121
+ 2,
122
+ torch.Size(
123
+ [
124
+ h // pretrained.model.patch_size[1],
125
+ w // pretrained.model.patch_size[0],
126
+ ]
127
+ ),
128
+ )
129
+ )
130
+
131
+ if layer_1.ndim == 3:
132
+ layer_1 = unflatten(layer_1)
133
+ if layer_2.ndim == 3:
134
+ layer_2 = unflatten(layer_2)
135
+ if layer_3.ndim == 3:
136
+ layer_3 = unflatten(layer_3)
137
+ if layer_4.ndim == 3:
138
+ layer_4 = unflatten(layer_4)
139
+
140
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144
+
145
+ return layer_1, layer_2, layer_3, layer_4
146
+
147
+
148
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
149
+ posemb_tok, posemb_grid = (
150
+ posemb[:, : self.start_index],
151
+ posemb[0, self.start_index :],
152
+ )
153
+
154
+ gs_old = int(math.sqrt(len(posemb_grid)))
155
+
156
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159
+
160
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161
+
162
+ return posemb
163
+
164
+
165
+ def forward_flex(self, x):
166
+ b, c, h, w = x.shape
167
+
168
+ pos_embed = self._resize_pos_embed(
169
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170
+ )
171
+
172
+ B = x.shape[0]
173
+
174
+ if hasattr(self.patch_embed, "backbone"):
175
+ x = self.patch_embed.backbone(x)
176
+ if isinstance(x, (list, tuple)):
177
+ x = x[-1] # last feature if backbone outputs list/tuple of features
178
+
179
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180
+
181
+ if getattr(self, "dist_token", None) is not None:
182
+ cls_tokens = self.cls_token.expand(
183
+ B, -1, -1
184
+ ) # stole cls_tokens impl from Phil Wang, thanks
185
+ dist_token = self.dist_token.expand(B, -1, -1)
186
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
187
+ else:
188
+ cls_tokens = self.cls_token.expand(
189
+ B, -1, -1
190
+ ) # stole cls_tokens impl from Phil Wang, thanks
191
+ x = torch.cat((cls_tokens, x), dim=1)
192
+
193
+ x = x + pos_embed
194
+ x = self.pos_drop(x)
195
+
196
+ for blk in self.blocks:
197
+ x = blk(x)
198
+
199
+ x = self.norm(x)
200
+
201
+ return x
202
+
203
+
204
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
205
+ if use_readout == "ignore":
206
+ readout_oper = [Slice(start_index)] * len(features)
207
+ elif use_readout == "add":
208
+ readout_oper = [AddReadout(start_index)] * len(features)
209
+ elif use_readout == "project":
210
+ readout_oper = [
211
+ ProjectReadout(vit_features, start_index) for out_feat in features
212
+ ]
213
+ else:
214
+ assert (
215
+ False
216
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217
+
218
+ return readout_oper
219
+
220
+
221
+ def _make_vit_b16_backbone(
222
+ model,
223
+ features=[96, 192, 384, 768],
224
+ size=[384, 384],
225
+ hooks=[2, 5, 8, 11],
226
+ vit_features=768,
227
+ use_readout="ignore",
228
+ start_index=1,
229
+ enable_attention_hooks=False,
230
+ ):
231
+ pretrained = nn.Module()
232
+
233
+ pretrained.model = model
234
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238
+
239
+ pretrained.activations = activations
240
+
241
+ if enable_attention_hooks:
242
+ pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243
+ get_attention("attn_1")
244
+ )
245
+ pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246
+ get_attention("attn_2")
247
+ )
248
+ pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249
+ get_attention("attn_3")
250
+ )
251
+ pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252
+ get_attention("attn_4")
253
+ )
254
+ pretrained.attention = attention
255
+
256
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257
+
258
+ # 32, 48, 136, 384
259
+ pretrained.act_postprocess1 = nn.Sequential(
260
+ readout_oper[0],
261
+ Transpose(1, 2),
262
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263
+ nn.Conv2d(
264
+ in_channels=vit_features,
265
+ out_channels=features[0],
266
+ kernel_size=1,
267
+ stride=1,
268
+ padding=0,
269
+ ),
270
+ nn.ConvTranspose2d(
271
+ in_channels=features[0],
272
+ out_channels=features[0],
273
+ kernel_size=4,
274
+ stride=4,
275
+ padding=0,
276
+ bias=True,
277
+ dilation=1,
278
+ groups=1,
279
+ ),
280
+ )
281
+
282
+ pretrained.act_postprocess2 = nn.Sequential(
283
+ readout_oper[1],
284
+ Transpose(1, 2),
285
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286
+ nn.Conv2d(
287
+ in_channels=vit_features,
288
+ out_channels=features[1],
289
+ kernel_size=1,
290
+ stride=1,
291
+ padding=0,
292
+ ),
293
+ nn.ConvTranspose2d(
294
+ in_channels=features[1],
295
+ out_channels=features[1],
296
+ kernel_size=2,
297
+ stride=2,
298
+ padding=0,
299
+ bias=True,
300
+ dilation=1,
301
+ groups=1,
302
+ ),
303
+ )
304
+
305
+ pretrained.act_postprocess3 = nn.Sequential(
306
+ readout_oper[2],
307
+ Transpose(1, 2),
308
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309
+ nn.Conv2d(
310
+ in_channels=vit_features,
311
+ out_channels=features[2],
312
+ kernel_size=1,
313
+ stride=1,
314
+ padding=0,
315
+ ),
316
+ )
317
+
318
+ pretrained.act_postprocess4 = nn.Sequential(
319
+ readout_oper[3],
320
+ Transpose(1, 2),
321
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322
+ nn.Conv2d(
323
+ in_channels=vit_features,
324
+ out_channels=features[3],
325
+ kernel_size=1,
326
+ stride=1,
327
+ padding=0,
328
+ ),
329
+ nn.Conv2d(
330
+ in_channels=features[3],
331
+ out_channels=features[3],
332
+ kernel_size=3,
333
+ stride=2,
334
+ padding=1,
335
+ ),
336
+ )
337
+
338
+ pretrained.model.start_index = start_index
339
+ pretrained.model.patch_size = [16, 16]
340
+
341
+ # We inject this function into the VisionTransformer instances so that
342
+ # we can use it with interpolated position embeddings without modifying the library source.
343
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344
+ pretrained.model._resize_pos_embed = types.MethodType(
345
+ _resize_pos_embed, pretrained.model
346
+ )
347
+
348
+ return pretrained
349
+
350
+
351
+ def _make_vit_b_rn50_backbone(
352
+ model,
353
+ features=[256, 512, 768, 768],
354
+ size=[384, 384],
355
+ hooks=[0, 1, 8, 11],
356
+ vit_features=768,
357
+ use_vit_only=False,
358
+ use_readout="ignore",
359
+ start_index=1,
360
+ enable_attention_hooks=False,
361
+ ):
362
+ pretrained = nn.Module()
363
+
364
+ pretrained.model = model
365
+
366
+ if use_vit_only == True:
367
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369
+ else:
370
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371
+ get_activation("1")
372
+ )
373
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374
+ get_activation("2")
375
+ )
376
+
377
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379
+
380
+ if enable_attention_hooks:
381
+ pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382
+ pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383
+ pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384
+ pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385
+ pretrained.attention = attention
386
+
387
+ pretrained.activations = activations
388
+
389
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390
+
391
+ if use_vit_only == True:
392
+ pretrained.act_postprocess1 = nn.Sequential(
393
+ readout_oper[0],
394
+ Transpose(1, 2),
395
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396
+ nn.Conv2d(
397
+ in_channels=vit_features,
398
+ out_channels=features[0],
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0,
402
+ ),
403
+ nn.ConvTranspose2d(
404
+ in_channels=features[0],
405
+ out_channels=features[0],
406
+ kernel_size=4,
407
+ stride=4,
408
+ padding=0,
409
+ bias=True,
410
+ dilation=1,
411
+ groups=1,
412
+ ),
413
+ )
414
+
415
+ pretrained.act_postprocess2 = nn.Sequential(
416
+ readout_oper[1],
417
+ Transpose(1, 2),
418
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419
+ nn.Conv2d(
420
+ in_channels=vit_features,
421
+ out_channels=features[1],
422
+ kernel_size=1,
423
+ stride=1,
424
+ padding=0,
425
+ ),
426
+ nn.ConvTranspose2d(
427
+ in_channels=features[1],
428
+ out_channels=features[1],
429
+ kernel_size=2,
430
+ stride=2,
431
+ padding=0,
432
+ bias=True,
433
+ dilation=1,
434
+ groups=1,
435
+ ),
436
+ )
437
+ else:
438
+ pretrained.act_postprocess1 = nn.Sequential(
439
+ nn.Identity(), nn.Identity(), nn.Identity()
440
+ )
441
+ pretrained.act_postprocess2 = nn.Sequential(
442
+ nn.Identity(), nn.Identity(), nn.Identity()
443
+ )
444
+
445
+ pretrained.act_postprocess3 = nn.Sequential(
446
+ readout_oper[2],
447
+ Transpose(1, 2),
448
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449
+ nn.Conv2d(
450
+ in_channels=vit_features,
451
+ out_channels=features[2],
452
+ kernel_size=1,
453
+ stride=1,
454
+ padding=0,
455
+ ),
456
+ )
457
+
458
+ pretrained.act_postprocess4 = nn.Sequential(
459
+ readout_oper[3],
460
+ Transpose(1, 2),
461
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462
+ nn.Conv2d(
463
+ in_channels=vit_features,
464
+ out_channels=features[3],
465
+ kernel_size=1,
466
+ stride=1,
467
+ padding=0,
468
+ ),
469
+ nn.Conv2d(
470
+ in_channels=features[3],
471
+ out_channels=features[3],
472
+ kernel_size=3,
473
+ stride=2,
474
+ padding=1,
475
+ ),
476
+ )
477
+
478
+ pretrained.model.start_index = start_index
479
+ pretrained.model.patch_size = [16, 16]
480
+
481
+ # We inject this function into the VisionTransformer instances so that
482
+ # we can use it with interpolated position embeddings without modifying the library source.
483
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484
+
485
+ # We inject this function into the VisionTransformer instances so that
486
+ # we can use it with interpolated position embeddings without modifying the library source.
487
+ pretrained.model._resize_pos_embed = types.MethodType(
488
+ _resize_pos_embed, pretrained.model
489
+ )
490
+
491
+ return pretrained
492
+
493
+
494
+ def _make_pretrained_vitb_rn50_384(
495
+ pretrained,
496
+ use_readout="ignore",
497
+ hooks=None,
498
+ use_vit_only=False,
499
+ enable_attention_hooks=False,
500
+ ):
501
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502
+
503
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
504
+ return _make_vit_b_rn50_backbone(
505
+ model,
506
+ features=[256, 512, 768, 768],
507
+ size=[384, 384],
508
+ hooks=hooks,
509
+ use_vit_only=use_vit_only,
510
+ use_readout=use_readout,
511
+ enable_attention_hooks=enable_attention_hooks,
512
+ )
513
+
514
+
515
+ def _make_pretrained_vitl16_384(
516
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517
+ ):
518
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519
+
520
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
521
+ return _make_vit_b16_backbone(
522
+ model,
523
+ features=[256, 512, 1024, 1024],
524
+ hooks=hooks,
525
+ vit_features=1024,
526
+ use_readout=use_readout,
527
+ enable_attention_hooks=enable_attention_hooks,
528
+ )
529
+
530
+
531
+ def _make_pretrained_vitb16_384(
532
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533
+ ):
534
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535
+
536
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
537
+ return _make_vit_b16_backbone(
538
+ model,
539
+ features=[96, 192, 384, 768],
540
+ hooks=hooks,
541
+ use_readout=use_readout,
542
+ enable_attention_hooks=enable_attention_hooks,
543
+ )
544
+
545
+
546
+ def _make_pretrained_deitb16_384(
547
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548
+ ):
549
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550
+
551
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
552
+ return _make_vit_b16_backbone(
553
+ model,
554
+ features=[96, 192, 384, 768],
555
+ hooks=hooks,
556
+ use_readout=use_readout,
557
+ enable_attention_hooks=enable_attention_hooks,
558
+ )
559
+
560
+
561
+ def _make_pretrained_deitb16_distil_384(
562
+ pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563
+ ):
564
+ model = timm.create_model(
565
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566
+ )
567
+
568
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
569
+ return _make_vit_b16_backbone(
570
+ model,
571
+ features=[96, 192, 384, 768],
572
+ hooks=hooks,
573
+ use_readout=use_readout,
574
+ start_index=2,
575
+ enable_attention_hooks=enable_attention_hooks,
576
+ )
prismer/experts/edge/generate_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from dataset.utils import *
11
+ from PIL import ImageFile
12
+
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+
15
+
16
+ class Dataset(Dataset):
17
+ def __init__(self, data_path, transform):
18
+ self.data_path = data_path
19
+ self.transform = transform
20
+ data_folders = glob.glob(f'{data_path}/*/')
21
+ self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
+ self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
+
24
+ def __len__(self):
25
+ return len(self.data_list)
26
+
27
+ def __getitem__(self, index):
28
+ image_path = self.data_list[index]
29
+ image = Image.open(image_path).convert('RGB')
30
+ img_size = [image.size[0], image.size[1]]
31
+ image = self.transform(image)
32
+ return torch.flip(image, dims=(0, )) * 255., image_path, img_size
prismer/experts/edge/images.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def image_normalization(img, img_min=0, img_max=255,
7
+ epsilon=1e-12):
8
+ """This is a typical image normalization function
9
+ where the minimum and maximum of the image is needed
10
+ source: https://en.wikipedia.org/wiki/Normalization_(image_processing)
11
+
12
+ :param img: an image could be gray scale or color
13
+ :param img_min: for default is 0
14
+ :param img_max: for default is 255
15
+
16
+ :return: a normalized image, if max is 255 the dtype is uint8
17
+ """
18
+
19
+ img = np.float32(img)
20
+ # whenever an inconsistent image
21
+ img = (img - np.min(img)) * (img_max - img_min) / \
22
+ ((np.max(img) - np.min(img)) + epsilon) + img_min
23
+ return img
24
+
25
+
26
+ def fuse_edge(pred):
27
+ edge_maps = []
28
+ for i in pred:
29
+ tmp = torch.sigmoid(i).cpu().detach().numpy()
30
+ edge_maps.append(tmp)
31
+ tensor = np.array(edge_maps)
32
+
33
+ fuses = []
34
+ for idx in range(tensor.shape[1]):
35
+ tmp = tensor[:, idx, ...]
36
+ tmp = np.squeeze(tmp)
37
+
38
+ # Iterate our all 7 NN outputs for a particular image
39
+ for i in range(tmp.shape[0]):
40
+ tmp_img = tmp[i]
41
+ tmp_img = np.uint8(image_normalization(tmp_img))
42
+ tmp_img = cv2.bitwise_not(tmp_img)
43
+
44
+ if i == 6:
45
+ fuse = tmp_img
46
+ fuse = fuse.astype(np.uint8)
47
+ fuses.append(fuse)
48
+ return fuses
49
+
50
+
prismer/experts/edge/model.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def weight_init(m):
7
+ if isinstance(m, (nn.Conv2d,)):
8
+ # torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
9
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
10
+ # torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
11
+ if m.weight.data.shape[1] == torch.Size([1]):
12
+ torch.nn.init.normal_(m.weight, mean=0.0)
13
+
14
+ if m.bias is not None:
15
+ torch.nn.init.zeros_(m.bias)
16
+
17
+ # for fusion layer
18
+ if isinstance(m, (nn.ConvTranspose2d,)):
19
+ # torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
20
+ torch.nn.init.xavier_normal_(m.weight, gain=1.0)
21
+ # torch.nn.init.normal_(m.weight, mean=0.0, std=0.01)
22
+
23
+ if m.weight.data.shape[1] == torch.Size([1]):
24
+ torch.nn.init.normal_(m.weight, std=0.1)
25
+ if m.bias is not None:
26
+ torch.nn.init.zeros_(m.bias)
27
+
28
+
29
+ class CoFusion(nn.Module):
30
+
31
+ def __init__(self, in_ch, out_ch):
32
+ super(CoFusion, self).__init__()
33
+ self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3,
34
+ stride=1, padding=1)
35
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
36
+ stride=1, padding=1)
37
+ self.conv3 = nn.Conv2d(64, out_ch, kernel_size=3,
38
+ stride=1, padding=1)
39
+ self.relu = nn.ReLU()
40
+
41
+ self.norm_layer1 = nn.GroupNorm(4, 64)
42
+ self.norm_layer2 = nn.GroupNorm(4, 64)
43
+
44
+ def forward(self, x):
45
+ # fusecat = torch.cat(x, dim=1)
46
+ attn = self.relu(self.norm_layer1(self.conv1(x)))
47
+ attn = self.relu(self.norm_layer2(self.conv2(attn)))
48
+ attn = F.softmax(self.conv3(attn), dim=1)
49
+
50
+ # return ((fusecat * attn).sum(1)).unsqueeze(1)
51
+ return ((x * attn).sum(1)).unsqueeze(1)
52
+
53
+ class _DenseLayer(nn.Sequential):
54
+ def __init__(self, input_features, out_features):
55
+ super(_DenseLayer, self).__init__()
56
+
57
+ # self.add_module('relu2', nn.ReLU(inplace=True)),
58
+ self.add_module('conv1', nn.Conv2d(input_features, out_features,
59
+ kernel_size=3, stride=1, padding=2, bias=True)),
60
+ self.add_module('norm1', nn.BatchNorm2d(out_features)),
61
+ self.add_module('relu1', nn.ReLU(inplace=True)),
62
+ self.add_module('conv2', nn.Conv2d(out_features, out_features,
63
+ kernel_size=3, stride=1, bias=True)),
64
+ self.add_module('norm2', nn.BatchNorm2d(out_features))
65
+
66
+ def forward(self, x):
67
+ x1, x2 = x
68
+
69
+ new_features = super(_DenseLayer, self).forward(F.relu(x1)) # F.relu()
70
+ # if new_features.shape[-1]!=x2.shape[-1]:
71
+ # new_features =F.interpolate(new_features,size=(x2.shape[2],x2.shape[-1]), mode='bicubic',
72
+ # align_corners=False)
73
+ return 0.5 * (new_features + x2), x2
74
+
75
+
76
+ class _DenseBlock(nn.Sequential):
77
+ def __init__(self, num_layers, input_features, out_features):
78
+ super(_DenseBlock, self).__init__()
79
+ for i in range(num_layers):
80
+ layer = _DenseLayer(input_features, out_features)
81
+ self.add_module('denselayer%d' % (i + 1), layer)
82
+ input_features = out_features
83
+
84
+
85
+ class UpConvBlock(nn.Module):
86
+ def __init__(self, in_features, up_scale):
87
+ super(UpConvBlock, self).__init__()
88
+ self.up_factor = 2
89
+ self.constant_features = 16
90
+
91
+ layers = self.make_deconv_layers(in_features, up_scale)
92
+ assert layers is not None, layers
93
+ self.features = nn.Sequential(*layers)
94
+
95
+ def make_deconv_layers(self, in_features, up_scale):
96
+ layers = []
97
+ all_pads=[0,0,1,3,7]
98
+ for i in range(up_scale):
99
+ kernel_size = 2 ** up_scale
100
+ pad = all_pads[up_scale] # kernel_size-1
101
+ out_features = self.compute_out_features(i, up_scale)
102
+ layers.append(nn.Conv2d(in_features, out_features, 1))
103
+ layers.append(nn.ReLU(inplace=True))
104
+ layers.append(nn.ConvTranspose2d(
105
+ out_features, out_features, kernel_size, stride=2, padding=pad))
106
+ in_features = out_features
107
+ return layers
108
+
109
+ def compute_out_features(self, idx, up_scale):
110
+ return 1 if idx == up_scale - 1 else self.constant_features
111
+
112
+ def forward(self, x):
113
+ return self.features(x)
114
+
115
+
116
+ class SingleConvBlock(nn.Module):
117
+ def __init__(self, in_features, out_features, stride,
118
+ use_bs=True
119
+ ):
120
+ super(SingleConvBlock, self).__init__()
121
+ self.use_bn = use_bs
122
+ self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
123
+ bias=True)
124
+ self.bn = nn.BatchNorm2d(out_features)
125
+
126
+ def forward(self, x):
127
+ x = self.conv(x)
128
+ if self.use_bn:
129
+ x = self.bn(x)
130
+ return x
131
+
132
+
133
+ class DoubleConvBlock(nn.Module):
134
+ def __init__(self, in_features, mid_features,
135
+ out_features=None,
136
+ stride=1,
137
+ use_act=True):
138
+ super(DoubleConvBlock, self).__init__()
139
+
140
+ self.use_act = use_act
141
+ if out_features is None:
142
+ out_features = mid_features
143
+ self.conv1 = nn.Conv2d(in_features, mid_features,
144
+ 3, padding=1, stride=stride)
145
+ self.bn1 = nn.BatchNorm2d(mid_features)
146
+ self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
147
+ self.bn2 = nn.BatchNorm2d(out_features)
148
+ self.relu = nn.ReLU(inplace=True)
149
+
150
+ def forward(self, x):
151
+ x = self.conv1(x)
152
+ x = self.bn1(x)
153
+ x = self.relu(x)
154
+ x = self.conv2(x)
155
+ x = self.bn2(x)
156
+ if self.use_act:
157
+ x = self.relu(x)
158
+ return x
159
+
160
+
161
+ class DexiNed(nn.Module):
162
+ """ Definition of the DXtrem network. """
163
+
164
+ def __init__(self):
165
+ super(DexiNed, self).__init__()
166
+ self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
167
+ self.block_2 = DoubleConvBlock(64, 128, use_act=False)
168
+ self.dblock_3 = _DenseBlock(2, 128, 256) # [128,256,100,100]
169
+ self.dblock_4 = _DenseBlock(3, 256, 512)
170
+ self.dblock_5 = _DenseBlock(3, 512, 512)
171
+ self.dblock_6 = _DenseBlock(3, 512, 256)
172
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
173
+
174
+ # left skip connections, figure in Journal
175
+ self.side_1 = SingleConvBlock(64, 128, 2)
176
+ self.side_2 = SingleConvBlock(128, 256, 2)
177
+ self.side_3 = SingleConvBlock(256, 512, 2)
178
+ self.side_4 = SingleConvBlock(512, 512, 1)
179
+ self.side_5 = SingleConvBlock(512, 256, 1) # Sory I forget to comment this line :(
180
+
181
+ # right skip connections, figure in Journal paper
182
+ self.pre_dense_2 = SingleConvBlock(128, 256, 2)
183
+ self.pre_dense_3 = SingleConvBlock(128, 256, 1)
184
+ self.pre_dense_4 = SingleConvBlock(256, 512, 1)
185
+ self.pre_dense_5 = SingleConvBlock(512, 512, 1)
186
+ self.pre_dense_6 = SingleConvBlock(512, 256, 1)
187
+
188
+
189
+ self.up_block_1 = UpConvBlock(64, 1)
190
+ self.up_block_2 = UpConvBlock(128, 1)
191
+ self.up_block_3 = UpConvBlock(256, 2)
192
+ self.up_block_4 = UpConvBlock(512, 3)
193
+ self.up_block_5 = UpConvBlock(512, 4)
194
+ self.up_block_6 = UpConvBlock(256, 4)
195
+ self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False) # hed fusion method
196
+ # self.block_cat = CoFusion(6,6)# cats fusion method
197
+
198
+
199
+ self.apply(weight_init)
200
+
201
+ def slice(self, tensor, slice_shape):
202
+ t_shape = tensor.shape
203
+ height, width = slice_shape
204
+ if t_shape[-1]!=slice_shape[-1]:
205
+ new_tensor = F.interpolate(
206
+ tensor, size=(height, width), mode='bicubic',align_corners=False)
207
+ else:
208
+ new_tensor=tensor
209
+ # tensor[..., :height, :width]
210
+ return new_tensor
211
+
212
+ def forward(self, x):
213
+ assert x.ndim == 4, x.shape
214
+
215
+ # Block 1
216
+ block_1 = self.block_1(x)
217
+ block_1_side = self.side_1(block_1)
218
+
219
+ # Block 2
220
+ block_2 = self.block_2(block_1)
221
+ block_2_down = self.maxpool(block_2)
222
+ block_2_add = block_2_down + block_1_side
223
+ block_2_side = self.side_2(block_2_add)
224
+
225
+ # Block 3
226
+ block_3_pre_dense = self.pre_dense_3(block_2_down)
227
+ block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
228
+ block_3_down = self.maxpool(block_3) # [128,256,50,50]
229
+ block_3_add = block_3_down + block_2_side
230
+ block_3_side = self.side_3(block_3_add)
231
+
232
+ # Block 4
233
+ block_2_resize_half = self.pre_dense_2(block_2_down)
234
+ block_4_pre_dense = self.pre_dense_4(block_3_down+block_2_resize_half)
235
+ block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
236
+ block_4_down = self.maxpool(block_4)
237
+ block_4_add = block_4_down + block_3_side
238
+ block_4_side = self.side_4(block_4_add)
239
+
240
+ # Block 5
241
+ block_5_pre_dense = self.pre_dense_5(
242
+ block_4_down) #block_5_pre_dense_512 +block_4_down
243
+ block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
244
+ block_5_add = block_5 + block_4_side
245
+
246
+ # Block 6
247
+ block_6_pre_dense = self.pre_dense_6(block_5)
248
+ block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
249
+
250
+ # upsampling blocks
251
+ out_1 = self.up_block_1(block_1)
252
+ out_2 = self.up_block_2(block_2)
253
+ out_3 = self.up_block_3(block_3)
254
+ out_4 = self.up_block_4(block_4)
255
+ out_5 = self.up_block_5(block_5)
256
+ out_6 = self.up_block_6(block_6)
257
+ results = [out_1, out_2, out_3, out_4, out_5, out_6]
258
+
259
+ # concatenate multiscale outputs
260
+ block_cat = torch.cat(results, dim=1) # Bx6xHxW
261
+ block_cat = self.block_cat(block_cat) # Bx1xHxW
262
+
263
+ # return results
264
+ results.append(block_cat)
265
+ return results
266
+
267
+
268
+ if __name__ == '__main__':
269
+ batch_size = 8
270
+ img_height = 352
271
+ img_width = 352
272
+
273
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
274
+ device = "cpu"
275
+ input = torch.rand(batch_size, 3, img_height, img_width).to(device)
276
+ # target = torch.rand(batch_size, 1, img_height, img_width).to(device)
277
+ print(f"input shape: {input.shape}")
278
+ model = DexiNed().to(device)
279
+ output = model(input)
280
+ print(f"output shapes: {[t.shape for t in output]}")
281
+
282
+ # for i in range(20000):
283
+ # print(i)
284
+ # output = model(input)
285
+ # loss = nn.MSELoss()(output[-1], target)
286
+ # loss.backward()
prismer/experts/generate_depth.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ try:
10
+ import ruamel_yaml as yaml
11
+ except ModuleNotFoundError:
12
+ import ruamel.yaml as yaml
13
+
14
+ from experts.model_bank import load_expert_model
15
+ from experts.depth.generate_dataset import Dataset
16
+ import PIL.Image as Image
17
+ from accelerate import Accelerator
18
+ from tqdm import tqdm
19
+
20
+ model, transform = load_expert_model(task='depth')
21
+ accelerator = Accelerator(mixed_precision='fp16')
22
+
23
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
+ data_path = config['data_path']
25
+ save_path = os.path.join(config['save_path'], 'depth')
26
+
27
+ batch_size = 64
28
+ dataset = Dataset(data_path, transform)
29
+ data_loader = torch.utils.data.DataLoader(
30
+ dataset=dataset,
31
+ batch_size=batch_size,
32
+ shuffle=False,
33
+ num_workers=4,
34
+ pin_memory=True
35
+ )
36
+
37
+ model, data_loader = accelerator.prepare(model, data_loader)
38
+
39
+ with torch.no_grad():
40
+ for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
41
+ test_pred = model(test_data)
42
+
43
+ for k in range(len(test_pred)):
44
+ img_path_split = img_path[k].split('/')
45
+ ps = img_path[k].split('.')[-1]
46
+ im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
47
+ os.makedirs(im_save_path, exist_ok=True)
48
+
49
+ im_size = img_size[0][k].item(), img_size[1][k].item()
50
+ depth = test_pred[k]
51
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
52
+ depth = torch.nn.functional.interpolate(depth.unsqueeze(0).unsqueeze(1), size=(im_size[1], im_size[0]), mode='bilinear', align_corners=True)
53
+ depth_im = Image.fromarray(255 * depth[0, 0].detach().cpu().numpy()).convert('L')
54
+ depth_im.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
55
+
56
+
prismer/experts/generate_edge.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ try:
10
+ import ruamel_yaml as yaml
11
+ except ModuleNotFoundError:
12
+ import ruamel.yaml as yaml
13
+
14
+ from experts.model_bank import load_expert_model
15
+ from experts.edge.generate_dataset import Dataset
16
+ from experts.edge.images import fuse_edge
17
+ import PIL.Image as Image
18
+ from accelerate import Accelerator
19
+ from tqdm import tqdm
20
+
21
+
22
+ model, transform = load_expert_model(task='edge')
23
+ accelerator = Accelerator(mixed_precision='fp16')
24
+
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
+ data_path = config['data_path']
27
+ save_path = os.path.join(config['save_path'], 'edge')
28
+
29
+ batch_size = 64
30
+ dataset = Dataset(data_path, transform)
31
+ data_loader = torch.utils.data.DataLoader(
32
+ dataset=dataset,
33
+ batch_size=batch_size,
34
+ shuffle=False,
35
+ num_workers=4,
36
+ pin_memory=True
37
+ )
38
+
39
+ model, data_loader = accelerator.prepare(model, data_loader)
40
+
41
+ with torch.no_grad():
42
+ for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
43
+ test_pred = model(test_data)
44
+ fuses = fuse_edge(test_pred)
45
+ for k in range(len(fuses)):
46
+ edge = fuses[k]
47
+ img_path_split = img_path[k].split('/')
48
+ ps = img_path[k].split('.')[-1]
49
+ im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
50
+ os.makedirs(im_save_path, exist_ok=True)
51
+
52
+ im_size = img_size[0][k].item(), img_size[1][k].item()
53
+ edge = Image.fromarray(edge).convert('L')
54
+ edge = edge.resize((im_size[0], im_size[1]), resample=Image.Resampling.BILINEAR)
55
+ edge.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
56
+
57
+
prismer/experts/generate_normal.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ try:
10
+ import ruamel_yaml as yaml
11
+ except ModuleNotFoundError:
12
+ import ruamel.yaml as yaml
13
+
14
+ from experts.model_bank import load_expert_model
15
+ from experts.normal.generate_dataset import CustomDataset
16
+ import PIL.Image as Image
17
+ from accelerate import Accelerator
18
+ from tqdm import tqdm
19
+ import numpy as np
20
+
21
+
22
+ model, transform = load_expert_model(task='normal')
23
+ accelerator = Accelerator(mixed_precision='fp16')
24
+
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
+ data_path = config['data_path']
27
+ save_path = os.path.join(config['save_path'], 'normal')
28
+
29
+ batch_size = 64
30
+ dataset = CustomDataset(data_path, transform)
31
+ data_loader = torch.utils.data.DataLoader(
32
+ dataset=dataset,
33
+ batch_size=batch_size,
34
+ shuffle=False,
35
+ num_workers=4,
36
+ pin_memory=True
37
+ )
38
+
39
+ model, data_loader = accelerator.prepare(model, data_loader)
40
+
41
+ with torch.no_grad():
42
+ for i, (test_data, img_path, img_size) in enumerate(tqdm(data_loader)):
43
+ test_pred = model(test_data)
44
+ pred_norm = test_pred[0][-1][:, :3]
45
+ for k in range(len(pred_norm)):
46
+ img_path_split = img_path[k].split('/')
47
+ ps = img_path[k].split('.')[-1]
48
+ im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
49
+ os.makedirs(im_save_path, exist_ok=True)
50
+
51
+ im_size = img_size[0][k].item(), img_size[1][k].item()
52
+ norm = pred_norm[k]
53
+ norm = ((norm + 1) * 0.5).clip(0, 1)
54
+ norm = torch.nn.functional.interpolate(norm.unsqueeze(0), size=(im_size[1], im_size[0]), mode='bilinear', align_corners=True)
55
+ norm_im = Image.fromarray((norm[0] * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)).convert('RGB')
56
+ norm_im.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
57
+
58
+
prismer/experts/generate_objdet.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ import json
10
+ import copy
11
+ import PIL.Image as Image
12
+ try:
13
+ import ruamel_yaml as yaml
14
+ except ModuleNotFoundError:
15
+ import ruamel.yaml as yaml
16
+
17
+ from experts.model_bank import load_expert_model
18
+ from experts.obj_detection.generate_dataset import Dataset, collate_fn
19
+ from accelerate import Accelerator
20
+ from tqdm import tqdm
21
+
22
+ model, transform = load_expert_model(task='obj_detection')
23
+ accelerator = Accelerator(mixed_precision='fp16')
24
+
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
+ data_path = config['data_path']
27
+ save_path = config['save_path']
28
+
29
+ depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])
30
+ batch_size = 32
31
+ dataset = Dataset(data_path, depth_path, transform)
32
+ data_loader = torch.utils.data.DataLoader(
33
+ dataset=dataset,
34
+ batch_size=batch_size,
35
+ shuffle=False,
36
+ num_workers=4,
37
+ pin_memory=True,
38
+ collate_fn=collate_fn,
39
+ )
40
+
41
+ model, data_loader = accelerator.prepare(model, data_loader)
42
+
43
+
44
+ def get_mask_labels(depth, instance_boxes, instance_id):
45
+ obj_masks = []
46
+ obj_ids = []
47
+ for i in range(len(instance_boxes)):
48
+ is_duplicate = False
49
+ mask = torch.zeros_like(depth)
50
+ x1, y1, x2, y2 = instance_boxes[i][0].item(), instance_boxes[i][1].item(), \
51
+ instance_boxes[i][2].item(), instance_boxes[i][3].item()
52
+ mask[int(y1):int(y2), int(x1):int(x2)] = 1
53
+ for j in range(len(obj_masks)):
54
+ if ((mask + obj_masks[j]) == 2).sum() / ((mask + obj_masks[j]) > 0).sum() > 0.95:
55
+ is_duplicate = True
56
+ break
57
+ if not is_duplicate:
58
+ obj_masks.append(mask)
59
+ obj_ids.append(instance_id[i])
60
+
61
+ obj_masked_modified = copy.deepcopy(obj_masks[:])
62
+ for i in range(len(obj_masks) - 1):
63
+ mask1 = obj_masks[i]
64
+ mask1_ = obj_masked_modified[i]
65
+ for j in range(i + 1, len(obj_masks)):
66
+ mask2 = obj_masks[j]
67
+ mask2_ = obj_masked_modified[j]
68
+ # case 1: if they don't intersect we don't touch them
69
+ if ((mask1 + mask2) == 2).sum() == 0:
70
+ continue
71
+ # case 2: the entire object 1 is inside of object 2, we say object 1 is in front of object 2:
72
+ elif (((mask1 + mask2) == 2).float() - mask1).sum() == 0:
73
+ mask2_ -= mask1_
74
+ # case 3: the entire object 2 is inside of object 1, we say object 2 is in front of object 1:
75
+ elif (((mask1 + mask2) == 2).float() - mask2).sum() == 0:
76
+ mask1_ -= mask2_
77
+ # case 4: use depth to check object order:
78
+ else:
79
+ # object 1 is closer
80
+ if (depth * mask1).sum() / mask1.sum() > (depth * mask2).sum() / mask2.sum():
81
+ mask2_ -= ((mask1 + mask2) == 2).float()
82
+ # object 2 is closer
83
+ if (depth * mask1).sum() / mask1.sum() < (depth * mask2).sum() / mask2.sum():
84
+ mask1_ -= ((mask1 + mask2) == 2).float()
85
+
86
+ final_mask = torch.ones_like(depth) * 255
87
+ instance_labels = {}
88
+ for i in range(len(obj_masked_modified)):
89
+ final_mask = final_mask.masked_fill(obj_masked_modified[i] > 0, i)
90
+ instance_labels[i] = obj_ids[i].item()
91
+ return final_mask, instance_labels
92
+
93
+
94
+ with torch.no_grad():
95
+ for i, test_data in enumerate(tqdm(data_loader)):
96
+ test_pred = model(test_data)
97
+ for k in range(len(test_pred)):
98
+ instance_boxes = test_pred[k]['instances'].get_fields()['pred_boxes'].tensor
99
+ instance_id = test_pred[k]['instances'].get_fields()['pred_classes']
100
+ depth = test_data[k]['depth']
101
+
102
+ final_mask, instance_labels = get_mask_labels(depth, instance_boxes, instance_id)
103
+
104
+ img_path_split = test_data[k]['image_path'].split('/')
105
+ im_save_path = os.path.join(save_path, 'obj_detection', img_path_split[-3], img_path_split[-2])
106
+ ps = test_data[k]['image_path'].split('.')[-1]
107
+ os.makedirs(im_save_path, exist_ok=True)
108
+
109
+ height, width = test_data[k]['true_height'], test_data[k]['true_width']
110
+ final_mask = Image.fromarray(final_mask.cpu().numpy()).convert('L')
111
+ final_mask = final_mask.resize((height, width), resample=Image.Resampling.NEAREST)
112
+ final_mask.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
113
+
114
+ with open(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.json')), 'w') as fp:
115
+ json.dump(instance_labels, fp)
prismer/experts/generate_ocrdet.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ import PIL.Image as Image
10
+ import numpy as np
11
+ import cv2
12
+ import clip
13
+ import pickle as pk
14
+ try:
15
+ import ruamel_yaml as yaml
16
+ except ModuleNotFoundError:
17
+ import ruamel.yaml as yaml
18
+
19
+ from experts.model_bank import load_expert_model
20
+ from experts.ocr_detection.generate_dataset import Dataset
21
+ from accelerate import Accelerator
22
+ from tqdm import tqdm
23
+
24
+
25
+ model, transform = load_expert_model(task='ocr_detection')
26
+ accelerator = Accelerator(mixed_precision='fp16')
27
+ pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
+
29
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
30
+ data_path = config['data_path']
31
+ save_path = os.path.join(config['save_path'], 'ocr_detection')
32
+
33
+ batch_size = 32
34
+ dataset = Dataset(data_path, transform)
35
+ data_loader = torch.utils.data.DataLoader(
36
+ dataset=dataset,
37
+ batch_size=batch_size,
38
+ shuffle=False,
39
+ num_workers=4,
40
+ pin_memory=True,
41
+ )
42
+
43
+ clip_model, _ = clip.load("ViT-L/14", device=accelerator.device)
44
+ model, data_loader = accelerator.prepare(model, data_loader)
45
+
46
+
47
+ def get_label(w, h, word_instances):
48
+ word_lists = []
49
+ final_mask = np.ones([h, w], dtype=np.uint8) * 255
50
+ counter = 0
51
+ for word_instance in word_instances[::-1]:
52
+ mask = np.zeros([h ,w])
53
+ mask = cv2.fillPoly(mask, [np.int32(word_instance.word_bbox.reshape(-1, 2))], 1)
54
+ text = word_instance.text.lower()
55
+ final_mask[mask > 0] = counter
56
+ word_lists.append(text)
57
+ counter += 1
58
+ return final_mask, word_lists
59
+
60
+
61
+ with torch.no_grad():
62
+ for i, (test_data, image_path, scale_w, scale_h, original_w, original_h) in enumerate(tqdm(data_loader)):
63
+ word_instance_lists = model(test_data, scale_w, scale_h, original_w, original_h)
64
+ for k in range(len(word_instance_lists)):
65
+ word_instance = word_instance_lists[k]
66
+ if len(word_instance) == 0:
67
+ continue
68
+ else:
69
+ final_mask, word_lists = get_label(original_w[k], original_h[k], word_instance)
70
+
71
+ final_mask = Image.fromarray(final_mask)
72
+ img_path_split = image_path[k].split('/')
73
+ ps = image_path[k].split('.')[-1]
74
+ im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
75
+ os.makedirs(im_save_path, exist_ok=True)
76
+
77
+ final_mask.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
78
+
79
+ if len(word_lists) > 0:
80
+ word_embed = clip.tokenize(word_lists).to(accelerator.device)
81
+ word_features = pca_clip.transform(clip_model.encode_text(word_embed).float().cpu())
82
+ word_lists = {j: {'features': torch.from_numpy(word_features[j]).float(),
83
+ 'text': word_lists[j]} for j in range(len(word_lists))}
84
+ torch.save(word_lists, os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.pt')))
85
+
86
+
prismer/experts/generate_segmentation.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import os
9
+ import PIL.Image as Image
10
+ try:
11
+ import ruamel_yaml as yaml
12
+ except ModuleNotFoundError:
13
+ import ruamel.yaml as yaml
14
+
15
+ from experts.model_bank import load_expert_model
16
+ from experts.segmentation.generate_dataset import Dataset, collate_fn
17
+ from accelerate import Accelerator
18
+ from tqdm import tqdm
19
+
20
+ model, transform = load_expert_model(task='seg_coco')
21
+ accelerator = Accelerator(mixed_precision='fp16')
22
+
23
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
+ data_path = config['data_path']
25
+ save_path = os.path.join(config['save_path'], 'seg_coco')
26
+
27
+ batch_size = 4
28
+ dataset = Dataset(data_path, transform)
29
+ data_loader = torch.utils.data.DataLoader(
30
+ dataset=dataset,
31
+ batch_size=batch_size,
32
+ shuffle=False,
33
+ num_workers=4,
34
+ pin_memory=True,
35
+ collate_fn=collate_fn,
36
+ )
37
+
38
+
39
+ model, data_loader = accelerator.prepare(model, data_loader)
40
+
41
+ with torch.no_grad():
42
+ for i, test_data in enumerate(tqdm(data_loader)):
43
+ test_pred = model(test_data)
44
+
45
+ for k in range(len(test_pred)):
46
+ pred = test_pred[k]['sem_seg']
47
+ labels = torch.argmax(pred, dim=0)
48
+
49
+ img_path_split = test_data[k]['image_path'].split('/')
50
+ ps = test_data[k]['image_path'].split('.')[-1]
51
+ im_save_path = os.path.join(save_path, img_path_split[-3], img_path_split[-2])
52
+ os.makedirs(im_save_path, exist_ok=True)
53
+
54
+ seg = Image.fromarray(labels.float().detach().cpu().numpy()).convert('L')
55
+ seg.save(os.path.join(im_save_path, img_path_split[-1].replace(f'.{ps}', '.png')))
56
+
prismer/experts/model_bank.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+
10
+
11
+ def load_expert_model(task=None):
12
+ if task == 'depth':
13
+ # DPT model is a standard pytorch model class
14
+ from experts.depth.models import DPTDepthModel
15
+
16
+ model = DPTDepthModel(path='experts/expert_weights/dpt_hybrid-midas-501f0c75.pt',
17
+ backbone="vitb_rn50_384",
18
+ non_negative=True,
19
+ enable_attention_hooks=False)
20
+ transform = transforms.Compose([
21
+ transforms.Resize([480, 480]),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=0.5, std=0.5)]
24
+ )
25
+
26
+ elif task == 'seg_coco':
27
+ # Mask2Former is wrapped in detection2,
28
+ # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
29
+ import argparse
30
+ from detectron2.engine.defaults import DefaultPredictor
31
+ from experts.segmentation.utils import setup_cfg
32
+
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--mode", default="client")
35
+ parser.add_argument("--port", default=2)
36
+ args = parser.parse_args()
37
+
38
+ args.config_file = 'experts/segmentation/configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml'
39
+ args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/model_final_f07440.pkl']
40
+ cfg = setup_cfg(args)
41
+ model = DefaultPredictor(cfg).model
42
+ transform = transforms.Compose([
43
+ transforms.Resize(size=479, max_size=480)
44
+ ])
45
+
46
+ elif task == 'seg_ade':
47
+ # Mask2Former is wrapped in detection2,
48
+ # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
49
+ import argparse
50
+ from detectron2.engine.defaults import DefaultPredictor
51
+ from experts.segmentation.utils import setup_cfg
52
+
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--mode", default="client")
55
+ parser.add_argument("--port", default=2)
56
+ args = parser.parse_args()
57
+
58
+ args.config_file = 'experts/segmentation/configs/ade20k/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_160k.yaml'
59
+ args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/model_final_e0c58e.pkl']
60
+ cfg = setup_cfg(args)
61
+ model = DefaultPredictor(cfg).model
62
+ transform = transforms.Compose([
63
+ transforms.Resize(size=479, max_size=480)
64
+ ])
65
+
66
+ elif task == 'obj_detection':
67
+ # UniDet is wrapped in detection2,
68
+ # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
69
+ import argparse
70
+ from detectron2.engine.defaults import DefaultPredictor
71
+ from experts.obj_detection.utils import setup_cfg
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument("--mode", default="client")
74
+ parser.add_argument("--port", default=2)
75
+ parser.add_argument("--confidence-threshold", type=float, default=0.5)
76
+ args = parser.parse_args()
77
+
78
+ args.config_file = 'experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml'
79
+ args.opts = ['MODEL.WEIGHTS', 'experts/expert_weights/Unified_learned_OCIM_RS200_6x+2x.pth']
80
+
81
+ cfg = setup_cfg(args)
82
+ model = DefaultPredictor(cfg).model
83
+ transform = transforms.Compose([
84
+ transforms.Resize(size=479, max_size=480)
85
+ ])
86
+
87
+ elif task == 'ocr_detection':
88
+ from experts.ocr_detection.charnet.modeling.model import CharNet
89
+ model = CharNet()
90
+ model.load_state_dict(torch.load('experts/expert_weights/icdar2015_hourglass88.pth'))
91
+ transform = transforms.Compose([
92
+ transforms.ToTensor(),
93
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
+ ])
95
+
96
+ elif task == 'normal':
97
+ # NLL-AngMF model is a standard pytorch model class
98
+ import argparse
99
+ from experts.normal.models.NNET import NNET
100
+ from experts.normal.utils import utils
101
+
102
+ parser = argparse.ArgumentParser()
103
+ parser.add_argument("--mode", default="client")
104
+ parser.add_argument("--port", default=2)
105
+ parser.add_argument('--architecture', default='BN', type=str, help='{BN, GN}')
106
+ parser.add_argument("--pretrained", default='scannet', type=str, help="{nyu, scannet}")
107
+ parser.add_argument('--sampling_ratio', type=float, default=0.4)
108
+ parser.add_argument('--importance_ratio', type=float, default=0.7)
109
+ args = parser.parse_args()
110
+ model = NNET(args)
111
+ model = utils.load_checkpoint('experts/expert_weights/scannet.pt', model)
112
+
113
+ transform = transforms.Compose([
114
+ transforms.Resize([480, 480]),
115
+ transforms.ToTensor(),
116
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
117
+ ])
118
+
119
+ elif task == 'edge':
120
+ # NLL-AngMF model is a standard pytorch model class
121
+ from experts.edge.model import DexiNed
122
+ model = DexiNed()
123
+ model.load_state_dict(torch.load('experts/expert_weights/10_model.pth', map_location='cpu'))
124
+ transform = transforms.Compose([
125
+ transforms.Resize([480, 480]),
126
+ transforms.ToTensor(),
127
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.0, 1.0, 1.0])
128
+ ])
129
+ else:
130
+ print('Task not supported')
131
+ model = None
132
+ transform = None
133
+
134
+ model.eval()
135
+ return model, transform
136
+
137
+
138
+
139
+
prismer/experts/normal/generate_dataset.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+ from PIL import ImageFile
12
+
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+
15
+
16
+ class CustomDataset(Dataset):
17
+ def __init__(self, data_path, transform):
18
+ self.data_path = data_path
19
+ self.transform = transform
20
+ data_folders = glob.glob(f'{data_path}/*/')
21
+ self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
22
+ self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
23
+
24
+ def __len__(self):
25
+ return len(self.data_list)
26
+
27
+ def __getitem__(self, index):
28
+ image_path = self.data_list[index]
29
+ image = Image.open(image_path).convert('RGB')
30
+ img_size = [image.size[0], image.size[1]]
31
+ image = self.transform(image)
32
+ return image, image_path, img_size
33
+
34
+
prismer/experts/normal/models/NNET.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from experts.normal.models.submodules.encoder import Encoder
6
+ from experts.normal.models.submodules.decoder import Decoder
7
+
8
+
9
+ class NNET(nn.Module):
10
+ def __init__(self, args):
11
+ super(NNET, self).__init__()
12
+ self.encoder = Encoder()
13
+ self.decoder = Decoder(args)
14
+
15
+ def get_1x_lr_params(self): # lr/10 learning rate
16
+ return self.encoder.parameters()
17
+
18
+ def get_10x_lr_params(self): # lr learning rate
19
+ return self.decoder.parameters()
20
+
21
+ def forward(self, img, **kwargs):
22
+ return self.decoder(self.encoder(img), **kwargs)
prismer/experts/normal/models/baseline.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from experts.normal.models.submodules.submodules import UpSampleBN, norm_normalize
6
+
7
+
8
+ # This is the baseline encoder-decoder we used in the ablation study
9
+ class NNET(nn.Module):
10
+ def __init__(self, args=None):
11
+ super(NNET, self).__init__()
12
+ self.encoder = Encoder()
13
+ self.decoder = Decoder(num_classes=4)
14
+
15
+ def forward(self, x, **kwargs):
16
+ out = self.decoder(self.encoder(x), **kwargs)
17
+
18
+ # Bilinearly upsample the output to match the input resolution
19
+ up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
20
+
21
+ # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
22
+ up_out = norm_normalize(up_out)
23
+ return up_out
24
+
25
+ def get_1x_lr_params(self): # lr/10 learning rate
26
+ return self.encoder.parameters()
27
+
28
+ def get_10x_lr_params(self): # lr learning rate
29
+ modules = [self.decoder]
30
+ for m in modules:
31
+ yield from m.parameters()
32
+
33
+
34
+ # Encoder
35
+ class Encoder(nn.Module):
36
+ def __init__(self):
37
+ super(Encoder, self).__init__()
38
+
39
+ basemodel_name = 'tf_efficientnet_b5_ap'
40
+ basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
41
+
42
+ # Remove last layer
43
+ basemodel.global_pool = nn.Identity()
44
+ basemodel.classifier = nn.Identity()
45
+
46
+ self.original_model = basemodel
47
+
48
+ def forward(self, x):
49
+ features = [x]
50
+ for k, v in self.original_model._modules.items():
51
+ if (k == 'blocks'):
52
+ for ki, vi in v._modules.items():
53
+ features.append(vi(features[-1]))
54
+ else:
55
+ features.append(v(features[-1]))
56
+ return features
57
+
58
+
59
+ # Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
60
+ class Decoder(nn.Module):
61
+ def __init__(self, num_classes=4):
62
+ super(Decoder, self).__init__()
63
+ self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
64
+ self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
65
+ self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
66
+ self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
67
+ self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
68
+ self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, features):
71
+ x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
72
+ x_d0 = self.conv2(x_block4)
73
+ x_d1 = self.up1(x_d0, x_block3)
74
+ x_d2 = self.up2(x_d1, x_block2)
75
+ x_d3 = self.up3(x_d2, x_block1)
76
+ x_d4 = self.up4(x_d3, x_block0)
77
+ out = self.conv3(x_d4)
78
+ return out
79
+
80
+
81
+ if __name__ == '__main__':
82
+ model = Baseline()
83
+ x = torch.rand(2, 3, 480, 640)
84
+ out = model(x)
85
+ print(out.shape)
prismer/experts/normal/models/submodules/decoder.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from experts.normal.models.submodules.submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
5
+
6
+
7
+ class Decoder(nn.Module):
8
+ def __init__(self, args):
9
+ super(Decoder, self).__init__()
10
+
11
+ # hyper-parameter for sampling
12
+ self.sampling_ratio = args.sampling_ratio
13
+ self.importance_ratio = args.importance_ratio
14
+
15
+ # feature-map
16
+ self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
17
+ if args.architecture == 'BN':
18
+ self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
19
+ self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
20
+ self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
21
+ self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
22
+
23
+ elif args.architecture == 'GN':
24
+ self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
25
+ self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
26
+ self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
27
+ self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
28
+
29
+ else:
30
+ raise Exception('invalid architecture')
31
+
32
+ # produces 1/8 res output
33
+ self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
34
+
35
+ # produces 1/4 res output
36
+ self.out_conv_res4 = nn.Sequential(
37
+ nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
38
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
39
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
40
+ nn.Conv1d(128, 4, kernel_size=1),
41
+ )
42
+
43
+ # produces 1/2 res output
44
+ self.out_conv_res2 = nn.Sequential(
45
+ nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
46
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
47
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
48
+ nn.Conv1d(128, 4, kernel_size=1),
49
+ )
50
+
51
+ # produces 1/1 res output
52
+ self.out_conv_res1 = nn.Sequential(
53
+ nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
54
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
55
+ nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
56
+ nn.Conv1d(128, 4, kernel_size=1),
57
+ )
58
+
59
+ def forward(self, features, gt_norm_mask=None, mode='test'):
60
+ x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
61
+
62
+ # generate feature-map
63
+
64
+ x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
65
+ x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
66
+ x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
67
+ x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
68
+ x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
69
+
70
+ # 1/8 res output
71
+ out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
72
+ out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
73
+
74
+ ################################################################################################################
75
+ # out_res4
76
+ ################################################################################################################
77
+
78
+ if mode == 'train':
79
+ # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
80
+ out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
81
+ B, _, H, W = out_res8_res4.shape
82
+
83
+ # samples: [B, 1, N, 2]
84
+ point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
85
+ sampling_ratio=self.sampling_ratio,
86
+ beta=self.importance_ratio)
87
+
88
+ # output (needed for evaluation / visualization)
89
+ out_res4 = out_res8_res4
90
+
91
+ # grid_sample feature-map
92
+ feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
93
+ init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
94
+ feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
95
+
96
+ # prediction (needed to compute loss)
97
+ samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
98
+ samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
99
+
100
+ for i in range(B):
101
+ out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
102
+
103
+ else:
104
+ # grid_sample feature-map
105
+ feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
106
+ init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
107
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
108
+ B, _, H, W = feat_map.shape
109
+
110
+ # try all pixels
111
+ out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
112
+ out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
113
+ out_res4 = out_res4.view(B, 4, H, W)
114
+ samples_pred_res4 = point_coords_res4 = None
115
+
116
+ ################################################################################################################
117
+ # out_res2
118
+ ################################################################################################################
119
+
120
+ if mode == 'train':
121
+
122
+ # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
123
+ out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
124
+ B, _, H, W = out_res4_res2.shape
125
+
126
+ # samples: [B, 1, N, 2]
127
+ point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
128
+ sampling_ratio=self.sampling_ratio,
129
+ beta=self.importance_ratio)
130
+
131
+ # output (needed for evaluation / visualization)
132
+ out_res2 = out_res4_res2
133
+
134
+ # grid_sample feature-map
135
+ feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
136
+ init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
137
+ feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
138
+
139
+ # prediction (needed to compute loss)
140
+ samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
141
+ samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
142
+
143
+ for i in range(B):
144
+ out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
145
+
146
+ else:
147
+ # grid_sample feature-map
148
+ feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
149
+ init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
150
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
151
+ B, _, H, W = feat_map.shape
152
+
153
+ out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
154
+ out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
155
+ out_res2 = out_res2.view(B, 4, H, W)
156
+ samples_pred_res2 = point_coords_res2 = None
157
+
158
+ ################################################################################################################
159
+ # out_res1
160
+ ################################################################################################################
161
+
162
+ if mode == 'train':
163
+ # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
164
+ out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
165
+ B, _, H, W = out_res2_res1.shape
166
+
167
+ # samples: [B, 1, N, 2]
168
+ point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
169
+ sampling_ratio=self.sampling_ratio,
170
+ beta=self.importance_ratio)
171
+
172
+ # output (needed for evaluation / visualization)
173
+ out_res1 = out_res2_res1
174
+
175
+ # grid_sample feature-map
176
+ feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
177
+ init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
178
+ feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
179
+
180
+ # prediction (needed to compute loss)
181
+ samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
182
+ samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
183
+
184
+ for i in range(B):
185
+ out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
186
+
187
+ else:
188
+ # grid_sample feature-map
189
+ feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
190
+ init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
191
+ feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
192
+ B, _, H, W = feat_map.shape
193
+
194
+ out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
195
+ out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
196
+ out_res1 = out_res1.view(B, 4, H, W)
197
+ samples_pred_res1 = point_coords_res1 = None
198
+
199
+ return [out_res8, out_res4, out_res2, out_res1], \
200
+ [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
201
+ [None, point_coords_res4, point_coords_res2, point_coords_res1]
202
+
prismer/experts/normal/models/submodules/encoder.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Encoder(nn.Module):
7
+ def __init__(self):
8
+ super(Encoder, self).__init__()
9
+
10
+ basemodel_name = 'tf_efficientnet_b5_ap'
11
+ print('Loading base model ()...'.format(basemodel_name), end='')
12
+ basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
13
+ print('Done.')
14
+
15
+ # Remove last layer
16
+ print('Removing last two layers (global_pool & classifier).')
17
+ basemodel.global_pool = nn.Identity()
18
+ basemodel.classifier = nn.Identity()
19
+
20
+ self.original_model = basemodel
21
+
22
+ def forward(self, x):
23
+ features = [x]
24
+ for k, v in self.original_model._modules.items():
25
+ if (k == 'blocks'):
26
+ for ki, vi in v._modules.items():
27
+ features.append(vi(features[-1]))
28
+ else:
29
+ features.append(v(features[-1]))
30
+ return features
31
+
32
+
prismer/experts/normal/models/submodules/submodules.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ ########################################################################################################################
7
+
8
+
9
+ # Upsample + BatchNorm
10
+ class UpSampleBN(nn.Module):
11
+ def __init__(self, skip_input, output_features):
12
+ super(UpSampleBN, self).__init__()
13
+
14
+ self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
15
+ nn.BatchNorm2d(output_features),
16
+ nn.LeakyReLU(),
17
+ nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
18
+ nn.BatchNorm2d(output_features),
19
+ nn.LeakyReLU())
20
+
21
+ def forward(self, x, concat_with):
22
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
23
+ f = torch.cat([up_x, concat_with], dim=1)
24
+ return self._net(f)
25
+
26
+
27
+ # Upsample + GroupNorm + Weight Standardization
28
+ class UpSampleGN(nn.Module):
29
+ def __init__(self, skip_input, output_features):
30
+ super(UpSampleGN, self).__init__()
31
+
32
+ self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
33
+ nn.GroupNorm(8, output_features),
34
+ nn.LeakyReLU(),
35
+ Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
36
+ nn.GroupNorm(8, output_features),
37
+ nn.LeakyReLU())
38
+
39
+ def forward(self, x, concat_with):
40
+ up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
41
+ f = torch.cat([up_x, concat_with], dim=1)
42
+ return self._net(f)
43
+
44
+
45
+ # Conv2d with weight standardization
46
+ class Conv2d(nn.Conv2d):
47
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
48
+ padding=0, dilation=1, groups=1, bias=True):
49
+ super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
50
+ padding, dilation, groups, bias)
51
+
52
+ def forward(self, x):
53
+ weight = self.weight
54
+ weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
55
+ keepdim=True).mean(dim=3, keepdim=True)
56
+ weight = weight - weight_mean
57
+ std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
58
+ weight = weight / std.expand_as(weight)
59
+ return F.conv2d(x, weight, self.bias, self.stride,
60
+ self.padding, self.dilation, self.groups)
61
+
62
+
63
+ # normalize
64
+ def norm_normalize(norm_out):
65
+ min_kappa = 0.01
66
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
67
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
68
+ kappa = F.elu(kappa) + 1.0 + min_kappa
69
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
70
+ return final_out
71
+
72
+
73
+ # uncertainty-guided sampling (only used during training)
74
+ @torch.no_grad()
75
+ def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
76
+ device = init_normal.device
77
+ B, _, H, W = init_normal.shape
78
+ N = int(sampling_ratio * H * W)
79
+ beta = beta
80
+
81
+ # uncertainty map
82
+ uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W
83
+
84
+ # gt_invalid_mask (B, H, W)
85
+ if gt_norm_mask is not None:
86
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
87
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
88
+ uncertainty_map[gt_invalid_mask] = -1e4
89
+
90
+ # (B, H*W)
91
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
92
+
93
+ # importance sampling
94
+ if int(beta * N) > 0:
95
+ importance = idx[:, :int(beta * N)] # B, beta*N
96
+
97
+ # remaining
98
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
99
+
100
+ # coverage
101
+ num_coverage = N - int(beta * N)
102
+
103
+ if num_coverage <= 0:
104
+ samples = importance
105
+ else:
106
+ coverage_list = []
107
+ for i in range(B):
108
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
109
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
110
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
111
+ samples = torch.cat((importance, coverage), dim=1) # B, N
112
+
113
+ else:
114
+ # remaining
115
+ remaining = idx[:, :] # B, H*W
116
+
117
+ # coverage
118
+ num_coverage = N
119
+
120
+ coverage_list = []
121
+ for i in range(B):
122
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
123
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
124
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
125
+ samples = coverage
126
+
127
+ # point coordinates
128
+ rows_int = samples // W # 0 for first row, H-1 for last row
129
+ rows_float = rows_int / float(H-1) # 0 to 1.0
130
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
131
+
132
+ cols_int = samples % W # 0 for first column, W-1 for last column
133
+ cols_float = cols_int / float(W-1) # 0 to 1.0
134
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
135
+
136
+ point_coords = torch.zeros(B, 1, N, 2)
137
+ point_coords[:, 0, :, 0] = cols_float # x coord
138
+ point_coords[:, 0, :, 1] = rows_float # y coord
139
+ point_coords = point_coords.to(device)
140
+ return point_coords, rows_int, cols_int
prismer/experts/normal/utils/losses.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+
6
+
7
+ # compute loss
8
+ class compute_loss(nn.Module):
9
+ def __init__(self, args):
10
+ """args.loss_fn can be one of following:
11
+ - L1 - L1 loss (no uncertainty)
12
+ - L2 - L2 loss (no uncertainty)
13
+ - AL - Angular loss (no uncertainty)
14
+ - NLL_vMF - NLL of vonMF distribution
15
+ - NLL_ours - NLL of Angular vonMF distribution
16
+ - UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
17
+ - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
18
+ """
19
+ super(compute_loss, self).__init__()
20
+ self.loss_type = args.loss_fn
21
+ if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours']:
22
+ self.loss_fn = self.forward_R
23
+ elif self.loss_type in ['UG_NLL_vMF', 'UG_NLL_ours']:
24
+ self.loss_fn = self.forward_UG
25
+ else:
26
+ raise Exception('invalid loss type')
27
+
28
+ def forward(self, *args):
29
+ return self.loss_fn(*args)
30
+
31
+ def forward_R(self, norm_out, gt_norm, gt_norm_mask):
32
+ pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]
33
+
34
+ if self.loss_type == 'L1':
35
+ l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
36
+ loss = torch.mean(l1[gt_norm_mask])
37
+
38
+ elif self.loss_type == 'L2':
39
+ l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
40
+ loss = torch.mean(l2[gt_norm_mask])
41
+
42
+ elif self.loss_type == 'AL':
43
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
44
+
45
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
46
+ * (dot.detach() < 0.999).float() \
47
+ * (dot.detach() > -0.999).float()
48
+ valid_mask = valid_mask > 0.0
49
+
50
+ al = torch.acos(dot[valid_mask])
51
+ loss = torch.mean(al)
52
+
53
+ elif self.loss_type == 'NLL_vMF':
54
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
55
+
56
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
57
+ * (dot.detach() < 0.999).float() \
58
+ * (dot.detach() > -0.999).float()
59
+ valid_mask = valid_mask > 0.0
60
+
61
+ dot = dot[valid_mask]
62
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
63
+
64
+ loss_pixelwise = - torch.log(kappa) \
65
+ - (kappa * (dot - 1)) \
66
+ + torch.log(1 - torch.exp(- 2 * kappa))
67
+ loss = torch.mean(loss_pixelwise)
68
+
69
+ elif self.loss_type == 'NLL_ours':
70
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
71
+
72
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
73
+ * (dot.detach() < 0.999).float() \
74
+ * (dot.detach() > -0.999).float()
75
+ valid_mask = valid_mask > 0.0
76
+
77
+ dot = dot[valid_mask]
78
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
79
+
80
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
81
+ + kappa * torch.acos(dot) \
82
+ + torch.log(1 + torch.exp(-kappa * np.pi))
83
+ loss = torch.mean(loss_pixelwise)
84
+
85
+ else:
86
+ raise Exception('invalid loss type')
87
+
88
+ return loss
89
+
90
+
91
+ def forward_UG(self, pred_list, coord_list, gt_norm, gt_norm_mask):
92
+ loss = 0.0
93
+ for (pred, coord) in zip(pred_list, coord_list):
94
+ if coord is None:
95
+ pred = F.interpolate(pred, size=[gt_norm.size(2), gt_norm.size(3)], mode='bilinear', align_corners=True)
96
+ pred_norm, pred_kappa = pred[:, 0:3, :, :], pred[:, 3:, :, :]
97
+
98
+ if self.loss_type == 'UG_NLL_vMF':
99
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
100
+
101
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
102
+ * (dot.detach() < 0.999).float() \
103
+ * (dot.detach() > -0.999).float()
104
+ valid_mask = valid_mask > 0.5
105
+
106
+ # mask
107
+ dot = dot[valid_mask]
108
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
109
+
110
+ loss_pixelwise = - torch.log(kappa) \
111
+ - (kappa * (dot - 1)) \
112
+ + torch.log(1 - torch.exp(- 2 * kappa))
113
+ loss = loss + torch.mean(loss_pixelwise)
114
+
115
+ elif self.loss_type == 'UG_NLL_ours':
116
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
117
+
118
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
119
+ * (dot.detach() < 0.999).float() \
120
+ * (dot.detach() > -0.999).float()
121
+ valid_mask = valid_mask > 0.5
122
+
123
+ dot = dot[valid_mask]
124
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
125
+
126
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
127
+ + kappa * torch.acos(dot) \
128
+ + torch.log(1 + torch.exp(-kappa * np.pi))
129
+ loss = loss + torch.mean(loss_pixelwise)
130
+
131
+ else:
132
+ raise Exception
133
+
134
+ else:
135
+ # coord: B, 1, N, 2
136
+ # pred: B, 4, N
137
+ gt_norm_ = F.grid_sample(gt_norm, coord, mode='nearest', align_corners=True) # (B, 3, 1, N)
138
+ gt_norm_mask_ = F.grid_sample(gt_norm_mask.float(), coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
139
+ gt_norm_ = gt_norm_[:, :, 0, :] # (B, 3, N)
140
+ gt_norm_mask_ = gt_norm_mask_[:, :, 0, :] > 0.5 # (B, 1, N)
141
+
142
+ pred_norm, pred_kappa = pred[:, 0:3, :], pred[:, 3:, :]
143
+
144
+ if self.loss_type == 'UG_NLL_vMF':
145
+ dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
146
+
147
+ valid_mask = gt_norm_mask_[:, 0, :].float() \
148
+ * (dot.detach() < 0.999).float() \
149
+ * (dot.detach() > -0.999).float()
150
+ valid_mask = valid_mask > 0.5
151
+
152
+ dot = dot[valid_mask]
153
+ kappa = pred_kappa[:, 0, :][valid_mask]
154
+
155
+ loss_pixelwise = - torch.log(kappa) \
156
+ - (kappa * (dot - 1)) \
157
+ + torch.log(1 - torch.exp(- 2 * kappa))
158
+ loss = loss + torch.mean(loss_pixelwise)
159
+
160
+ elif self.loss_type == 'UG_NLL_ours':
161
+ dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
162
+
163
+ valid_mask = gt_norm_mask_[:, 0, :].float() \
164
+ * (dot.detach() < 0.999).float() \
165
+ * (dot.detach() > -0.999).float()
166
+ valid_mask = valid_mask > 0.5
167
+
168
+ dot = dot[valid_mask]
169
+ kappa = pred_kappa[:, 0, :][valid_mask]
170
+
171
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
172
+ + kappa * torch.acos(dot) \
173
+ + torch.log(1 + torch.exp(-kappa * np.pi))
174
+ loss = loss + torch.mean(loss_pixelwise)
175
+
176
+ else:
177
+ raise Exception
178
+ return loss
prismer/experts/normal/utils/utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+ import matplotlib.pyplot as plt
12
+
13
+
14
+
15
+ # convert arg line to args
16
+ def convert_arg_line_to_args(arg_line):
17
+ for arg in arg_line.split():
18
+ if not arg.strip():
19
+ continue
20
+ yield str(arg)
21
+
22
+
23
+ # save args
24
+ def save_args(args, filename):
25
+ with open(filename, 'w') as f:
26
+ for arg in vars(args):
27
+ f.write('{}: {}\n'.format(arg, getattr(args, arg)))
28
+
29
+
30
+ # concatenate images
31
+ def concat_image(image_path_list, concat_image_path):
32
+ imgs = [Image.open(i).convert("RGB").resize((640, 480), resample=Image.BILINEAR) for i in image_path_list]
33
+ imgs_list = []
34
+ for i in range(len(imgs)):
35
+ img = imgs[i]
36
+ imgs_list.append(np.asarray(img))
37
+
38
+ H, W, _ = np.asarray(img).shape
39
+ imgs_list.append(255 * np.ones((H, 20, 3)).astype('uint8'))
40
+
41
+ imgs_comb = np.hstack(imgs_list[:-1])
42
+ imgs_comb = Image.fromarray(imgs_comb)
43
+ imgs_comb.save(concat_image_path)
44
+
45
+
46
+ # load model
47
+ def load_checkpoint(fpath, model):
48
+ ckpt = torch.load(fpath, map_location='cpu')['model']
49
+
50
+ load_dict = {}
51
+ for k, v in ckpt.items():
52
+ if k.startswith('module.'):
53
+ k_ = k.replace('module.', '')
54
+ load_dict[k_] = v
55
+ else:
56
+ load_dict[k] = v
57
+
58
+ model.load_state_dict(load_dict)
59
+ return model
60
+
61
+
62
+ # compute normal errors
63
+ def compute_normal_errors(total_normal_errors):
64
+ metrics = {
65
+ 'mean': np.average(total_normal_errors),
66
+ 'median': np.median(total_normal_errors),
67
+ 'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / total_normal_errors.shape),
68
+ 'a1': 100.0 * (np.sum(total_normal_errors < 5) / total_normal_errors.shape[0]),
69
+ 'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / total_normal_errors.shape[0]),
70
+ 'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / total_normal_errors.shape[0]),
71
+ 'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / total_normal_errors.shape[0]),
72
+ 'a5': 100.0 * (np.sum(total_normal_errors < 30) / total_normal_errors.shape[0])
73
+ }
74
+ return metrics
75
+
76
+
77
+ # log normal errors
78
+ def log_normal_errors(metrics, where_to_write, first_line):
79
+ print(first_line)
80
+ print("mean median rmse 5 7.5 11.25 22.5 30")
81
+ print("%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f" % (
82
+ metrics['mean'], metrics['median'], metrics['rmse'],
83
+ metrics['a1'], metrics['a2'], metrics['a3'], metrics['a4'], metrics['a5']))
84
+
85
+ with open(where_to_write, 'a') as f:
86
+ f.write('%s\n' % first_line)
87
+ f.write("mean median rmse 5 7.5 11.25 22.5 30\n")
88
+ f.write("%.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f\n\n" % (
89
+ metrics['mean'], metrics['median'], metrics['rmse'],
90
+ metrics['a1'], metrics['a2'], metrics['a3'], metrics['a4'], metrics['a5']))
91
+
92
+
93
+ # makedir
94
+ def makedir(dirpath):
95
+ if not os.path.exists(dirpath):
96
+ os.makedirs(dirpath)
97
+
98
+
99
+ # makedir from list
100
+ def make_dir_from_list(dirpath_list):
101
+ for dirpath in dirpath_list:
102
+ makedir(dirpath)
103
+
104
+
105
+
106
+ ########################################################################################################################
107
+ # Visualization
108
+ ########################################################################################################################
109
+
110
+
111
+ # unnormalize image
112
+ __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
113
+ def unnormalize(img_in):
114
+ img_out = np.zeros(img_in.shape)
115
+ for ich in range(3):
116
+ img_out[:, :, ich] = img_in[:, :, ich] * __imagenet_stats['std'][ich]
117
+ img_out[:, :, ich] += __imagenet_stats['mean'][ich]
118
+ img_out = (img_out * 255).astype(np.uint8)
119
+ return img_out
120
+
121
+
122
+ # kappa to exp error (only applicable to AngMF distribution)
123
+ def kappa_to_alpha(pred_kappa):
124
+ alpha = ((2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)) \
125
+ + ((np.exp(- pred_kappa * np.pi) * np.pi) / (1 + np.exp(- pred_kappa * np.pi)))
126
+ alpha = np.degrees(alpha)
127
+ return alpha
128
+
129
+
130
+ # normal vector to rgb values
131
+ def norm_to_rgb(norm):
132
+ # norm: (B, H, W, 3)
133
+ norm_rgb = ((norm[0, ...] + 1) * 0.5) * 255
134
+ norm_rgb = np.clip(norm_rgb, a_min=0, a_max=255)
135
+ norm_rgb = norm_rgb.astype(np.uint8)
136
+ return norm_rgb
137
+
138
+
139
+ # visualize during training
140
+ def visualize(args, img, gt_norm, gt_norm_mask, norm_out_list, total_iter):
141
+ B, _, H, W = gt_norm.shape
142
+
143
+ pred_norm_list = []
144
+ pred_kappa_list = []
145
+ for norm_out in norm_out_list:
146
+ norm_out = F.interpolate(norm_out, size=[gt_norm.size(2), gt_norm.size(3)], mode='nearest')
147
+ pred_norm = norm_out[:, :3, :, :] # (B, 3, H, W)
148
+ pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
149
+ pred_norm_list.append(pred_norm)
150
+
151
+ pred_kappa = norm_out[:, 3:, :, :] # (B, 1, H, W)
152
+ pred_kappa = pred_kappa.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 1)
153
+ pred_kappa_list.append(pred_kappa)
154
+
155
+ # to numpy arrays
156
+ img = img.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
157
+ gt_norm = gt_norm.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
158
+ gt_norm_mask = gt_norm_mask.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 1)
159
+
160
+ # input image
161
+ target_path = '%s/%08d_img.jpg' % (args.exp_vis_dir, total_iter)
162
+ img = unnormalize(img[0, ...])
163
+ plt.imsave(target_path, img)
164
+
165
+ # gt norm
166
+ gt_norm_rgb = ((gt_norm[0, ...] + 1) * 0.5) * 255
167
+ gt_norm_rgb = np.clip(gt_norm_rgb, a_min=0, a_max=255)
168
+ gt_norm_rgb = gt_norm_rgb.astype(np.uint8)
169
+
170
+ target_path = '%s/%08d_gt_norm.jpg' % (args.exp_vis_dir, total_iter)
171
+ plt.imsave(target_path, gt_norm_rgb * gt_norm_mask[0, ...])
172
+
173
+ # pred_norm
174
+ for i in range(len(pred_norm_list)):
175
+ pred_norm = pred_norm_list[i]
176
+ pred_norm_rgb = norm_to_rgb(pred_norm)
177
+ target_path = '%s/%08d_pred_norm_%d.jpg' % (args.exp_vis_dir, total_iter, i)
178
+ plt.imsave(target_path, pred_norm_rgb)
179
+
180
+ pred_kappa = pred_kappa_list[i]
181
+ pred_alpha = kappa_to_alpha(pred_kappa)
182
+ target_path = '%s/%08d_pred_alpha_%d.jpg' % (args.exp_vis_dir, total_iter, i)
183
+ plt.imsave(target_path, pred_alpha[0, :, :, 0], vmin=0, vmax=60, cmap='jet')
184
+
185
+ # error in angles
186
+ DP = np.sum(gt_norm * pred_norm, axis=3, keepdims=True) # (B, H, W, 1)
187
+ DP = np.clip(DP, -1, 1)
188
+ E = np.degrees(np.arccos(DP)) # (B, H, W, 1)
189
+ E = E * gt_norm_mask
190
+ target_path = '%s/%08d_pred_error_%d.jpg' % (args.exp_vis_dir, total_iter, i)
191
+ plt.imsave(target_path, E[0, :, :, 0], vmin=0, vmax=60, cmap='jet')
prismer/experts/obj_detection/configs/Base-CRCNN-COCO.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ META_ARCHITECTURE: "GeneralizedRCNN"
3
+ BACKBONE:
4
+ NAME: "build_p67_resnet_fpn_backbone"
5
+ WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
6
+ RESNETS:
7
+ OUT_FEATURES: ["res3", "res4", "res5"]
8
+ DEPTH: 50
9
+ FPN:
10
+ IN_FEATURES: ["res3", "res4", "res5"]
11
+ ANCHOR_GENERATOR:
12
+ SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
13
+ ASPECT_RATIOS: [[0.5, 1.0, 2.0]]
14
+ RPN:
15
+ IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
16
+ PRE_NMS_TOPK_TRAIN: 2000
17
+ PRE_NMS_TOPK_TEST: 1000
18
+ POST_NMS_TOPK_TRAIN: 2000
19
+ POST_NMS_TOPK_TEST: 1000
20
+ ROI_HEADS:
21
+ NUM_CLASSES: 80
22
+ NAME: CustomCascadeROIHeads
23
+ IN_FEATURES: ["p3", "p4", "p5"]
24
+ SCORE_THRESH_TEST: 0.0001
25
+ ROI_BOX_HEAD:
26
+ NAME: "FastRCNNConvFCHead"
27
+ NUM_FC: 2
28
+ POOLER_RESOLUTION: 7
29
+ CLS_AGNOSTIC_BBOX_REG: True
30
+ DATASETS:
31
+ TRAIN: ("coco_2017_train",)
32
+ TEST: ("coco_2017_val",)
33
+ TEST:
34
+ DETECTIONS_PER_IMAGE: 300
35
+ SOLVER:
36
+ IMS_PER_BATCH: 16
37
+ BASE_LR: 0.01
38
+ STEPS: (60000, 80000)
39
+ MAX_ITER: 90000
40
+ CHECKPOINT_PERIOD: 1000000
41
+ WARMUP_ITERS: 4000
42
+ WARMUP_FACTOR: 0.00025
43
+ CLIP_GRADIENTS:
44
+ ENABLED: True
45
+ INPUT:
46
+ MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
47
+ VERSION: 2
48
+ OUTPUT_DIR: "output/UniDet/auto"
prismer/experts/obj_detection/configs/O365_CRFR50_CAS_2x.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO"
2
+ MODEL:
3
+ ROI_HEADS:
4
+ NUM_CLASSES: 365
5
+ DATASETS:
6
+ TRAIN: ("objects365_train",)
7
+ TEST: ("objects365_val",)
8
+ SOLVER:
9
+ IMS_PER_BATCH: 16
10
+ BASE_LR: 0.01
11
+ STEPS: (120000, 160000,)
12
+ MAX_ITER: 180000
13
+ CHECKPOINT_PERIOD: 120000
14
+ DATALOADER:
15
+ SAMPLER_TRAIN: "ClassAwareSampler"
prismer/experts/obj_detection/configs/OID_CRFR50_CAS_2x.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO"
2
+ MODEL:
3
+ ROI_HEADS:
4
+ NUM_CLASSES: 500
5
+ ROI_BOX_HEAD:
6
+ USE_SIGMOID_CE: True
7
+ USE_EQL_LOSS: True
8
+ EQL_FREQ_CAT: 200
9
+ EQL_CAT_INFO: 'datasets/oid/annotations/openimages_challenge_2019_train_v2_cat_info.json'
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("oid_train",)
14
+ TEST: ("oid_val_expanded",)
15
+ SOLVER:
16
+ IMS_PER_BATCH: 16
17
+ BASE_LR: 0.01
18
+ STEPS: (120000, 160000,)
19
+ MAX_ITER: 180000
20
+ CHECKPOINT_PERIOD: 120000
21
+ DATALOADER:
22
+ SAMPLER_TRAIN: "ClassAwareSampler"
prismer/experts/obj_detection/configs/Partitioned_COIM_R50_6x+2x.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "SplitClassifierRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: -1
6
+ NAME: "MultiDatasetCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train","mapillary_960_train")
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',"mapillary_val")
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid', 'mapillary']
21
+ NUM_CLASSES: [80, 365, 500, 37]
22
+ DATA_RATIO: [1, 1, 1, 1]
23
+ USE_CAS: [False, True, True, False]
24
+ SOLVER:
25
+ IMS_PER_BATCH: 16
26
+ BASE_LR: 0.001
27
+ STEPS: (160000,)
28
+ MAX_ITER: 180000
prismer/experts/obj_detection/configs/Partitioned_COI_R50_2x.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "SplitClassifierRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: -1
6
+ NAME: "MultiDatasetCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ SOLVER:
25
+ IMS_PER_BATCH: 16
26
+ STEPS: (120000, 160000)
27
+ MAX_ITER: 180000
28
+ CHECKPOINT_PERIOD: 120000
29
+
prismer/experts/obj_detection/configs/Partitioned_COI_R50_6x.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "SplitClassifierRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: -1
6
+ NAME: "MultiDatasetCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ SOLVER:
25
+ IMS_PER_BATCH: 16
26
+ STEPS: (480000, 500000)
27
+ MAX_ITER: 540000
28
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Partitioned_COI_R50_8x.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "SplitClassifierRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: -1
6
+ NAME: "MultiDatasetCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ SOLVER:
25
+ IMS_PER_BATCH: 16
26
+ STEPS: (660000, 700000)
27
+ MAX_ITER: 720000
28
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Partitioned_COI_RS101_2x.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "build_p67_resnest_fpn_backbone"
5
+ WEIGHTS: "https://hangzh.s3-us-west-1.amazonaws.com/encoding/models/resnest101_detectron-486f69a8.pth"
6
+ PIXEL_MEAN: [123.68, 116.779, 103.939]
7
+ PIXEL_STD: [58.393, 57.12, 57.375]
8
+ RESNETS:
9
+ DEPTH: 101
10
+ STRIDE_IN_1X1: False
11
+ RADIX: 2
12
+ NORM: "SyncBN"
13
+ FPN:
14
+ NORM: "SyncBN"
15
+ META_ARCHITECTURE: "SplitClassifierRCNN"
16
+ ROI_HEADS:
17
+ NUM_CLASSES: -1
18
+ NAME: "MultiDatasetCascadeROIHeads"
19
+ ROI_BOX_HEAD:
20
+ USE_SIGMOID_CE: True
21
+ # USE_EQL_LOSS: True
22
+ HIERARCHY_IGNORE: True
23
+ HIERARCHY_POS_PARENTS: True
24
+ NAME: "FastRCNNConvFCHead"
25
+ NUM_CONV: 4
26
+ NUM_FC: 1
27
+ NORM: "SyncBN"
28
+ INPUT:
29
+ FORMAT: "RGB"
30
+ DATASETS:
31
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
32
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
33
+ DATALOADER:
34
+ SAMPLER_TRAIN: "MultiDatasetSampler"
35
+ NUM_WORKERS: 1
36
+ MULTI_DATASET:
37
+ ENABLED: True
38
+ DATASETS: ['coco', 'objects365', 'oid']
39
+ NUM_CLASSES: [80, 365, 500]
40
+ DATA_RATIO: [1, 1, 1]
41
+ USE_CAS: [False, True, True]
42
+ SOLVER:
43
+ IMS_PER_BATCH: 16
44
+ STEPS: (120000, 160000)
45
+ MAX_ITER: 180000
46
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Unified_human_OCI_R50_2x.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "UnifiedRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: 659
6
+ NAME: "UnifiedCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ UNIFIED_LABEL_FILE: 'datasets/label_spaces/manual.json'
25
+ SOLVER:
26
+ IMS_PER_BATCH: 16
27
+ STEPS: (120000,160000)
28
+ MAX_ITER: 180000
29
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Unified_learned_OCIM_R50_6x+2x.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "UnifiedRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: 722
6
+ NAME: "UnifiedCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("objects365_train","coco_2017_train","oid_train","mapillary_960_train")
14
+ TEST: ("coco_2017_val", "objects365_val", "oid_val_v2_expanded","mapillary_val")
15
+ # TEST: ('voc_cocoformat_test','viper_val', 'scannet_val','wilddash_public',
16
+ # 'kitti_train','crowdhuman_val', 'cityscapes_cocoformat_val',)
17
+ DATALOADER:
18
+ SAMPLER_TRAIN: "MultiDatasetSampler"
19
+ NUM_WORKERS: 1
20
+ MULTI_DATASET:
21
+ ENABLED: True
22
+ DATASETS: ['objects365', 'coco', 'oid', 'mapillary']
23
+ NUM_CLASSES: [365, 80, 500, 37]
24
+ DATA_RATIO: [1, 1, 1, 1]
25
+ USE_CAS: [True, False, True, False]
26
+ UNIFIED_LABEL_FILE: 'experts/obj_detection/datasets/label_spaces/learned_mAP+M.json'
27
+ # MATCH_NOVEL_CLASSES_FILE: 'datasets/label_spaces/mAP_val+M_722_4d_labelmap_test.json'
28
+ # UNIFIED_EVAL: True
29
+ # UNIFIED_NOVEL_CLASSES_EVAL: True
30
+ # UNIFY_LABEL_TEST: False
31
+ SOLVER:
32
+ IMS_PER_BATCH: 16
33
+ BASE_LR: 0.001
34
+ STEPS: (160000,)
35
+ MAX_ITER: 180000
prismer/experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "build_p67_resnest_fpn_backbone"
5
+ PIXEL_MEAN: [123.68, 116.779, 103.939]
6
+ PIXEL_STD: [58.393, 57.12, 57.375]
7
+ RESNETS:
8
+ DEPTH: 200
9
+ STRIDE_IN_1X1: False
10
+ RADIX: 2
11
+ NORM: "SyncBN"
12
+ FPN:
13
+ NORM: "SyncBN"
14
+ META_ARCHITECTURE: "UnifiedRCNN"
15
+ ROI_HEADS:
16
+ NUM_CLASSES: 722
17
+ NAME: "UnifiedCascadeROIHeads"
18
+ ROI_BOX_HEAD:
19
+ USE_SIGMOID_CE: True
20
+ # USE_EQL_LOSS: True
21
+ HIERARCHY_IGNORE: True
22
+ HIERARCHY_POS_PARENTS: True
23
+ NAME: "FastRCNNConvFCHead"
24
+ NUM_CONV: 4
25
+ NUM_FC: 1
26
+ NORM: "SyncBN"
27
+ INPUT:
28
+ FORMAT: "RGB"
29
+ DATASETS:
30
+ TRAIN: ("objects365_train","coco_2017_train","oid_train","mapillary_960_train")
31
+ TEST: ("coco_2017_val", "objects365_val", "oid_val_v2_expanded","mapillary_val")
32
+ DATALOADER:
33
+ SAMPLER_TRAIN: "MultiDatasetSampler"
34
+ NUM_WORKERS: 1
35
+ MULTI_DATASET:
36
+ ENABLED: True
37
+ DATASETS: ['objects365', 'coco', 'oid', 'mapillary']
38
+ NUM_CLASSES: [365, 80, 500, 37]
39
+ DATA_RATIO: [1, 1, 1, 1]
40
+ USE_CAS: [True, False, True, False]
41
+ UNIFIED_LABEL_FILE: 'experts/obj_detection/datasets/label_spaces/learned_mAP+M.json'
42
+ SOLVER:
43
+ IMS_PER_BATCH: 16
44
+ STEPS: (480000, 500000)
45
+ MAX_ITER: 540000
46
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_2x.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "UnifiedRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: 701
6
+ NAME: "UnifiedCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ UNIFIED_LABEL_FILE: 'datasets/label_spaces/learned_mAP.json'
25
+ SOLVER:
26
+ IMS_PER_BATCH: 16
27
+ STEPS: (120000,160000)
28
+ MAX_ITER: 180000
29
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_6x.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "UnifiedRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: 701
6
+ NAME: "UnifiedCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ UNIFIED_LABEL_FILE: 'datasets/label_spaces/learned_mAP.json'
25
+ SOLVER:
26
+ IMS_PER_BATCH: 16
27
+ STEPS: (480000, 500000)
28
+ MAX_ITER: 540000
29
+ CHECKPOINT_PERIOD: 120000
prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_8x.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base-CRCNN-COCO.yaml"
2
+ MODEL:
3
+ META_ARCHITECTURE: "UnifiedRCNN"
4
+ ROI_HEADS:
5
+ NUM_CLASSES: 722
6
+ NAME: "UnifiedCascadeROIHeads"
7
+ ROI_BOX_HEAD:
8
+ USE_SIGMOID_CE: True
9
+ # USE_EQL_LOSS: True
10
+ HIERARCHY_IGNORE: True
11
+ HIERARCHY_POS_PARENTS: True
12
+ DATASETS:
13
+ TRAIN: ("coco_2017_train","objects365_train","oid_train",)
14
+ TEST: ('coco_2017_val','oid_val_expanded','objects365_val',)
15
+ DATALOADER:
16
+ SAMPLER_TRAIN: "MultiDatasetSampler"
17
+ NUM_WORKERS: 1
18
+ MULTI_DATASET:
19
+ ENABLED: True
20
+ DATASETS: ['coco', 'objects365', 'oid']
21
+ NUM_CLASSES: [80, 365, 500]
22
+ DATA_RATIO: [1, 1, 1]
23
+ USE_CAS: [False, True, True]
24
+ UNIFIED_LABEL_FILE: 'datasets/label_spaces/learned_mAP+M.json'
25
+ SOLVER:
26
+ IMS_PER_BATCH: 16
27
+ STEPS: (660000, 700000)
28
+ MAX_ITER: 720000
29
+ CHECKPOINT_PERIOD: 120000