shikunl commited on
Commit
bd0d673
β€’
1 Parent(s): 65af262

Add Dockerfile

Browse files
README.md CHANGED
@@ -5,7 +5,6 @@ colorFrom: yellow
5
  colorTo: purple
6
  sdk: docker
7
  sdk_version: 3.20.1
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: purple
6
  sdk: docker
7
  sdk_version: 3.20.1
 
8
  pinned: false
9
  ---
10
 
prismer/dataset/ade_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cb5c802ef9eaf9c9213219bdc5f8f2d6579400275ab443c43905fdcdd805e5a
3
+ size 41899
prismer/dataset/background_features.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f5bf8e158715633a9e0f307be1947ca74c294e522b705e1e61bd49361088092
3
+ size 1003
prismer/dataset/classification_dataset.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+ from torch.utils.data import Dataset
9
+ from dataset.utils import *
10
+
11
+
12
class Classification(Dataset):
    """ImageNet classification dataset with multi-modal expert labels.

    Each item is a dict of expert-transformed inputs plus either a text
    caption built from the class name (train) or the class index (eval).
    """

    def __init__(self, config, train):
        self.data_path = config['data_path']
        self.label_path = config['label_path']
        self.experts = config['experts']
        self.dataset = config['dataset']
        self.shots = config['shots']    # few-shot: images kept per class folder (train only)
        self.prefix = config['prefix']  # prompt prepended to the class-name caption

        self.train = train
        # NOTE(review): train=True is passed even when train is False, so eval
        # images also go through the random-augmentation pipeline — confirm intended.
        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=True)

        # The index->answer and class-folder->index maps are identical for both
        # splits; load them once instead of in each branch.
        self.answer_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_answer.json'))
        self.class_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_class.json'))

        if train:
            data_folders = glob.glob(f'{self.data_path}/imagenet_train/*/')
            self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')[:self.shots]]
        else:
            data_folders = glob.glob(f'{self.data_path}/imagenet/*/')
            self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        img_path = self.data_list[index]['image']
        # Both splits store images as <class_folder>/<file>.JPEG; only the
        # dataset folder name passed to the expert-label loader differs.
        img_path_split = img_path.split('/')
        img_name = img_path_split[-2] + '/' + img_path_split[-1]
        class_name = img_path_split[-2]
        dataset_name = 'imagenet_train' if self.train else 'imagenet'
        image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, dataset_name, self.experts)

        experts = self.transform(image, labels)
        experts = post_label_process(experts, labels_info)

        if self.train:
            caption = self.prefix + ' ' + self.answer_list[int(self.class_list[class_name])].lower()
            return experts, caption
        else:
            return experts, self.class_list[class_name]
