WwYc committed on
Commit
9c3230e
1 Parent(s): 4f00e9e

Delete lxmert/src/tasks

Browse files
lxmert/src/tasks/__init__.py DELETED
File without changes
lxmert/src/tasks/gqa.py DELETED
@@ -1,210 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import os
5
- import collections
6
-
7
- import torch
8
- from tqdm import tqdm
9
- import torch.nn as nn
10
- from torch.utils.data.dataloader import DataLoader
11
-
12
- from param import args
13
- from pretrain.qa_answer_table import load_lxmert_qa
14
- from tasks.gqa_model import GQAModel
15
- from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator
16
-
17
-
18
- DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
19
-
20
-
21
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    """Assemble the (dataset, loader, evaluator) triple for the given GQA splits.

    :param splits: comma-separated split names, e.g. "train,valid"
    :param bs: batch size for the DataLoader
    :param shuffle: whether the loader shuffles each epoch
    :param drop_last: whether the loader drops the final partial batch
    :return: a DataTuple bundling dataset, loader and evaluator
    """
    raw_dset = GQADataset(splits)
    torch_dset = GQATorchDataset(raw_dset)
    loader = DataLoader(
        torch_dset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )
    return DataTuple(dataset=raw_dset, loader=loader, evaluator=GQAEvaluator(raw_dset))
