Spaces:
Sleeping
Sleeping
Delete lxmert/src/tasks
Browse files- lxmert/src/tasks/__init__.py +0 -0
- lxmert/src/tasks/gqa.py +0 -210
- lxmert/src/tasks/gqa_data.py +0 -194
- lxmert/src/tasks/gqa_model.py +0 -45
- lxmert/src/tasks/nlvr2.py +0 -182
- lxmert/src/tasks/nlvr2_data.py +0 -157
- lxmert/src/tasks/nlvr2_model.py +0 -55
- lxmert/src/tasks/vqa.py +0 -214
- lxmert/src/tasks/vqa_data.py +0 -188
- lxmert/src/tasks/vqa_model.py +0 -50
lxmert/src/tasks/__init__.py
DELETED
File without changes
|
lxmert/src/tasks/gqa.py
DELETED
@@ -1,210 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import os
|
5 |
-
import collections
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from tqdm import tqdm
|
9 |
-
import torch.nn as nn
|
10 |
-
from torch.utils.data.dataloader import DataLoader
|
11 |
-
|
12 |
-
from param import args
|
13 |
-
from pretrain.qa_answer_table import load_lxmert_qa
|
14 |
-
from tasks.gqa_model import GQAModel
|
15 |
-
from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator
|
16 |
-
|
17 |
-
|
18 |
-
DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
|
19 |
-
|
20 |
-
|
21 |
-
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
|
22 |
-
dset = GQADataset(splits)
|
23 |
-
tset = GQATorchDataset(dset)
|
24 |
-
evaluator = GQAEvaluator(dset)
|
25 |
-
data_loader = DataLoader(
|
26 |
-
tset, batch_size=bs,
|
27 |
-
shuffle=shuffle, num_workers=args.num_workers,
|
28 |
-
drop_last=drop_last, pin_memory=True
|
29 |
-
)
|
30 |
-
|
31 |
-
return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
|
32 |
-
|
33 |
-
|
34 |
-
class GQA:
|
35 |
-
def __init__(self):
|
36 |
-
self.train_tuple = get_tuple(
|
37 |
-
args.train, bs=args.batch_size, shuffle=True, drop_last=True
|
38 |
-
)
|
39 |
-
if args.valid != "":
|
40 |
-
valid_bsize = 2048 if args.multiGPU else 512
|
41 |
-
self.valid_tuple = get_tuple(
|
42 |
-
args.valid, bs=valid_bsize,
|
43 |
-
shuffle=False, drop_last=False
|
44 |
-
)
|
45 |
-
else:
|
46 |
-
self.valid_tuple = None
|
47 |
-
|
48 |
-
self.model = GQAModel(self.train_tuple.dataset.num_answers)
|
49 |
-
|
50 |
-
# Load pre-trained weights
|
51 |
-
if args.load_lxmert is not None:
|
52 |
-
self.model.lxrt_encoder.load(args.load_lxmert)
|
53 |
-
if args.load_lxmert_qa is not None:
|
54 |
-
load_lxmert_qa(args.load_lxmert_qa, self.model,
|
55 |
-
label2ans=self.train_tuple.dataset.label2ans)
|
56 |
-
|
57 |
-
# GPU options
|
58 |
-
self.model = self.model.cuda()
|
59 |
-
if args.multiGPU:
|
60 |
-
self.model.lxrt_encoder.multi_gpu()
|
61 |
-
|
62 |
-
# Losses and optimizer
|
63 |
-
self.bce_loss = nn.BCEWithLogitsLoss()
|
64 |
-
self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
|
65 |
-
if 'bert' in args.optim:
|
66 |
-
batch_per_epoch = len(self.train_tuple.loader)
|
67 |
-
t_total = int(batch_per_epoch * args.epochs)
|
68 |
-
print("Total Iters: %d" % t_total)
|
69 |
-
from lxrt.optimization import BertAdam
|
70 |
-
self.optim = BertAdam(list(self.model.parameters()),
|
71 |
-
lr=args.lr,
|
72 |
-
warmup=0.1,
|
73 |
-
t_total=t_total)
|
74 |
-
else:
|
75 |
-
self.optim = args.optimizer(list(self.model.parameters()), args.lr)
|
76 |
-
|
77 |
-
self.output = args.output
|
78 |
-
os.makedirs(self.output, exist_ok=True)
|
79 |
-
|
80 |
-
def train(self, train_tuple, eval_tuple):
|
81 |
-
dset, loader, evaluator = train_tuple
|
82 |
-
iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
|
83 |
-
|
84 |
-
best_valid = 0.
|
85 |
-
for epoch in range(args.epochs):
|
86 |
-
quesid2ans = {}
|
87 |
-
for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
|
88 |
-
|
89 |
-
self.model.train()
|
90 |
-
self.optim.zero_grad()
|
91 |
-
|
92 |
-
feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
|
93 |
-
logit = self.model(feats, boxes, sent)
|
94 |
-
assert logit.dim() == target.dim() == 2
|
95 |
-
if args.mce_loss:
|
96 |
-
max_value, target = target.max(1)
|
97 |
-
loss = self.mce_loss(logit, target) * logit.size(1)
|
98 |
-
else:
|
99 |
-
loss = self.bce_loss(logit, target)
|
100 |
-
loss = loss * logit.size(1)
|
101 |
-
|
102 |
-
loss.backward()
|
103 |
-
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
|
104 |
-
self.optim.step()
|
105 |
-
|
106 |
-
score, label = logit.max(1)
|
107 |
-
for qid, l in zip(ques_id, label.cpu().numpy()):
|
108 |
-
ans = dset.label2ans[l]
|
109 |
-
quesid2ans[qid] = ans
|
110 |
-
|
111 |
-
log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
|
112 |
-
|
113 |
-
if self.valid_tuple is not None: # Do Validation
|
114 |
-
valid_score = self.evaluate(eval_tuple)
|
115 |
-
if valid_score > best_valid:
|
116 |
-
best_valid = valid_score
|
117 |
-
self.save("BEST")
|
118 |
-
|
119 |
-
log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
|
120 |
-
"Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
|
121 |
-
|
122 |
-
print(log_str, end='')
|
123 |
-
|
124 |
-
with open(self.output + "/log.log", 'a') as f:
|
125 |
-
f.write(log_str)
|
126 |
-
f.flush()
|
127 |
-
|
128 |
-
self.save("LAST")
|
129 |
-
|
130 |
-
def predict(self, eval_tuple: DataTuple, dump=None):
|
131 |
-
self.model.eval()
|
132 |
-
dset, loader, evaluator = eval_tuple
|
133 |
-
quesid2ans = {}
|
134 |
-
for i, datum_tuple in enumerate(loader):
|
135 |
-
ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target
|
136 |
-
with torch.no_grad():
|
137 |
-
feats, boxes = feats.cuda(), boxes.cuda()
|
138 |
-
logit = self.model(feats, boxes, sent)
|
139 |
-
score, label = logit.max(1)
|
140 |
-
for qid, l in zip(ques_id, label.cpu().numpy()):
|
141 |
-
ans = dset.label2ans[l]
|
142 |
-
quesid2ans[qid] = ans
|
143 |
-
if dump is not None:
|
144 |
-
evaluator.dump_result(quesid2ans, dump)
|
145 |
-
return quesid2ans
|
146 |
-
|
147 |
-
def evaluate(self, eval_tuple: DataTuple, dump=None):
|
148 |
-
dset, loader, evaluator = eval_tuple
|
149 |
-
quesid2ans = self.predict(eval_tuple, dump)
|
150 |
-
return evaluator.evaluate(quesid2ans)
|
151 |
-
|
152 |
-
@staticmethod
|
153 |
-
def oracle_score(data_tuple):
|
154 |
-
dset, loader, evaluator = data_tuple
|
155 |
-
quesid2ans = {}
|
156 |
-
for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
|
157 |
-
_, label = target.max(1)
|
158 |
-
for qid, l in zip(ques_id, label.cpu().numpy()):
|
159 |
-
ans = dset.label2ans[l]
|
160 |
-
quesid2ans[qid] = ans
|
161 |
-
return evaluator.evaluate(quesid2ans)
|
162 |
-
|
163 |
-
def save(self, name):
|
164 |
-
torch.save(self.model.state_dict(),
|
165 |
-
os.path.join(self.output, "%s.pth" % name))
|
166 |
-
|
167 |
-
def load(self, path):
|
168 |
-
print("Load model from %s" % path)
|
169 |
-
state_dict = torch.load("%s.pth" % path)
|
170 |
-
for key in list(state_dict.keys()):
|
171 |
-
if '.module' in key:
|
172 |
-
state_dict[key.replace('.module', '')] = state_dict.pop(key)
|
173 |
-
self.model.load_state_dict(state_dict, strict=False)
|
174 |
-
|
175 |
-
|
176 |
-
if __name__ == "__main__":
|
177 |
-
# Build Class
|
178 |
-
gqa = GQA()
|
179 |
-
|
180 |
-
# Load Model
|
181 |
-
if args.load is not None:
|
182 |
-
gqa.load(args.load)
|
183 |
-
|
184 |
-
# Test or Train
|
185 |
-
if args.test is not None:
|
186 |
-
args.fast = args.tiny = False # Always loading all data in test
|
187 |
-
if 'submit' in args.test:
|
188 |
-
gqa.predict(
|
189 |
-
get_tuple(args.test, bs=args.batch_size,
|
190 |
-
shuffle=False, drop_last=False),
|
191 |
-
dump=os.path.join(args.output, 'submit_predict.json')
|
192 |
-
)
|
193 |
-
if 'testdev' in args.test:
|
194 |
-
result = gqa.evaluate(
|
195 |
-
get_tuple('testdev', bs=args.batch_size,
|
196 |
-
shuffle=False, drop_last=False),
|
197 |
-
dump=os.path.join(args.output, 'testdev_predict.json')
|
198 |
-
)
|
199 |
-
print(result)
|
200 |
-
else:
|
201 |
-
# print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
|
202 |
-
print('Splits in Train data:', gqa.train_tuple.dataset.splits)
|
203 |
-
if gqa.valid_tuple is not None:
|
204 |
-
print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
|
205 |
-
print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
|
206 |
-
else:
|
207 |
-
print("DO NOT USE VALIDATION")
|
208 |
-
gqa.train(gqa.train_tuple, gqa.valid_tuple)
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/gqa_data.py
DELETED
@@ -1,194 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import json
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
from torch.utils.data import Dataset
|
9 |
-
|
10 |
-
from param import args
|
11 |
-
from utils import load_obj_tsv
|
12 |
-
|
13 |
-
# Load part of the dataset for fast checking.
|
14 |
-
# Notice that here is the number of images instead of the number of data,
|
15 |
-
# which means all related data to the images would be used.
|
16 |
-
TINY_IMG_NUM = 512
|
17 |
-
FAST_IMG_NUM = 5000
|
18 |
-
|
19 |
-
|
20 |
-
class GQADataset:
|
21 |
-
"""
|
22 |
-
A GQA data example in json file:
|
23 |
-
{
|
24 |
-
"img_id": "2375429",
|
25 |
-
"label": {
|
26 |
-
"pipe": 1.0
|
27 |
-
},
|
28 |
-
"question_id": "07333408",
|
29 |
-
"sent": "What is on the white wall?"
|
30 |
-
}
|
31 |
-
"""
|
32 |
-
def __init__(self, splits: str):
|
33 |
-
self.name = splits
|
34 |
-
self.splits = splits.split(',')
|
35 |
-
|
36 |
-
# Loading datasets to data
|
37 |
-
self.data = []
|
38 |
-
for split in self.splits:
|
39 |
-
self.data.extend(json.load(open("data/gqa/%s.json" % split)))
|
40 |
-
print("Load %d data from split(s) %s." % (len(self.data), self.name))
|
41 |
-
|
42 |
-
# List to dict (for evaluation and others)
|
43 |
-
self.id2datum = {
|
44 |
-
datum['question_id']: datum
|
45 |
-
for datum in self.data
|
46 |
-
}
|
47 |
-
|
48 |
-
# Answers
|
49 |
-
self.ans2label = json.load(open("data/gqa/trainval_ans2label.json"))
|
50 |
-
self.label2ans = json.load(open("data/gqa/trainval_label2ans.json"))
|
51 |
-
assert len(self.ans2label) == len(self.label2ans)
|
52 |
-
for ans, label in self.ans2label.items():
|
53 |
-
assert self.label2ans[label] == ans
|
54 |
-
|
55 |
-
@property
|
56 |
-
def num_answers(self):
|
57 |
-
return len(self.ans2label)
|
58 |
-
|
59 |
-
def __len__(self):
|
60 |
-
return len(self.data)
|
61 |
-
|
62 |
-
|
63 |
-
class GQABufferLoader():
|
64 |
-
def __init__(self):
|
65 |
-
self.key2data = {}
|
66 |
-
|
67 |
-
def load_data(self, name, number):
|
68 |
-
if name == 'testdev':
|
69 |
-
path = "data/vg_gqa_imgfeat/gqa_testdev_obj36.tsv"
|
70 |
-
else:
|
71 |
-
path = "data/vg_gqa_imgfeat/vg_gqa_obj36.tsv"
|
72 |
-
key = "%s_%d" % (path, number)
|
73 |
-
if key not in self.key2data:
|
74 |
-
self.key2data[key] = load_obj_tsv(
|
75 |
-
path,
|
76 |
-
topk=number
|
77 |
-
)
|
78 |
-
return self.key2data[key]
|
79 |
-
|
80 |
-
|
81 |
-
gqa_buffer_loader = GQABufferLoader()
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
-
Example in obj tsv:
|
86 |
-
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
|
87 |
-
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
|
88 |
-
"""
|
89 |
-
class GQATorchDataset(Dataset):
|
90 |
-
def __init__(self, dataset: GQADataset):
|
91 |
-
super().__init__()
|
92 |
-
self.raw_dataset = dataset
|
93 |
-
|
94 |
-
if args.tiny:
|
95 |
-
topk = TINY_IMG_NUM
|
96 |
-
elif args.fast:
|
97 |
-
topk = FAST_IMG_NUM
|
98 |
-
else:
|
99 |
-
topk = -1
|
100 |
-
|
101 |
-
# Loading detection features to img_data
|
102 |
-
# Since images in train and valid both come from Visual Genome,
|
103 |
-
# buffer the image loading to save memory.
|
104 |
-
img_data = []
|
105 |
-
if 'testdev' in dataset.splits or 'testdev_all' in dataset.splits: # Always loading all the data in testdev
|
106 |
-
img_data.extend(gqa_buffer_loader.load_data('testdev', -1))
|
107 |
-
else:
|
108 |
-
img_data.extend(gqa_buffer_loader.load_data('train', topk))
|
109 |
-
self.imgid2img = {}
|
110 |
-
for img_datum in img_data:
|
111 |
-
self.imgid2img[img_datum['img_id']] = img_datum
|
112 |
-
|
113 |
-
# Only kept the data with loaded image features
|
114 |
-
self.data = []
|
115 |
-
for datum in self.raw_dataset.data:
|
116 |
-
if datum['img_id'] in self.imgid2img:
|
117 |
-
self.data.append(datum)
|
118 |
-
print("Use %d data in torch dataset" % (len(self.data)))
|
119 |
-
print()
|
120 |
-
|
121 |
-
def __len__(self):
|
122 |
-
return len(self.data)
|
123 |
-
|
124 |
-
def __getitem__(self, item: int):
|
125 |
-
datum = self.data[item]
|
126 |
-
|
127 |
-
img_id = datum['img_id']
|
128 |
-
ques_id = datum['question_id']
|
129 |
-
ques = datum['sent']
|
130 |
-
|
131 |
-
# Get image info
|
132 |
-
img_info = self.imgid2img[img_id]
|
133 |
-
obj_num = img_info['num_boxes']
|
134 |
-
boxes = img_info['boxes'].copy()
|
135 |
-
feats = img_info['features'].copy()
|
136 |
-
assert len(boxes) == len(feats) == obj_num
|
137 |
-
|
138 |
-
# Normalize the boxes (to 0 ~ 1)
|
139 |
-
img_h, img_w = img_info['img_h'], img_info['img_w']
|
140 |
-
boxes = boxes.copy()
|
141 |
-
boxes[:, (0, 2)] /= img_w
|
142 |
-
boxes[:, (1, 3)] /= img_h
|
143 |
-
np.testing.assert_array_less(boxes, 1+1e-5)
|
144 |
-
np.testing.assert_array_less(-boxes, 0+1e-5)
|
145 |
-
|
146 |
-
# Create target
|
147 |
-
if 'label' in datum:
|
148 |
-
label = datum['label']
|
149 |
-
target = torch.zeros(self.raw_dataset.num_answers)
|
150 |
-
for ans, score in label.items():
|
151 |
-
if ans in self.raw_dataset.ans2label:
|
152 |
-
target[self.raw_dataset.ans2label[ans]] = score
|
153 |
-
return ques_id, feats, boxes, ques, target
|
154 |
-
else:
|
155 |
-
return ques_id, feats, boxes, ques
|
156 |
-
|
157 |
-
|
158 |
-
class GQAEvaluator:
|
159 |
-
def __init__(self, dataset: GQADataset):
|
160 |
-
self.dataset = dataset
|
161 |
-
|
162 |
-
def evaluate(self, quesid2ans: dict):
|
163 |
-
score = 0.
|
164 |
-
for quesid, ans in quesid2ans.items():
|
165 |
-
datum = self.dataset.id2datum[quesid]
|
166 |
-
label = datum['label']
|
167 |
-
if ans in label:
|
168 |
-
score += label[ans]
|
169 |
-
return score / len(quesid2ans)
|
170 |
-
|
171 |
-
def dump_result(self, quesid2ans: dict, path):
|
172 |
-
"""
|
173 |
-
Dump the result to a GQA-challenge submittable json file.
|
174 |
-
GQA json file submission requirement:
|
175 |
-
results = [result]
|
176 |
-
result = {
|
177 |
-
"questionId": str, # Note: it's a actually an int number but the server requires an str.
|
178 |
-
"prediction": str
|
179 |
-
}
|
180 |
-
|
181 |
-
:param quesid2ans: A dict mapping question id to its predicted answer.
|
182 |
-
:param path: The file path to save the json file.
|
183 |
-
:return:
|
184 |
-
"""
|
185 |
-
with open(path, 'w') as f:
|
186 |
-
result = []
|
187 |
-
for ques_id, ans in quesid2ans.items():
|
188 |
-
result.append({
|
189 |
-
'questionId': ques_id,
|
190 |
-
'prediction': ans
|
191 |
-
})
|
192 |
-
json.dump(result, f, indent=4, sort_keys=True)
|
193 |
-
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/gqa_model.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
from param import args
|
7 |
-
from lxrt.entry import LXRTEncoder
|
8 |
-
from lxrt.modeling import BertLayerNorm, GeLU
|
9 |
-
|
10 |
-
# Max length including <bos> and <eos>
|
11 |
-
MAX_GQA_LENGTH = 20
|
12 |
-
|
13 |
-
|
14 |
-
class GQAModel(nn.Module):
|
15 |
-
def __init__(self, num_answers):
|
16 |
-
super().__init__()
|
17 |
-
self.lxrt_encoder = LXRTEncoder(
|
18 |
-
args,
|
19 |
-
max_seq_length=MAX_GQA_LENGTH
|
20 |
-
)
|
21 |
-
hid_dim = self.lxrt_encoder.dim
|
22 |
-
self.logit_fc = nn.Sequential(
|
23 |
-
nn.Linear(hid_dim, hid_dim * 2),
|
24 |
-
GeLU(),
|
25 |
-
BertLayerNorm(hid_dim * 2, eps=1e-12),
|
26 |
-
nn.Linear(hid_dim * 2, num_answers)
|
27 |
-
)
|
28 |
-
self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
|
29 |
-
|
30 |
-
def forward(self, feat, pos, sent):
|
31 |
-
"""
|
32 |
-
b -- batch_size, o -- object_number, f -- visual_feature_size
|
33 |
-
|
34 |
-
:param feat: (b, o, f)
|
35 |
-
:param pos: (b, o, 4)
|
36 |
-
:param sent: (b,) Type -- list of string
|
37 |
-
:param leng: (b,) Type -- int numpy array
|
38 |
-
:return: (b, num_answer) The logit of each answers.
|
39 |
-
"""
|
40 |
-
x = self.lxrt_encoder(sent, (feat, pos))
|
41 |
-
logit = self.logit_fc(x)
|
42 |
-
|
43 |
-
return logit
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/nlvr2.py
DELETED
@@ -1,182 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import os
|
5 |
-
import collections
|
6 |
-
|
7 |
-
from tqdm import tqdm
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
from torch.utils.data.dataloader import DataLoader
|
11 |
-
|
12 |
-
from param import args
|
13 |
-
from tasks.nlvr2_model import NLVR2Model
|
14 |
-
from tasks.nlvr2_data import NLVR2Dataset, NLVR2TorchDataset, NLVR2Evaluator
|
15 |
-
|
16 |
-
DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
|
17 |
-
|
18 |
-
|
19 |
-
def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
|
20 |
-
dset = NLVR2Dataset(splits)
|
21 |
-
tset = NLVR2TorchDataset(dset)
|
22 |
-
evaluator = NLVR2Evaluator(dset)
|
23 |
-
data_loader = DataLoader(
|
24 |
-
tset, batch_size=bs,
|
25 |
-
shuffle=shuffle, num_workers=args.num_workers,
|
26 |
-
drop_last=drop_last, pin_memory=True
|
27 |
-
)
|
28 |
-
|
29 |
-
return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
|
30 |
-
|
31 |
-
|
32 |
-
class NLVR2:
|
33 |
-
def __init__(self):
|
34 |
-
self.train_tuple = get_tuple(
|
35 |
-
args.train, bs=args.batch_size, shuffle=True, drop_last=True
|
36 |
-
)
|
37 |
-
if args.valid != "":
|
38 |
-
valid_bsize = 2048 if args.multiGPU else 512
|
39 |
-
self.valid_tuple = get_tuple(
|
40 |
-
args.valid, bs=valid_bsize,
|
41 |
-
shuffle=False, drop_last=False
|
42 |
-
)
|
43 |
-
else:
|
44 |
-
self.valid_tuple = None
|
45 |
-
|
46 |
-
self.model = NLVR2Model()
|
47 |
-
|
48 |
-
# Load pre-trained weights
|
49 |
-
if args.load_lxmert is not None:
|
50 |
-
self.model.lxrt_encoder.load(args.load_lxmert)
|
51 |
-
|
52 |
-
# GPU options
|
53 |
-
if args.multiGPU:
|
54 |
-
self.model.lxrt_encoder.multi_gpu()
|
55 |
-
self.model = self.model.cuda()
|
56 |
-
|
57 |
-
# Losses and optimizer
|
58 |
-
self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
|
59 |
-
if 'bert' in args.optim:
|
60 |
-
batch_per_epoch = len(self.train_tuple.loader)
|
61 |
-
t_total = int(batch_per_epoch * args.epochs)
|
62 |
-
print("Total Iters: %d" % t_total)
|
63 |
-
from lxrt.optimization import BertAdam
|
64 |
-
self.optim = BertAdam(list(self.model.parameters()),
|
65 |
-
lr=args.lr,
|
66 |
-
warmup=0.1,
|
67 |
-
t_total=t_total)
|
68 |
-
else:
|
69 |
-
self.optim = args.optimizer(list(self.model.parameters()), args.lr)
|
70 |
-
|
71 |
-
self.output = args.output
|
72 |
-
os.makedirs(self.output, exist_ok=True)
|
73 |
-
|
74 |
-
def train(self, train_tuple, eval_tuple):
|
75 |
-
dset, loader, evaluator = train_tuple
|
76 |
-
iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
|
77 |
-
|
78 |
-
best_valid = 0.
|
79 |
-
for epoch in range(args.epochs):
|
80 |
-
quesid2ans = {}
|
81 |
-
for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)):
|
82 |
-
self.model.train()
|
83 |
-
|
84 |
-
self.optim.zero_grad()
|
85 |
-
feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda()
|
86 |
-
logit = self.model(feats, boxes, sent)
|
87 |
-
|
88 |
-
loss = self.mce_loss(logit, label)
|
89 |
-
|
90 |
-
loss.backward()
|
91 |
-
nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
|
92 |
-
self.optim.step()
|
93 |
-
|
94 |
-
score, predict = logit.max(1)
|
95 |
-
for qid, l in zip(ques_id, predict.cpu().numpy()):
|
96 |
-
quesid2ans[qid] = l
|
97 |
-
|
98 |
-
log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
|
99 |
-
|
100 |
-
if self.valid_tuple is not None: # Do Validation
|
101 |
-
valid_score = self.evaluate(eval_tuple)
|
102 |
-
if valid_score > best_valid:
|
103 |
-
best_valid = valid_score
|
104 |
-
self.save("BEST")
|
105 |
-
|
106 |
-
log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
|
107 |
-
"Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
|
108 |
-
|
109 |
-
print(log_str, end='')
|
110 |
-
|
111 |
-
with open(self.output + "/log.log", 'a') as f:
|
112 |
-
f.write(log_str)
|
113 |
-
f.flush()
|
114 |
-
|
115 |
-
self.save("LAST")
|
116 |
-
|
117 |
-
def predict(self, eval_tuple: DataTuple, dump=None):
|
118 |
-
self.model.eval()
|
119 |
-
dset, loader, evaluator = eval_tuple
|
120 |
-
quesid2ans = {}
|
121 |
-
for i, datum_tuple in enumerate(loader):
|
122 |
-
ques_id, feats, boxes, sent = datum_tuple[:4] # avoid handling target
|
123 |
-
with torch.no_grad():
|
124 |
-
feats, boxes = feats.cuda(), boxes.cuda()
|
125 |
-
logit = self.model(feats, boxes, sent)
|
126 |
-
score, predict = logit.max(1)
|
127 |
-
for qid, l in zip(ques_id, predict.cpu().numpy()):
|
128 |
-
quesid2ans[qid] = l
|
129 |
-
if dump is not None:
|
130 |
-
evaluator.dump_result(quesid2ans, dump)
|
131 |
-
return quesid2ans
|
132 |
-
|
133 |
-
def evaluate(self, eval_tuple: DataTuple, dump=None):
|
134 |
-
dset, loader, evaluator = eval_tuple
|
135 |
-
quesid2ans = self.predict(eval_tuple, dump)
|
136 |
-
return evaluator.evaluate(quesid2ans)
|
137 |
-
|
138 |
-
def save(self, name):
|
139 |
-
torch.save(self.model.state_dict(),
|
140 |
-
os.path.join(self.output, "%s.pth" % name))
|
141 |
-
|
142 |
-
def load(self, path):
|
143 |
-
print("Load model from %s" % path)
|
144 |
-
state_dict = torch.load("%s.pth" % path)
|
145 |
-
self.model.load_state_dict(state_dict)
|
146 |
-
|
147 |
-
|
148 |
-
if __name__ == "__main__":
|
149 |
-
# Build Class
|
150 |
-
nlvr2 = NLVR2()
|
151 |
-
|
152 |
-
# Load Model
|
153 |
-
if args.load is not None:
|
154 |
-
nlvr2.load(args.load)
|
155 |
-
|
156 |
-
# Test or Train
|
157 |
-
if args.test is not None:
|
158 |
-
args.fast = args.tiny = False # Always loading all data in test
|
159 |
-
if 'hidden' in args.test:
|
160 |
-
nlvr2.predict(
|
161 |
-
get_tuple(args.test, bs=args.batch_size,
|
162 |
-
shuffle=False, drop_last=False),
|
163 |
-
dump=os.path.join(args.output, 'hidden_predict.csv')
|
164 |
-
)
|
165 |
-
elif 'test' in args.test or 'valid' in args.test:
|
166 |
-
result = nlvr2.evaluate(
|
167 |
-
get_tuple(args.test, bs=args.batch_size,
|
168 |
-
shuffle=False, drop_last=False),
|
169 |
-
dump=os.path.join(args.output, '%s_predict.csv' % args.test)
|
170 |
-
)
|
171 |
-
print(result)
|
172 |
-
else:
|
173 |
-
assert False, "No such test option for %s" % args.test
|
174 |
-
else:
|
175 |
-
print('Splits in Train data:', nlvr2.train_tuple.dataset.splits)
|
176 |
-
if nlvr2.valid_tuple is not None:
|
177 |
-
print('Splits in Valid data:', nlvr2.valid_tuple.dataset.splits)
|
178 |
-
else:
|
179 |
-
print("DO NOT USE VALIDATION")
|
180 |
-
nlvr2.train(nlvr2.train_tuple, nlvr2.valid_tuple)
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/nlvr2_data.py
DELETED
@@ -1,157 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import json
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
from torch.utils.data import Dataset
|
8 |
-
|
9 |
-
from param import args
|
10 |
-
from utils import load_obj_tsv
|
11 |
-
|
12 |
-
# Load part of the dataset for fast checking.
|
13 |
-
# Notice that here is the number of images instead of the number of data,
|
14 |
-
# which means all related data to the images would be used.
|
15 |
-
TINY_IMG_NUM = 512
|
16 |
-
FAST_IMG_NUM = 5000
|
17 |
-
|
18 |
-
|
19 |
-
class NLVR2Dataset:
|
20 |
-
"""
|
21 |
-
An NLVR2 data example in json file:
|
22 |
-
{
|
23 |
-
"identifier": "train-10171-0-0",
|
24 |
-
"img0": "train-10171-0-img0",
|
25 |
-
"img1": "train-10171-0-img1",
|
26 |
-
"label": 0,
|
27 |
-
"sent": "An image shows one leather pencil case, displayed open with writing implements tucked inside.
|
28 |
-
",
|
29 |
-
"uid": "nlvr2_train_0"
|
30 |
-
}
|
31 |
-
"""
|
32 |
-
def __init__(self, splits: str):
|
33 |
-
self.name = splits
|
34 |
-
self.splits = splits.split(',')
|
35 |
-
|
36 |
-
# Loading datasets to data
|
37 |
-
self.data = []
|
38 |
-
for split in self.splits:
|
39 |
-
self.data.extend(json.load(open("data/nlvr2/%s.json" % split)))
|
40 |
-
print("Load %d data from split(s) %s." % (len(self.data), self.name))
|
41 |
-
|
42 |
-
# List to dict (for evaluation and others)
|
43 |
-
self.id2datum = {
|
44 |
-
datum['uid']: datum
|
45 |
-
for datum in self.data
|
46 |
-
}
|
47 |
-
|
48 |
-
def __len__(self):
|
49 |
-
return len(self.data)
|
50 |
-
|
51 |
-
|
52 |
-
"""
|
53 |
-
An example in obj36 tsv:
|
54 |
-
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
|
55 |
-
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
|
56 |
-
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
|
57 |
-
"""
|
58 |
-
class NLVR2TorchDataset(Dataset):
|
59 |
-
def __init__(self, dataset: NLVR2Dataset):
|
60 |
-
super().__init__()
|
61 |
-
self.raw_dataset = dataset
|
62 |
-
|
63 |
-
if args.tiny:
|
64 |
-
topk = TINY_IMG_NUM
|
65 |
-
elif args.fast:
|
66 |
-
topk = FAST_IMG_NUM
|
67 |
-
else:
|
68 |
-
topk = -1
|
69 |
-
|
70 |
-
# Loading detection features to img_data
|
71 |
-
img_data = []
|
72 |
-
if 'train' in dataset.splits:
|
73 |
-
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/train_obj36.tsv', topk=topk))
|
74 |
-
if 'valid' in dataset.splits:
|
75 |
-
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/valid_obj36.tsv', topk=topk))
|
76 |
-
if 'test' in dataset.name:
|
77 |
-
img_data.extend(load_obj_tsv('data/nlvr2_imgfeat/test_obj36.tsv', topk=topk))
|
78 |
-
self.imgid2img = {}
|
79 |
-
for img_datum in img_data:
|
80 |
-
self.imgid2img[img_datum['img_id']] = img_datum
|
81 |
-
|
82 |
-
# Filter out the dataset
|
83 |
-
self.data = []
|
84 |
-
for datum in self.raw_dataset.data:
|
85 |
-
if datum['img0'] in self.imgid2img and datum['img1'] in self.imgid2img:
|
86 |
-
self.data.append(datum)
|
87 |
-
print("Use %d data in torch dataset" % (len(self.data)))
|
88 |
-
print()
|
89 |
-
|
90 |
-
def __len__(self):
|
91 |
-
return len(self.data)
|
92 |
-
|
93 |
-
def __getitem__(self, item: int):
|
94 |
-
datum = self.data[item]
|
95 |
-
|
96 |
-
ques_id = datum['uid']
|
97 |
-
ques = datum['sent']
|
98 |
-
|
99 |
-
# Get image info
|
100 |
-
boxes2 = []
|
101 |
-
feats2 = []
|
102 |
-
for key in ['img0', 'img1']:
|
103 |
-
img_id = datum[key]
|
104 |
-
img_info = self.imgid2img[img_id]
|
105 |
-
boxes = img_info['boxes'].copy()
|
106 |
-
feats = img_info['features'].copy()
|
107 |
-
assert len(boxes) == len(feats)
|
108 |
-
|
109 |
-
# Normalize the boxes (to 0 ~ 1)
|
110 |
-
img_h, img_w = img_info['img_h'], img_info['img_w']
|
111 |
-
boxes[..., (0, 2)] /= img_w
|
112 |
-
boxes[..., (1, 3)] /= img_h
|
113 |
-
np.testing.assert_array_less(boxes, 1+1e-5)
|
114 |
-
np.testing.assert_array_less(-boxes, 0+1e-5)
|
115 |
-
|
116 |
-
boxes2.append(boxes)
|
117 |
-
feats2.append(feats)
|
118 |
-
feats = np.stack(feats2)
|
119 |
-
boxes = np.stack(boxes2)
|
120 |
-
|
121 |
-
# Create target
|
122 |
-
if 'label' in datum:
|
123 |
-
label = datum['label']
|
124 |
-
return ques_id, feats, boxes, ques, label
|
125 |
-
else:
|
126 |
-
return ques_id, feats, boxes, ques
|
127 |
-
|
128 |
-
|
129 |
-
class NLVR2Evaluator:
|
130 |
-
def __init__(self, dataset: NLVR2Dataset):
|
131 |
-
self.dataset = dataset
|
132 |
-
|
133 |
-
def evaluate(self, quesid2ans: dict):
|
134 |
-
score = 0.
|
135 |
-
for quesid, ans in quesid2ans.items():
|
136 |
-
datum = self.dataset.id2datum[quesid]
|
137 |
-
label = datum['label']
|
138 |
-
if ans == label:
|
139 |
-
score += 1
|
140 |
-
return score / len(quesid2ans)
|
141 |
-
|
142 |
-
def dump_result(self, quesid2ans: dict, path):
|
143 |
-
"""
|
144 |
-
Dump result to a CSV file, which is compatible with NLVR2 evaluation system.
|
145 |
-
NLVR2 CSV file requirement:
|
146 |
-
Each line contains: identifier, answer
|
147 |
-
|
148 |
-
:param quesid2ans: nlvr2 uid to ans (either "True" or "False")
|
149 |
-
:param path: The desired path of saved file.
|
150 |
-
:return:
|
151 |
-
"""
|
152 |
-
with open(path, 'w') as f:
|
153 |
-
for uid, ans in quesid2ans.items():
|
154 |
-
idt = self.dataset.id2datum[uid]["identifier"]
|
155 |
-
ans = 'True' if ans == 1 else 'False'
|
156 |
-
f.write("%s,%s\n" % (idt, ans))
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/nlvr2_model.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import torch.nn as nn
|
5 |
-
from lxrt.modeling import GeLU, BertLayerNorm
|
6 |
-
from lxrt.entry import LXRTEncoder
|
7 |
-
from param import args
|
8 |
-
|
9 |
-
|
10 |
-
class NLVR2Model(nn.Module):
    """LXRT encoder plus a two-way classification head for NLVR2.

    Each example is one sentence paired with two images. Both
    (sentence, image) pairs are encoded separately and their encoder
    outputs are concatenated before classification.
    """

    def __init__(self):
        super().__init__()
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=20
        )
        self.hid_dim = hid_dim = self.lxrt_encoder.dim

        # The head sees the features of both images side by side, hence
        # the ``hid_dim * 2`` input width; it produces 2 logits.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim * 2, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 2)
        )
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, 2, o, f) visual features of the two images
        :param pos: (b, 2, o, 4) box coordinates of the two images
        :param sent: (b,) sequence of strings, one sentence per example
        :return: (b, 2) logits of the two answers
        """
        # Pairing images and sentences:
        # The input of NLVR2 is two images and one sentence. In batch level, they are saved as
        #   [ [img0_0, img0_1], [img1_0, img1_1], ...] and [sent0, sent1, ...]
        # Here, we flat them to
        #   feat/pos = [ img0_0, img0_1, img1_0, img1_1, ...]
        #   sent     = [ sent0,  sent0,  sent1,  sent1,  ...]
        # Built in one linear pass; repeated tuple concatenation via
        # ``sum(..., ())`` is quadratic in the batch size.
        sent = tuple(s for s in sent for _ in range(2))
        batch_size, img_num, obj_num, feat_size = feat.size()
        assert img_num == 2 and obj_num == 36 and feat_size == 2048
        feat = feat.view(batch_size * 2, obj_num, feat_size)
        pos = pos.view(batch_size * 2, obj_num, 4)

        # Extract feature --> Concat
        # Rows 2i and 2i+1 belong to example i, so the view below puts the
        # two per-image vectors of one example next to each other.
        x = self.lxrt_encoder(sent, (feat, pos))
        x = x.view(-1, self.hid_dim*2)

        # Compute logit of answers
        logit = self.logit_fc(x)

        return logit
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/vqa.py
DELETED
@@ -1,214 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import os
|
5 |
-
import collections
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import torch.nn as nn
|
9 |
-
from torch.utils.data.dataloader import DataLoader
|
10 |
-
from tqdm import tqdm
|
11 |
-
|
12 |
-
from ..param import args
|
13 |
-
from ..pretrain.qa_answer_table import load_lxmert_qa
|
14 |
-
from .vqa_model import VQAModel
|
15 |
-
from .vqa_data import VQADataset, VQATorchDataset, VQAEvaluator
|
16 |
-
|
17 |
-
# Bundles a raw dataset with the DataLoader and evaluator built from it.
DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
|
18 |
-
|
19 |
-
|
20 |
-
def get_data_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
    """Build the dataset, loader and evaluator for the given split(s).

    :param splits: Split name(s) handed to ``VQADataset``.
    :param bs: Batch size of the loader.
    :param shuffle: Whether the loader shuffles the data.
    :param drop_last: Whether the loader drops the last incomplete batch.
    :return: A ``DataTuple`` of (dataset, loader, evaluator).
    """
    raw_dataset = VQADataset(splits)
    torch_dataset = VQATorchDataset(raw_dataset)
    scorer = VQAEvaluator(raw_dataset)

    loader_options = dict(
        batch_size=bs,
        shuffle=shuffle,
        num_workers=args.num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
    batch_loader = DataLoader(torch_dataset, **loader_options)

    return DataTuple(dataset=raw_dataset, loader=batch_loader, evaluator=scorer)
|
31 |
-
|
32 |
-
|
33 |
-
class VQA:
    """Training / inference driver for the VQA task.

    Builds the data tuples, model, loss and optimizer from the global
    ``args`` and exposes ``train``, ``predict``, ``evaluate``, ``save``
    and ``load``.
    """

    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            self.valid_tuple = get_data_tuple(
                args.valid, bs=1024,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from ..lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Train for ``args.epochs`` epochs, validating after each epoch.

        :param train_tuple: DataTuple providing the training data.
        :param eval_tuple: DataTuple to validate on after every epoch, or
            None to skip validation.

        Saves the best-scoring weights as "BEST" whenever validation
        improves and the final weights as "LAST"; appends a per-epoch
        summary to ``log.log`` in the output directory.
        """
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                # predict() switches the model to eval mode during
                # validation, so training mode is re-enabled here.
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                # The loss is averaged over every element; rescale by the
                # number of answers (columns of logit).
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            # Guard on the tuple that is actually evaluated; checking
            # self.valid_tuple here would crash when eval_tuple is None.
            if eval_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(os.path.join(self.output, "log.log"), 'a') as f:
                f.write(log_str)

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple.

        :param eval_tuple: The data tuple to predict on and score.
        :param dump: Optional path to dump the predictions to.
        :return: The score computed by the tuple's evaluator.
        """
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        """Score obtained by always answering the highest-scored target.

        :param data_tuple: DataTuple whose loader yields batches that
            include the ground-truth ``target`` matrix.
        :return: The evaluator's score for those oracle answers.
        """
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        """Save the model weights to ``<output>/<name>.pth``."""
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        """Load model weights from ``<path>.pth`` (path given without suffix)."""
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
|
174 |
-
|
175 |
-
|
176 |
-
if __name__ == "__main__":
    # Build Class
    vqa = VQA()

    # Load VQA model weights
    # Note: It is different from loading lxmert pre-trained weights.
    if args.load is not None:
        vqa.load(args.load)

    # Train or Test
    if args.test is None:
        print('Splits in Train data:', vqa.train_tuple.dataset.splits)
        if vqa.valid_tuple is None:
            print("DO NOT USE VALIDATION")
        else:
            print('Splits in Valid data:', vqa.valid_tuple.dataset.splits)
            print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
        vqa.train(vqa.train_tuple, vqa.valid_tuple)
    else:
        args.fast = args.tiny = False       # Always loading all data in test
        if 'test' in args.test:
            test_tuple = get_data_tuple(args.test, bs=950,
                                        shuffle=False, drop_last=False)
            vqa.predict(
                test_tuple,
                dump=os.path.join(args.output, 'test_predict.json')
            )
        elif 'val' in args.test:
            # Since part of validation data are used in pre-training/fine-tuning,
            # only validate on the minival set.
            minival_tuple = get_data_tuple('minival', bs=950,
                                           shuffle=False, drop_last=False)
            result = vqa.evaluate(
                minival_tuple,
                dump=os.path.join(args.output, 'minival_predict.json')
            )
            print(result)
        else:
            assert False, "No such test option for %s" % args.test
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/vqa_data.py
DELETED
@@ -1,188 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import json
|
5 |
-
import os
|
6 |
-
import pickle
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import torch
|
10 |
-
from torch.utils.data import Dataset
|
11 |
-
|
12 |
-
from ..param import args
|
13 |
-
from ..utils import load_obj_tsv
|
14 |
-
|
15 |
-
# Load part of the dataset for fast checking.
|
16 |
-
# Notice that here is the number of images instead of the number of data,
|
17 |
-
# which means all related data to the images would be used.
|
18 |
-
TINY_IMG_NUM = 512
|
19 |
-
FAST_IMG_NUM = 5000
|
20 |
-
|
21 |
-
# The path to data and image features.
|
22 |
-
VQA_DATA_ROOT = 'data/vqa/'
|
23 |
-
MSCOCO_IMGFEAT_ROOT = 'data/mscoco_imgfeat/'
|
24 |
-
SPLIT2NAME = {
|
25 |
-
'train': 'train2014',
|
26 |
-
'valid': 'val2014',
|
27 |
-
'minival': 'val2014',
|
28 |
-
'nominival': 'val2014',
|
29 |
-
'test': 'test2015',
|
30 |
-
}
|
31 |
-
|
32 |
-
|
33 |
-
class VQADataset:
    """
    Raw VQA annotations loaded from the json files of one or more splits.

    A VQA data example in json file:
        {
            "answer_type": "other",
            "img_id": "COCO_train2014_000000458752",
            "label": {
                "net": 1
            },
            "question_id": 458752000,
            "question_type": "what is this",
            "sent": "What is this photo taken looking through?"
        }
    """
    def __init__(self, splits: str):
        """
        :param splits: Comma-separated split names; each one is read from
            ``data/vqa/<split>.json``.
        """
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets
        self.data = []
        for split in self.splits:
            # Context managers close each file; a bare json.load(open(...))
            # leaves the handle open until garbage collection.
            with open("data/vqa/%s.json" % split) as f:
                self.data.extend(json.load(f))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # Convert list to dict (for evaluation)
        self.id2datum = {
            datum['question_id']: datum
            for datum in self.data
        }

        # Answers
        with open("data/vqa/trainval_ans2label.json") as f:
            self.ans2label = json.load(f)
        with open("data/vqa/trainval_label2ans.json") as f:
            self.label2ans = json.load(f)
        assert len(self.ans2label) == len(self.label2ans)

    @property
    def num_answers(self):
        """Number of entries in the answer vocabulary."""
        return len(self.ans2label)

    def __len__(self):
        """Number of loaded data examples across all splits."""
        return len(self.data)
|
74 |
-
|
75 |
-
|
76 |
-
"""
|
77 |
-
An example in obj36 tsv:
|
78 |
-
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
|
79 |
-
"attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
|
80 |
-
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
|
81 |
-
"""
|
82 |
-
class VQATorchDataset(Dataset):
    """Torch dataset pairing VQA questions with detected-object features."""

    def __init__(self, dataset: "VQADataset"):
        """
        :param dataset: The raw VQA dataset; its ``splits`` select which
            feature files are loaded and its ``data`` supplies the questions.
        """
        super().__init__()
        self.raw_dataset = dataset

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = None

        # Loading detection features to img_data
        img_data = []
        for split in dataset.splits:
            # Minival is 5K images in MS COCO, which is used in evaluating VQA/lxmert-pre-training.
            # It is saved as the top 5K features in val2014_***.tsv
            load_topk = 5000 if (split == 'minival' and topk is None) else topk
            img_data.extend(load_obj_tsv(
                os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])),
                topk=load_topk))

        # Convert img list to dict
        self.imgid2img = {
            img_datum['img_id']: img_datum
            for img_datum in img_data
        }

        # Only kept the data with loaded image features
        self.data = [
            datum for datum in self.raw_dataset.data
            if datum['img_id'] in self.imgid2img
        ]
        print("Use %d data in torch dataset" % (len(self.data)))
        print()

    def __len__(self):
        """Number of questions whose image features were loaded."""
        return len(self.data)

    def __getitem__(self, item: int):
        """
        :param item: Index into the retained data.
        :return: ``(ques_id, feats, boxes, ques, target)`` when the datum
            has a ``'label'``, otherwise ``(ques_id, feats, boxes, ques)``.
            ``boxes`` are normalized to 0 ~ 1 and ``target`` holds the
            score of every labelled answer, zero elsewhere.
        """
        datum = self.data[item]

        img_id = datum['img_id']
        ques_id = datum['question_id']
        ques = datum['sent']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        # Copies keep the cached arrays untouched by the in-place
        # normalization below.
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # Provide label (target)
        if 'label' in datum:
            label = datum['label']
            target = torch.zeros(self.raw_dataset.num_answers)
            for ans, score in label.items():
                target[self.raw_dataset.ans2label[ans]] = score
            return ques_id, feats, boxes, ques, target
        else:
            return ques_id, feats, boxes, ques