59
+
60
+
61
+
62
+
63
+
64
+ # import os
65
+ # import glob
66
+ #
67
+ # data_path = '/Users/shikunliu/Documents/dataset/mscoco/mscoco'
68
+ #
69
+ # data_folders = glob.glob(f'{data_path}/*/')
70
+ # data_list = [data for f in data_folders for data in glob.glob(f + '*.jpg')]
71
+
72
+
prismer/dataset/clip_pca.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:064950869bce7d622ff7e733dc64c8e1be18be926c3f64e66ae57164df3f8bc5
3
+ size 401612
prismer/dataset/pretrain_dataset.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from dataset.utils import *
11
+
12
+
13
class Pretrain(Dataset):
    """Image-caption pretraining dataset pooled from CC12M, CC3M, COCO and VG.

    Each item is a dict of expert-transformed inputs plus a cleaned caption
    (truncated to 30 words).
    """

    def __init__(self, config):
        self.cc12m_data_path = config['cc12m_data_path']
        self.cc3m_data_path = config['cc3m_data_path']
        self.coco_data_path = config['coco_data_path']
        self.vg_data_path = config['vg_data_path']
        self.label_path = config['label_path']
        self.experts = config['experts']

        self.data_list = []
        # Web-scraped corpora are globbed from disk; COCO/VG come from JSON
        # annotation files that already carry per-image captions.
        if 'cc12m' in config['datasets']:
            data_folders = glob.glob(f'{self.cc12m_data_path}/cc12m/*/')
            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
        if 'cc3m_sgu' in config['datasets']:
            data_folders = glob.glob(f'{self.cc3m_data_path}/cc3m_sgu/*/')
            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
        if 'coco' in config['datasets']:
            self.data_list += json.load(open(os.path.join(self.coco_data_path, 'coco_karpathy_train.json'), 'r'))
        if 'vg' in config['datasets']:
            self.data_list += json.load(open(os.path.join(self.vg_data_path, 'vg_caption.json'), 'r'))

        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.5], train=True)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        img_path = self.data_list[index]['image']

        if 'cc12m' in img_path or 'cc3m_sgu' in img_path:
            # Web-caption data: the caption lives in a sidecar .txt next to the image.
            if 'cc12m' in img_path:
                data_path, dataset_name = self.cc12m_data_path, 'cc12m'
            else:
                data_path, dataset_name = self.cc3m_data_path, 'cc3m_sgu'
            img_path_split = img_path.split('/')
            img_name = img_path_split[-2] + '/' + img_path_split[-1]
            image, labels, labels_info = get_expert_labels(data_path, self.label_path, img_name, dataset_name, self.experts)

            caption_path = img_path.replace('.jpg', '.txt')
            with open(caption_path) as f:
                caption = f.readlines()[0]

        elif 'train2014' in img_path or 'val2014' in img_path:
            # COCO images: expert labels are stored under the shared 'vqav2' folder.
            image, labels, labels_info = get_expert_labels(self.coco_data_path, self.label_path, img_path, 'vqav2', self.experts)
            caption = self.data_list[index]['caption']

        elif 'visual-genome' in img_path:
            img_path_split = img_path.split('/')
            img_name = img_path_split[-2] + '/' + img_path_split[-1]
            image, labels, labels_info = get_expert_labels(self.vg_data_path, self.label_path, img_name, 'vg', self.experts)
            caption = self.data_list[index]['caption']

        else:
            # Previously fell through to an UnboundLocalError on `image`/`caption`;
            # fail loudly with a diagnosable message instead.
            raise ValueError(f'unrecognised data source for image path: {img_path}')

        experts = self.transform(image, labels)
        experts = post_label_process(experts, labels_info)
        caption = pre_caption(caption, max_words=30)
        return experts, caption
prismer/dataset/randaugment.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+
11
+ from PIL import Image, ImageOps, ImageEnhance, ImageDraw
12
+
13
+
14
# Fill value painted outside the warped region of each expert label map.
fillmask = {'depth': 0, 'normal': 0, 'edge': 0, 'seg_coco': 255, 'seg_ade': 255,
            'obj_detection': 255, 'ocr_detection': 255}
# RGB fill for the image itself.
fillcolor = (0, 0, 0)


def affine_transform(pair, affine_params):
    """Apply one affine warp: bilinear for the image, nearest for every label map."""
    img, label = pair
    img = img.transform(img.size, Image.AFFINE, affine_params,
                        resample=Image.BILINEAR, fillcolor=fillcolor)
    if label is not None:
        for name, expert_map in label.items():
            label[name] = expert_map.transform(expert_map.size, Image.AFFINE, affine_params,
                                               resample=Image.NEAREST, fillcolor=fillmask[name])
    return img, label
28
+
29
+
30
def ShearX(pair, v):  # shear factor in [-0.3, 0.3]
    """Shear image + labels along x; the sign is randomised."""
    assert -0.3 <= v <= 0.3
    v = -v if random.random() > 0.5 else v
    return affine_transform(pair, (1, v, 0, 0, 1, 0))


def ShearY(pair, v):  # shear factor in [-0.3, 0.3]
    """Shear image + labels along y; the sign is randomised."""
    assert -0.3 <= v <= 0.3
    v = -v if random.random() > 0.5 else v
    return affine_transform(pair, (1, 0, 0, v, 1, 0))


def TranslateX(pair, v):  # fraction of image width in [-0.45, 0.45]
    """Translate image + labels horizontally by v * width; sign randomised."""
    assert -0.45 <= v <= 0.45
    v = -v if random.random() > 0.5 else v
    img, _ = pair
    return affine_transform(pair, (1, 0, v * img.size[0], 0, 1, 0))


def TranslateY(pair, v):  # fraction of image height in [-0.45, 0.45]
    """Translate image + labels vertically by v * height; sign randomised."""
    assert -0.45 <= v <= 0.45
    v = -v if random.random() > 0.5 else v
    img, _ = pair
    return affine_transform(pair, (1, 0, 0, 0, 1, v * img.size[1]))