32
-
33
-
34
class GQA:
    """GQA fine-tuning harness.

    Owns the train/valid data tuples, the GQAModel, the losses and the
    optimizer, and drives training, prediction, evaluation and
    checkpointing.  All configuration comes from the global ``args``.
    """

    def __init__(self):
        # Training data is always required; shuffle + drop_last keep batch
        # shapes constant across iterations.
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            # Larger eval batch when multiple GPUs are available
            # (evaluation has no gradients, so more memory is free).
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        # Output dimension is the number of candidate answers in the train split.
        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights (full LXMERT, then optional QA-head init).
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer.  BCE treats answers as independent soft
        # targets; CE (mce) uses the single argmax answer instead.
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            # Imported lazily: BertAdam is only needed for this branch.
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Run the fine-tuning loop.

        Saves the best-validation checkpoint as "BEST" and the final model
        as "LAST"; appends per-epoch scores to ``<output>/log.log``.

        :param train_tuple: DataTuple used for optimization
        :param eval_tuple: DataTuple used for validation (may equal self.valid_tuple)
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    # Collapse soft targets to the single best answer for CE.
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    # Scale by answer count so the magnitude matches a
                    # summed (rather than averaged) BCE — TODO confirm intent.
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Track train-set predictions for the epoch-level accuracy log.
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """Predict an answer for every question in the tuple.

        :param eval_tuple: DataTuple to run inference on
        :param dump: optional path; when given, predictions are written via
            the evaluator's dump_result
        :return: dict mapping question id -> predicted answer string
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Predict on the tuple and return the evaluator's accuracy score."""
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Upper-bound accuracy: score obtained by always answering with the
        highest-scored ground-truth label (measures answer-vocab coverage)."""
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model state dict to <output>/<name>.pth."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load a state dict from <path>.pth, stripping any DataParallel
        '.module' prefixes; strict=False tolerates missing/extra keys."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
174
-
175
-
176
- if __name__ == "__main__":
177
- # Build Class
178
- gqa = GQA()
179
-
180
- # Load Model
181
- if args.load is not None:
182
- gqa.load(args.load)
183
-
184
- # Test or Train
185
- if args.test is not None:
186
- args.fast = args.tiny = False # Always loading all data in test
187
- if 'submit' in args.test:
188
- gqa.predict(
189
- get_tuple(args.test, bs=args.batch_size,
190
- shuffle=False, drop_last=False),
191
- dump=os.path.join(args.output, 'submit_predict.json')
192
- )
193
- if 'testdev' in args.test:
194
- result = gqa.evaluate(
195
- get_tuple('testdev', bs=args.batch_size,
196
- shuffle=False, drop_last=False),
197
- dump=os.path.join(args.output, 'testdev_predict.json')
198
- )
199
- print(result)
200
- else:
201
- # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
202
- print('Splits in Train data:', gqa.train_tuple.dataset.splits)
203
- if gqa.valid_tuple is not None:
204
- print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
205
- print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
206
- else:
207
- print("DO NOT USE VALIDATION")
208
- gqa.train(gqa.train_tuple, gqa.valid_tuple)
209
-
210
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/gqa_data.py DELETED
@@ -1,194 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import json
5
-
6
- import numpy as np
7
- import torch
8
- from torch.utils.data import Dataset
9
-
10
- from param import args
11
- from utils import load_obj_tsv
12
-
13
- # Load part of the dataset for fast checking.
14
- # Notice that here is the number of images instead of the number of data,
15
- # which means all related data to the images would be used.
16
- TINY_IMG_NUM = 512
17
- FAST_IMG_NUM = 5000
18
-
19
-
20
class GQADataset:
    """In-memory view of one or more GQA json splits.

    Each datum looks like::

        {
            "img_id": "2375429",
            "label": {"pipe": 1.0},
            "question_id": "07333408",
            "sent": "What is on the white wall?"
        }

    Also loads the shared answer vocabulary (ans2label / label2ans).
    """

    def __init__(self, splits: str):
        self.name = splits
        self.splits = splits.split(',')

        # Concatenate every requested split into one flat list.
        per_split = (json.load(open("data/gqa/%s.json" % split))
                     for split in self.splits)
        self.data = [datum for chunk in per_split for datum in chunk]
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # Index by question id for evaluation and lookups.
        self.id2datum = {datum['question_id']: datum for datum in self.data}

        # Answer vocabulary; the two maps must be exact inverses.
        self.ans2label = json.load(open("data/gqa/trainval_ans2label.json"))
        self.label2ans = json.load(open("data/gqa/trainval_label2ans.json"))
        assert len(self.ans2label) == len(self.label2ans)
        for ans, label in self.ans2label.items():
            assert self.label2ans[label] == ans

    @property
    def num_answers(self):
        """Size of the answer vocabulary."""
        return len(self.ans2label)

    def __len__(self):
        return len(self.data)
61
-
62
-
63
class GQABufferLoader():
    """Process-wide cache for object-feature TSV files.

    GQA train and valid images both come from Visual Genome, so buffering
    avoids loading the same (large) TSV twice.
    """

    def __init__(self):
        # Maps "<path>_<topk>" -> loaded feature list.
        self.key2data = {}

    def load_data(self, name, number):
        """Return (and cache) up to `number` image feature rows.

        :param name: 'testdev' selects the testdev TSV; anything else
            selects the shared Visual Genome TSV.
        :param number: topk rows to load (-1 means all).
        """
        path = ("data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv"
                if name == 'testdev'
                else "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv")
        key = "%s_%d" % (path, number)
        if key not in self.key2data:
            self.key2data[key] = load_obj_tsv(path, topk=number)
        return self.key2data[key]


gqa_buffer_loader = GQABufferLoader()
82
-
83
-
84
- """
85
- Example in obj tsv:
86
- FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
87
- "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
88
- """
89
- class GQATorchDataset(Dataset):
90
- def __init__(self, dataset: GQADataset):
91
- super().__init__()
92
- self.raw_dataset = dataset
93
-
94
- if args.tiny:
95
- topk = TINY_IMG_NUM
96
- elif args.fast:
97
- topk = FAST_IMG_NUM
98
- else:
99
- topk = -1
100
-
101
- # Loading detection features to img_data
102
- # Since images in train and valid both come from Visual Genome,
103
- # buffer the image loading to save memory.
104
- img_data = []
105
- if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits: # Always loading all the data in testdev
106
- img_data.extend(gqa_buffer_loader.load_data('testdev', -1))
107
- else:
108
- img_data.extend(gqa_buffer_loader.load_data('train', topk))
109
- self.imgid2img = {}
110
- for img_datum in img_data:
111
- self.imgid2img[img_datum['img_id']] = img_datum
112
-
113
- # Only kept the data with loaded image features
114
- self.data = []
115
- for datum in self.raw_dataset.data:
116
- if datum['img_id'] in self.imgid2img:
117
- self.data.append(datum)
118
- print("Use %d data in torch dataset" % (len(self.data)))
119
- print()
120
-
121
- def __len__(self):
122
- return len(self.data)
123
-
124
- def __getitem__(self, item: int):
125
- datum = self.data[item]
126
-
127
- img_id = datum['img_id']
128
- ques_id = datum['question_id']
129
- ques = datum['sent']
130
-
131
- # Get image info
132
- img_info = self.imgid2img[img_id]
133
- obj_num = img_info['num_boxes']
134
- boxes = img_info['boxes'].copy()
135
- feats = img_info['features'].copy()
136
- assert len(boxes) == len(feats) == obj_num
137
-
138
- # Normalize the boxes (to 0 ~ 1)
139
- img_h, img_w = img_info['img_h'], img_info['img_w']
140
- boxes = boxes.copy()
141
- boxes[:, (0, 2)] /= img_w
142
- boxes[:, (1, 3)] /= img_h
143
- np.testing.assert_array_less(boxes, 1+1e-5)
144
- np.testing.assert_array_less(-boxes, 0+1e-5)
145
-
146
- # Create target
147
- if 'label' in datum:
148
- label = datum['label']
149
- target = torch.zeros(self.raw_dataset.num_answers)
150
- for ans, score in label.items():
151
- if ans in self.raw_dataset.ans2label:
152
- target[self.raw_dataset.ans2label[ans]] = score
153
- return ques_id, feats, boxes, ques, target
154
- else:
155
- return ques_id, feats, boxes, ques
156
-
157
-
158
class GQAEvaluator:
    """Scores predicted answers against GQA soft labels and dumps
    challenge-format submissions."""

    def __init__(self, dataset: GQADataset):
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """Return the mean ground-truth score of the predictions.

        :param quesid2ans: dict mapping question id -> predicted answer
        :return: average label score in [0, 1]; 0. for an empty dict
            (the original raised ZeroDivisionError in that case).
        """
        if not quesid2ans:
            return 0.
        score = 0.
        for quesid, ans in quesid2ans.items():
            datum = self.dataset.id2datum[quesid]
            label = datum['label']
            if ans in label:
                score += label[ans]
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump the result to a GQA-challenge submittable json file.
        GQA json file submission requirement:
            results = [result]
            result = {
                "questionId": str,      # Note: it's a actually an int number but the server requires an str.
                "prediction": str
            }

        :param quesid2ans: A dict mapping question id to its predicted answer.
        :param path: The file path to save the json file.
        :return:
        """
        with open(path, 'w') as f:
            result = []
            for ques_id, ans in quesid2ans.items():
                result.append({
                    'questionId': ques_id,
                    'prediction': ans
                })
            json.dump(result, f, indent=4, sort_keys=True)