|
151 |
-
|
152 |
-
|
153 |
-
class VQAEvaluator:
    """Scores predicted answers and dumps them in VQA submission format."""

    def __init__(self, dataset: "VQADataset"):
        """
        :param dataset: The raw VQA dataset; its ``id2datum`` table
            supplies the labels used for scoring.
        """
        self.dataset = dataset

    def evaluate(self, quesid2ans: dict):
        """Average label score of the predicted answers.

        :param quesid2ans: dict of quesid --> ans. Every quesid must be a
            key of ``self.dataset.id2datum``.
        :return: Mean over all predictions of the score the datum's
            ``'label'`` dict assigns to the predicted answer (0 when the
            answer is not in the label); ``0.`` when there are no
            predictions.
        :raises KeyError: if a quesid is unknown to the dataset.
        """
        # An empty prediction dict would otherwise divide by zero.
        if not quesid2ans:
            return 0.

        score = 0.
        for quesid, ans in quesid2ans.items():
            label = self.dataset.id2datum[quesid]['label']
            if ans in label:
                score += label[ans]
        return score / len(quesid2ans)

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump results to a json file, which could be submitted to the VQA online evaluation.
            VQA json file submission requirement:
                results = [result]
                result = {
                    "question_id": int,
                    "answer": str
                }

        :param quesid2ans: dict of quesid --> ans
        :param path: The desired path of saved file.
        """
        with open(path, 'w') as f:
            result = [
                {'question_id': ques_id, 'answer': ans}
                for ques_id, ans in quesid2ans.items()
            ]
            json.dump(result, f, indent=4, sort_keys=True)
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lxmert/src/tasks/vqa_model.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# Copyleft 2019 project LXRT.
|
3 |
-
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
from ..param import args
|
7 |
-
from ..lxrt.entry import LXRTEncoder
|
8 |
-
from ..lxrt.modeling import BertLayerNorm, GeLU
|
9 |
-
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
10 |
-
|
11 |
-
# Max length including <bos> and <eos>
|
12 |
-
MAX_VQA_LENGTH = 20
|
13 |
-
|
14 |
-
|
15 |
-
class VQAModel(nn.Module):
    """VQA model backed by the pretrained ``unc-nlp/lxmert-vqa-uncased`` checkpoint."""

    def __init__(self, num_answers):
        """
        :param num_answers: Size of the answer vocabulary.
            NOTE(review): currently unused -- the answer head comes from the
            pretrained checkpoint; confirm its size matches the dataset's.
        """
        super().__init__()

        # NOTE(review): this replaces the earlier LXRTEncoder + logit_fc
        # head, so the model no longer has a ``lxrt_encoder`` attribute;
        # callers that still access it need updating.
        self.tokenizer = AutoTokenizer.from_pretrained("unc-nlp/lxmert-vqa-uncased")
        self.model = AutoModelForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, o, f)
        :param pos:  (b, o, 4)
        :param sent: (b,) Type -- list of string
        :return: (b, num_answer) The logit of each answers.
        """
        # The HF model consumes token ids, not raw strings, so the
        # questions are tokenized first (padded/truncated to the task's
        # maximum length).
        encoded = self.tokenizer(
            list(sent),
            padding="max_length",
            max_length=MAX_VQA_LENGTH,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Keep the text tensors on the same device as the visual inputs.
        device = feat.device
        output = self.model(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
            token_type_ids=encoded["token_type_ids"].to(device),
            visual_feats=feat,
            visual_pos=pos,
        )
        # Callers expect a plain logit tensor, not the model output object.
        return output.question_answering_score
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|