Alberto Carmona committed
Commit 35df8d2
1 Parent(s): 2773b59

Add required folders and files

configs/phase2/FineCapEval_clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
+ caption_model: transformer
+ noamopt: true
+ noamopt_warmup: 20000
+ label_smoothing: 0.0
+ input_json: data/FineCapEval.json
+ input_label_h5: none
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
+ input_att_dir: data/FineCapEval_clip_RN50_att
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
+ seq_per_img: 5
+ batch_size: 160
+ learning_rate: 0.0005
+
+ checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
+
+ use_multi_rewards: true
+ use_grammar: true
+ use_grammar_baseline: true
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
+
+ # Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer:
+ # N=num_layers
+ # d_model=input_encoding_size
+ # d_ff=rnn_size
+
+ # will be ignored
+ num_layers: 6
+ input_encoding_size: 512
+ rnn_size: 2048
+
+ # Transformer config
+ N_enc: 6
+ N_dec: 6
+ d_model: 512
+ d_ff: 2048
+ num_att_heads: 8
+ dropout: 0.1
+
+
+ learning_rate_decay_start: 0
+ scheduled_sampling_start: -1
+ save_checkpoint_every: 3000
+ language_eval: 0
+ val_images_use: 5000
+ max_epochs: 15
+ train_sample_n: 5
+
+ REFORWARD: false
+
+ # _BASE_: transformer.yml
+ reduce_on_plateau: false
+ noamopt: false
+ learning_rate: 0.000005
+ learning_rate_decay_start: -1
+
+ self_critical_after: 15
+ max_epochs: 50
+
+ verbose: false
+ precision: 32
+
+ # use_clipscore: true
+ use_clipscore: false
+ clipscore_reward_weight: 2.0
configs/phase2/clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
+ caption_model: transformer
+ noamopt: true
+ noamopt_warmup: 20000
+ label_smoothing: 0.0
+ input_json: data/cocotalk.json
+ input_label_h5: data/cocotalk_label.h5
+ input_fc_dir: data/cocotalk_clip_RN50_fc
+ input_att_dir: data/cocotalk_clip_RN50_att
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
+ seq_per_img: 5
+ batch_size: 160
+ learning_rate: 0.0005
+
+ checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
+
+ use_multi_rewards: true
+ use_grammar: true
+ use_grammar_baseline: true
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
+ clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
+
+ # Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer:
+ # N=num_layers
+ # d_model=input_encoding_size
+ # d_ff=rnn_size
+
+ # will be ignored
+ num_layers: 6
+ input_encoding_size: 512
+ rnn_size: 2048
+
+ # Transformer config
+ N_enc: 6
+ N_dec: 6
+ d_model: 512
+ d_ff: 2048
+ num_att_heads: 8
+ dropout: 0.1
+
+
+ learning_rate_decay_start: 0
+ scheduled_sampling_start: -1
+ save_checkpoint_every: 3000
+ language_eval: 1
+ val_images_use: 5000
+ max_epochs: 15
+ train_sample_n: 5
+
+ REFORWARD: false
+
+ # _BASE_: transformer.yml
+ reduce_on_plateau: false
+ noamopt: false
+ learning_rate: 0.000005
+ learning_rate_decay_start: -1
+
+ self_critical_after: 15
+ max_epochs: 40
+
+ verbose: false
+ precision: 32
+
+ use_clipscore: true
+ clipscore_reward_weight: 2.0
configs/phase2/transformer.yml ADDED
@@ -0,0 +1,41 @@
+ caption_model: transformer
+ noamopt: true
+ noamopt_warmup: 20000
+ label_smoothing: 0.0
+ input_json: data/cocotalk.json
+ input_label_h5: data/cocotalk_label.h5
+ input_att_dir: data/cocotalk_att
+ seq_per_img: 5
+ batch_size: 10
+ learning_rate: 0.0005
+
+ checkpoint_path: ./save/trans_rn50_sc
+
+ # Notice: because I'm too lazy, I reuse the option names for RNNs to set the hyperparameters for the transformer:
+ # N=num_layers
+ # d_model=input_encoding_size
+ # d_ff=rnn_size
+
+ # will be ignored
+ num_layers: 6
+ input_encoding_size: 512
+ rnn_size: 2048
+
+ # Transformer config
+ N_enc: 6
+ N_dec: 6
+ d_model: 512
+ d_ff: 2048
+ num_att_heads: 8
+ dropout: 0.1
+
+
+ learning_rate_decay_start: 0
+ scheduled_sampling_start: -1
+ save_checkpoint_every: 3000
+ language_eval: 1
+ val_images_use: 5000
+ max_epochs: 15
+ train_sample_n: 5
+
+ REFORWARD: false
data/README.md ADDED
@@ -0,0 +1 @@
+ directory to store preprocessed files
retrieval/README.md ADDED
@@ -0,0 +1,5 @@
+ # Finetuning CLIP reward model
+
+ ```bash
+ python train_pl.py --cfg clip_negative_text --id clip_negative_text
+ ```
retrieval/caption_data.py ADDED
@@ -0,0 +1,500 @@
1
+ from torch.utils.data import DataLoader, Dataset, Sampler
2
+ from pathlib import Path
3
+ import json
4
+ from multiprocessing import Pool
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ import torchvision.transforms as T
12
+
13
+ from torch.utils.data.distributed import DistributedSampler
14
+
15
+ from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer
16
+
17
+ import text_utils
18
+
19
+ project_dir = Path(__file__).parent.resolve()
20
+ workspace_dir = project_dir.parent.parent
21
+ dataset_dir = workspace_dir.joinpath('datasets/').resolve()
22
+ # coco_dir = dataset_dir.joinpath('COCO')
23
+ # vg_dir = dataset_dir.joinpath('VG')
24
+ coco_img_dir = dataset_dir.joinpath('COCO/images/')
25
+ coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
26
+ # coco_feature_dir = coco_dir.joinpath('features')
27
+
28
+
29
+ class COCORetrievalDataset(Dataset):
30
+ def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
31
+ super().__init__()
32
+
33
+ self.topk = topk
34
+ self.verbose = verbose
35
+ self.args = args
36
+ self.rank = rank
37
+ self.mode = mode
38
+
39
+ # Loading datasets to data
40
+ self.source = split
41
+ if self.verbose:
42
+ print('Data source: ', self.source)
43
+
44
+ # if self.args.tokenizer is None:
45
+ # self.args.tokenizer = self.args.decoder_backbone
46
+
47
+ # if 'bert' in self.args.tokenizer:
48
+ # self.tokenizer = BertTokenizerFast.from_pretrained(
49
+ # self.args.tokenizer,
50
+ # # max_length=self.args.max_text_length,
51
+ # # do_lower_case=self.args.do_lower_case
52
+ # )
53
+ # elif 'clip' in self.args.tokenizer:
54
+ # self.tokenizer = CLIPTokenizer.from_pretrained(
55
+ # self.args.tokenizer,
56
+ # # max_length=self.args.max_text_length,
57
+ # # do_lower_case=self.args.do_lower_case
58
+ # )
59
+
60
+ self.tokenizer = CLIPTokenizer.from_pretrained(
61
+ self.args.tokenizer,
62
+ # max_length=self.args.max_text_length,
63
+ # do_lower_case=self.args.do_lower_case
64
+ )
65
+
66
+ with open(coco_data_dir.joinpath('cocotalk.json')) as f:
67
+ self.vocab = list(json.load(f)['ix_to_word'].values())
68
+ popped = self.vocab.pop(-1)
69
+ assert popped == 'UNK'
70
+ if self.verbose:
71
+ print('vocab size: ', len(self.vocab))
72
+
73
+
74
+ data_info_path = coco_data_dir.joinpath('dataset_coco.json')
75
+ with open(data_info_path) as f:
76
+ karpathy_data = json.load(f)
77
+
78
+ split_rename = {
79
+ 'train': 'train',
80
+ 'restval': 'train',
81
+ 'val': 'val',
82
+ 'test': 'test'
83
+ }
84
+
85
+ n_images = 0
86
+
87
+ data = []
88
+ # self.vocab = set()
89
+ for datum in karpathy_data['images']:
90
+ re_split = split_rename[datum['split']]
91
+
92
+ # if re_split == 'train':
93
+ # for d in datum['sentences']:
94
+ # self.vocab = self.vocab.union(set(d['tokens']))
95
+
96
+ if re_split != self.source.split('_')[-1]:
97
+ continue
98
+
99
+ if re_split == 'train':
100
+ # for d in datum['sentences']:
101
+ # img_id = datum['filename'].split('.')[0]
102
+ # new_datum = {
103
+ # 'filename': datum['filename'],
104
+ # 'img_id': img_id,
105
+ # 'sent': d['raw'].strip(),
106
+ # 'targets': [d['raw'].strip() for d in datum['sentences']],
107
+ # 'is_train': True,
108
+ # 'cocoid': datum['cocoid']
109
+ # }
110
+ # data.append(new_datum)
111
+ img_id = datum['filename'].split('.')[0]
112
+ new_datum = {
113
+ 'filename': datum['filename'],
114
+ 'img_id': img_id,
115
+ # 'sent': d['raw'],
116
+ # 'targets': [d['raw'].strip() for d in datum['sentences']],
117
+ 'targets': [" ".join(d['tokens']) for d in datum['sentences']],
118
+ 'is_train': True,
119
+ 'cocoid': datum['cocoid']
120
+ }
121
+ data.append(new_datum)
122
+
123
+ else:
124
+ img_id = datum['filename'].split('.')[0]
125
+ new_datum = {
126
+ 'filename': datum['filename'],
127
+ 'img_id': img_id,
128
+ # 'sent': d['raw'],
129
+ # 'targets': [d['raw'].strip() for d in datum['sentences']],
130
+ 'targets': [" ".join(d['tokens']) for d in datum['sentences']],
131
+ 'is_train': False,
132
+ 'cocoid': datum['cocoid']
133
+ }
134
+ data.append(new_datum)
135
+
136
+ n_images += 1
137
+
138
+ if self.verbose:
139
+ print(f"{self.source} has {n_images} images")
140
+ # print(f"Loaded {len(data)} data from", split)
141
+
142
+ self.n_gpus = torch.cuda.device_count()
143
+
144
+ if self.topk > 0:
145
+ data = data[:self.topk]
146
+ if self.verbose:
147
+ print(f"Use only {self.topk} data")
148
+
149
+ self.data = data
150
+
151
+ # if self.verbose:
152
+ # print("# all sentences:", len(self.data))
153
+
154
+ if self.args.load_feat:
155
+ # feat_dir = coco_dir.joinpath(''
156
+ # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
157
+ self.feat_loader = HybridLoader(
158
+ coco_data_dir.joinpath('cocotalk_clipscore_vis'),
159
+ ext='.npy', in_memory=False)
160
+ else:
161
+ if 'openai/clip' in self.args.encoder_backbone:
162
+ # from transformers import CLIPProcessor
163
+ # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",
164
+ # size=args.image_size,
165
+ # do_resize=True,
166
+ # do_center_crop=False,
167
+ # )
168
+ # self.img_transform = lambda image: self.processor.feature_extractor(
169
+ # image,
170
+ # return_tensors='pt')['pixel_values'][0]
171
+
172
+ self.image_mean = [0.48145466, 0.4578275, 0.40821073]
173
+ self.image_std = [0.26862954, 0.26130258, 0.27577711]
174
+
175
+ # captioning
176
+ # self.img_transform = T.Compose([
177
+ # T.Resize((self.args.image_size, self.args.image_size))
178
+ # ])
179
+
180
+ # retrieval
181
+ self.img_transform = T.Compose([
182
+ T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC),
183
+ T.CenterCrop(self.args.image_size)
184
+ ])
185
+
186
+ self.img_tensor_transform = T.Compose([
187
+ # T.RandomCrop(224),
188
+ # T.RandomHorizontalFlip(p=0.3),
189
+ T.ConvertImageDtype(torch.float),
190
+ T.Normalize(self.image_mean, self.image_std)
191
+ ]
192
+ )
193
+ # elif 'google/vit' in self.args.encoder_backbone:
194
+ # self.image_mean = [0.5, 0.5, 0.5]
195
+ # self.image_std = [0.5, 0.5, 0.5]
196
+
197
+ # self.img_transform = T.Compose([
198
+ # # T.PILToTensor(),
199
+ # T.Resize((self.args.image_size, self.args.image_size))
200
+ # ])
201
+
202
+ # self.img_tensor_transform = T.Compose([
203
+ # # T.RandomCrop(224),
204
+ # # T.RandomHorizontalFlip(p=0.3),
205
+ # T.ConvertImageDtype(torch.float),
206
+ # T.Normalize(self.image_mean, self.image_std)
207
+ # ]
208
+ # )
209
+
210
+ def get_negative_text(self, text):
211
+ neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle'])
212
+
213
+ if neg_type == 'repeat':
214
+ text = text_utils.repeat(text)
215
+ elif neg_type == 'remove':
216
+ text = text_utils.remove(text)
217
+ elif neg_type == 'insert':
218
+ text = text_utils.insert(text, self.vocab)
219
+ elif neg_type == 'swap':
220
+ text = text_utils.swap(text, self.vocab)
221
+ elif neg_type == 'shuffle':
222
+ text = text_utils.shuffle(text)
223
+
224
+ return text, neg_type
225
+
226
+ def __len__(self):
227
+ return len(self.data)
228
+
229
+ def __getitem__(self, idx):
230
+ datum = self.data[idx]
231
+ return self.process_datum(datum)
232
+
233
+ def process_datum(self, datum):
234
+ out_dict = {}
235
+
236
+ ###### Image ######
237
+
238
+ if self.args.load_feat:
239
+ cocoid = datum['cocoid']
240
+ out_dict['cocoid'] = str(cocoid)
241
+ img_feat = self.feat_loader.get(str(cocoid))
242
+ out_dict['img_feat'] = torch.from_numpy(img_feat)
243
+
244
+ else:
245
+ img_id = datum['img_id']
246
+ out_dict['img_id'] = img_id
247
+
248
+ if 'train' in datum['filename']:
249
+ img_split = 'train2014'
250
+ elif 'val' in datum['filename']:
251
+ img_split = 'val2014'
252
+ img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg')
253
+ assert img_path.exists()
254
+ img_path = str(img_path)
255
+ out_dict['img_path'] = img_path
256
+
257
+ img_tensor = torchvision.io.read_image(img_path)
258
+ # out_dict['img_tensor'] = img
259
+
260
+ # img = Image.open(img_path).convert('RGB')
261
+ # img_tensor = torch.as_tensor(np.asarray(img))
262
+ out_dict['img_tensor'] = self.img_transform(img_tensor)
263
+ # self.img_transform(img_tensor)
264
+ # out_dict['img_tensor'] = self.img_transform(img)
265
+
266
+ ###### Text #####
267
+ # if datum['is_train']:
268
+ # sent = datum['sent'].strip()
269
+
270
+ sent = random.choice(datum['targets'])
271
+
272
+ # target_ids = self.tokenizer.encode(
273
+ # sent, max_length=self.args.gen_max_length, truncation=True)
274
+
275
+ # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)
276
+ out_dict['sent'] = sent
277
+ # out_dict['target_ids'] = torch.LongTensor(target_ids)
278
+ # out_dict['target_length'] = len(target_ids)
279
+
280
+
281
+ # negative sample
282
+ neg_sent, neg_type = self.get_negative_text(sent)
283
+
284
+ # neg_target_ids = self.tokenizer.encode(
285
+ # neg_sent, max_length=self.args.gen_max_length, truncation=True)
286
+
287
+ # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)
288
+ out_dict['neg_sent'] = neg_sent
289
+ out_dict['neg_type'] = neg_type
290
+ # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
291
+ # out_dict['neg_target_length'] = len(neg_target_ids)
292
+
293
+
294
+ if 'targets' in datum:
295
+ out_dict['targets'] = datum['targets']
296
+
297
+ return out_dict
298
+
299
+ def collate_fn(self, batch):
300
+ batch_entry = {}
301
+
302
+ B = len(batch)
303
+
304
+ # if 'target_ids' in batch[0]:
305
+ # T_W_L = max(entry['target_length'] for entry in batch)
306
+ # target_ids = torch.ones(
307
+ # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
308
+
309
+ # if 'target_ids' in batch[0]:
310
+ # T_W_L = max(entry['target_length'] for entry in batch)
311
+ # target_ids = torch.ones(
312
+ # B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
313
+
314
+
315
+
316
+ targets = []
317
+ img_ids = []
318
+ img_paths = []
319
+
320
+ coco_ids = []
321
+
322
+ if self.args.load_feat:
323
+ img_feats = torch.zeros(B, 512, dtype=torch.float)
324
+ else:
325
+ # imgs = []
326
+ img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)
327
+
328
+ for i, entry in enumerate(batch):
329
+
330
+ if self.args.load_feat:
331
+ coco_ids.append(entry['cocoid'])
332
+ img_feats[i] = entry['img_feat']
333
+
334
+ else:
335
+
336
+ img_ids.append(entry['img_id'])
337
+ img_paths.append(entry['img_path'])
338
+ img_tensor[i] = entry['img_tensor']
339
+
340
+ # if 'target_ids' in entry:
341
+ # target_ids[i, :entry['target_length']] = entry['target_ids']
342
+
343
+ if 'targets' in entry:
344
+ targets.append(entry['targets'])
345
+
346
+ if 'sent' in batch[0]:
347
+ # word_mask = target_ids != self.tokenizer.pad_token_id
348
+ # target_ids[~word_mask] = -100
349
+ # batch_entry['target_ids'] = target_ids
350
+
351
+ tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
352
+ neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
353
+ # sent, max_length=self.args.gen_max_length, truncation=True)
354
+
355
+ batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
356
+ batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)
357
+
358
+
359
+ if self.args.load_feat:
360
+ batch_entry['coco_ids'] = coco_ids
361
+ batch_entry['img_feats'] = img_feats
362
+
363
+ else:
364
+
365
+ img_tensor = self.img_tensor_transform(img_tensor)
366
+
367
+ batch_entry['img_id'] = img_ids
368
+ batch_entry['img_paths'] = img_paths
369
+ batch_entry['img_tensor'] = img_tensor
370
+
371
+ batch_entry['targets'] = targets
372
+
373
+ # print('batch created')
374
+
375
+ # batch_entry['task'] = 'caption'
376
+
377
+ return batch_entry
378
+
379
+
380
+ # def get_loader(args, split='karpathy_train', mode='train',
381
+ # batch_size=32, workers=4, distributed=False, gpu=0,
382
+ # topk=-1):
383
+
384
+ # verbose = (gpu == 0)
385
+
386
+ # dataset = COCORetrievalDataset(
387
+ # split,
388
+ # rank=gpu,
389
+ # topk=topk,
390
+ # verbose=verbose,
391
+ # args=args,
392
+ # mode=mode)
393
+
394
+ # # if distributed:
395
+ # # sampler = DistributedSampler(dataset)
396
+ # # else:
397
+ # # sampler = None
398
+
399
+ # if mode == 'train':
400
+ # loader = DataLoader(
401
+ # dataset, batch_size=batch_size, shuffle=(sampler is None),
402
+ # num_workers=workers, pin_memory=True, sampler=sampler,
403
+ # collate_fn=dataset.collate_fn)
404
+ # else:
405
+ # loader = DataLoader(
406
+ # dataset,
407
+ # batch_size=batch_size, shuffle=False,
408
+ # num_workers=workers, pin_memory=True,
409
+ # sampler=sampler,
410
+ # collate_fn=dataset.collate_fn,
411
+ # drop_last=False)
412
+
413
+ # # if verbose:
414
+ # # loader.evaluator = COCOCaptionEvaluator()
415
+
416
+ # # loader.task = 'caption'
417
+
418
+ # return loader
419
+
420
+
421
+ # class COCOCaptionEvaluator:
422
+ # def __init__(self):
423
+ # import language_evaluation
424
+ # self.evaluator = language_evaluation.CocoEvaluator(verbose=False)
425
+
426
+ # def evaluate(self, predicts, answers):
427
+
428
+ # results = self.evaluator.run_evaluation(predicts, answers)
429
+
430
+ # return results
431
+
432
+ import six
433
+ import os
434
+ import h5py
435
+
436
+ class HybridLoader:
437
+ """
438
+ If db_path is a director, then use normal file loading
439
+ If lmdb, then load from lmdb
440
+ The loading method depend on extention.
441
+
442
+ in_memory: if in_memory is True, we save all the features in memory
443
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
444
+ Should be useful for lmdb or h5.
445
+ (Copied this idea from vilbert)
446
+ """
447
+
448
+ def __init__(self, db_path, ext='.npy', in_memory=False):
449
+ self.db_path = db_path
450
+ self.ext = ext
451
+ if self.ext == '.npy':
452
+ self.loader = lambda x: np.load(six.BytesIO(x))
453
+ else:
454
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
455
+ # if db_path.endswith('.lmdb'):
456
+ # self.db_type = 'lmdb'
457
+ # self.lmdb = lmdbdict(db_path, unsafe=True)
458
+ # self.lmdb._key_dumps = DUMPS_FUNC['ascii']
459
+ # self.lmdb._value_loads = LOADS_FUNC['identity']
460
+ # elif db_path.endswith('.pth'): # Assume a key,value dictionary
461
+ # self.db_type = 'pth'
462
+ # self.feat_file = torch.load(db_path)
463
+ # self.loader = lambda x: x
464
+ # print('HybridLoader: ext is ignored')
465
+ # elif db_path.endswith('h5'):
466
+ # self.db_type = 'h5'
467
+ # self.loader = lambda x: np.array(x).astype('float32')
468
+ # else:
469
+ # self.db_type = 'dir'
470
+
471
+ self.in_memory = in_memory
472
+ if self.in_memory:
473
+ self.features = {}
474
+
475
+ def get(self, key):
476
+
477
+ # if self.in_memory and key in self.features:
478
+ # # We save f_input because we want to save the
479
+ # # compressed bytes to save memory
480
+ # f_input = self.features[key]
481
+ # elif self.db_type == 'lmdb':
482
+ # f_input = self.lmdb[key]
483
+ # elif self.db_type == 'pth':
484
+ # f_input = self.feat_file[key]
485
+ # elif self.db_type == 'h5':
486
+ # f_input = h5py.File(self.db_path, 'r')[key]
487
+ # else:
488
+ # f_input = open(os.path.join(
489
+ # self.db_path, key + self.ext), 'rb').read()
490
+
491
+ f_input = open(os.path.join(
492
+ self.db_path, key + self.ext), 'rb').read()
493
+
494
+ if self.in_memory and key not in self.features:
495
+ self.features[key] = f_input
496
+
497
+ # load image
498
+ feat = self.loader(f_input)
499
+
500
+ return feat
retrieval/clip_model.py ADDED
@@ -0,0 +1,350 @@
1
+ from transformers import CLIPModel, CLIPTokenizer
2
+ import os
3
+ import json
4
+ import argparse
5
+ from random import shuffle, seed
6
+ import string
7
+ # non-standard dependencies:
8
+ import h5py
9
+ from six.moves import cPickle
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.models as models
13
+ import skimage.io
14
+
15
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
16
+ from PIL import Image
17
+ from torch import nn
18
+
19
+
20
+ class CLIPScore(nn.Module):
21
+ def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
22
+ super(CLIPScore, self).__init__()
23
+ # from transformers import CLIPModel, CLIPTokenizer
24
+ self.clip_model = CLIPModel.from_pretrained(
25
+ 'openai/clip-vit-base-patch32')
26
+ self.tokenizer = CLIPTokenizer.from_pretrained(
27
+ 'openai/clip-vit-base-patch32')
28
+
29
+ self.clip_model.eval()
30
+
31
+ self.clipscore_w = clipscore_w
32
+
33
+ self.image_transform = self._transform(image_size)
34
+
35
+ self.mode = mode
36
+ assert mode in ['clip_s', 'refclip_s']
37
+
38
+ self.use_grammar = use_grammar
39
+ self.joint_out = joint_out
40
+
41
+ if self.use_grammar and self.joint_out is False:
42
+ self.grammar_score_head = nn.Sequential(
43
+ nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
44
+ nn.ReLU(),
45
+ nn.Linear(self.clip_model.projection_dim, 2, bias=False)
46
+ )
47
+
48
+ def _transform(self, n_px):
49
+ return Compose([
50
+ Resize(n_px, interpolation=Image.BICUBIC),
51
+ CenterCrop(n_px),
52
+ lambda image: image.convert("RGB"),
53
+ ToTensor(),
54
+ Normalize((0.48145466, 0.4578275, 0.40821073),
55
+ (0.26862954, 0.26130258, 0.27577711)),
56
+ ])
57
+
58
+ def load_image(self, image_path):
59
+ image = Image.open(image_path)
60
+ return image
61
+
62
+ # @torch.no_grad()
63
+ def image_extract(self, image):
64
+ if isinstance(image, str):
65
+ image = self.load_image(image)
66
+ if not isinstance(image, torch.Tensor):
67
+ image = self.image_transform(image)
68
+
69
+ img_tensor = image.view(-1, 3, 224, 224)
70
+ device = next(self.clip_model.parameters()).device
71
+ img_tensor = img_tensor.to(device)
72
+
73
+ clip_model = self.clip_model
74
+
75
+ img_feat = clip_model.vision_model(img_tensor).pooler_output
76
+ img_feat = clip_model.visual_projection(img_feat)
77
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
78
+
79
+ return img_feat
80
+
81
+ # @torch.no_grad()
82
+ def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
83
+ if isinstance(text, str):
84
+ text_batch = [" ".join([prompt, text])]
85
+ elif isinstance(text, list):
86
+ text_batch = [" ".join([prompt, txt]) for txt in text]
87
+
88
+ if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
89
+ input_ids, attention_mask = text
90
+ else:
91
+ input_text = text_batch
92
+
93
+ tokenized = self.tokenizer(
94
+ input_text, return_tensors='pt', padding=True)
95
+
96
+ input_ids = tokenized.input_ids
97
+ attention_mask = tokenized.attention_mask
98
+
99
+ clip_model = self.clip_model
100
+ device = next(self.clip_model.parameters()).device
101
+ input_ids = input_ids.to(device)
102
+ attention_mask = attention_mask.to(device)
103
+
104
+ text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
105
+
106
+ if proj_norm:
107
+ text_feat = clip_model.text_projection(text_feat)
108
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
109
+
110
+ return text_feat
111
+
112
+ # @torch.no_grad()
113
+ def calc_clip_s(self, img_feat, text_feat):
114
+ return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
115
+
116
+ # @torch.no_grad()
117
+ def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
118
+
119
+ if clip_s is None:
120
+ clip_s = self.calc_clip_s(img_feat, text_feat)
121
+
122
+ B, dim = img_feat.size()
123
+
124
+ ref_text_feat = ref_text_feat.view(B, -1, dim)
125
+
126
+ K = ref_text_feat.size(1)
127
+
128
+ text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
129
+ assert ref_text_feat.size() == text_feat.size(
130
+ ), (ref_text_feat.size(), text_feat.size())
131
+
132
+ ref_score = self.calc_clip_s(text_feat, ref_text_feat)
133
+ if ref_text_mask is not None:
134
+ if not isinstance(ref_text_mask, torch.Tensor):
135
+ ref_text_mask = torch.tensor(
136
+ ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
137
+ ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
138
+
139
+ ref_score = ref_score.view(B, K).max(dim=1).values
140
+
141
+ assert clip_s.size() == (B,)
142
+ assert clip_s.size() == ref_score.size()
143
+
144
+ # harmonic mean
145
+ refclip_s = 2 / (1 / clip_s + 1 / ref_score)
146
+ return refclip_s
147
+
148
+ # # @torch.no_grad()
149
+ # def forward(self,
150
+ # images=None, text=None,
151
+ # img_feat=None, text_feat=None,
152
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
153
+ # prompt="A photo depicts",
154
+ # mode=None):
155
+ # if img_feat is None:
156
+ # img_feat = self.image_extract(images)
157
+ # img_feat = img_feat.view(-1, 512)
158
+
159
+ # if text_feat is None:
160
+ # text_feat = self.text_extract(text, prompt=prompt)
161
+ # text_feat = text_feat.view(-1, 512)
162
+
163
+ # if mode is None:
164
+ # mode = self.mode
165
+ # assert mode in ['clip_s', 'refclip_s']
166
+
167
+ # if mode == 'clip_s':
168
+ # clip_s = self.calc_clip_s(img_feat, text_feat)
169
+ # return clip_s
170
+ # elif mode == 'refclip_s':
171
+ # if ref_text_feat is None:
172
+ # ref_text_feat = self.text_extract(ref_text, prompt=prompt)
173
+ # ref_text_feat = ref_text_feat.view(-1, 512)
174
+
175
+ # refclip_s = self.calc_refclip_s(
176
+ # img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
177
+ # return refclip_s
178
+
179
+
180
+ def train_step(self,
181
+ images=None, text=None,
182
+ img_feat=None, text_feat=None,
183
+ neg_text=None, neg_text_feat=None,
184
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
185
+ prompt="A photo depicts",
186
+ # return_loss=True,
187
+ **kwargs):
188
+
189
+ if img_feat is None:
190
+ img_feat = self.image_extract(images)
191
+ img_feat = img_feat.view(-1, 512)
192
+
193
+ B = img_feat.size(0)
194
+
195
+ if self.joint_out:
196
+ pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512)
197
+ neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512)
198
+ neg_B = neg_text_feat.size(0)
199
+
200
+ # [B+neg_B, 512]
201
+ text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
202
+
203
+ text_cont_feat = self.clip_model.text_projection(text_feat)
204
+ text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
205
+
206
+ text_cont_feat = text_cont_feat.view(B+neg_B, 512)
207
+
208
+ logit_scale = self.clip_model.logit_scale.exp()
209
+
210
+ # [B+neg_B * B]
211
+ logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
212
+
213
+ # image-to-text label: positive text
214
+ caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean()
215
+
216
+ # calculate text-to-image only on positive text
217
+ image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean()
218
+
219
+ clip_loss = (caption_loss + image_loss) / 2.0
220
+
221
+ out = {
222
+ 'clip_loss': clip_loss,
223
+ 'img_feat': img_feat,
224
+ 'text_feat': text_cont_feat[:B].detach(),
225
+ # 'neg_text_feat': neg_text_feat,
226
+ }
227
+
228
+ return out
229
+
230
+
231
+ else:
232
+ if text_feat is None:
233
+ text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
234
+
235
+ text_cont_feat = self.clip_model.text_projection(text_feat)
236
+ text_cont_feat = text_cont_feat / \
237
+ text_cont_feat.norm(dim=-1, keepdim=True)
238
+
239
+ text_cont_feat = text_cont_feat.view(B, 512)
240
+
241
+
242
+ # cosine similarity as logits
243
+ logit_scale = self.clip_model.logit_scale.exp()
244
+ logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
245
+ # logits_per_image = logits_per_text.T
246
+
247
+ clip_loss = clip_loss_fn(logits_per_text)
248
+
249
+
250
+ # negative sampling
251
+ pos_text_feat = text_feat.view(B, 512)
252
+ neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
253
+
254
+ grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
255
+
256
+ # 2B, 1
257
+ grammar_text_logit = self.grammar_score_head(grammar_text_feat)
258
+ grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
259
+
260
+ grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
261
+
262
+ grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
263
+ grammar_pos_pred = grammar_pred[:B]
264
+ grammar_neg_pred = grammar_pred[B:]
265
+ # grammar_acc = (grammar_pred == grammar_labels).float().mean()
266
+
267
+ out = {
268
+ 'clip_loss': clip_loss,
269
+ 'grammar_loss': grammar_loss,
270
+ 'img_feat': img_feat,
271
+ 'text_feat': text_cont_feat,
272
+ 'neg_text_feat': neg_text_feat,
273
+ 'grammar_pos_pred': grammar_pos_pred,
274
+ 'grammar_neg_pred': grammar_neg_pred,
275
+ }
276
+
277
+ return out
278
+
279
+ def train_step_old(self,
280
+ images=None, text=None,
281
+ img_feat=None, text_feat=None,
282
+ neg_text=None, neg_text_feat=None,
283
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
284
+ prompt="A photo depicts",
285
+ # return_loss=True,
286
+ **kwargs):
287
+
288
+ if img_feat is None:
289
+ img_feat = self.image_extract(images)
290
+ img_feat = img_feat.view(-1, 512)
291
+
292
+ B = img_feat.size(0)
293
+
294
+
295
+
296
+ if text_feat is None:
297
+ text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
298
+
299
+ text_cont_feat = self.clip_model.text_projection(text_feat)
300
+ text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
301
+ text_cont_feat = text_cont_feat.view(B, 512)
302
+
303
+ # cosine similarity as logits
304
+ logit_scale = self.clip_model.logit_scale.exp()
305
+ logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
306
+ # logits_per_image = logits_per_text.T
307
+
308
+ clip_loss = clip_loss_fn(logits_per_text)
309
+
310
+
311
+ # negative sampling
312
+ pos_text_feat = text_feat.view(B, 512)
313
+ neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
314
+
315
+ grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
316
+
317
+ # 2B, 1
318
+ grammar_text_logit = self.grammar_score_head(grammar_text_feat)
319
+ grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
320
+
321
+ grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
322
+
323
+ grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
324
+ grammar_pos_pred = grammar_pred[:B]
325
+ grammar_neg_pred = grammar_pred[B:]
326
+ # grammar_acc = (grammar_pred == grammar_labels).float().mean()
327
+
328
+ out = {
329
+ 'clip_loss': clip_loss,
330
+ 'grammar_loss': grammar_loss,
331
+ 'img_feat': img_feat,
332
+ 'text_feat': text_cont_feat,
333
+ 'neg_text_feat': neg_text_feat,
334
+ 'grammar_pos_pred': grammar_pos_pred,
335
+ 'grammar_neg_pred': grammar_neg_pred,
336
+ }
337
+
338
+ return out
339
+
340
+ # contrastive loss function, adapted from
341
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
342
+ def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
343
+ neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
344
+ return -neg_ce.mean()
345
+
346
+
347
+ def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
348
+ caption_loss = contrastive_loss(similarity, dim=0)
349
+ image_loss = contrastive_loss(similarity, dim=1)
350
+ return (caption_loss + image_loss) / 2.0
retrieval/configs/clip_negative_text.yaml ADDED
@@ -0,0 +1,14 @@
+ checkpoint_dir: ./save/clip_negative_text/
+
+ losses_log_every: 25
+ precision: 32
+ load_feat: true
+ data_in_memory: false
+
+ batch_size: 1600
+ valid_batch_size: 200
+ clip_grad_norm: 0
+
+ epochs: 30
+ use_grammar: true
+ joint_out: false
retrieval/param.py ADDED
@@ -0,0 +1,209 @@
+ import argparse
+ import random
+
+ import numpy as np
+ import torch
+
+ import pprint
+ import yaml
+
+
+ def str2bool(v):
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+ def is_interactive():
+     import __main__ as main
+     return not hasattr(main, '__file__')
+
+
+ def get_optimizer(optim, verbose=False):
+     # Bind the optimizer
+     if optim == 'rms':
+         if verbose:
+             print("Optimizer: Using RMSProp")
+         optimizer = torch.optim.RMSprop
+     elif optim == 'adam':
+         if verbose:
+             print("Optimizer: Using Adam")
+         optimizer = torch.optim.Adam
+     elif optim == 'adamw':
+         if verbose:
+             print("Optimizer: Using AdamW")
+         # optimizer = torch.optim.AdamW
+         optimizer = 'adamw'
+     elif optim == 'adamax':
+         if verbose:
+             print("Optimizer: Using Adamax")
+         optimizer = torch.optim.Adamax
+     elif optim == 'sgd':
+         if verbose:
+             print("Optimizer: SGD")
+         optimizer = torch.optim.SGD
+     else:
+         assert False, "Please add your optimizer %s in the list." % optim
+
+     return optimizer
+
+
+ def parse_args(parse=True, **optional_kwargs):
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--seed', type=int, default=9595, help='random seed')
+
+     # Data Splits
+     parser.add_argument("--train", default='karpathy_train')
+     parser.add_argument("--valid", default='karpathy_val')
+     parser.add_argument("--test", default='karpathy_test')
+     # parser.add_argument('--test_only', action='store_true')
+
+     # Quick experiments
+     parser.add_argument('--train_topk', type=int, default=-1)
+     parser.add_argument('--valid_topk', type=int, default=-1)
+
+     # Checkpoint
+     parser.add_argument('--output', type=str, default='snap/test')
+     parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
+     parser.add_argument('--from_scratch', action='store_true')
+
+     # CPU/GPU
+     parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
+     parser.add_argument('--fp16', action='store_true')
+     parser.add_argument("--distributed", action='store_true')
+     parser.add_argument("--num_workers", default=0, type=int)
+     parser.add_argument('--local_rank', type=int, default=-1)
+     # parser.add_argument('--rank', type=int, default=-1)
+
+     # Model Config
+     # parser.add_argument('--encoder_backbone', type=str, default='openai/clip-vit-base-patch32')
+     # parser.add_argument('--decoder_backbone', type=str, default='bert-base-uncased')
+     parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')
+
+     # parser.add_argument('--position_embedding_type', type=str, default='absolute')
+
+     # parser.add_argument('--encoder_transform', action='store_true')
+
+     parser.add_argument('--max_text_length', type=int, default=40)
+
+     # parser.add_argument('--image_size', type=int, default=224)
+     # parser.add_argument('--patch_size', type=int, default=32)
+
+     # parser.add_argument('--decoder_num_layers', type=int, default=12)
+
+     # Training
+     parser.add_argument('--batch_size', type=int, default=256)
+     parser.add_argument('--valid_batch_size', type=int, default=None)
+
+     parser.add_argument('--optim', default='adamw')
+
+     parser.add_argument('--warmup_ratio', type=float, default=0.05)
+     parser.add_argument('--weight_decay', type=float, default=0.01)
+     parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
+     parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
+     parser.add_argument('--lr', type=float, default=1e-4)
+     parser.add_argument('--adam_eps', type=float, default=1e-6)
+     parser.add_argument('--adam_beta1', type=float, default=0.9)
+     parser.add_argument('--adam_beta2', type=float, default=0.999)
+
+     parser.add_argument('--epochs', type=int, default=20)
+     # parser.add_argument('--dropout', type=float, default=0.1)
+
+
+     # Inference
+     # parser.add_argument('--num_beams', type=int, default=1)
+     # parser.add_argument('--gen_max_length', type=int, default=20)
+
+     parser.add_argument('--start_from', type=str, default=None)
+
+     # Data
+     # parser.add_argument('--do_lower_case', type=str2bool, default=None)
+
+     # parser.add_argument('--prefix', type=str, default=None)
+
+
+     # COCO Caption
+     # parser.add_argument('--no_prefix', action='store_true')
+
+     parser.add_argument('--no_cls', action='store_true')
+
+     parser.add_argument('--cfg', type=str, default=None)
+     parser.add_argument('--id', type=str, default=None)
+
+     # Etc.
+     parser.add_argument('--comment', type=str, default='')
+     parser.add_argument("--dry", action='store_true')
+
+     # Parse the arguments.
+     if parse:
+         args = parser.parse_args()
+     # For interactive environments (e.g. jupyter)
+     else:
+         args = parser.parse_known_args()[0]
+
+     loaded_kwargs = {}
+     if args.cfg is not None:
+         cfg_path = f'configs/{args.cfg}.yaml'
+         with open(cfg_path, 'r') as f:
+             loaded_kwargs = yaml.safe_load(f)
+
+     # Namespace => Dictionary
+     parsed_kwargs = vars(args)
+     parsed_kwargs.update(optional_kwargs)
+
+     kwargs = {}
+     kwargs.update(parsed_kwargs)
+     kwargs.update(loaded_kwargs)
+
+     args = Config(**kwargs)
+
+     # Bind optimizer class.
+     verbose = False
+     args.optimizer = get_optimizer(args.optim, verbose=verbose)
+
+     # Set seeds
+     torch.manual_seed(args.seed)
+     random.seed(args.seed)
+     np.random.seed(args.seed)
+
+     return args
+
+
+ class Config(object):
+     def __init__(self, **kwargs):
+         """Configuration Class: set kwargs as class attributes with setattr"""
+         for k, v in kwargs.items():
+             setattr(self, k, v)
+
+     @property
+     def config_str(self):
+         return pprint.pformat(self.__dict__)
+
+     def __repr__(self):
+         """Pretty-print configurations in alphabetical order"""
+         config_str = 'Configurations\n'
+         config_str += self.config_str
+         return config_str
+
+     # def update(self, **kwargs):
+     #     for k, v in kwargs.items():
+     #         setattr(self, k, v)
+
+     # def save(self, path):
+     #     with open(path, 'w') as f:
+     #         yaml.dump(self.__dict__, f, default_flow_style=False)
+
+     # @classmethod
+     # def load(cls, path):
+     #     with open(path, 'r') as f:
+     #         kwargs = yaml.load(f)
+
+     #     return Config(**kwargs)
+
+
+ if __name__ == '__main__':
+     args = parse_args(True)
retrieval/pth_loader.py ADDED
@@ -0,0 +1,334 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ verbose = True
21
+ # import torch
22
+ # if torch.cuda.current_device() in [0, -1]:
23
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
+ verbose = False
25
+
26
+ class HybridLoader:
27
+ """
28
+ If db_path is a director, then use normal file loading
29
+ If lmdb, then load from lmdb
30
+ The loading method depend on extention.
31
+
32
+ in_memory: if in_memory is True, we save all the features in memory
33
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
+ Should be useful for lmdb or h5.
35
+ (Copied this idea from vilbert)
36
+ """
37
+ def __init__(self, db_path, ext, in_memory=False):
38
+ self.db_path = db_path
39
+ self.ext = ext
40
+ if self.ext == '.npy':
41
+ self.loader = lambda x: np.load(six.BytesIO(x))
42
+ else:
43
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
+ if db_path.endswith('.lmdb'):
45
+ self.db_type = 'lmdb'
46
+ self.lmdb = lmdbdict(db_path, unsafe=True)
47
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
+ self.lmdb._value_loads = LOADS_FUNC['identity']
49
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
+ self.db_type = 'pth'
51
+ self.feat_file = torch.load(db_path)
52
+ self.loader = lambda x: x
53
+ print('HybridLoader: ext is ignored')
54
+ elif db_path.endswith('h5'):
55
+ self.db_type = 'h5'
56
+ self.loader = lambda x: np.array(x).astype('float32')
57
+ else:
58
+ self.db_type = 'dir'
59
+
60
+ self.in_memory = in_memory
61
+ if self.in_memory:
62
+ self.features = {}
63
+
64
+ def get(self, key):
65
+
66
+ if self.in_memory and key in self.features:
67
+ # We save f_input because we want to save the
68
+ # compressed bytes to save memory
69
+ f_input = self.features[key]
70
+ elif self.db_type == 'lmdb':
71
+ f_input = self.lmdb[key]
72
+ elif self.db_type == 'pth':
73
+ f_input = self.feat_file[key]
74
+ elif self.db_type == 'h5':
75
+ f_input = h5py.File(self.db_path, 'r')[key]
76
+ else:
77
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
+
79
+ if self.in_memory and key not in self.features:
80
+ self.features[key] = f_input
81
+
82
+ # load image
83
+ feat = self.loader(f_input)
84
+
85
+ return feat
86
+
87
+ class CaptionDataset(data.Dataset):
88
+
89
+ def get_vocab_size(self):
90
+ return self.vocab_size
91
+
92
+ def get_vocab(self):
93
+ return self.ix_to_word
94
+
95
+ def get_seq_length(self):
96
+ return self.seq_length
97
+
98
+ def __init__(self, opt):
99
+ self.opt = opt
100
+ self.seq_per_img = opt.seq_per_img
101
+
102
+ # feature related options
103
+ self.use_fc = getattr(opt, 'use_fc', True)
104
+ self.use_att = getattr(opt, 'use_att', True)
105
+ self.use_box = getattr(opt, 'use_box', 0)
106
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
+
109
+ # load the json file which contains additional information about the dataset
110
+ if verbose:
111
+ print('DataLoader loading json file: ', opt.input_json)
112
+ self.info = json.load(open(self.opt.input_json))
113
+ if 'ix_to_word' in self.info:
114
+ self.ix_to_word = self.info['ix_to_word']
115
+ self.vocab_size = len(self.ix_to_word)
116
+ if verbose:
117
+ print('vocab size is ', self.vocab_size)
118
+
119
+ # open the hdf5 file
120
+ if verbose:
121
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
+ """
123
+ Setting input_label_h5 to none is used when only doing generation.
124
+ For example, when you need to test on coco test set.
125
+ """
126
+ if self.opt.input_label_h5 != 'none':
127
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
+ # load in the sequence data
129
+ seq_size = self.h5_label_file['labels'].shape
130
+ self.label = self.h5_label_file['labels'][:]
131
+ self.seq_length = seq_size[1]
132
+ if verbose:
133
+ print('max sequence length in data is', self.seq_length)
134
+ # load the pointers in full to RAM (should be small enough)
135
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
+ else:
138
+ self.seq_length = 1
139
+
140
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
+
145
+ self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
+ if self.use_clipscore:
147
+ self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
+
149
+
150
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
+ if verbose:
152
+ print('read %d image features' %(self.num_images))
153
+
154
+ # separate out indexes for each of the provided splits
155
+ self.split_ix = {'train': [], 'val': [], 'test': []}
156
+ for ix in range(len(self.info['images'])):
157
+ img = self.info['images'][ix]
158
+ if not 'split' in img:
159
+ self.split_ix['train'].append(ix)
160
+ self.split_ix['val'].append(ix)
161
+ self.split_ix['test'].append(ix)
162
+ elif img['split'] == 'train':
163
+ self.split_ix['train'].append(ix)
164
+ elif img['split'] == 'val':
165
+ self.split_ix['val'].append(ix)
166
+ elif img['split'] == 'test':
167
+ self.split_ix['test'].append(ix)
168
+ elif opt.train_only == 0: # restval
169
+ self.split_ix['train'].append(ix)
170
+
171
+ if verbose:
172
+ print('assigned %d images to split train' %len(self.split_ix['train']))
173
+ print('assigned %d images to split val' %len(self.split_ix['val']))
174
+ print('assigned %d images to split test' %len(self.split_ix['test']))
175
+
176
+ def get_captions(self, ix, seq_per_img):
177
+ # fetch the sequence labels
178
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
+ ix2 = self.label_end_ix[ix] - 1
180
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
181
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
+
183
+ if ncap < seq_per_img:
184
+ # we need to subsample (with replacement)
185
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
+ for q in range(seq_per_img):
187
+ ixl = random.randint(ix1,ix2)
188
+ seq[q, :] = self.label[ixl, :self.seq_length]
189
+ else:
190
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
+
193
+ return seq
194
+
195
+ def collate_func(self, batch):
196
+ seq_per_img = self.seq_per_img
197
+
198
+ fc_batch = []
199
+ att_batch = []
200
+ label_batch = []
201
+
202
+ clip_vis_feat_batch = []
203
+
204
+ wrapped = False
205
+
206
+ infos = []
207
+ gts = []
208
+
209
+ for sample in batch:
210
+ # fetch image
211
+ if self.use_clipscore:
212
+ tmp_fc, tmp_att, tmp_seq, \
213
+ ix, tmp_clip_vis_feat = sample
214
+
215
+ clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
+ else:
217
+ tmp_fc, tmp_att, tmp_seq, \
218
+ ix = sample
219
+
220
+ fc_batch.append(tmp_fc)
221
+ att_batch.append(tmp_att)
222
+
223
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
+ if hasattr(self, 'h5_label_file'):
225
+ # if there is ground truth
226
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
+ label_batch.append(tmp_label)
228
+
229
+ # Used for reward evaluation
230
+ if hasattr(self, 'h5_label_file'):
231
+ # if there is ground truth
232
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
+ else:
234
+ gts.append([])
235
+
236
+ # record associated info as well
237
+ info_dict = {}
238
+ info_dict['ix'] = ix
239
+ info_dict['id'] = self.info['images'][ix]['id']
240
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
+ infos.append(info_dict)
242
+
243
+ # #sort by att_feat length
244
+ # fc_batch, att_batch, label_batch, gts, infos = \
245
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
+ if self.use_clipscore:
247
+ fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
+ else:
250
+ fc_batch, att_batch, label_batch, gts, infos = \
251
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
+ data = {}
253
+ data['fc_feats'] = np.stack(fc_batch)
254
+ # merge att_feats
255
+ max_att_len = max([_.shape[0] for _ in att_batch])
256
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
+ for i in range(len(att_batch)):
258
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
+ for i in range(len(att_batch)):
261
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
+ # set att_masks to None if attention features have same length
263
+ if data['att_masks'].sum() == data['att_masks'].size:
264
+ data['att_masks'] = None
265
+
266
+ if self.use_clipscore:
267
+ data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
+
269
+ data['labels'] = np.vstack(label_batch)
270
+ # generate mask
271
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
+ for ix, row in enumerate(mask_batch):
274
+ row[:nonzeros[ix]] = 1
275
+ data['masks'] = mask_batch
276
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
+
279
+ data['gts'] = gts # all ground truth captions of each images
280
+ data['infos'] = infos
281
+
282
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
+
284
+ return data
285
+
286
+ def __getitem__(self, ix):
287
+ """This function returns a tuple that is further passed to collate_fn
288
+ """
289
+ if self.use_att:
290
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
+ # Reshape to K x C
292
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
+ if self.norm_att_feat:
294
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
+ if self.use_box:
296
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
+ # devided by image width and height
298
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
+ if self.norm_box_feat:
302
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
+ att_feat = np.hstack([att_feat, box_feat])
304
+ # sort the features by the size of boxes
305
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
+ else:
307
+ att_feat = np.zeros((0,0), dtype='float32')
308
+ if self.use_fc:
309
+ try:
310
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
+ except:
312
+ # Use average of attention when there is no fc provided (For bottomup feature)
313
+ fc_feat = att_feat.mean(0)
314
+ else:
315
+ fc_feat = np.zeros((0), dtype='float32')
316
+ if hasattr(self, 'h5_label_file'):
317
+ seq = self.get_captions(ix, self.seq_per_img)
318
+ else:
319
+ seq = None
320
+
321
+ if self.use_clipscore:
322
+ clip_vis_feat = self.clipscore_loader.get(
323
+ str(self.info['images'][ix]['id']))
324
+
325
+ return (fc_feat,
326
+ att_feat, seq,
327
+ ix, clip_vis_feat)
328
+
329
+ return (fc_feat,
330
+ att_feat, seq,
331
+ ix)
332
+
333
+ def __len__(self):
334
+ return len(self.info['images'])
retrieval/text_utils.py ADDED
@@ -0,0 +1,74 @@
+ import random
+
+ def repeat(text, n_max_gram=3, n_max_repeat=3):
+     """repeat n-grams"""
+     tokens = text.split()
+
+     n_gram = random.randint(1, n_max_gram)
+     repeat_token_idx = random.randint(0, len(tokens) - n_gram)
+     repeated_tokens = tokens[repeat_token_idx:repeat_token_idx+n_gram]
+
+     n_repeat = random.randint(1, n_max_repeat)
+     for _ in range(n_repeat):
+         insert_idx = random.randint(0, len(tokens))
+         tokens = tokens[:insert_idx] + \
+             repeated_tokens + tokens[insert_idx:]
+
+     new_text = " ".join(tokens)
+     return new_text
+
+ def remove(text, n_max_gram=3):
+     """remove n-grams"""
+     tokens = text.split()
+
+     n_gram = random.randint(1, n_max_gram)
+     remove_token_idx = random.randint(0, len(tokens) - n_gram)
+
+     tokens = tokens[:remove_token_idx] + tokens[remove_token_idx + n_gram:]
+
+     new_text = " ".join(tokens)
+     return new_text
+
+ def insert(text, vocab, n_max_tokens=3):
+     """Insert tokens"""
+     tokens = text.split()
+
+     n_insert_token = random.randint(1, n_max_tokens)
+     for _ in range(n_insert_token):
+         insert_token_idx = random.randint(0, len(tokens) - 1)
+         insert_token = random.choice(vocab)
+         tokens = tokens[:insert_token_idx] + [insert_token] + tokens[insert_token_idx:]
+
+     new_text = " ".join(tokens)
+     return new_text
+
+ def swap(text, vocab, n_max_tokens=3):
+     """Swap tokens"""
+     tokens = text.split()
+
+     n_swap_tokens = random.randint(1, n_max_tokens)
+     for _ in range(n_swap_tokens):
+         swap_token_idx = random.randint(0, len(tokens) - 1)
+
+         swap_token = random.choice(vocab)
+         while swap_token == tokens[swap_token_idx]:
+             swap_token = random.choice(vocab)
+
+         tokens[swap_token_idx] = swap_token
+
+     new_text = " ".join(tokens)
+     return new_text
+
+ def shuffle(text):
+     """shuffle tokens"""
+     tokens = text.split()
+
+     random.shuffle(tokens)
+
+     new_text = " ".join(tokens)
+     return new_text
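These helpers corrupt a clean caption to produce grammatically or semantically broken negatives. A hedged usage sketch follows; the caption, the small vocabulary, and the idea of sampling one corruption at random are illustrative assumptions, not taken from this commit:

```python
# Illustrative usage of the perturbation functions defined above.
import random
from text_utils import repeat, remove, insert, swap, shuffle

random.seed(0)
caption = "a man riding a horse on the beach"   # made-up caption
vocab = ["dog", "red", "running", "table", "sky"]  # made-up vocabulary

corruptions = [
    lambda t: repeat(t),
    lambda t: remove(t),
    lambda t: insert(t, vocab),
    lambda t: swap(t, vocab),
    lambda t: shuffle(t),
]
neg_caption = random.choice(corruptions)(caption)
print(neg_caption)
```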
retrieval/train_pl.py ADDED
@@ -0,0 +1,661 @@
+ from ast import parse
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+
+ import numpy as np
+
+ import time
+ import os
+ from collections import defaultdict
+
+ # import captioning.utils.opts as opts
+ # import captioning.models as models
+ # from captioning.data.pth_loader import CaptionDataset
+ # import captioning.utils.eval_utils as eval_utils
+ # import captioning.utils.misc as utils
+ # from captioning.utils.rewards import init_scorer, get_self_critical_reward
+ # from captioning.modules.loss_wrapper import LossWrapper
+
+ from clip_model import CLIPScore
+ from caption_data import COCORetrievalDataset
+
+ import pytorch_lightning as pl
+
+ import detectron2.utils.comm as d2comm
+ from detectron2.utils.env import seed_all_rng
+ seed_all_rng(1234)
+
+
+ class LitModel(pl.LightningModule):
+     def __init__(self, opt):
+         super().__init__()
+         self.opt = opt
+         self.args = args
+         # Initialize dataset
+         # self.dataset = CaptionDataset(opt)
+
+         # self.dataset =
+
+         # opt.vocab_size = self.dataset.vocab_size
+         # opt.seq_length = self.dataset.seq_length
+         # self.batch_size = opt.batch_size
+
+         # Build model
+         # opt.vocab = self.dataset.get_vocab()
+         # model = models.setup(opt)
+         # print(model)
+         # del opt.vocab
+
+         # wrapper with loss in it.
+         # lw_model = LossWrapper(model, opt)
+
+         self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)
+         # self.lw_model = lw_model
+
+         for p in self.model.clip_model.vision_model.parameters():
+             p.requires_grad = False
+         for p in self.model.clip_model.visual_projection.parameters():
+             p.requires_grad = False
+
+         # self.struc_flag = None
+         # self.sc_flag = None
+
+     def forward(self, *args, **kwargs):
+         """
+         I hate this design. Never pretend this is a regular nn.Module.
+         """
+         raise NotImplementedError
+
+     def train_dataloader(self):
+         # train_dataset = torch.utils.data.Subset(
+         #     self.dataset,
+         #     self.dataset.split_ix['train']
+         # )
+
+         # train_loader = torch.utils.data.DataLoader(
+         #     dataset=train_dataset,
+         #     batch_size=self.batch_size,
+         #     shuffle=True,
+         #     num_workers=4,
+         #     collate_fn=self.dataset.collate_func
+         # )
+
+         train_dataset = COCORetrievalDataset(
+             split='karpathy_train', mode='train',
+             args=opt,
+             verbose=verbose
+         )
+
+         train_loader = torch.utils.data.DataLoader(
+             dataset=train_dataset,
+             batch_size=opt.batch_size,
+             shuffle=True,
+             num_workers=4,
+             collate_fn=train_dataset.collate_fn
+         )
+
+         return train_loader
+
+     def val_dataloader(self, split='karpathy_val'):
+         # val_dataset = torch.utils.data.Subset(
+         #     self.dataset,
+         #     self.dataset.split_ix[split]
+         # )
+         # val_loader = torch.utils.data.DataLoader(
+         #     val_dataset,
+         #     batch_size=self.batch_size,
+         #     shuffle=False,
+         #     num_workers=4,
+         #     drop_last=False,
+         #     collate_fn=self.dataset.collate_func
+         # )
+
+         val_dataset = COCORetrievalDataset(
+             split=split, mode='val',
+             args=opt,
+             verbose=verbose
+         )
+
+         val_loader = torch.utils.data.DataLoader(
+             dataset=val_dataset,
+             batch_size=opt.valid_batch_size,
+             shuffle=False,
+             num_workers=4,
+             drop_last=False,
+             collate_fn=val_dataset.collate_fn
+         )
+
+         return val_loader
+
+     def test_dataloader(self):
+         return self.val_dataloader('karpathy_test')
+
+     def training_step(self, data, batch_idx):
+         batch = data
+         self.model.train()
+
+         model_out = self.model.train_step(
+             img_feat=batch['img_feats'],
+             text=batch['text'],
+             neg_text=batch['neg_text'],
+         )
+
+         clip_loss = model_out['clip_loss']
+
+         if self.opt.joint_out:
+             loss = clip_loss
+         else:
+             grammar_loss = model_out['grammar_loss']
+             loss = clip_loss + grammar_loss
+
+         data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
+         data_time = torch.tensor(data_time)
+
+         # print('batch_idx', batch_idx)
+         # print('loss:', loss)
+
+         # logger_logs = model_out.copy()
+         logger_logs = {}
+
+         logger_logs['loss'] = loss.detach()
+         logger_logs['clip_loss'] = clip_loss.detach()
+         if not self.opt.joint_out:
+             logger_logs['grammar_loss'] = grammar_loss.detach()
+         logger_logs['data_time'] = data_time.detach()
+
+         # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
+         # Please use self.log(...) inside the lightningModule instead.
+         # # log on a step or aggregate epoch metric to the logger and/or progress bar
+         # # (inside LightningModule)
+         # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
+         # warnings.warn(*args, **kwargs)
+         # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
+         # Please use self.log(...) inside the lightningModule instead.
+
+         # output = {
+         #     'loss': loss,
+         #     'log': logger_logs,
+         #     'progress_bar': {'data_time': data_time}
+         # }
+
+         for k, v in logger_logs.items():
+             if k in ['data_time', 'clip_loss', 'grammar_loss']:
+                 self.log('train/'+k, v, prog_bar=True)
+             else:
+                 self.log('train/'+k, v)
+
+         # print('training step logged')
+
+         return loss
+
+     def validation_step(self, data, batch_idx):
+         batch = data
+         self.model.eval()
+
+         with torch.no_grad():
+             model_out = self.model.train_step(
+                 img_feat=batch['img_feats'],
+                 text=batch['text'],
+                 neg_text=batch['neg_text'],
+             )
+
+         if self.opt.joint_out:
+             clip_loss = model_out['clip_loss']
+             loss = clip_loss
+
+             output = {
+                 # 'val_loss': loss,
+                 'loss': loss.detach(),
+                 'clip_loss': clip_loss.detach(),
+                 # 'grammar_loss': grammar_loss.detach(),
+
+                 'img_feat': model_out['img_feat'].detach(),
+                 'text_feat': model_out['text_feat'].detach(),
+                 # 'neg_text_feat': model_out['neg_text_feat'].detach(),
+                 # 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
+                 # 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
+                 # 'predictions': predictions,
+                 # 'n_predictions': n_predictions,
+             }
+         else:
+             clip_loss = model_out['clip_loss']
+             grammar_loss = model_out['grammar_loss']
+             loss = clip_loss + grammar_loss
+
+             output = {
+                 # 'val_loss': loss,
+                 'loss': loss.detach(),
+                 'clip_loss': clip_loss.detach(),
+                 'grammar_loss': grammar_loss.detach(),
+
+                 'img_feat': model_out['img_feat'].detach(),
+                 'text_feat': model_out['text_feat'].detach(),
+                 # 'neg_text_feat': model_out['neg_text_feat'].detach(),
+                 'grammar_pos_pred': model_out['grammar_pos_pred'].detach(),
+                 'grammar_neg_pred': model_out['grammar_neg_pred'].detach(),
+                 # 'predictions': predictions,
+                 # 'n_predictions': n_predictions,
+             }
+         return output
+
+     def test_step(self, *args, **kwargs):
+         return self.validation_step(*args, **kwargs)
+
+     def validation_epoch_end(self, outputs, split='val'):
+         outputs = d2comm.gather(outputs)
+         # master node
+         if d2comm.is_main_process():
+             assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
+             outputs = sum(outputs, [])
+
+             out = {}
+
+             val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)
+             val_clip_loss_mean = sum([_['clip_loss'].cpu() for _ in outputs]) / len(outputs)
+             if not self.opt.joint_out:
+                 val_grammar_loss_mean = sum([_['grammar_loss'].cpu() for _ in outputs]) / len(outputs)
+
+             print('loss', val_loss_mean.item())
+             print('clip_loss', val_clip_loss_mean.item())
+             if not self.opt.joint_out:
+                 print('grammar_loss', val_grammar_loss_mean.item())
+
+             logit_scale = self.model.clip_model.logit_scale.exp().cpu()
+
+             text_feats = torch.cat([_['text_feat'].cpu() for _ in outputs], dim=0)
+             img_feats = torch.cat([_['img_feat'].cpu() for _ in outputs], dim=0)
+
+             assert text_feats.size() == (5000, 512), text_feats.size()
+             assert img_feats.size() == (5000, 512), img_feats.size()
+
+             logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
+             logits_per_image = logits_per_text.T
+
+             # text-to-image retrieval
+             print('Text-to-Image retrieval')
+             for k in [1, 5, 10]:
+                 text_to_image_topk = logits_per_text.topk(k, dim=1).indices
+                 n_text = len(text_to_image_topk)
+                 labels = torch.arange(0, n_text).view(-1, 1)
+                 n_retrieved = ((text_to_image_topk == labels).sum(dim=1) > 0).sum()
+                 recall_k = n_retrieved / n_text * 100
+                 out[f'text_to_image_recall_{k}'] = recall_k.item()
+                 print(f'R@{k}: {recall_k.item():.2f}%')
+
+             # image-to-text retrieval
+             print('Image-to-Text retrieval')
+             for k in [1, 5, 10]:
+                 image_to_text_topk = logits_per_image.topk(k, dim=1).indices
+                 n_image = len(image_to_text_topk)
+                 labels = torch.arange(0, n_image).view(-1, 1)
+                 n_retrieved = ((image_to_text_topk == labels).sum(dim=1) > 0).sum()
+                 recall_k = n_retrieved / n_image * 100
+                 out[f'image_to_text_recall_{k}'] = recall_k.item()
+                 print(f'R@{k}: {recall_k.item():.2f}%')
+
+             out.update({
+                 'loss': val_loss_mean.item(),
+                 'clip_loss': val_clip_loss_mean.item()
+             })
+
+             if not self.opt.joint_out:
+                 # grammar scoring
+                 grammar_pos_pred = torch.cat([_['grammar_pos_pred'].cpu() for _ in outputs], dim=0)
+                 grammar_neg_pred = torch.cat([_['grammar_neg_pred'].cpu() for _ in outputs], dim=0)
+
+                 TP = (grammar_pos_pred == 1).sum().item()
+                 FP = (grammar_pos_pred == 0).sum().item()
+                 FN = (grammar_neg_pred == 1).sum().item()
+                 TN = (grammar_neg_pred == 0).sum().item()
+                 print('Grammar check')
+                 print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')
+
+                 precision = TP / (TP + FP) * 100
+                 recall = TP / (TP + FN) * 100
+                 accuracy = (TP + TN) / (TP + FP + FN + TN) * 100
+                 f1 = 2 * precision * recall / (precision + recall)
+                 print(f'Precision: {precision:.2f}%')
+                 print(f'Recall: {recall:.2f}%')
+                 print(f'Accuracy: {accuracy:.2f}%')
+                 print(f'F1: {f1:.2f}%')
+                 print('Total: {}'.format(len(grammar_pos_pred)))
+
+                 out.update({
+                     'grammar_loss': val_grammar_loss_mean,
+                     'grammar_precision': precision,
+                     'grammar_recall': recall,
+                     'grammar_accuracy': accuracy,
+                     'grammar_f1': f1,
+                 })
+
+         else:
+             out = {}
+
+         out = d2comm.all_gather(out)[0]  # Only the one from master node
+         assert len(out) > 0  # make sure the head has index 0
+
+         # must all be tensors
+         out = {k: torch.tensor(v) if not torch.is_tensor(
+             v) else v for k, v in out.items()}
+
+         for k, v in out.items():
+             self.log(f'{split}/{k}', v)
+
+     def test_epoch_end(self, outputs):
+         self.validation_epoch_end(outputs, 'test')
+
+     def configure_optimizers(self):
+         # opt = self.opt
+         # model = self.model
+
+         # parameters = [p for p in model.parameters() if p.requires_grad]
+
+         # if opt.noamopt:
+         #     # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
+         #     optimizer = utils.get_std_opt(
+         #         model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
+         # elif opt.reduce_on_plateau:
+         #     # optimizer = utils.build_optimizer(model.parameters(), opt)
+         #     optimizer = utils.build_optimizer(parameters, opt)
+         #     optimizer = utils.ReduceLROnPlateau(optimizer,
+         #                                         factor=opt.reduce_on_plateau_factor,
+         #                                         patience=opt.reduce_on_plateau_patience)
+         # else:
+         #     # optimizer = utils.build_optimizer(model.parameters(), opt)
+         #     optimizer = utils.build_optimizer(parameters, opt)
+
+         # from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+         # batch_per_epoch = len(self.train_loader)
+         # t_total = batch_per_epoch // self.args.gradient_accumulation_steps * self.args.epochs
+         # warmup_ratio = self.args.warmup_ratio
+         # warmup_iters = int(t_total * warmup_ratio)
+         # if self.verbose:
+         #     print("Batch per epoch: %d" % batch_per_epoch)
+         #     print("Total Iters: %d" % t_total)
+         #     print('Warmup ratio:', warmup_ratio)
+         #     print("Warm up Iters: %d" % warmup_iters)
+
+         if self.args.optim == 'adamw':
+             no_decay = ["bias", "LayerNorm.weight"]
+             optimizer_grouped_parameters = [
+                 {
+                     "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                     "weight_decay": self.args.weight_decay,
+                 },
+                 {
+                     "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                     "weight_decay": 0.0,
+                 },
+             ]
+
+             for group in optimizer_grouped_parameters:
+                 group['params'] = [p for p in group['params'] if p.requires_grad]
+
+             from transformers.optimization import AdamW
+             optim = AdamW(optimizer_grouped_parameters,
+                           lr=self.args.lr, eps=self.args.adam_eps)
+             # lr_scheduler = get_linear_schedule_with_warmup(
+             #     optim, warmup_iters, t_total)
+
+         # optimizers = []
+         optimizers = [optim]
+         lr_schedulers = []
+
+         return optimizers, lr_schedulers
+
+     def optimizer_step(self, epoch, batch_idx, optimizer,
+                        optimizer_idx, *args, **kwargs):
+         # # warm up lr
+         # opt = self.opt
+         # iteration = self.trainer.global_step
+         # if opt.use_warmup and (iteration < opt.noamopt_warmup):
+         #     opt.current_lr = opt.learning_rate * \
+         #         (iteration+1) / opt.noamopt_warmup
+         #     utils.set_lr(optimizer, opt.current_lr)
+
+         super().optimizer_step(epoch, batch_idx, optimizer,
+                                optimizer_idx, *args, **kwargs)
+
+         # print('optimizer step')
+
+     def state_dict(self):
+         """
+         Save the model state dict as well as opt and vocab
+         """
+         state_dict = self.model.state_dict()
+         device = next(iter(state_dict.values())).device
+         assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
+         # state_dict.update({
+         #     '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
+         #     '_opt': utils.serialize_to_tensor(self.opt).to(device)
+         # })
+         return state_dict
+
+     def load_state_dict(self, state_dict=None, strict=True):
+         # if '_vocab' in state_dict:
+         #     self.model.vocab = utils.deserialize(state_dict['_vocab'])
+         #     del state_dict['_vocab']
+         # elif strict:
+         #     raise KeyError
+         # if '_opt' in state_dict:
+         #     saved_model_opt = utils.deserialize(state_dict['_opt'])
+         #     del state_dict['_opt']
+         #     opt = self.opt
+         #     # Make sure the saved opt is compatible with the current opt
+         #     need_be_same = ["caption_model",
+         #                     "rnn_type", "rnn_size", "num_layers"]
+         #     for checkme in need_be_same:
+         #         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
+         #                 getattr(opt, checkme) in ['updown', 'topdown']:
+         #             continue
+         #         assert getattr(saved_model_opt, checkme) == getattr(
+         #             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
+         # elif strict:
+         #     raise KeyError
+         self.model.load_state_dict(state_dict, strict)
+
+ class OnEpochStartCallback(pl.Callback):
+
+     def on_epoch_start(self, trainer, pl_module):
+         # Update lr/training stage/scheduled sampling prob etc.
+         opt = pl_module.opt
+         model = pl_module.model
+         epoch = trainer.current_epoch
+         optimizer = trainer.optimizers[0]
+
+         # if not opt.noamopt and not opt.reduce_on_plateau:
+         #     # Assign the learning rate
+         #     if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
+         #         frac = (
+         #             epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
+         #         decay_factor = opt.learning_rate_decay_rate ** frac
+         #         opt.current_lr = opt.learning_rate * decay_factor
+         #     else:
+         #         opt.current_lr = opt.learning_rate
+         #     utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
+         # # Assign the scheduled sampling prob
+         # if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
+         #     frac = (
+         #         epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
+         #     opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
+         #                       frac, opt.scheduled_sampling_max_prob)
+         #     model.ss_prob = opt.ss_prob
+
+         # # If start self critical training
+         # if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
+         #     sc_flag = True
+         #     init_scorer(opt.cached_tokens)
+         # else:
+         #     sc_flag = False
+
+         # # If start structure loss training
+         # if opt.structure_after != -1 and epoch >= opt.structure_after:
+         #     struc_flag = True
+         #     init_scorer(opt.cached_tokens)
+         # else:
+         #     struc_flag = False
+
+         # pl_module.struc_flag = struc_flag
+         # pl_module.sc_flag = sc_flag
+
+
+ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+
+     def on_keyboard_interrupt(self, trainer, pl_module):
+         # Save model when keyboard interrupt
+         filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
+         self._save_model(filepath)
+
+ from param import parse_args
+ # opt = opts.parse_opt()
+ args = parse_args()
+ opt = args
+
+ checkpoint_callback = ModelCheckpoint(
+     filepath=opt.checkpoint_dir + '{epoch:02d}',
+     # dirpath=opt.checkpoint_path,
+     save_last=True,
+     save_top_k=1,
+     verbose=True,
+     # monitor='to_monitor',
+     # monitor='val/to_monitor',
+     # monitor='val/CIDEr',
+     monitor='val/loss',
+     mode='min',
+     # prefix=opt.id+'_',
+     prefix=opt.id,
+     # filename=f'{opt.id}_',
+ )
+
+ verbose = True
+ # import torch
+ # if torch.cuda.current_device() in [0, -1]:
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
+     verbose = False
+
+ # if verbose:
+ #     print(opt)
+ #     print("""
+ #     val_images_use,
+ #     save_checkpoint_every,
+ #     save_every_epoch,
+ #     save_history_ckpt will be ignored.
+ #     """)
+
+ # Lightning defines batch size as batch size per GPU
+ assert opt.batch_size % torch.cuda.device_count() == 0
+ opt.batch_size = opt.batch_size // torch.cuda.device_count()
+ opt.valid_batch_size = opt.valid_batch_size // torch.cuda.device_count()
+
+ # If resuming from the last checkpoint
+ # if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')):
+ #     resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt')
+ if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}-last.ckpt')):
+     resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
+     if verbose:
+         print('resume from', resume_from)
+ else:
+     resume_from = None
+
+ from pytorch_lightning.loggers import WandbLogger
+ wandb_logger = WandbLogger(
+     # project='CLIP-ViL-COCOCaption',
+     project='CLIP-Finetune-COCO',
+     name=opt.id,
+ )
+
+ if verbose:
+     wandb_logger.experiment.config.update(opt)
+     from pathlib import Path
+     import glob
+     import wandb
+     # src_dir = Path(__file__).resolve().parent.parent
+     glob_str = "*.py"
+     base_path = './'
+     wandb.save(glob_str=glob_str, base_path=base_path)
+
+     glob_str = "**/*.yaml"
+     base_path = './'
+     wandb.save(glob_str=glob_str, base_path=base_path)
+
+     # code = wandb.Artifact('project-source', type='code')
+     # for path in glob.glob('**/*.py', recursive=True):
+     #     code.add_file(path, name='source/'+path)
+     #     print(path)
+     # wandb.run.use_artifact(code)
+
+
+ lit = LitModel(opt)
+ # warning: grad_clip_mode is ignored.
+ trainer = pl.Trainer(
+     callbacks=[
+         OnEpochStartCallback(),
+         # pl.callbacks.lr_logger.LearningRateLogger()
+         pl.callbacks.LearningRateMonitor()
+     ],
+     default_root_dir=opt.checkpoint_dir,
+     resume_from_checkpoint=resume_from,
+
+     distributed_backend='ddp',
+     gpus=torch.cuda.device_count(),
+     # gpus=1,
+
+     check_val_every_n_epoch=1,
+     # max_epochs=opt.max_epochs,
+     max_epochs=opt.epochs,
+     # gradient_clip_val=opt.grad_clip_value,
+     gradient_clip_val=opt.clip_grad_norm,
+
+     checkpoint_callback=checkpoint_callback,
+     log_gpu_memory='min_max',
+     # log_save_interval=opt.losses_log_every,
+     log_every_n_steps=opt.losses_log_every,
+     profiler=True,
+     # profiler='simple',
+     # row_log_interval=10,  # what is it?
+     flush_logs_every_n_steps=10,
+     num_sanity_val_steps=0,
+     # val_check_interval=0.01,
+     # limit_train_batches=500,
+     # progress_bar_refresh_rate=0,
+     # fast_dev_run=True,
+     precision=opt.precision,
+     logger=wandb_logger
+ )
+
+ if os.getenv('EVALUATE', '0') == '1':
+     trainer.test(lit)
+ else:
+     trainer.fit(lit)
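In `validation_epoch_end` above, retrieval is scored by checking whether each caption's paired image index appears in the top-k entries of its similarity row. A self-contained sketch of that R@k computation on random features (the sizes here are made up; the script itself asserts 5000 x 512 feature matrices):

```python
# Illustrative sketch of the recall@k metric used in validation_epoch_end.
import torch

torch.manual_seed(0)
n, d = 8, 4
text_feats = torch.randn(n, d)
img_feats = text_feats + 0.1 * torch.randn(n, d)  # image i is paired with caption i

logits_per_text = text_feats @ img_feats.t()       # (n, n) similarity matrix

for k in [1, 5]:
    topk = logits_per_text.topk(k, dim=1).indices   # (n, k) retrieved image indices
    labels = torch.arange(n).view(-1, 1)            # ground-truth image index per caption
    n_retrieved = ((topk == labels).sum(dim=1) > 0).sum().item()
    print(f'R@{k}: {n_retrieved / n * 100:.2f}%')
```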
save/README.md ADDED
@@ -0,0 +1 @@
+ Directory for checkpoints