193
-
194
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/gqa_model.py DELETED
@@ -1,45 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import torch.nn as nn
5
-
6
- from param import args
7
- from lxrt.entry import LXRTEncoder
8
- from lxrt.modeling import BertLayerNorm, GeLU
9
-
10
# Max length including <bos> and <eos>
MAX_GQA_LENGTH = 20


class GQAModel(nn.Module):
    """LXRT encoder followed by a two-layer MLP answer head for GQA."""

    def __init__(self, num_answers):
        super().__init__()
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=MAX_GQA_LENGTH
        )
        hid_dim = self.lxrt_encoder.dim
        # Answer head: hid -> 2*hid (GeLU + LayerNorm) -> num_answers logits.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_answers)
        )
        # Initialize the new head with BERT's weight-init scheme.
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, o, f) object features
        :param pos: (b, o, 4) normalized bounding boxes
        :param sent: (b,) Type -- list of string
        :return: (b, num_answer) The logit of each answers.
        """
        x = self.lxrt_encoder(sent, (feat, pos))
        logit = self.logit_fc(x)

        return logit
44
-
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/nlvr2.py DELETED
@@ -1,182 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import os
5
- import collections
6
-
7
- from tqdm import tqdm
8
- import torch
9
- import torch.nn as nn
10
- from torch.utils.data.dataloader import DataLoader
11
-
12
- from param import args
13
- from tasks.nlvr2_model import NLVR2Model
14
- from tasks.nlvr2_data import NLVR2Dataset, NLVR2TorchDataset, NLVR2Evaluator
15
-
16
- DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
17
-
18
-
19
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    """Assemble the (dataset, loader, evaluator) triple for the given NLVR2 splits.

    :param splits: comma-separated split names
    :param bs: DataLoader batch size
    :param shuffle: shuffle each epoch
    :param drop_last: drop the final partial batch
    :return: DataTuple bundling dataset, loader and evaluator
    """
    raw_dset = NLVR2Dataset(splits)
    torch_dset = NLVR2TorchDataset(raw_dset)
    loader = DataLoader(
        torch_dset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )
    return DataTuple(dataset=raw_dset, loader=loader, evaluator=NLVR2Evaluator(raw_dset))
30
-
31
-
32
class NLVR2:
    """NLVR2 fine-tuning harness: data tuples, model, cross-entropy loss,
    optimizer, and the train / predict / evaluate / save / load loop.
    Configuration comes from the global ``args``.
    """

    def __init__(self):
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            # Larger eval batch when multiple GPUs are available.
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        self.model = NLVR2Model()

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
        self.model = self.model.cuda()

        # Losses and optimizer — NLVR2 is a binary (True/False) decision,
        # so a single cross-entropy over 2 classes suffices.
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            # Imported lazily: only needed for the BertAdam branch.
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Run the fine-tuning loop; saves "BEST" on validation improvement
        and "LAST" at the end, logging to ``<output>/log.log``."""
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)):
                self.model.train()

                self.optim.zero_grad()
                feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda()
                logit = self.model(feats, boxes, sent)

                loss = self.mce_loss(logit, label)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                # Track train predictions (class index 0/1) for the epoch log.
                score, predict = logit.max(1)
                for qid, l in zip(ques_id, predict.cpu().numpy()):
                    quesid2ans[qid] = l

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """Predict the 0/1 class for every example in the tuple.

        :param eval_tuple: DataTuple to run inference on
        :param dump: optional path; when given, written via dump_result
        :return: dict mapping uid -> predicted class index
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, predict = logit.max(1)
                for qid, l in zip(ques_id, predict.cpu().numpy()):
                    quesid2ans[qid] = l
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Predict on the tuple and return the evaluator's accuracy."""
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model state dict to <output>/<name>.pth."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load a state dict from <path>.pth (strict, unlike the GQA loader)."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
146
-
147
-
148
- if __name__ == "__main__":
149
- # Build Class
150
- nlvr2 = NLVR2()
151
-
152
- # Load Model
153
- if args.load is not None:
154
- nlvr2.load(args.load)
155
-
156
- # Test or Train
157
- if args.test is not None:
158
- args.fast = args.tiny = False # Always loading all data in test
159
- if 'hidden' in args.test:
160
- nlvr2.predict(
161
- get_tuple(args.test, bs=args.batch_size,
162
- shuffle=False, drop_last=False),
163
- dump=os.path.join(args.output, 'hidden_predict.csv')
164
- )
165
- elif 'test' in args.test or 'valid' in args.test:
166
- result = nlvr2.evaluate(
167
- get_tuple(args.test, bs=args.batch_size,
168
- shuffle=False, drop_last=False),
169
- dump=os.path.join(args.output, '%s_predict.csv' % args.test)
170
- )
171
- print(result)
172
- else:
173
- assert False, "No such test option for %s" % args.test
174
- else:
175
- print('Splits in Train data:', nlvr2.train_tuple.dataset.splits)
176
- if nlvr2.valid_tuple is not None:
177
- print('Splits in Valid data:', nlvr2.valid_tuple.dataset.splits)
178
- else:
179
- print("DO NOT USE VALIDATION")
180
- nlvr2.train(nlvr2.train_tuple, nlvr2.valid_tuple)
181
-
182
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/nlvr2_data.py DELETED
@@ -1,157 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import json
5
-
6
- import numpy as np
7
- from torch.utils.data import Dataset
8
-
9
- from param import args
10
- from utils import load_obj_tsv
11
-
12
- # Load part of the dataset for fast checking.
13
- # Notice that here is the number of images instead of the number of data,
14
- # which means all related data to the images would be used.
15
- TINY_IMG_NUM = 512
16
- FAST_IMG_NUM = 5000
17
-
18
-
19
class NLVR2Dataset:
    """In-memory view of one or more NLVR2 json splits.

    Each datum looks like::

        {
            "identifier": "train-10171-0-0",
            "img0": "train-10171-0-img0",
            "img1": "train-10171-0-img1",
            "label": 0,
            "sent": "An image shows one leather pencil case, ...",
            "uid": "nlvr2_train_0"
        }
    """

    def __init__(self, splits: str):
        self.name = splits
        self.splits = splits.split(',')

        # Concatenate every requested split into one flat list.
        per_split = (json.load(open("data/nlvr2/%s.json" % split))
                     for split in self.splits)
        self.data = [datum for chunk in per_split for datum in chunk]
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # Index by uid for evaluation and lookups.
        self.id2datum = {datum['uid']: datum for datum in self.data}

    def __len__(self):
        return len(self.data)
