WwYc committed
Commit 4f00e9e · verified · 1 Parent(s): dd0b9c4

Delete lxmert/src/pretrain

lxmert/src/pretrain/__init__.py DELETED
File without changes
lxmert/src/pretrain/lxmert_data.py DELETED
@@ -1,255 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

from collections import defaultdict
import json
import random

import numpy as np
from torch.utils.data import Dataset

from param import args
from pretrain.qa_answer_table import AnswerTable
from utils import load_obj_tsv

TINY_IMG_NUM = 500
FAST_IMG_NUM = 5000

Split2ImgFeatPath = {
    'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
    'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
}


class InputExample(object):
    """A single training/test example for the language model."""
    def __init__(self, uid, sent, visual_feats=None,
                 obj_labels=None, attr_labels=None,
                 is_matched=None, label=None):
        self.uid = uid
        self.sent = sent
        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.attr_labels = attr_labels
        self.is_matched = is_matched  # whether the sentence and the image are matched
        self.label = label


class LXMERTDataset:
    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded.
        :param qa_sets: if None, no action;
            otherwise, only keep the answers appearing in these datasets
            and remove all unlabeled data (MSCOCO captions).
        """
        self.name = splits
        self.sources = splits.split(',')

        # Loading datasets to data
        self.data = []
        for source in self.sources:
            self.data.extend(json.load(open("data/lxmert/%s.json" % source)))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Modify the answers
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)


def make_uid(img_id, dset, sent_idx):
    return "%s_%s_%03d" % (img_id, dset, sent_idx)


"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class LXMERTTorchDataset(Dataset):
    def __init__(self, dataset: LXMERTDataset, topk=-1):
        super().__init__()
        self.raw_dataset = dataset
        self.task_matched = args.task_matched

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        # Load the dataset
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out the dataset
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.imgid2img:
                used_data.append(datum)

        # Flatten the dataset (into one sent + one image entries)
        self.data = []
        for datum in used_data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                        'img_id': datum['img_id'],
                        'sent': sent
                    }
                    if labels is not None:
                        new_datum['label'] = labels[sent_idx]
                    self.data.append(new_datum)
        print("Use %d data in torch dataset" % (len(self.data)))

    def __len__(self):
        return len(self.data)

    def random_feat(self):
        """Get a random obj feat from the dataset."""
        datum = self.data[random.randint(0, len(self.data)-1)]
        img_id = datum['img_id']
        img_info = self.imgid2img[img_id]
        feat = img_info['features'][random.randint(0, 35)]
        return feat

    def __getitem__(self, item: int):
        datum = self.data[item]

        uid = datum['uid']
        img_id = datum['img_id']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        obj_labels = img_info['objects_id'].copy()
        obj_confs = img_info['objects_conf'].copy()
        attr_labels = img_info['attrs_id'].copy()
        attr_confs = img_info['attrs_conf'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # If calculating the matched loss, replace the sentence with a sentence
        # corresponding to another image.
        is_matched = 1
        sent = datum['sent']
        if self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                other_datum = self.data[random.randint(0, len(self.data)-1)]
                while other_datum['img_id'] == img_id:
                    other_datum = self.data[random.randint(0, len(self.data)-1)]
                sent = other_datum['sent']

        # Label, convert answer to id
        if 'label' in datum:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
        else:
            label = None

        # Create target
        example = InputExample(
            uid, sent, (feats, boxes),
            (obj_labels, obj_confs), (attr_labels, attr_confs),
            is_matched, label
        )
        return example


class LXMERTEvaluator:
    def __init__(self, dataset: LXMERTDataset):
        self.raw_dataset = dataset

        # Create QA Eval Data
        self.data = []
        for datum in self.raw_dataset.data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:    # A labeled dataset
                    labels = datum['labelf'][sents_cat]
                    for sent_idx, sent in enumerate(sents):
                        new_datum = {
                            'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                            'img_id': datum['img_id'],
                            'sent': sent,
                            'dset': sents_cat,
                            'label': labels[sent_idx]
                        }
                        self.data.append(new_datum)

        # uid2datum
        self.uid2datum = {}
        for datum in self.data:
            self.uid2datum[datum['uid']] = datum

    def evaluate(self, uid2ans: dict, pprint=False):
        score = 0.
        cnt = 0
        dset2score = defaultdict(lambda: 0.)
        dset2cnt = defaultdict(lambda: 0)
        for uid, ans in uid2ans.items():
            if uid not in self.uid2datum:   # Not a labeled datum
                continue
            datum = self.uid2datum[uid]
            label = datum['label']
            dset = datum['dset']
            if ans in label:
                score += label[ans]
                dset2score[dset] += label[ans]
            cnt += 1
            dset2cnt[dset] += 1
        accu = score / cnt
        dset2accu = {}
        for dset in dset2cnt:
            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]

        if pprint:
            accu_str = "Overall Accu %0.4f, " % (accu)
            sorted_keys = sorted(dset2accu.keys())
            for key in sorted_keys:
                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
            print(accu_str)

        return accu, dset2accu

    def dump_result(self, uid2ans: dict, path):
        raise NotImplementedError
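
For reference, the loaders above index each record in data/lxmert/*.json by 'img_id', 'sentf', and 'labelf'. The following is a minimal sketch of one such record, inferred from those field accesses rather than taken from the data release; all values are illustrative:

# Hypothetical example of one datum in data/lxmert/mscoco_train.json (values are illustrative).
datum = {
    'img_id': 'COCO_train2014_000000000009',       # must match an img_id in the obj36 TSV
    'sentf': {                                     # sentence sources, keyed by dataset name
        'mscoco': ['A plate of food on a table.'],
        'vqa': ['What is on the plate?'],
    },
    'labelf': {                                    # QA labels, aligned index-by-index with sentf
        'vqa': [{'food': 1.0, 'breakfast': 0.3}],  # answer -> soft score
        # captions ('mscoco') have no entry here, so their 'label' field is left unset
    },
}
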
lxmert/src/pretrain/lxmert_pretrain.py DELETED
@@ -1,435 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import collections
import os
import random

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from param import args
from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
from lxrt.entry import set_visual_config
from lxrt.tokenization import BertTokenizer
from lxrt.modeling import LXRTPretraining

DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')


def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        collate_fn=lambda x: x,
        drop_last=drop_last, pin_memory=True
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)


train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
valid_batch_size = 2048 if args.multiGPU else 512
valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids, input_mask, segment_ids, lm_label_ids,
                 visual_feats, obj_labels,
                 is_matched, ans):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids

        self.visual_feats = visual_feats
        self.obj_labels = obj_labels

        self.is_matched = is_matched

        self.ans = ans


def random_word(tokens, tokenizer):
    """
    Mask some random tokens for the masked-LM task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with probability
        ratio = args.word_mask_rate
        if prob < ratio:
            prob /= ratio

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label


def random_feat(feats):
    mask_feats = feats.copy()
    feat_mask = np.zeros(len(feats), dtype=np.float32)
    for i in range(len(feats)):
        prob = random.random()
        # mask token with probability
        if prob < args.obj_mask_rate:
            prob /= args.obj_mask_rate

            # 80% randomly change token to zero feat
            if prob < 0.8:
                mask_feats[i, :] = 0.

            # 10% randomly change token to random feat
            elif prob < 0.9:
                mask_feats[i, :] = train_tuple.torchdset.random_feat()
            # -> rest 10% randomly keep current feat

            # Need to predict this feat
            feat_mask[i] = 1.

    return mask_feats, feat_mask


def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens = tokenizer.tokenize(example.sent.strip())

    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[:(max_seq_length - 2)]

    # Get random words
    masked_tokens, masked_label = random_word(tokens, tokenizer)

    # concatenate lm labels and account for CLS, SEP, SEP
    masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)

    # Mask & Segment Word
    lm_label_ids = ([-1] + masked_label + [-1])
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    feat, boxes = example.visual_feats
    obj_labels, obj_confs = example.obj_labels
    attr_labels, attr_confs = example.attr_labels

    # Mask Image Features:
    masked_feat, feat_mask = random_feat(feat)

    # QA answer label
    if example.label is None or len(example.label) == 0 or example.is_matched != 1:
        # 1. No label 2. Label is pruned 3. unmatched visual + language pair
        ans = -1
    else:
        keys, values = zip(*example.label.items())
        if len(keys) == 1:
            ans = keys[0]
        else:
            value_sum = sum(values)
            prob = [value / value_sum for value in values]
            choice = np.random.multinomial(1, prob).argmax()
            ans = keys[choice]

    features = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        lm_label_ids=lm_label_ids,
        visual_feats=(masked_feat, boxes),
        obj_labels={
            'obj': (obj_labels, obj_confs),
            'attr': (attr_labels, attr_confs),
            'feat': (feat, feat_mask),
        },
        is_matched=example.is_matched,
        ans=ans,
    )
    return features


LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')


class LXMERT:
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        set_visual_config(args)
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.load is not None:
            self.load(args.load)
        if args.load_lxmert is not None:
            # Loading lxmert does not load the answer head.
            self.load_lxmert(args.load_lxmert)

        # GPU Options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)

    def forward(self, examples):
        train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
                          for example in examples]

        # Language Inputs
        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Inputs
        feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
        pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()

        # Language Prediction
        lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Prediction
        obj_labels = {}
        for key in ('obj', 'attr', 'feat'):
            visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
            visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
            assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
            obj_labels[key] = (visn_labels, visn_mask)

        # Joint Prediction
        matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
        ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()

        """
        forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
        """
        loss, losses, ans_logit = self.model(
            input_ids, segment_ids, input_mask, lm_labels,
            feats, pos, obj_labels, matched_labels, ans
        )
        return loss, losses.detach().cpu(), ans_logit

    def train_batch(self, optim, batch):
        optim.zero_grad()
        loss, losses, ans_logit = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit

    def valid_batch(self, batch):
        with torch.no_grad():
            loss, losses, ans_logit = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit

    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = 0.05
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Total Iters: %d" % t_total)
        print("Warm up Iters: %d" % warmup_iters)
        optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        # Train
        best_eval_loss = 9595.
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch in tqdm(train_ld, total=len(train_ld)):
                loss, losses, logit = self.train_batch(optim, batch)
                total_loss += loss
                total_losses += losses

                if args.task_qa:
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
            losses_str = "The losses are "
            for name, loss in zip(LOSSES_NAME, total_losses):
                losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
            print(losses_str)
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            # Save
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch+1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1):
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        for i, batch in enumerate(eval_ld):
            loss, losses, logit = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break

        print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
        losses_str = "The losses are "
        for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
            losses_str += "%s: %0.4f " % (name, loss)
        print(losses_str)

        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)

        return total_loss / len(eval_ld)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))

    def load(self, path):
        print("Load BERT extractor from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        self.model.load_state_dict(state_dict)

    def load_lxmert(self, path):
        print("Load LXMERT model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Change Multi GPU to single GPU
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value  # keep keys saved without the DataParallel prefix
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    lxmert = LXMERT(max_seq_length=20)
    lxmert.train(train_tuple, valid_tuple)
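
The masking in random_word (and, analogously, random_feat) follows the BERT 80/10/10 recipe: each token is selected with probability args.word_mask_rate; a selected token is replaced by [MASK] 80% of the time, by a random vocabulary token 10% of the time, and kept unchanged 10% of the time, while its original id is recorded as the prediction target (-1 elsewhere). The following is a self-contained sketch of that recipe; mask_tokens and toy_vocab are illustrative names only, and the real code operates on BertTokenizer's wordpiece vocab:

import random

# Toy stand-in for tokenizer.vocab; values are token ids.
toy_vocab = {'[UNK]': 0, '[MASK]': 1, 'a': 2, 'plate': 3, 'of': 4, 'food': 5}

def mask_tokens(tokens, vocab, mask_rate=0.15):
    labels = []
    for i, tok in enumerate(tokens):
        p = random.random()
        if p < mask_rate:
            p /= mask_rate
            if p < 0.8:
                tokens[i] = '[MASK]'                       # 80%: replace with [MASK]
            elif p < 0.9:
                tokens[i] = random.choice(list(vocab))     # 10%: replace with a random token
            # remaining 10%: keep the original token
            labels.append(vocab.get(tok, vocab['[UNK]']))  # predict the original id
        else:
            labels.append(-1)                              # ignored by the LM loss
    return tokens, labels

print(mask_tokens(['a', 'plate', 'of', 'food'], toy_vocab))
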
lxmert/src/pretrain/qa_answer_table.py DELETED
@@ -1,158 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import json
import torch


class AnswerTable:
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    def __init__(self, dsets=None):
        self.all_ans = json.load(open("data/lxmert/all_ans.json"))
        if dsets is not None:
            dsets = set(dsets)
            # Only keep an answer if it is used in the given dsets
            self.anss = [ans['ans'] for ans in self.all_ans if
                         len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    @property
    def num_answers(self):
        return len(self.anss)


def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.
    The answers in the fine-tuned QA task (indicated by label2ans)
    would also be properly initialized with lxmert pre-trained
    QA heads.

    :param path: Path to lxmert snapshot.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}
    :return:
    """
    print("Load QA pre-trained lxmert from %s " % path)
    loaded_state_dict = torch.load("%s_LXRT.pth" % path)
    model_state_dict = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate bert model
    bert_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith('bert.'):
            bert_state_dict[key] = value

    # Isolate answer head
    answer_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith("answer_head."):
            answer_state_dict[key.replace('answer_head.', '')] = value

    # Do surgery on answer state dict
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    import copy
    new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
    answer_table = AnswerTable()
    loaded = 0
    unload = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unload += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load BERT weights
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load answer logit FC weights
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0

    model.load_state_dict(answer_state_dict, strict=False)
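
The answer normalization applied before any table lookup lowercases the string, strips a trailing period and a leading article, and maps a few synonyms through ANS_CONVERT. A small sketch of the resulting behavior; it assumes data/lxmert/all_ans.json is available so that AnswerTable() can be constructed, and the expected outputs in the comments follow directly from convert_ans above:

# Behavior of AnswerTable.convert_ans on a few inputs (expected values shown in comments).
table = AnswerTable()                   # loads data/lxmert/all_ans.json
print(table.convert_ans("The man."))    # -> "man"   (lowercase, strip '.', strip article)
print(table.convert_ans("a Grey"))      # -> "gray"  (article stripped, 'grey' -> 'gray')
print(table.convert_ans("Two"))         # -> "2"     (via ANS_CONVERT)
print(table.convert_ans("pizza"))       # -> "pizza" (unchanged)
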