def TranslateXAbs(pair, v):  # absolute pixel offset in [0, 10]
    """Translate horizontally by v pixels; sign randomised."""
    assert 0 <= v <= 10
    v = -v if random.random() > 0.5 else v
    return affine_transform(pair, (1, 0, v, 0, 1, 0))


def TranslateYAbs(pair, v):  # absolute pixel offset in [0, 10]
    """Translate vertically by v pixels; sign randomised."""
    assert 0 <= v <= 10
    v = -v if random.random() > 0.5 else v
    return affine_transform(pair, (1, 0, 0, 0, 1, v))
74
+
75
+
76
def Rotate(pair, v):  # degrees in [-30, 30]
    """Rotate the image (bilinear default) and every label map (nearest) by the
    same randomly-signed angle."""
    assert -30 <= v <= 30
    v = -v if random.random() > 0.5 else v
    img, label = pair
    img = img.rotate(v, fillcolor=fillcolor)
    if label is not None:
        for name in label:
            label[name] = label[name].rotate(v, resample=Image.NEAREST, fillcolor=fillmask[name])
    return img, label
86
+
87
+
88
def AutoContrast(pair, _):
    """Maximise image contrast; label maps are untouched (photometric op)."""
    img, label = pair
    return ImageOps.autocontrast(img), label


def Invert(pair, _):
    """Invert image colours; label maps are untouched (photometric op)."""
    img, label = pair
    return ImageOps.invert(img), label


def Equalize(pair, _):
    """Equalise the image histogram; label maps are untouched (photometric op)."""
    img, label = pair
    return ImageOps.equalize(img), label


def Flip(pair, _):  # not from the paper
    """Mirror the image and every label map horizontally.

    Bug fix: `label` is a dict of expert maps (or None) everywhere else in this
    module, but the old code passed the dict itself to ImageOps.mirror, which
    raises. Mirror entry-wise and tolerate None, matching the other geometric ops.
    """
    img, label = pair
    if label is not None:
        for exp in label:
            label[exp] = ImageOps.mirror(label[exp])
    return ImageOps.mirror(img), label
106
+
107
+
108
def Solarize(pair, v):  # threshold in [0, 256]
    """Invert all pixels above threshold v; labels untouched."""
    assert 0 <= v <= 256
    img, label = pair
    return ImageOps.solarize(img, v), label


def Posterize(pair, v):  # bits in [4, 8]
    """Reduce colour depth to int(v) bits per channel; labels untouched."""
    assert 4 <= v <= 8
    img, label = pair
    return ImageOps.posterize(img, int(v)), label


def Posterize2(pair, v):  # bits in [0, 4]
    """Aggressive posterize variant, int(v) bits per channel; labels untouched."""
    assert 0 <= v <= 4
    img, label = pair
    return ImageOps.posterize(img, int(v)), label


def Contrast(pair, v):  # factor in [0.1, 1.9]
    """Scale image contrast by v; labels untouched."""
    assert 0.1 <= v <= 1.9
    img, label = pair
    return ImageEnhance.Contrast(img).enhance(v), label


def Color(pair, v):  # factor in [0.1, 1.9]
    """Scale colour saturation by v; labels untouched."""
    assert 0.1 <= v <= 1.9
    img, label = pair
    return ImageEnhance.Color(img).enhance(v), label


def Brightness(pair, v):  # factor in [0.1, 1.9]
    """Scale brightness by v; labels untouched."""
    assert 0.1 <= v <= 1.9
    img, label = pair
    return ImageEnhance.Brightness(img).enhance(v), label


def Sharpness(pair, v):  # factor in [0.1, 1.9]
    """Scale sharpness by v; labels untouched."""
    assert 0.1 <= v <= 1.9
    img, label = pair
    return ImageEnhance.Sharpness(img).enhance(v), label
150
+
151
+
152
def Cutout(pair, v):  # fraction of image width in [0, 0.2]
    """Erase a random square patch whose side is v * width; labels untouched."""
    assert 0.0 <= v <= 0.2
    if v <= 0.:
        return pair
    img, label = pair
    return CutoutAbs(img, v * img.size[0]), label