50
-
51
-
52
- """
53
- An example in obj36 tsv:
54
- FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
55
- "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
56
- FIELDNAMES would be keys in the dict returned by load_obj_tsv.
57
- """
58
- class NLVR2TorchDataset(Dataset):
59
- def __init__(self, dataset: NLVR2Dataset):
60
- super().__init__()
61
- self.raw_dataset = dataset
62
-
63
- if args.tiny:
64
- topk = TINY_IMG_NUM
65
- elif args.fast:
66
- topk = FAST_IMG_NUM
67
- else:
68
- topk = -1
69
-
70
- # Loading detection features to img_data
71
- img_data = []
72
- if 'train' in dataset.splits:
73
- img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj36.tsv', topk=topk))
74
- if 'valid' in dataset.splits:
75
- img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj36.tsv', topk=topk))
76
- if 'test' in dataset.name:
77
- img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj36.tsv', topk=topk))
78
- self.imgid2img = {}
79
- for img_datum in img_data:
80
- self.imgid2img[img_datum['img_id']] = img_datum
81
-
82
- # Filter out the dataset
83
- self.data = []
84
- for datum in self.raw_dataset.data:
85
- if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img:
86
- self.data.append(datum)
87
- print("Use %d data in torch dataset" % (len(self.data)))
88
- print()
89
-
90
- def __len__(self):
91
- return len(self.data)
92
-
93
- def __getitem__(self, item: int):
94
- datum = self.data[item]
95
-
96
- ques_id = datum['uid']
97
- ques = datum['sent']
98
-
99
- # Get image info
100
- boxes2 = []
101
- feats2 = []
102
- for key in ['img0', 'img1']:
103
- img_id = datum[key]
104
- img_info = self.imgid2img[img_id]
105
- boxes = img_info['boxes'].copy()
106
- feats = img_info['features'].copy()
107
- assert len(boxes) == len(feats)
108
-
109
- # Normalize the boxes (to 0 ~ 1)
110
- img_h, img_w = img_info['img_h'], img_info['img_w']
111
- boxes[..., (0, 2)] /= img_w
112
- boxes[..., (1, 3)] /= img_h
113
- np.testing.assert_array_less(boxes, 1+1e-5)
114
- np.testing.assert_array_less(-boxes, 0+1e-5)
115
-
116
- boxes2.append(boxes)
117
- feats2.append(feats)
118
- feats = np.stack(feats2)
119
- boxes = np.stack(boxes2)
120
-
121
- # Create target
122
- if 'label' in datum:
123
- label = datum['label']
124
- return ques_id, feats, boxes, ques, label
125
- else:
126
- return ques_id, feats, boxes, ques
127
-
128
-
129
class NLVR2Evaluator:
    """Scores 0/1 predictions against NLVR2 labels and dumps CSV submissions."""

    def __init__(self, dataset: NLVR2Dataset):
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """Return plain accuracy of the predictions.

        :param quesid2ans: dict mapping uid -> predicted class (0 or 1)
        :return: fraction of exact label matches in [0, 1]; 0. for an empty
            dict (the original raised ZeroDivisionError in that case).
        """
        if not quesid2ans:
            return 0.
        score = 0.
        for quesid, ans in quesid2ans.items():
            datum = self.dataset.id2datum[quesid]
            label = datum['label']
            if ans == label:
                score += 1
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump result to a CSV file, which is compatible with NLVR2 evaluation system.
        NLVR2 CSV file requirement:
            Each line contains: identifier, answer

        :param quesid2ans: nlvr2 uid to ans (either "True" or "False")
        :param path: The desired path of saved file.
        :return:
        """
        with open(path, 'w') as f:
            for uid, ans in quesid2ans.items():
                idt = self.dataset.id2datum[uid]["identifier"]
                ans = 'True' if ans == 1 else 'False'
                f.write("%s,%s\n" % (idt, ans))
157
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/nlvr2_model.py DELETED
@@ -1,55 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import torch.nn as nn
5
- from lxrt.modeling import GeLU, BertLayerNorm
6
- from lxrt.entry import LXRTEncoder
7
- from param import args
8
-
9
-
10
class NLVR2Model(nn.Module):
    """LXRT encoder applied to each (image, sentence) pair, followed by a
    binary classification head over the concatenated pair embeddings."""

    def __init__(self):
        super().__init__()
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=20
        )
        self.hid_dim = hid_dim = self.lxrt_encoder.dim
        # Head input is 2*hid: the two per-image embeddings concatenated.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim * 2, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 2)
        )
        # Initialize the new head with BERT's weight-init scheme.
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        :param feat: b, 2, o, f — object features for the two images
        :param pos: b, 2, o, 4 — normalized boxes for the two images
        :param sent: b, (string)
        :return: (b, 2) logits over {False, True}
        """
        # Pairing images and sentences:
        # The input of NLVR2 is two images and one sentence. In batch level, they are saved as
        #   [ [img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...]
        # Here, we flat them to
        #   feat/pos = [ img0_0, img0_1, img1_0, img1_1, ...]
        #   sent     = [ sent0,  sent0,  sent1,  sent1,  ...]
        # (each sentence is duplicated so it aligns with both of its images)
        sent = sum(zip(sent, sent), ())
        batch_size, img_num, obj_num, feat_size = feat.size()
        assert img_num == 2 and obj_num == 36 and feat_size == 2048
        feat = feat.view(batch_size * 2, obj_num, feat_size)
        pos = pos.view(batch_size * 2, obj_num, 4)

        # Extract feature --> Concat: (2b, hid) -> (b, 2*hid) pairs the two
        # image embeddings of each example back together.
        x = self.lxrt_encoder(sent, (feat, pos))
        x = x.view(-1, self.hid_dim*2)

        # Compute logit of answers
        logit = self.logit_fc(x)

        return logit
