shikunl committed on
Commit
087df0e
1 Parent(s): 61a0078

Add prismer

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. prismer/dataset/__init__.py +32 -0
  2. prismer/dataset/ade_features.pt +3 -0
  3. prismer/dataset/background_features.pt +3 -0
  4. prismer/dataset/caption_dataset.py +63 -0
  5. prismer/dataset/clip_pca.pkl +3 -0
  6. prismer/dataset/coco_features.pt +3 -0
  7. prismer/dataset/detection_features.pt +3 -0
  8. prismer/dataset/randaugment.py +265 -0
  9. prismer/dataset/utils.py +192 -0
  10. prismer/dataset/vqa_dataset.py +53 -0
  11. prismer/demo.py +77 -0
  12. prismer/demo_vis.py +161 -0
  13. prismer/download_checkpoints.py +124 -0
  14. prismer/experts/depth/base_model.py +16 -0
  15. prismer/experts/depth/blocks.py +383 -0
  16. prismer/experts/depth/generate_dataset.py +32 -0
  17. prismer/experts/depth/models.py +124 -0
  18. prismer/experts/depth/vit.py +576 -0
  19. prismer/experts/edge/generate_dataset.py +32 -0
  20. prismer/experts/edge/images.py +50 -0
  21. prismer/experts/edge/model.py +286 -0
  22. prismer/experts/generate_depth.py +56 -0
  23. prismer/experts/generate_edge.py +57 -0
  24. prismer/experts/generate_normal.py +58 -0
  25. prismer/experts/generate_objdet.py +115 -0
  26. prismer/experts/generate_ocrdet.py +86 -0
  27. prismer/experts/generate_segmentation.py +56 -0
  28. prismer/experts/model_bank.py +139 -0
  29. prismer/experts/normal/generate_dataset.py +34 -0
  30. prismer/experts/normal/models/NNET.py +22 -0
  31. prismer/experts/normal/models/baseline.py +85 -0
  32. prismer/experts/normal/models/submodules/decoder.py +202 -0
  33. prismer/experts/normal/models/submodules/encoder.py +32 -0
  34. prismer/experts/normal/models/submodules/submodules.py +140 -0
  35. prismer/experts/normal/utils/losses.py +178 -0
  36. prismer/experts/normal/utils/utils.py +191 -0
  37. prismer/experts/obj_detection/configs/Base-CRCNN-COCO.yaml +48 -0
  38. prismer/experts/obj_detection/configs/O365_CRFR50_CAS_2x.yaml +15 -0
  39. prismer/experts/obj_detection/configs/OID_CRFR50_CAS_2x.yaml +22 -0
  40. prismer/experts/obj_detection/configs/Partitioned_COIM_R50_6x+2x.yaml +28 -0
  41. prismer/experts/obj_detection/configs/Partitioned_COI_R50_2x.yaml +29 -0
  42. prismer/experts/obj_detection/configs/Partitioned_COI_R50_6x.yaml +28 -0
  43. prismer/experts/obj_detection/configs/Partitioned_COI_R50_8x.yaml +28 -0
  44. prismer/experts/obj_detection/configs/Partitioned_COI_RS101_2x.yaml +46 -0
  45. prismer/experts/obj_detection/configs/Unified_human_OCI_R50_2x.yaml +29 -0
  46. prismer/experts/obj_detection/configs/Unified_learned_OCIM_R50_6x+2x.yaml +35 -0
  47. prismer/experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml +46 -0
  48. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_2x.yaml +29 -0
  49. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_6x.yaml +29 -0
  50. prismer/experts/obj_detection/configs/Unified_learned_OCI_R50_8x.yaml +29 -0
prismer/dataset/__init__.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
+
+ from torch.utils.data import DataLoader
+
+ from dataset.vqa_dataset import VQA
+ from dataset.caption_dataset import Caption
+
+
+ def create_dataset(dataset, config):
+     if dataset == 'vqa':
+         train_dataset = VQA(config, train=True)
+         test_dataset = VQA(config, train=False)
+         return train_dataset, test_dataset
+
+     elif dataset == 'caption':
+         train_dataset = Caption(config, train=True)
+         test_dataset = Caption(config, train=False)
+         return train_dataset, test_dataset
+
+
+ def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
+     data_loader = DataLoader(dataset,
+                              batch_size=batch_size,
+                              num_workers=num_workers,
+                              collate_fn=collate_fn,
+                              shuffle=True if train else False,
+                              drop_last=True if train else False)
+     return data_loader
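A minimal usage sketch (not part of this commit), assuming a config dict carrying the keys read by the dataset classes; the paths and values below are placeholders:

# Hypothetical example: build the caption loaders exposed by dataset/__init__.py.
from dataset import create_dataset, create_loader

config = {'data_path': 'data', 'label_path': 'labels', 'experts': 'none',
          'prefix': '', 'dataset': 'coco', 'image_resolution': 384}
train_set, test_set = create_dataset('caption', config)
train_loader = create_loader(train_set, batch_size=8, num_workers=4, train=True)
test_loader = create_loader(test_set, batch_size=8, num_workers=4, train=False)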
prismer/dataset/ade_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cb5c802ef9eaf9c9213219bdc5f8f2d6579400275ab443c43905fdcdd805e5a
+ size 41899
prismer/dataset/background_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5f5bf8e158715633a9e0f307be1947ca74c294e522b705e1e61bd49361088092
+ size 1003
prismer/dataset/caption_dataset.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
+
+ import glob
+
+ from torch.utils.data import Dataset
+ from dataset.utils import *
+ from PIL import ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ class Caption(Dataset):
+     def __init__(self, config, train=True):
+         self.data_path = config['data_path']
+         self.label_path = config['label_path']
+         self.experts = config['experts']
+         self.prefix = config['prefix']
+         self.dataset = config['dataset']
+         self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
+         self.train = train
+
+         if train:
+             self.data_list = []
+             if self.dataset in ['coco', 'nocaps']:
+                 self.data_list += json.load(open(os.path.join(self.data_path, 'coco_karpathy_train.json'), 'r'))
+         else:
+             if self.dataset == 'coco':
+                 self.data_list = json.load(open(os.path.join(self.data_path, 'coco_karpathy_test.json'), 'r'))
+             elif self.dataset == 'nocaps':
+                 self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
+             elif self.dataset == 'demo':
+                 data_folders = glob.glob(f'{self.data_path}/*/')
+                 self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
+                 self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
+                 self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         data = self.data_list[index]
+
+         if self.dataset == 'coco':
+             image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
+         elif self.dataset == 'nocaps':
+             image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'nocaps', self.experts)
+         elif self.dataset == 'demo':
+             img_path_split = self.data_list[index]['image'].split('/')
+             img_name = img_path_split[-2] + '/' + img_path_split[-1]
+             image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
+
+         experts = self.transform(image, labels)
+         experts = post_label_process(experts, labels_info)
+
+         if self.train:
+             caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
+             return experts, caption
+         else:
+             return experts, index
+
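A rough sketch of how the demo branch above might be driven (illustrative only; the folder layout and config values are assumptions, not part of this commit):

# Hypothetical example: iterate the demo split of the Caption dataset.
from dataset.caption_dataset import Caption

demo_config = {'data_path': 'helpers', 'label_path': 'helpers/labels', 'experts': 'none',
               'prefix': '', 'dataset': 'demo', 'image_resolution': 384}
dataset = Caption(demo_config, train=False)
experts, index = dataset[0]   # dict with the processed image (plus expert labels when enabled) and the sample index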
prismer/dataset/clip_pca.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:064950869bce7d622ff7e733dc64c8e1be18be926c3f64e66ae57164df3f8bc5
+ size 401612
prismer/dataset/coco_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ccf18221afe8dddef3ffb9daad31d5c7a92cdc2f2f434d77cbeb48031bc75756
+ size 36651
prismer/dataset/detection_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c38ba9352b2a9f832b14fdc19ac407527ffeaa2903958a73f6eb649f78119c76
+ size 198443
prismer/dataset/randaugment.py ADDED
@@ -0,0 +1,265 @@
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+
11
+ from PIL import Image, ImageOps, ImageEnhance, ImageDraw
12
+
13
+
14
+ fillmask = {'depth': 0, 'normal': 0, 'edge': 0, 'seg_coco': 255, 'seg_ade': 255,
15
+ 'obj_detection': 255, 'ocr_detection': 255}
16
+ fillcolor = (0, 0, 0)
17
+
18
+
19
+ def affine_transform(pair, affine_params):
20
+ img, label = pair
21
+ img = img.transform(img.size, Image.AFFINE, affine_params,
22
+ resample=Image.BILINEAR, fillcolor=fillcolor)
23
+ if label is not None:
24
+ for exp in label:
25
+ label[exp] = label[exp].transform(label[exp].size, Image.AFFINE, affine_params,
26
+ resample=Image.NEAREST, fillcolor=fillmask[exp])
27
+ return img, label
28
+
29
+
30
+ def ShearX(pair, v): # [-0.3, 0.3]
31
+ assert -0.3 <= v <= 0.3
32
+ if random.random() > 0.5:
33
+ v = -v
34
+ return affine_transform(pair, (1, v, 0, 0, 1, 0))
35
+
36
+
37
+ def ShearY(pair, v): # [-0.3, 0.3]
38
+ assert -0.3 <= v <= 0.3
39
+ if random.random() > 0.5:
40
+ v = -v
41
+ return affine_transform(pair, (1, 0, 0, v, 1, 0))
42
+
43
+
44
+ def TranslateX(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
45
+ assert -0.45 <= v <= 0.45
46
+ if random.random() > 0.5:
47
+ v = -v
48
+ img, _ = pair
49
+ v = v * img.size[0]
50
+ return affine_transform(pair, (1, 0, v, 0, 1, 0))
51
+
52
+
53
+ def TranslateY(pair, v): # [-150, 150] => percentage: [-0.45, 0.45]
54
+ assert -0.45 <= v <= 0.45
55
+ if random.random() > 0.5:
56
+ v = -v
57
+ img, _ = pair
58
+ v = v * img.size[1]
59
+ return affine_transform(pair, (1, 0, 0, 0, 1, v))
60
+
61
+
62
+ def TranslateXAbs(pair, v): # absolute translation in pixels: [0, 10]
63
+ assert 0 <= v <= 10
64
+ if random.random() > 0.5:
65
+ v = -v
66
+ return affine_transform(pair, (1, 0, v, 0, 1, 0))
67
+
68
+
69
+ def TranslateYAbs(pair, v): # absolute translation in pixels: [0, 10]
70
+ assert 0 <= v <= 10
71
+ if random.random() > 0.5:
72
+ v = -v
73
+ return affine_transform(pair, (1, 0, 0, 0, 1, v))
74
+
75
+
76
+ def Rotate(pair, v): # [-30, 30]
77
+ assert -30 <= v <= 30
78
+ if random.random() > 0.5:
79
+ v = -v
80
+ img, label = pair
81
+ img = img.rotate(v, fillcolor=fillcolor)
82
+ if label is not None:
83
+ for exp in label:
84
+ label[exp] = label[exp].rotate(v, resample=Image.NEAREST, fillcolor=fillmask[exp])
85
+ return img, label
86
+
87
+
88
+ def AutoContrast(pair, _):
89
+ img, label = pair
90
+ return ImageOps.autocontrast(img), label
91
+
92
+
93
+ def Invert(pair, _):
94
+ img, label = pair
95
+ return ImageOps.invert(img), label
96
+
97
+
98
+ def Equalize(pair, _):
99
+ img, label = pair
100
+ return ImageOps.equalize(img), label
101
+
102
+
103
+ def Flip(pair, _): # not from the paper
104
+ img, label = pair
105
+ return ImageOps.mirror(img), ImageOps.mirror(label)
106
+
107
+
108
+ def Solarize(pair, v): # [0, 256]
109
+ img, label = pair
110
+ assert 0 <= v <= 256
111
+ return ImageOps.solarize(img, v), label
112
+
113
+
114
+ def Posterize(pair, v): # [4, 8]
115
+ img, label = pair
116
+ assert 4 <= v <= 8
117
+ v = int(v)
118
+ return ImageOps.posterize(img, v), label
119
+
120
+
121
+ def Posterize2(pair, v): # [0, 4]
122
+ img, label = pair
123
+ assert 0 <= v <= 4
124
+ v = int(v)
125
+ return ImageOps.posterize(img, v), label
126
+
127
+
128
+ def Contrast(pair, v): # [0.1,1.9]
129
+ img, label = pair
130
+ assert 0.1 <= v <= 1.9
131
+ return ImageEnhance.Contrast(img).enhance(v), label
132
+
133
+
134
+ def Color(pair, v): # [0.1,1.9]
135
+ img, label = pair
136
+ assert 0.1 <= v <= 1.9
137
+ return ImageEnhance.Color(img).enhance(v), label
138
+
139
+
140
+ def Brightness(pair, v): # [0.1,1.9]
141
+ img, label = pair
142
+ assert 0.1 <= v <= 1.9
143
+ return ImageEnhance.Brightness(img).enhance(v), label
144
+
145
+
146
+ def Sharpness(pair, v): # [0.1,1.9]
147
+ img, label = pair
148
+ assert 0.1 <= v <= 1.9
149
+ return ImageEnhance.Sharpness(img).enhance(v), label
150
+
151
+
152
+ def Cutout(pair, v): # [0, 60] => percentage: [0, 0.2]
153
+ assert 0.0 <= v <= 0.2
154
+ if v <= 0.:
155
+ return pair
156
+ img, label = pair
157
+ v = v * img.size[0]
158
+ return CutoutAbs(img, v), label
159
+
160
+
161
+ def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
162
+ # assert 0 <= v <= 20
163
+ if v < 0:
164
+ return img
165
+ w, h = img.size
166
+ x0 = np.random.uniform(w)
167
+ y0 = np.random.uniform(h)
168
+
169
+ x0 = int(max(0, x0 - v / 2.))
170
+ y0 = int(max(0, y0 - v / 2.))
171
+ x1 = min(w, x0 + v)
172
+ y1 = min(h, y0 + v)
173
+
174
+ xy = (x0, y0, x1, y1)
175
+ color = (125, 123, 114)
176
+ # color = (0, 0, 0)
177
+ img = img.copy()
178
+ ImageDraw.Draw(img).rectangle(xy, color)
179
+ return img
180
+
181
+
182
+ def Identity(pair, v):
183
+ return pair
184
+
185
+
186
+ def augment_list(): # 16 operations and their ranges
187
+ # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
188
+ l = [
189
+ (Identity, 0., 1.0),
190
+ (ShearX, 0., 0.3), # 0
191
+ (ShearY, 0., 0.3), # 1
192
+ (TranslateX, 0., 0.33), # 2
193
+ (TranslateY, 0., 0.33), # 3
194
+ (Rotate, 0, 30), # 4
195
+ (AutoContrast, 0, 1), # 5
196
+ # (Invert, 0, 1), # 6
197
+ (Equalize, 0, 1), # 7
198
+ # (Solarize, 0, 110), # 8
199
+ # (Posterize, 4, 8), # 9
200
+ # (Color, 0.1, 1.9), # 11
201
+ (Brightness, 0.1, 1.9), # 12
202
+ (Sharpness, 0.1, 1.9), # 13
203
+ ]
204
+ return l
205
+
206
+
207
+ class Lighting(object):
208
+ """Lighting noise(AlexNet - style PCA - based noise)"""
209
+
210
+ def __init__(self, alphastd, eigval, eigvec):
211
+ self.alphastd = alphastd
212
+ self.eigval = torch.Tensor(eigval)
213
+ self.eigvec = torch.Tensor(eigvec)
214
+
215
+ def __call__(self, img):
216
+ if self.alphastd == 0:
217
+ return img
218
+
219
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
220
+ rgb = self.eigvec.type_as(img).clone() \
221
+ .mul(alpha.view(1, 3).expand(3, 3)) \
222
+ .mul(self.eigval.view(1, 3).expand(3, 3)) \
223
+ .sum(1).squeeze()
224
+
225
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
226
+
227
+
228
+ class CutoutDefault(object):
229
+ """
230
+ Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
231
+ """
232
+ def __init__(self, length):
233
+ self.length = length
234
+
235
+ def __call__(self, img):
236
+ h, w = img.size(1), img.size(2)
237
+ mask = np.ones((h, w), np.float32)
238
+ y = np.random.randint(h)
239
+ x = np.random.randint(w)
240
+
241
+ y1 = np.clip(y - self.length // 2, 0, h)
242
+ y2 = np.clip(y + self.length // 2, 0, h)
243
+ x1 = np.clip(x - self.length // 2, 0, w)
244
+ x2 = np.clip(x + self.length // 2, 0, w)
245
+
246
+ mask[y1: y2, x1: x2] = 0.
247
+ mask = torch.from_numpy(mask)
248
+ mask = mask.expand_as(img)
249
+ img *= mask
250
+ return img
251
+
252
+
253
+ class RandAugment:
254
+ def __init__(self, n, m):
255
+ self.n = n
256
+ self.m = m # [0, 10]
257
+ self.augment_list = augment_list()
258
+
259
+ def __call__(self, img, label):
260
+ pair = img, label
261
+ ops = random.choices(self.augment_list, k=self.n)
262
+ for op, minval, maxval in ops:
263
+ val = (float(self.m) / 10) * float(maxval - minval) + minval
264
+ pair = op(pair, val)
265
+ return pair
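A small usage sketch for the pair-wise RandAugment above (illustrative; the image and label file names are placeholders, not part of this commit):

# Hypothetical example: apply the same random ops to an RGB image and its expert label maps.
from PIL import Image
from dataset.randaugment import RandAugment

augment = RandAugment(n=2, m=5)                    # 2 ops per call, magnitude 5 of 10
img = Image.open('example.jpg').convert('RGB')     # placeholder path
labels = {'depth': Image.open('example_depth.png').convert('L')}
img, labels = augment(img, labels)                 # geometric ops are applied to the labels as well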
prismer/dataset/utils.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import os
8
+ import re
9
+ import json
10
+ import torch
11
+ import PIL.Image as Image
12
+ import numpy as np
13
+ import torchvision.transforms as transforms
14
+ import torchvision.transforms.functional as transforms_f
15
+ import pathlib
16
+ from dataset.randaugment import RandAugment
17
+
18
+
19
+ cur_dir = pathlib.Path(__file__).parent
20
+
21
+ COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
22
+ ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
23
+ DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
24
+ BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
25
+
26
+
27
+ class Transform:
28
+ def __init__(self, resize_resolution=384, scale_size=[0.5, 1.0], train=False):
29
+ self.resize_size = [resize_resolution, resize_resolution]
30
+ self.scale_size = scale_size
31
+ self.train = train
32
+ self.randaugment = RandAugment(2, 5)
33
+
34
+ def __call__(self, image, labels):
35
+ if self.train:
36
+ # random resize crop
37
+ i, j, h, w = transforms.RandomResizedCrop.get_params(img=image, scale=self.scale_size, ratio=[3. / 4, 4. / 3])
38
+ image = transforms_f.crop(image, i, j, h, w)
39
+ if labels is not None:
40
+ for exp in labels:
41
+ labels[exp] = transforms_f.crop(labels[exp], i, j, h, w)
42
+
43
+ # resize to the defined shape
44
+ image = transforms_f.resize(image, self.resize_size, transforms_f.InterpolationMode.BICUBIC)
45
+ if labels is not None:
46
+ for exp in labels:
47
+ labels[exp] = transforms_f.resize(labels[exp], [224, 224], transforms_f.InterpolationMode.NEAREST)
48
+
49
+ if self.train:
50
+ # random flipping
51
+ if torch.rand(1) > 0.5:
52
+ image = transforms_f.hflip(image)
53
+ if labels is not None:
54
+ for exp in labels:
55
+ labels[exp] = transforms_f.hflip(labels[exp])
56
+
57
+ # random augmentation
58
+ image, labels = self.randaugment(image, labels)
59
+
60
+ # transform to tensor
61
+ image = transforms_f.to_tensor(image)
62
+ if labels is not None:
63
+ for exp in labels:
64
+ if exp in ['depth', 'normal', 'edge']:
65
+ labels[exp] = transforms_f.to_tensor(labels[exp])
66
+ else:
67
+ labels[exp] = (transforms_f.to_tensor(labels[exp]) * 255).long()
68
+
69
+ # apply normalisation:
70
+ image = transforms_f.normalize(image, mean=[0.48145466, 0.4578275, 0.40821073],
71
+ std=[0.26862954, 0.26130258, 0.27577711])
72
+ if labels is not None:
73
+ return {'rgb': image, **labels}
74
+ else:
75
+ return {'rgb': image}
76
+
77
+
78
+ def get_expert_labels(data_path, label_path, image_path, dataset, experts):
79
+ image_full_path = os.path.join(data_path, dataset, image_path)
80
+ image = Image.open(image_full_path).convert('RGB')
81
+ if experts != 'none':
82
+ labels = {}
83
+ labels_info = {}
84
+ ps = image_path.split('.')[-1]
85
+ for exp in experts:
86
+ if exp in ['seg_coco', 'seg_ade', 'edge', 'depth']:
87
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
88
+ if os.stat(label_full_path).st_size > 0:
89
+ labels[exp] = Image.open(label_full_path).convert('L')
90
+ else:
91
+ labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0]])).convert('L')
92
+ elif exp == 'normal':
93
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
94
+ if os.stat(label_full_path).st_size > 0:
95
+ labels[exp] = Image.open(label_full_path).convert('RGB')
96
+ else:
97
+ labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0], 3])).convert('RGB')
98
+ elif exp == 'obj_detection':
99
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
100
+ if os.stat(label_full_path).st_size > 0:
101
+ labels[exp] = Image.open(label_full_path).convert('L')
102
+ else:
103
+ labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
104
+ label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.json'))
105
+ labels_info[exp] = json.load(open(label_info_path, 'r'))
106
+ elif exp == 'ocr_detection':
107
+ label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
108
+ label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.pt'))
109
+ if os.path.exists(label_info_path):
110
+ labels[exp] = Image.open(label_full_path).convert('L')
111
+ labels_info[exp] = torch.load(label_info_path)
112
+ else:
113
+ labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]])).convert('L')
114
+ labels_info[exp] = None
115
+
116
+ else:
117
+ labels, labels_info = None, None
118
+ return image, labels, labels_info
119
+
120
+
121
+ def post_label_process(inputs, labels_info):
122
+ eps = 1e-6
123
+ for exp in inputs:
124
+ if exp in ['depth', 'normal', 'edge']: # remap to -1 to 1 range
125
+ inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
126
+
127
+ elif exp == 'seg_coco': # in-paint with CLIP features
128
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
129
+ for l in inputs[exp].unique():
130
+ if l == 255:
131
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
132
+ else:
133
+ text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
134
+ inputs[exp] = text_emb
135
+
136
+ elif exp == 'seg_ade': # in-paint with CLIP features
137
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
138
+ for l in inputs[exp].unique():
139
+ if l == 255:
140
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
141
+ else:
142
+ text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
143
+ inputs[exp] = text_emb
144
+
145
+ elif exp == 'obj_detection': # in-paint with CLIP features
146
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
147
+ label_map = labels_info[exp]
148
+ for l in inputs[exp].unique():
149
+ if l == 255:
150
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
151
+ else:
152
+ text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
153
+ inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}
154
+
155
+ elif exp == 'ocr_detection': # in-paint with CLIP features
156
+ text_emb = torch.empty([64, *inputs[exp].shape[1:]])
157
+ label_map = labels_info[exp]
158
+ for l in inputs[exp].unique():
159
+ if l == 255:
160
+ text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
161
+ else:
162
+ text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
163
+ inputs[exp] = text_emb
164
+ return inputs
165
+
166
+
167
+ def pre_caption(caption, max_words=50):
168
+ caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.capitalize()) # remove special characters
169
+ caption = re.sub(r"\s{2,}", ' ', caption)  # collapse repeated white spaces
170
+
171
+ caption = caption.rstrip('\n')  # remove trailing newline characters
172
+ caption = caption.strip(' ') # remove leading and trailing white spaces
173
+
174
+ # truncate caption to the max words
175
+ caption_words = caption.split(' ')
176
+ if len(caption_words) > max_words:
177
+ caption = ' '.join(caption_words[:max_words])
178
+ return caption
179
+
180
+
181
+ def pre_question(question, max_words=50):
182
+ question = re.sub(r"([.!\"()*#:;~])", ' ', question.capitalize()) # remove special characters
183
+ question = question.strip()
184
+
185
+ # truncate question
186
+ question_words = question.split(' ')
187
+ if len(question_words) > max_words:
188
+ question = ' '.join(question_words[:max_words])
189
+ if question[-1] != '?':
190
+ question += '?'
191
+ return question
192
+
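A quick illustration of the text pre-processing helpers above (the inputs are made up for the example):

# Hypothetical example: caption/question normalisation as implemented in pre_caption/pre_question.
from dataset.utils import pre_caption, pre_question

pre_caption('a photo of: a dog!! running', max_words=30)   # -> 'A photo of a dog running'
pre_question('what colour is the dog', max_words=30)       # -> 'What colour is the dog?'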
prismer/dataset/vqa_dataset.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, visit
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
+
+ from torch.utils.data import Dataset
+ from dataset.utils import *
+ from PIL import ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ class VQA(Dataset):
+     def __init__(self, config, train=True):
+         self.data_path = config['data_path']
+         self.label_path = config['label_path']
+         self.experts = config['experts']
+         self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
+         self.train = train
+
+         if train:
+             self.data_list = []
+             if 'vqav2' in config['datasets']:
+                 self.data_list += json.load(open(os.path.join(self.data_path, 'vqav2_train_val.json'), 'r'))
+             if 'vg' in config['datasets']:
+                 self.data_list += json.load(open(os.path.join(self.data_path, 'vg_qa.json'), 'r'))
+         else:
+             self.data_list = json.load(open(os.path.join(self.data_path, 'vqav2_test.json'), 'r'))
+             self.answer_list = json.load(open(os.path.join(self.data_path, 'answer_list.json'), 'r'))
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, index):
+         data = self.data_list[index]
+
+         if data['dataset'] == 'vqa':
+             image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
+         elif data['dataset'] == 'vg':
+             image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vg', self.experts)
+
+         experts = self.transform(image, labels)
+         experts = post_label_process(experts, labels_info)
+
+         if self.train:
+             question = pre_question(data['question'], max_words=30)
+             answers = data['answer']
+             weights = torch.tensor(data['weight']) if data['dataset'] != 'vg' else torch.tensor(0.2)
+             return experts, question, answers, weights
+         else:
+             question = pre_question(data['question'], max_words=30)
+             question_id = data['question_id']
+             return experts, index, question, question_id
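For reference, a sketch of the annotation record the loader above expects (field values are invented for illustration; only the keys are taken from the code):

# Hypothetical training entry consumed by VQA.__getitem__.
entry = {'dataset': 'vqa',            # 'vqa' or 'vg', selects the image root
         'image': 'val2014/example.jpg',
         'question': 'what colour is the dog',
         'answer': ['brown'],          # candidate answers
         'weight': [1.0],              # per-answer weights (vg entries use a fixed 0.2)
         'question_id': 42}            # used at test time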
prismer/demo.py ADDED
@@ -0,0 +1,77 @@
+ import os
+ import argparse
+ import torch
+ try:
+     import ruamel_yaml as yaml
+ except ModuleNotFoundError:
+     import ruamel.yaml as yaml
+
+
+ from model.prismer_caption import PrismerCaption
+ from dataset import create_dataset, create_loader
+ from tqdm import tqdm
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--mode', default='')
+ parser.add_argument('--port', default='')
+
+ parser.add_argument('--exp_name', default='', type=str)
+ args = parser.parse_args()
+
+ # load config
+ config = yaml.load(open('configs/caption.yaml', 'r'), Loader=yaml.Loader)['demo']
+
+ # generate expert labels
+ if len(config['experts']) > 0:
+     script_name = f'python experts/generate_depth.py'
+     os.system(script_name)
+     print('***** Generated Depth *****')
+
+     script_name = f'python experts/generate_edge.py'
+     os.system(script_name)
+     print('***** Generated Edge *****')
+
+     script_name = f'python experts/generate_normal.py'
+     os.system(script_name)
+     print('***** Generated Surface Normals *****')
+
+     script_name = f'python experts/generate_objdet.py'
+     os.system(script_name)
+     print('***** Generated Object Detection Labels *****')
+
+     script_name = f'python experts/generate_ocrdet.py'
+     os.system(script_name)
+     print('***** Generated OCR Detection Labels *****')
+
+     script_name = f'python experts/generate_segmentation.py'
+     os.system(script_name)
+     print('***** Generated Segmentation Labels *****')
+
+ # load datasets
+ _, test_dataset = create_dataset('caption', config)
+ test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
+
+ # load pre-trained model
+ model = PrismerCaption(config)
+ state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cuda:0')
+ model.load_state_dict(state_dict)
+ tokenizer = model.tokenizer
+
+ # inference
+ model.eval()
+ with torch.no_grad():
+     for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
+         captions = model(experts, train=False, prefix=config['prefix'])
+
+         captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
+         caption = captions.to(experts['rgb'].device)[0]
+
+         caption = tokenizer.decode(caption, skip_special_tokens=True)
+         caption = caption.capitalize() + '.'
+
+         # save caption
+         save_path = test_loader.dataset.data_list[data_ids[0]]['image'].replace('jpg', 'txt')
+         with open(save_path, 'w') as f:
+             f.write(caption)
+
+ print('All Done.')
prismer/demo_vis.py ADDED
@@ -0,0 +1,161 @@
1
+ import glob
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ from utils import create_ade20k_label_colormap
10
+
11
+ obj_label_map = torch.load('dataset/detection_features.pt')['labels']
12
+ coco_label_map = torch.load('dataset/coco_features.pt')['labels']
13
+ ade_color = create_ade20k_label_colormap()
14
+
15
+ file_path = 'helpers/images'
16
+ expert_path = 'helpers/labels'
17
+ plt.ioff()
18
+
19
+
20
+ def get_label_path(file_name, expert_name, with_suffix=False):
21
+ file_suffix = '.png' if not with_suffix else '_.png'
22
+ label_name = ''.join(file_name.split('.')[:-1] + [file_suffix])
23
+ label_path = os.path.join(expert_path, expert_name, label_name)
24
+ return label_path
25
+
26
+
27
+ def depth_prettify(file_name):
28
+ label_path = get_label_path(file_name, 'depth')
29
+ save_path = get_label_path(file_name, 'depth', True)
30
+ depth = plt.imread(label_path)
31
+ plt.imsave(save_path, depth, cmap='rainbow')
32
+
33
+
34
+ def obj_detection_prettify(file_name):
35
+ label_path = get_label_path(file_name, 'obj_detection')
36
+ save_path = get_label_path(file_name, 'obj_detection', True)
37
+
38
+ rgb = plt.imread(file_name)
39
+ obj_labels = plt.imread(label_path)
40
+ obj_labels_dict = json.load(open(label_path.replace('.png', '.json')))
41
+
42
+ plt.imshow(rgb)
43
+
44
+ num_objs = np.unique(obj_labels)[:-1].max()
45
+ plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.5)
46
+
47
+ for i in np.unique(obj_labels)[:-1]:
48
+ obj_idx_all = np.where(obj_labels == i)
49
+ obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)  # randint is inclusive on both ends
50
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
51
+ obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
52
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
53
+
54
+ plt.axis('off')
55
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
56
+ plt.close()
57
+
58
+
59
+ def seg_prettify(file_name):
60
+ label_path = get_label_path(file_name, 'seg_coco')
61
+ save_path = get_label_path(file_name, 'seg_coco', True)
62
+
63
+ rgb = plt.imread(file_name)
64
+ seg_labels = plt.imread(label_path)
65
+
66
+ plt.imshow(rgb)
67
+
68
+ seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
69
+ for i in np.unique(seg_labels):
70
+ seg_map[seg_labels == i] = ade_color[int(i * 255)]
71
+
72
+ plt.imshow(seg_map, alpha=0.5)
73
+
74
+ for i in np.unique(seg_labels):
75
+ obj_idx_all = np.where(seg_labels == i)
76
+ obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)  # randint is inclusive on both ends
77
+ x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
78
+ obj_name = coco_label_map[int(i * 255)]
79
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center')
80
+
81
+ plt.axis('off')
82
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
83
+ plt.close()
84
+
85
+
86
+ def ocr_detection_prettify(file_name):
87
+ label_path = get_label_path(file_name, 'ocr_detection')
88
+ save_path = get_label_path(file_name, 'ocr_detection', True)
89
+
90
+ if os.path.exists(label_path):
91
+ rgb = plt.imread(file_name)
92
+ ocr_labels = plt.imread(label_path)
93
+ ocr_labels_dict = torch.load(label_path.replace('.png', '.pt'))
94
+
95
+ plt.imshow(rgb)
96
+ plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
97
+
98
+ for i in np.unique(ocr_labels)[:-1]:
99
+ text_idx_all = np.where(ocr_labels == i)
100
+ x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
101
+ text = ocr_labels_dict[int(i * 255)]['text']
102
+ plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center')
103
+
104
+ plt.axis('off')
105
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
106
+ plt.close()
107
+ else:
108
+ rgb = plt.imread(file_name)
109
+ ocr_labels = np.ones_like(rgb, dtype=np.float32)
110
+
111
+ plt.imshow(rgb)
112
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
113
+
114
+ x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
115
+ plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center')
116
+
117
+ plt.axis('off')
118
+ plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=0)
119
+ plt.close()
120
+
121
+
122
+ im_list = glob.glob(file_path + '/*.jpg') + glob.glob(file_path + '/*.png') + glob.glob(file_path + '/*.jpeg')
123
+
124
+ # prettify labels first:
125
+ for i in range(len(im_list)):
126
+ depth_prettify(im_list[i])
127
+ seg_prettify(im_list[i])
128
+ ocr_detection_prettify(im_list[i])
129
+ obj_detection_prettify(im_list[i])
130
+
131
+ pretty = {'depth': True, 'normal': False, 'edge': False,
132
+ 'obj_detection': True, 'ocr_detection': True, 'seg_coco': True}
133
+
134
+ # plot expert labels
135
+ for im_path in im_list:
136
+ fig, axs = plt.subplots(1, 7, figsize=(20, 4))
137
+ rgb = plt.imread(im_path)
138
+ axs[0].imshow(rgb)
139
+ axs[0].axis('off')
140
+ axs[0].set_title('RGB')
141
+
142
+ for j in range(6):
143
+ label_name = list(pretty.keys())[j]
144
+ label_path = get_label_path(im_path, label_name, with_suffix=pretty[label_name])
145
+ label = plt.imread(label_path)
146
+ if label_name != 'edge':
147
+ axs[j + 1].imshow(label)
148
+ else:
149
+ axs[j + 1].imshow(label, cmap='gray')
150
+
151
+ axs[j + 1].axis('off')
152
+ axs[j + 1].set_title(label_name)
153
+
154
+ caption_path = ''.join(im_path.split('.')[:-1] + ['.txt'])
155
+ with open(caption_path) as f:
156
+ caption = f.readlines()[0]
157
+
158
+ plt.suptitle(caption)
159
+ plt.tight_layout()
160
+
161
+ plt.show()
prismer/download_checkpoints.py ADDED
@@ -0,0 +1,124 @@
1
+ from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata
2
+ from huggingface_hub.utils import disable_progress_bars
3
+ from pathlib import Path
4
+ from rich.progress import Progress
5
+ from fire import Fire
6
+ from typing import Union, List
7
+
8
+ _EXPERTS = [
9
+ "10_model.pth",
10
+ "Unified_learned_OCIM_RS200_6x+2x.pth",
11
+ "dpt_hybrid-midas-501f0c75.pt",
12
+ "icdar2015_hourglass88.pth",
13
+ "model_final_e0c58e.pkl",
14
+ "model_final_f07440.pkl",
15
+ "scannet.pt",
16
+ ]
17
+
18
+ _MODELS = [
19
+ "vqa_prismer_base",
20
+ "vqa_prismer_large",
21
+ "vqa_prismerz_base",
22
+ "vqa_prismerz_large",
23
+ "caption_prismerz_base",
24
+ "caption_prismerz_large",
25
+ "caption_prismer_base",
26
+ "caption_prismer_large",
27
+ "pretrain_prismer_base",
28
+ "pretrain_prismer_large",
29
+ "pretrain_prismerz_base",
30
+ "pretrain_prismerz_large",
31
+ ]
32
+
33
+ _REPO_ID = "lorenmt/prismer"
34
+
35
+
36
+ def download_checkpoints(
37
+ download_experts: bool = False,
38
+ download_models: Union[bool, List] = False,
39
+ hide_tqdm: bool = False,
40
+ force_redownload: bool = False,
41
+ ):
42
+ if hide_tqdm:
43
+ disable_progress_bars()
44
+ # Convert to list and check for invalid names
45
+ download_experts = _EXPERTS if download_experts else []
46
+ if download_models:
47
+ # only download single model
48
+ if isinstance(download_models, str):
49
+ download_models = [download_models]
50
+
51
+ assert all([m in _MODELS for m in download_models]), f"Invalid model name. Must be one of {_MODELS}"
52
+ download_models = _MODELS if isinstance(download_models, bool) else download_models
53
+ else:
54
+ download_models = []
55
+
56
+ # Check if files already exist
57
+ if not force_redownload:
58
+ download_experts = [e for e in download_experts if not Path(f"./experts/expert_weights/{e}").exists()]
59
+ download_models = [m for m in download_models if not Path(f"{m}/pytorch_model.bin").exists()]
60
+
61
+ assert download_experts or download_models, "Nothing to download."
62
+
63
+ with Progress() as progress:
64
+ # Calculate total download size
65
+ progress.print("[blue]Calculating download size...")
66
+ total_size = 0
67
+ for expert in download_experts:
68
+ url = hf_hub_url(
69
+ filename=expert,
70
+ repo_id=_REPO_ID,
71
+ subfolder="expert_weights"
72
+ )
73
+ total_size += get_hf_file_metadata(url).size
74
+
75
+ for model in download_models:
76
+ url = hf_hub_url(
77
+ filename=f"pytorch_model.bin",
78
+ repo_id=_REPO_ID,
79
+ subfolder=model
80
+ )
81
+ total_size += get_hf_file_metadata(url).size
82
+ progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")
83
+
84
+ # Download files
85
+ total_files = len(download_experts) + len(download_models)
86
+ total_task = progress.add_task(f"[green]Downloading files", total=total_files)
87
+ if download_experts:
88
+ expert_task = progress.add_task(
89
+ f"[green]Downloading experts...", total=len(download_experts)
90
+ )
91
+ out_folder = Path("experts/expert_weights")
92
+ out_folder.mkdir(parents=True, exist_ok=True)
93
+ for expert in download_experts:
94
+ path = Path(hf_hub_download(
95
+ filename=expert,
96
+ repo_id=_REPO_ID,
97
+ subfolder="expert_weights"
98
+ ))
99
+ path.resolve().rename(out_folder/path.name)
100
+ path.unlink()
101
+ progress.advance(expert_task)
102
+ progress.advance(total_task)
103
+
104
+ if download_models:
105
+ model_task = progress.add_task(
106
+ f"[green]Downloading models...", total=len(download_models)
107
+ )
108
+ for model in download_models:
109
+ path = Path(hf_hub_download(
110
+ filename=f"pytorch_model.bin",
111
+ repo_id=_REPO_ID,
112
+ subfolder=model
113
+ ))
114
+ out_folder = Path("./logging")/model
115
+ out_folder.mkdir(parents=True, exist_ok=True)
116
+ path.resolve().rename(out_folder/"pytorch_model.bin")
117
+ path.unlink()
118
+ progress.advance(model_task)
119
+ progress.advance(total_task)
120
+ progress.print("[green]Done!")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ Fire(download_checkpoints)
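A hedged usage sketch for the downloader above (the model name is just one entry picked from the _MODELS list):

# Hypothetical example: fetch the expert weights plus one finetuned model.
# Roughly equivalent CLI call via Fire:
#   python download_checkpoints.py --download_experts=True --download_models=caption_prismer_base
from download_checkpoints import download_checkpoints

download_checkpoints(download_experts=True, download_models=["caption_prismer_base"])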
prismer/experts/depth/base_model.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+
+
+ class BaseModel(torch.nn.Module):
+     def load(self, path):
+         """Load model from file.
+
+         Args:
+             path (str): file path
+         """
+         parameters = torch.load(path, map_location=torch.device("cpu"))
+
+         if "optimizer" in parameters:
+             parameters = parameters["model"]
+
+         self.load_state_dict(parameters)
prismer/experts/depth/blocks.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .vit import (
5
+ _make_pretrained_vitb_rn50_384,
6
+ _make_pretrained_vitl16_384,
7
+ _make_pretrained_vitb16_384,
8
+ forward_vit,
9
+ )
10
+
11
+
12
+ def _make_encoder(
13
+ backbone,
14
+ features,
15
+ use_pretrained,
16
+ groups=1,
17
+ expand=False,
18
+ exportable=True,
19
+ hooks=None,
20
+ use_vit_only=False,
21
+ use_readout="ignore",
22
+ enable_attention_hooks=False,
23
+ ):
24
+ if backbone == "vitl16_384":
25
+ pretrained = _make_pretrained_vitl16_384(
26
+ use_pretrained,
27
+ hooks=hooks,
28
+ use_readout=use_readout,
29
+ enable_attention_hooks=enable_attention_hooks,
30
+ )
31
+ scratch = _make_scratch(
32
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
33
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
34
+ elif backbone == "vitb_rn50_384":
35
+ pretrained = _make_pretrained_vitb_rn50_384(
36
+ use_pretrained,
37
+ hooks=hooks,
38
+ use_vit_only=use_vit_only,
39
+ use_readout=use_readout,
40
+ enable_attention_hooks=enable_attention_hooks,
41
+ )
42
+ scratch = _make_scratch(
43
+ [256, 512, 768, 768], features, groups=groups, expand=expand
44
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
45
+ elif backbone == "vitb16_384":
46
+ pretrained = _make_pretrained_vitb16_384(
47
+ use_pretrained,
48
+ hooks=hooks,
49
+ use_readout=use_readout,
50
+ enable_attention_hooks=enable_attention_hooks,
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
55
+ elif backbone == "resnext101_wsl":
56
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57
+ scratch = _make_scratch(
58
+ [256, 512, 1024, 2048], features, groups=groups, expand=expand
59
+ ) # efficientnet_lite3
60
+ else:
61
+ print(f"Backbone '{backbone}' not implemented")
62
+ assert False
63
+
64
+ return pretrained, scratch
65
+
66
+
67
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68
+ scratch = nn.Module()
69
+
70
+ out_shape1 = out_shape
71
+ out_shape2 = out_shape
72
+ out_shape3 = out_shape
73
+ out_shape4 = out_shape
74
+ if expand == True:
75
+ out_shape1 = out_shape
76
+ out_shape2 = out_shape * 2
77
+ out_shape3 = out_shape * 4
78
+ out_shape4 = out_shape * 8
79
+
80
+ scratch.layer1_rn = nn.Conv2d(
81
+ in_shape[0],
82
+ out_shape1,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1,
86
+ bias=False,
87
+ groups=groups,
88
+ )
89
+ scratch.layer2_rn = nn.Conv2d(
90
+ in_shape[1],
91
+ out_shape2,
92
+ kernel_size=3,
93
+ stride=1,
94
+ padding=1,
95
+ bias=False,
96
+ groups=groups,
97
+ )
98
+ scratch.layer3_rn = nn.Conv2d(
99
+ in_shape[2],
100
+ out_shape3,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups,
106
+ )
107
+ scratch.layer4_rn = nn.Conv2d(
108
+ in_shape[3],
109
+ out_shape4,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1,
113
+ bias=False,
114
+ groups=groups,
115
+ )
116
+
117
+ return scratch
118
+
119
+
120
+ def _make_resnet_backbone(resnet):
121
+ pretrained = nn.Module()
122
+ pretrained.layer1 = nn.Sequential(
123
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124
+ )
125
+
126
+ pretrained.layer2 = resnet.layer2
127
+ pretrained.layer3 = resnet.layer3
128
+ pretrained.layer4 = resnet.layer4
129
+
130
+ return pretrained
131
+
132
+
133
+ def _make_pretrained_resnext101_wsl(use_pretrained):
134
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135
+ return _make_resnet_backbone(resnet)
136
+
137
+
138
+ class Interpolate(nn.Module):
139
+ """Interpolation module."""
140
+
141
+ def __init__(self, scale_factor, mode, align_corners=False):
142
+ """Init.
143
+
144
+ Args:
145
+ scale_factor (float): scaling
146
+ mode (str): interpolation mode
147
+ """
148
+ super(Interpolate, self).__init__()
149
+
150
+ self.interp = nn.functional.interpolate
151
+ self.scale_factor = scale_factor
152
+ self.mode = mode
153
+ self.align_corners = align_corners
154
+
155
+ def forward(self, x):
156
+ """Forward pass.
157
+
158
+ Args:
159
+ x (tensor): input
160
+
161
+ Returns:
162
+ tensor: interpolated data_list
163
+ """
164
+
165
+ x = self.interp(
166
+ x,
167
+ scale_factor=self.scale_factor,
168
+ mode=self.mode,
169
+ align_corners=self.align_corners,
170
+ )
171
+
172
+ return x
173
+
174
+
175
+ class ResidualConvUnit(nn.Module):
176
+ """Residual convolution module."""
177
+
178
+ def __init__(self, features):
179
+ """Init.
180
+
181
+ Args:
182
+ features (int): number of features
183
+ """
184
+ super().__init__()
185
+
186
+ self.conv1 = nn.Conv2d(
187
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
188
+ )
189
+
190
+ self.conv2 = nn.Conv2d(
191
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
192
+ )
193
+
194
+ self.relu = nn.ReLU(inplace=True)
195
+
196
+ def forward(self, x):
197
+ """Forward pass.
198
+
199
+ Args:
200
+ x (tensor): input
201
+
202
+ Returns:
203
+ tensor: output
204
+ """
205
+ out = self.relu(x)
206
+ out = self.conv1(out)
207
+ out = self.relu(out)
208
+ out = self.conv2(out)
209
+
210
+ return out + x
211
+
212
+
213
+ class FeatureFusionBlock(nn.Module):
214
+ """Feature fusion block."""
215
+
216
+ def __init__(self, features):
217
+ """Init.
218
+
219
+ Args:
220
+ features (int): number of features
221
+ """
222
+ super(FeatureFusionBlock, self).__init__()
223
+
224
+ self.resConfUnit1 = ResidualConvUnit(features)
225
+ self.resConfUnit2 = ResidualConvUnit(features)
226
+
227
+ def forward(self, *xs):
228
+ """Forward pass.
229
+
230
+ Returns:
231
+ tensor: output
232
+ """
233
+ output = xs[0]
234
+
235
+ if len(xs) == 2:
236
+ output += self.resConfUnit1(xs[1])
237
+
238
+ output = self.resConfUnit2(output)
239
+
240
+ output = nn.functional.interpolate(
241
+ output, scale_factor=2, mode="bilinear", align_corners=True
242
+ )
243
+
244
+ return output
245
+
246
+
247
+ class ResidualConvUnit_custom(nn.Module):
248
+ """Residual convolution module."""
249
+
250
+ def __init__(self, features, activation, bn):
251
+ """Init.
252
+
253
+ Args:
254
+ features (int): number of features
255
+ """
256
+ super().__init__()
257
+
258
+ self.bn = bn
259
+
260
+ self.groups = 1
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ features,
264
+ features,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ bias=not self.bn,
269
+ groups=self.groups,
270
+ )
271
+
272
+ self.conv2 = nn.Conv2d(
273
+ features,
274
+ features,
275
+ kernel_size=3,
276
+ stride=1,
277
+ padding=1,
278
+ bias=not self.bn,
279
+ groups=self.groups,
280
+ )
281
+
282
+ if self.bn == True:
283
+ self.bn1 = nn.BatchNorm2d(features)
284
+ self.bn2 = nn.BatchNorm2d(features)
285
+
286
+ self.activation = activation
287
+
288
+ self.skip_add = nn.quantized.FloatFunctional()
289
+
290
+ def forward(self, x):
291
+ """Forward pass.
292
+
293
+ Args:
294
+ x (tensor): input
295
+
296
+ Returns:
297
+ tensor: output
298
+ """
299
+
300
+ out = self.activation(x)
301
+ out = self.conv1(out)
302
+ if self.bn == True:
303
+ out = self.bn1(out)
304
+
305
+ out = self.activation(out)
306
+ out = self.conv2(out)
307
+ if self.bn == True:
308
+ out = self.bn2(out)
309
+
310
+ if self.groups > 1:
311
+ out = self.conv_merge(out)
312
+
313
+ return self.skip_add.add(out, x)
314
+
315
+ # return out + x
316
+
317
+
318
+ class FeatureFusionBlock_custom(nn.Module):
319
+ """Feature fusion block."""
320
+
321
+ def __init__(
322
+ self,
323
+ features,
324
+ activation,
325
+ deconv=False,
326
+ bn=False,
327
+ expand=False,
328