def CutoutAbs(img, v):  # patch side length in pixels
    """Paint a random v x v grey square onto a copy of img."""
    if v < 0:
        return img
    w, h = img.size
    centre_x = np.random.uniform(w)
    centre_y = np.random.uniform(h)

    x0 = int(max(0, centre_x - v / 2.))
    y0 = int(max(0, centre_y - v / 2.))
    box = (x0, y0, min(w, x0 + v), min(h, y0 + v))

    patched = img.copy()
    # Mean-ish ImageNet grey; the original also kept a black alternative commented out.
    ImageDraw.Draw(patched).rectangle(box, (125, 123, 114))
    return patched
180
+
181
+
182
def Identity(pair, v):
    """No-op augmentation: return the (image, label) pair unchanged."""
    return pair


def augment_list():  # the enabled operations and their magnitude ranges
    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
    return [
        (Identity, 0., 1.0),
        (ShearX, 0., 0.3),       # 0
        (ShearY, 0., 0.3),       # 1
        (TranslateX, 0., 0.33),  # 2
        (TranslateY, 0., 0.33),  # 3
        (Rotate, 0, 30),         # 4
        (AutoContrast, 0, 1),    # 5
        # (Invert, 0, 1),        # 6
        (Equalize, 0, 1),        # 7
        # (Solarize, 0, 110),    # 8
        # (Posterize, 4, 8),     # 9
        # (Color, 0.1, 1.9),     # 11
        (Brightness, 0.1, 1.9),  # 12
        (Sharpness, 0.1, 1.9),   # 13
    ]
205
+
206
+
207
class Lighting(object):
    """Lighting noise (AlexNet-style PCA-based colour jitter)."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd          # std-dev of the per-channel alpha draw
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        # alphastd == 0 disables the perturbation entirely.
        if self.alphastd == 0:
            return img

        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        scaled = self.eigvec.type_as(img).clone()
        scaled = scaled.mul(alpha.view(1, 3).expand(3, 3))
        scaled = scaled.mul(self.eigval.view(1, 3).expand(3, 3))
        rgb = scaled.sum(1).squeeze()
        return img.add(rgb.view(3, 1, 1).expand_as(img))
226
+
227
+
228
class CutoutDefault(object):
    """Zero out one random `length` x `length` square of a CHW tensor in place.

    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
    """
    def __init__(self, length):
        self.length = length  # side of the erased square, in pixels

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        centre_y = np.random.randint(h)
        centre_x = np.random.randint(w)
        half = self.length // 2

        y1 = np.clip(centre_y - half, 0, h)
        y2 = np.clip(centre_y + half, 0, h)
        x1 = np.clip(centre_x - half, 0, w)
        x2 = np.clip(centre_x + half, 0, w)

        mask = np.ones((h, w), np.float32)
        mask[y1: y2, x1: x2] = 0.
        img *= torch.from_numpy(mask).expand_as(img)
        return img
251
+
252
+
253
class RandAugment:
    """Apply `n` randomly chosen augmentations, each at magnitude `m` on [0, 10]."""

    def __init__(self, n, m):
        self.n = n  # number of ops sampled per call
        self.m = m  # shared magnitude on a [0, 10] scale
        self.augment_list = augment_list()

    def __call__(self, img, label):
        pair = img, label
        for op, low, high in random.choices(self.augment_list, k=self.n):
            # Map m in [0, 10] linearly onto the op's [low, high] range.
            strength = (float(self.m) / 10) * float(high - low) + low
            pair = op(pair, strength)
        return pair
prismer_model.py CHANGED
@@ -7,14 +7,13 @@ import sys
7
  import cv2
8
  import torch
9
 
10
- from prismer.dataset import create_dataset, create_loader
11
- from prismer.model.prismer_caption import PrismerCaption
12
-
13
-
14
  repo_dir = pathlib.Path(__file__).parent
15
  submodule_dir = repo_dir / 'prismer'
16
  sys.path.insert(0, submodule_dir.as_posix())
17
 
 
 
 
18
 
19
  def download_models() -> None:
20
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
 
7
  import cv2
8
  import torch
9
 
 
 
 
 
10
  repo_dir = pathlib.Path(__file__).parent
11
  submodule_dir = repo_dir / 'prismer'
12
  sys.path.insert(0, submodule_dir.as_posix())
13
 
14
+ from dataset import create_dataset, create_loader
15
+ from model.prismer_caption import PrismerCaption
16
+
17
 
18
  def download_models() -> None:
19
  if not pathlib.Path('prismer/experts/expert_weights/').exists():