54
-
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/vqa.py DELETED
@@ -1,214 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import os
5
- import collections
6
-
7
- import torch
8
- import torch.nn as nn
9
- from torch.utils.data.dataloader import DataLoader
10
- from tqdm import tqdm
11
-
12
- from ..param import args
13
- from ..pretrain.qa_answer_table import load_lxmert_qa
14
- from .vqa_model import VQAModel
15
- from .vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
16
-
17
- DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
18
-
19
-
20
- def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
21
- dset = VQADataset(splits)
22
- tset = VQATorchDataset(dset)
23
- evaluator = VQAEvaluator(dset)
24
- data_loader = DataLoader(
25
- tset, batch_size=bs,
26
- shuffle=shuffle, num_workers=args.num_workers,
27
- drop_last=drop_last, pin_memory=True
28
- )
29
-
30
- return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
31
-
32
-
33
- class VQA:
34
- def __init__(self):
35
- # Datasets
36
- self.train_tuple = get_data_tuple(
37
- args.train, bs=args.batch_size, shuffle=True, drop_last=True
38
- )
39
- if args.valid != "":
40
- self.valid_tuple = get_data_tuple(
41
- args.valid, bs=1024,
42
- shuffle=False, drop_last=False
43
- )
44
- else:
45
- self.valid_tuple = None
46
-
47
- # Model
48
- self.model = VQAModel(self.train_tuple.dataset.num_answers)
49
-
50
- # Load pre-trained weights
51
- if args.load_lxmert is not None:
52
- self.model.lxrt_encoder.load(args.load_lxmert)
53
- if args.load_lxmert_qa is not None:
54
- load_lxmert_qa(args.load_lxmert_qa, self.model,
55
- label2ans=self.train_tuple.dataset.label2ans)
56
-
57
- # GPU options
58
- self.model = self.model.cuda()
59
- if args.multiGPU:
60
- self.model.lxrt_encoder.multi_gpu()
61
-
62
- # Loss and Optimizer
63
- self.bce_loss = nn.BCEWithLogitsLoss()
64
- if 'bert' in args.optim:
65
- batch_per_epoch = len(self.train_tuple.loader)
66
- t_total = int(batch_per_epoch * args.epochs)
67
- print("BertAdam Total Iters: %d" % t_total)
68
- from ..lxrt.optimization import BertAdam
69
- self.optim = BertAdam(list(self.model.parameters()),
70
- lr=args.lr,
71
- warmup=0.1,
72
- t_total=t_total)
73
- else:
74
- self.optim = args.optimizer(self.model.parameters(), args.lr)
75
-
76
- # Output Directory
77
- self.output = args.output
78
- os.makedirs(self.output, exist_ok=True)
79
-
80
- def train(self, train_tuple, eval_tuple):
81
- dset, loader, evaluator = train_tuple
82
- iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
83
-
84
- best_valid = 0.
85
- for epoch in range(args.epochs):
86
- quesid2ans = {}
87
- for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
88
-
89
- self.model.train()
90
- self.optim.zero_grad()
91
-
92
- feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
93
- logit = self.model(feats, boxes, sent)
94
- assert logit.dim() == target.dim() == 2
95
- loss = self.bce_loss(logit, target)
96
- loss = loss * logit.size(1)
97
-
98
- loss.backward()
99
- nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
100
- self.optim.step()
101
-
102
- score, label = logit.max(1)
103
- for qid, l in zip(ques_id, label.cpu().numpy()):
104
- ans = dset.label2ans[l]
105
- quesid2ans[qid.item()] = ans
106
-
107
- log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
108
-
109
- if self.valid_tuple is not None: # Do Validation
110
- valid_score = self.evaluate(eval_tuple)
111
- if valid_score > best_valid:
112
- best_valid = valid_score
113
- self.save("BEST")
114
-
115
- log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
116
- "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
117
-
118
- print(log_str, end='')
119
-
120
- with open(self.output + "/log.log", 'a') as f:
121
- f.write(log_str)
122
- f.flush()
123
-
124
- self.save("LAST")
125
-
126
- def predict(self, eval_tuple: DataTuple, dump=None):
127
- """
128
- Predict the answers to questions in a data split.
129
-
130
- :param eval_tuple: The data tuple to be evaluated.
131
- :param dump: The path of saved file to dump results.
132
- :return: A dict of question_id to answer.
133
- """
134
- self.model.eval()
135
- dset, loader, evaluator = eval_tuple
136
- quesid2ans = {}
137
- for i, datum_tuple in enumerate(loader):
138
- ques_id, feats, boxes, sent = datum_tuple[:4] # Avoid seeing ground truth
139
- with torch.no_grad():
140
- feats, boxes = feats.cuda(), boxes.cuda()
141
- logit = self.model(feats, boxes, sent)
142
- score, label = logit.max(1)
143
- for qid, l in zip(ques_id, label.cpu().numpy()):
144
- ans = dset.label2ans[l]
145
- quesid2ans[qid.item()] = ans
146
- if dump is not None:
147
- evaluator.dump_result(quesid2ans, dump)
148
- return quesid2ans
149
-
150
- def evaluate(self, eval_tuple: DataTuple, dump=None):
151
- """Evaluate all data in data_tuple."""
152
- quesid2ans = self.predict(eval_tuple, dump)
153
- return eval_tuple.evaluator.evaluate(quesid2ans)
154
-
155
- @staticmethod
156
- def oracle_score(data_tuple):
157
- dset, loader, evaluator = data_tuple
158
- quesid2ans = {}
159
- for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
160
- _, label = target.max(1)
161
- for qid, l in zip(ques_id, label.cpu().numpy()):
162
- ans = dset.label2ans[l]
163
- quesid2ans[qid.item()] = ans
164
- return evaluator.evaluate(quesid2ans)
165
-
166
- def save(self, name):
167
- torch.save(self.model.state_dict(),
168
- os.path.join(self.output, "%s.pth" % name))
169
-
170
- def load(self, path):
171
- print("Load model from %s" % path)
172
- state_dict = torch.load("%s.pth" % path)
173
- self.model.load_state_dict(state_dict)
174
-
175
-
176
- if __name__ == "__main__":
177
- # Build Class
178
- vqa = VQA()
179
-
180
- # Load VQA model weights
181
- # Note: It is different from loading lxmert pre-trained weights.
182
- if args.load is not None:
183
- vqa.load(args.load)
184
-
185
- # Test or Train
186
- if args.test is not None:
187
- args.fast = args.tiny = False # Always loading all data in test
188
- if 'test' in args.test:
189
- vqa.predict(
190
- get_data_tuple(args.test, bs=950,
191
- shuffle=False, drop_last=False),
192
- dump=os.path.join(args.output, 'test_predict.json')
193
- )
194
- elif 'val' in args.test:
195
- # Since part of valididation data are used in pre-training/fine-tuning,
196
- # only validate on the minival set.
197
- result = vqa.evaluate(
198
- get_data_tuple('minival', bs=950,
199
- shuffle=False, drop_last=False),
200
- dump=os.path.join(args.output, 'minival_predict.json')
201
- )
202
- print(result)
203
- else:
204
- assert False, "No such test option for %s" % args.test
205
- else:
206
- print('Splits in Train data:', vqa.train_tuple.dataset.splits)
207
- if vqa.valid_tuple is not None:
208
- print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
209
- print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
210
- else:
211
- print("DO NOT USE VALIDATION")
212
- vqa.train(vqa.train_tuple, vqa.valid_tuple)
213
-
214
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/vqa_data.py DELETED
@@ -1,188 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import json
5
- import os
6
- import pickle
7
-
8
- import numpy as np
9
- import torch
10
- from torch.utils.data import Dataset
11
-
12
- from ..param import args
13
- from ..utils import load_obj_tsv
14
-
15
- # Load part of the dataset for fast checking.
16
- # Notice that here is the number of images instead of the number of data,
17
- # which means all related data to the images would be used.
18
- TINY_IMG_NUM = 512
19
- FAST_IMG_NUM = 5000
20
-
21
- # The path to data and image features.
22
- VQA_DATA_ROOT = 'data/vqa/'
23
- MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/'
24
- SPLIT2NAME = {
25
- 'train': 'train2014',
26
- 'valid': 'val2014',
27
- 'minival': 'val2014',
28
- 'nominival': 'val2014',
29
- 'test': 'test2015',
30
- }
31
-
32
-
33
- class VQADataset:
34
- """
35
- A VQA data example in json file:
36
- {
37
- "answer_type": "other",
38
- "img_id": "COCO_train2014_000000458752",
39
- "label": {
40
- "net": 1
41
- },
42
- "question_id": 458752000,
43
- "question_type": "what is this",
44
- "sent": "What is this photo taken looking through?"
45
- }
46
- """
47
- def __init__(self, splits: str):
48
- self.name = splits
49
- self.splits = splits.split(',')
50
-
51
- # Loading datasets
52
- self.data = []
53
- for split in self.splits:
54
- self.data.extend(json.load(open("data/vqa/%s.json" % split)))
55
- print("Load %d data from split(s) %s." % (len(self.data), self.name))
56
-
57
- # Convert list to dict (for evaluation)
58
- self.id2datum = {
59
- datum['question_id']: datum
60
- for datum in self.data
61
- }
62
-
63
- # Answers
64
- self.ans2label = json.load(open("data/vqa/trainval_ans2label.json"))
65
- self.label2ans = json.load(open("data/vqa/trainval_label2ans.json"))
66
- assert len(self.ans2label) == len(self.label2ans)
67
-
68
- @property
69
- def num_answers(self):
70
- return len(self.ans2label)
71
-
72
- def __len__(self):
73
- return len(self.data)
74
-
75
-
76
- """
77
- An example in obj36 tsv:
78
- FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
79
- "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
80
- FIELDNAMES would be keys in the dict returned by load_obj_tsv.
81
- """
82
- class VQATorchDataset(Dataset):
83
- def __init__(self, dataset: VQADataset):
84
- super().__init__()
85
- self.raw_dataset = dataset
86
-
87
- if args.tiny:
88
- topk = TINY_IMG_NUM
89
- elif args.fast:
90
- topk = FAST_IMG_NUM
91
- else:
92
- topk = None
93
-
94
- # Loading detection features to img_data
95
- img_data = []
96
- for split in dataset.splits:
97
- # Minival is 5K images in MS COCO, which is used in evaluating VQA/lxmert-pre-training.
98
- # It is saved as the top 5K features in val2014_***.tsv
99
- load_topk = 5000 if (split == 'minival' and topk is None) else topk
100
- img_data.extend(load_obj_tsv(
101
- os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])),
102
- topk=load_topk))
103
-
104
- # Convert img list to dict
105
- self.imgid2img = {}
106
- for img_datum in img_data:
107
- self.imgid2img[img_datum['img_id']] = img_datum
108
-
109
- # Only kept the data with loaded image features
110
- self.data = []
111
- for datum in self.raw_dataset.data:
112
- if datum['img_id'] in self.imgid2img:
113
- self.data.append(datum)
114
- print("Use %d data in torch dataset" % (len(self.data)))
115
- print()
116
-
117
- def __len__(self):
118
- return len(self.data)
119
-
120
- def __getitem__(self, item: int):
121
- datum = self.data[item]
122
-
123
- img_id = datum['img_id']
124
- ques_id = datum['question_id']
125
- ques = datum['sent']
126
-
127
- # Get image info
128
- img_info = self.imgid2img[img_id]
129
- obj_num = img_info['num_boxes']
130
- feats = img_info['features'].copy()
131
- boxes = img_info['boxes'].copy()
132
- assert obj_num == len(boxes) == len(feats)
133
-
134
- # Normalize the boxes (to 0 ~ 1)
135
- img_h, img_w = img_info['img_h'], img_info['img_w']
136
- boxes = boxes.copy()
137
- boxes[:, (0, 2)] /= img_w
138
- boxes[:, (1, 3)] /= img_h
139
- np.testing.assert_array_less(boxes, 1+1e-5)
140
- np.testing.assert_array_less(-boxes, 0+1e-5)
141
-
142
- # Provide label (target)
143
- if 'label' in datum:
144
- label = datum['label']
145
- target = torch.zeros(self.raw_dataset.num_answers)
146
- for ans, score in label.items():
147
- target[self.raw_dataset.ans2label[ans]] = score
148
- return ques_id, feats, boxes, ques, target
149
- else:
150
- return ques_id, feats, boxes, ques
151
-
152
-
153
- class VQAEvaluator:
154
- def __init__(self, dataset: VQADataset):
155
- self.dataset = dataset
156
-
157
- def evaluate(self, quesid2ans: dict):
158
- score = 0.
159
- for quesid, ans in quesid2ans.items():
160
- datum = self.dataset.id2datum[quesid]
161
- label = datum['label']
162
- if ans in label:
163
- score += label[ans]
164
- return score / len(quesid2ans)
165
-
166
- def dump_result(self, quesid2ans: dict, path):
167
- """
168
- Dump results to a json file, which could be submitted to the VQA online evaluation.
169
- VQA json file submission requirement:
170
- results = [result]
171
- result = {
172
- "question_id": int,
173
- "answer": str
174
- }
175
-
176
- :param quesid2ans: dict of quesid --> ans
177
- :param path: The desired path of saved file.
178
- """
179
- with open(path, 'w') as f:
180
- result = []
181
- for ques_id, ans in quesid2ans.items():
182
- result.append({
183
- 'question_id': ques_id,
184
- 'answer': ans
185
- })
186
- json.dump(result, f, indent=4, sort_keys=True)
187
-
188
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lxmert/src/tasks/vqa_model.py DELETED
@@ -1,50 +0,0 @@
1
- # coding=utf-8
2
- # Copyleft 2019 project LXRT.
3
-
4
- import torch.nn as nn
5
-
6
- from ..param import args
7
- from ..lxrt.entry import LXRTEncoder
8
- from ..lxrt.modeling import BertLayerNorm, GeLU
9
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
10
-
11
- # Max length including <bos> and <eos>
12
- MAX_VQA_LENGTH = 20
13
-
14
-
15
- class VQAModel(nn.Module):
16
- def __init__(self, num_answers):
17
- super().__init__()
18
-
19
- # # Build LXRT encoder
20
- # self.lxrt_encoder = LXRTEncoder(
21
- # args,
22
- # max_seq_length=MAX_VQA_LENGTH
23
- # )
24
- # hid_dim = self.lxrt_encoder.dim
25
- #
26
- # # VQA Answer heads
27
- # self.logit_fc = nn.Sequential(
28
- # nn.Linear(hid_dim, hid_dim * 2),
29
- # GeLU(),
30
- # BertLayerNorm(hid_dim * 2, eps=1e-12),
31
- # nn.Linear(hid_dim * 2, num_answers)
32
- # )
33
- # self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
34
-
35
- self.tokenizer = AutoTokenizer.from_pretrained("unc-nlp/lxmert-vqa-uncased")
36
- self.model = AutoModelForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")
37
-
38
- def forward(self, feat, pos, sent):
39
- """
40
- b -- batch_size, o -- object_number, f -- visual_feature_size
41
-
42
- :param feat: (b, o, f)
43
- :param pos: (b, o, 4)
44
- :param sent: (b,) Type -- list of string
45
- :param leng: (b,) Type -- int numpy array
46
- :return: (b, num_answer) The logit of each answers.
47
- """
48
- return self.model(sent, feat, pos)
49
-
50
-