Delete lxmert/src/pretrain
lxmert/src/pretrain/__init__.py
DELETED
File without changes
lxmert/src/pretrain/lxmert_data.py
DELETED
@@ -1,255 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

from collections import defaultdict
import json
import random

import numpy as np
from torch.utils.data import Dataset

from param import args
from pretrain.qa_answer_table import AnswerTable
from utils import load_obj_tsv

TINY_IMG_NUM = 500
FAST_IMG_NUM = 5000

Split2ImgFeatPath = {
    'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
    'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
    'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
}


class InputExample(object):
    """A single training/test example for the language model."""
    def __init__(self, uid, sent, visual_feats=None,
                 obj_labels=None, attr_labels=None,
                 is_matched=None, label=None):
        self.uid = uid
        self.sent = sent
        self.visual_feats = visual_feats
        self.obj_labels = obj_labels
        self.attr_labels = attr_labels
        self.is_matched = is_matched  # whether the sentence and the image are matched
        self.label = label


class LXMERTDataset:
    def __init__(self, splits: str, qa_sets=None):
        """
        :param splits: The data sources to be loaded.
        :param qa_sets: if None, no action;
                        otherwise, only take the answers appearing in these dsets
                        and remove all unlabeled data (MSCOCO captions).
        """
        self.name = splits
        self.sources = splits.split(',')

        # Loading datasets to data
        self.data = []
        for source in self.sources:
            self.data.extend(json.load(open("data/lxmert/%s.json" % source)))
        print("Load %d data from %s" % (len(self.data), self.name))

        # Create answer table according to the qa_sets
        self.answer_table = AnswerTable(qa_sets)
        print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))

        # Modify the answers
        for datum in self.data:
            labelf = datum['labelf']
            for cat, labels in labelf.items():
                for label in labels:
                    for ans in list(label.keys()):
                        new_ans = self.answer_table.convert_ans(ans)
                        if self.answer_table.used(new_ans):
                            if ans != new_ans:
                                label[new_ans] = label.pop(ans)
                        else:
                            label.pop(ans)

    def __len__(self):
        return len(self.data)


def make_uid(img_id, dset, sent_idx):
    return "%s_%s_%03d" % (img_id, dset, sent_idx)


"""
Example in obj tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
"""
class LXMERTTorchDataset(Dataset):
    def __init__(self, dataset: LXMERTDataset, topk=-1):
        super().__init__()
        self.raw_dataset = dataset
        self.task_matched = args.task_matched

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM

        # Load the dataset
        img_data = []
        for source in self.raw_dataset.sources:
            img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))

        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Filter out the dataset
        used_data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.imgid2img:
                used_data.append(datum)

        # Flatten the dataset (into one sent + one image entries)
        self.data = []
        for datum in used_data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:
                    labels = datum['labelf'][sents_cat]
                else:
                    labels = None
                for sent_idx, sent in enumerate(sents):
                    new_datum = {
                        'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                        'img_id': datum['img_id'],
                        'sent': sent
                    }
                    if labels is not None:
                        new_datum['label'] = labels[sent_idx]
                    self.data.append(new_datum)
        print("Use %d data in torch dataset" % (len(self.data)))

    def __len__(self):
        return len(self.data)

    def random_feat(self):
        """Get a random obj feat from the dataset."""
        datum = self.data[random.randint(0, len(self.data)-1)]
        img_id = datum['img_id']
        img_info = self.imgid2img[img_id]
        feat = img_info['features'][random.randint(0, 35)]
        return feat

    def __getitem__(self, item: int):
        datum = self.data[item]

        uid = datum['uid']
        img_id = datum['img_id']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        obj_labels = img_info['objects_id'].copy()
        obj_confs = img_info['objects_conf'].copy()
        attr_labels = img_info['attrs_id'].copy()
        attr_confs = img_info['attrs_conf'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # If calculating the matched loss, replace the sentence with a sentence
        # corresponding to another image.
        is_matched = 1
        sent = datum['sent']
        if self.task_matched:
            if random.random() < 0.5:
                is_matched = 0
                other_datum = self.data[random.randint(0, len(self.data)-1)]
                while other_datum['img_id'] == img_id:
                    other_datum = self.data[random.randint(0, len(self.data)-1)]
                sent = other_datum['sent']

        # Label, convert answer to id
        if 'label' in datum:
            label = datum['label'].copy()
            for ans in list(label.keys()):
                label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
        else:
            label = None

        # Create target
        example = InputExample(
            uid, sent, (feats, boxes),
            (obj_labels, obj_confs), (attr_labels, attr_confs),
            is_matched, label
        )
        return example


class LXMERTEvaluator:
    def __init__(self, dataset: LXMERTDataset):
        self.raw_dataset = dataset

        # Create QA Eval Data
        self.data = []
        for datum in self.raw_dataset.data:
            sentf = datum['sentf']
            for sents_cat, sents in sentf.items():
                if sents_cat in datum['labelf']:    # A labeled dataset
                    labels = datum['labelf'][sents_cat]
                    for sent_idx, sent in enumerate(sents):
                        new_datum = {
                            'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
                            'img_id': datum['img_id'],
                            'sent': sent,
                            'dset': sents_cat,
                            'label': labels[sent_idx]
                        }
                        self.data.append(new_datum)

        # uid2datum
        self.uid2datum = {}
        for datum in self.data:
            self.uid2datum[datum['uid']] = datum

    def evaluate(self, uid2ans: dict, pprint=False):
        score = 0.
        cnt = 0
        dset2score = defaultdict(lambda: 0.)
        dset2cnt = defaultdict(lambda: 0)
        for uid, ans in uid2ans.items():
            if uid not in self.uid2datum:   # Not a labeled datum
                continue
            datum = self.uid2datum[uid]
            label = datum['label']
            dset = datum['dset']
            if ans in label:
                score += label[ans]
                dset2score[dset] += label[ans]
            cnt += 1
            dset2cnt[dset] += 1
        accu = score / cnt
        dset2accu = {}
        for dset in dset2cnt:
            dset2accu[dset] = dset2score[dset] / dset2cnt[dset]

        if pprint:
            accu_str = "Overall Accu %0.4f, " % (accu)
            sorted_keys = sorted(dset2accu.keys())
            for key in sorted_keys:
                accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
            print(accu_str)

        return accu, dset2accu

    def dump_result(self, uid2ans: dict, path):
        raise NotImplementedError
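Aside (not part of the deleted file): a minimal, self-contained sketch of the flattening step above. The datum dict below is invented for illustration and only mimics the shape of the data/lxmert/*.json entries; it shows the uid format produced by make_uid.

# Hypothetical record shaped like a data/lxmert/*.json entry (values invented).
datum = {
    'img_id': 'COCO_train2014_000000000001',
    'sentf': {'mscoco': ['a cat sits on a mat', 'a mat under a cat']},
    'labelf': {},  # caption-only categories carry no QA labels
}

flat = []
for sents_cat, sents in datum['sentf'].items():
    labels = datum['labelf'].get(sents_cat)  # None when the category is unlabeled
    for sent_idx, sent in enumerate(sents):
        entry = {
            'uid': "%s_%s_%03d" % (datum['img_id'], sents_cat, sent_idx),
            'img_id': datum['img_id'],
            'sent': sent,
        }
        if labels is not None:
            entry['label'] = labels[sent_idx]
        flat.append(entry)

print([e['uid'] for e in flat])
# ['COCO_train2014_000000000001_mscoco_000', 'COCO_train2014_000000000001_mscoco_001']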
lxmert/src/pretrain/lxmert_pretrain.py
DELETED
@@ -1,435 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import collections
import os
import random

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from param import args
from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
from lxrt.entry import set_visual_config
from lxrt.tokenization import BertTokenizer
from lxrt.modeling import LXRTPretraining

DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')


def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
    # Decide which QA datasets would be used in pre-training.
    # Options: vqa, gqa, visual7w
    # Note: visual7w is a part of vgqa, we take the name here.
    qa_sets = args.qa_sets
    if qa_sets is not None:
        qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))

    # Build dataset, data loader, and evaluator.
    dset = LXMERTDataset(splits, qa_sets=qa_sets)
    tset = LXMERTTorchDataset(dset, topk)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        collate_fn=lambda x: x,
        drop_last=drop_last, pin_memory=True
    )
    evaluator = LXMERTEvaluator(dset)
    print()

    return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)


train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
valid_batch_size = 2048 if args.multiGPU else 512
valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 input_ids, input_mask, segment_ids, lm_label_ids,
                 visual_feats, obj_labels,
                 is_matched, ans):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.lm_label_ids = lm_label_ids

        self.visual_feats = visual_feats
        self.obj_labels = obj_labels

        self.is_matched = is_matched

        self.ans = ans


def random_word(tokens, tokenizer):
    """
    Mask some random tokens for the Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with probability
        ratio = args.word_mask_rate
        if prob < ratio:
            prob /= ratio

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"

            # 10% randomly change token to random token
            elif prob < 0.9:
                tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label


def random_feat(feats):
    mask_feats = feats.copy()
    feat_mask = np.zeros(len(feats), dtype=np.float32)
    for i in range(len(feats)):
        prob = random.random()
        # mask token with probability
        if prob < args.obj_mask_rate:
            prob /= args.obj_mask_rate

            # 80% randomly change token to zero feat
            if prob < 0.8:
                mask_feats[i, :] = 0.

            # 10% randomly change token to random feat
            elif prob < 0.9:
                mask_feats[i, :] = train_tuple.torchdset.random_feat()
            # -> rest 10% randomly keep current feat

            # Need to predict this feat
            feat_mask[i] = 1.

    return mask_feats, feat_mask


def convert_example_to_features(example: InputExample, max_seq_length, tokenizer) -> InputFeatures:
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens = tokenizer.tokenize(example.sent.strip())

    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[:(max_seq_length - 2)]

    # Get random words
    masked_tokens, masked_label = random_word(tokens, tokenizer)

    # concatenate lm labels and account for CLS, SEP, SEP
    masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)

    # Mask & Segment Word
    lm_label_ids = ([-1] + masked_label + [-1])
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    feat, boxes = example.visual_feats
    obj_labels, obj_confs = example.obj_labels
    attr_labels, attr_confs = example.attr_labels

    # Mask Image Features:
    masked_feat, feat_mask = random_feat(feat)

    # QA answer label
    if example.label is None or len(example.label) == 0 or example.is_matched != 1:
        # 1. No label 2. Label is pruned 3. unmatched visual + language pair
        ans = -1
    else:
        keys, values = zip(*example.label.items())
        if len(keys) == 1:
            ans = keys[0]
        else:
            value_sum = sum(values)
            prob = [value / value_sum for value in values]
            choice = np.random.multinomial(1, prob).argmax()
            ans = keys[choice]

    features = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        lm_label_ids=lm_label_ids,
        visual_feats=(masked_feat, boxes),
        obj_labels={
            'obj': (obj_labels, obj_confs),
            'attr': (attr_labels, attr_confs),
            'feat': (feat, feat_mask),
        },
        is_matched=example.is_matched,
        ans=ans,
    )
    return features


LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')


class LXMERT:
    def __init__(self, max_seq_length):
        super().__init__()
        self.max_seq_length = max_seq_length

        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build model
        set_visual_config(args)
        self.model = LXRTPretraining.from_pretrained(
            "bert-base-uncased",
            task_mask_lm=args.task_mask_lm,
            task_obj_predict=args.task_obj_predict,
            task_matched=args.task_matched,
            task_qa=args.task_qa,
            visual_losses=args.visual_losses,
            num_answers=train_tuple.dataset.answer_table.num_answers
        )

        # Weight initialization and loading
        if args.from_scratch:
            print("Train from Scratch: re-initialize all BERT weights.")
            self.model.apply(self.model.init_bert_weights)
        if args.load is not None:
            self.load(args.load)
        if args.load_lxmert is not None:
            # Loading lxmert does not load the answer head.
            self.load_lxmert(args.load_lxmert)

        # GPU Options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model = nn.DataParallel(self.model)

    def forward(self, examples):
        train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
                          for example in examples]

        # Language Inputs
        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Inputs
        feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
        pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()

        # Language Prediction
        lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()

        # Visual Prediction
        obj_labels = {}
        for key in ('obj', 'attr', 'feat'):
            visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
            visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
            assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
            obj_labels[key] = (visn_labels, visn_mask)

        # Joint Prediction
        matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
        ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()

        """
        forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
        """
        loss, losses, ans_logit = self.model(
            input_ids, segment_ids, input_mask, lm_labels,
            feats, pos, obj_labels, matched_labels, ans
        )
        return loss, losses.detach().cpu(), ans_logit

    def train_batch(self, optim, batch):
        optim.zero_grad()
        loss, losses, ans_logit = self.forward(batch)
        if args.multiGPU:
            loss = loss.mean()
            losses = losses.mean(0)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        optim.step()

        return loss.item(), losses.cpu().numpy(), ans_logit

    def valid_batch(self, batch):
        with torch.no_grad():
            loss, losses, ans_logit = self.forward(batch)
            if args.multiGPU:
                loss = loss.mean()
                losses = losses.mean(0)
        return loss.item(), losses.cpu().numpy(), ans_logit

    def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
        train_ld = train_tuple.loader

        # Optimizer
        from lxrt.optimization import BertAdam
        batch_per_epoch = len(train_ld)
        t_total = int(batch_per_epoch * args.epochs)
        warmup_ratio = 0.05
        warmup_iters = int(t_total * warmup_ratio)
        print("Batch per epoch: %d" % batch_per_epoch)
        print("Total Iters: %d" % t_total)
        print("Warm up Iters: %d" % warmup_iters)
        optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)

        # Train
        best_eval_loss = 9595.
        for epoch in range(args.epochs):
            # Train
            self.model.train()
            total_loss = 0.
            total_losses = 0.
            uid2ans = {}
            for batch in tqdm(train_ld, total=len(train_ld)):
                loss, losses, logit = self.train_batch(optim, batch)
                total_loss += loss
                total_losses += losses

                if args.task_qa:
                    score, label = logit.max(1)
                    for datum, l in zip(batch, label.cpu().numpy()):
                        uid = datum.uid
                        ans = train_tuple.dataset.answer_table.id2ans(l)
                        uid2ans[uid] = ans

            print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
            losses_str = "The losses are "
            for name, loss in zip(LOSSES_NAME, total_losses):
                losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
            print(losses_str)
            if args.task_qa:
                train_tuple.evaluator.evaluate(uid2ans, pprint=True)

            # Eval
            avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)

            # Save
            if avg_eval_loss < best_eval_loss:
                best_eval_loss = avg_eval_loss
                self.save("BEST_EVAL_LOSS")
            self.save("Epoch%02d" % (epoch+1))

    def evaluate_epoch(self, eval_tuple: DataTuple, iters: int = -1):
        self.model.eval()
        eval_ld = eval_tuple.loader
        total_loss = 0.
        total_losses = 0.
        uid2ans = {}
        for i, batch in enumerate(eval_ld):
            loss, losses, logit = self.valid_batch(batch)
            total_loss += loss
            total_losses += losses
            if args.task_qa:
                score, label = logit.max(1)
                for datum, l in zip(batch, label.cpu().numpy()):
                    uid = datum.uid
                    ans = train_tuple.dataset.answer_table.id2ans(l)
                    uid2ans[uid] = ans
            if i == iters:
                break

        print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
        losses_str = "The losses are "
        for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
            losses_str += "%s: %0.4f " % (name, loss)
        print(losses_str)

        if args.task_qa:
            eval_tuple.evaluator.evaluate(uid2ans, pprint=True)

        return total_loss / len(eval_ld)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(args.output, "%s_LXRT.pth" % name))

    def load(self, path):
        print("Load BERT extractor from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)
        self.model.load_state_dict(state_dict)

    def load_lxmert(self, path):
        print("Load lxmert model from %s" % path)
        state_dict = torch.load("%s_LXRT.pth" % path)

        # Do not load any answer head
        for key in list(state_dict.keys()):
            if 'answer' in key:
                state_dict.pop(key)

        # Change Multi GPU to single GPU
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
        state_dict = new_state_dict

        load_keys = set(state_dict.keys())
        model_keys = set(self.model.state_dict().keys())
        print()
        print("Keys in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Keys in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        self.model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":

    lxmert = LXMERT(max_seq_length=20)

    lxmert.train(train_tuple, valid_tuple)
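Aside (not part of the deleted file): a standalone sketch of the 80/10/10 masking rule implemented by random_word above. The mask rate value and the toy vocabulary are assumptions made for the example, not values taken from param.args.

import random

WORD_MASK_RATE = 0.15  # assumed stand-in for args.word_mask_rate
TOY_VOCAB = {'[UNK]': 0, '[MASK]': 1, 'a': 2, 'cat': 3, 'sits': 4, 'here': 5}

def mask_tokens(tokens):
    tokens = list(tokens)
    labels = []
    for i, token in enumerate(tokens):
        prob = random.random()
        if prob < WORD_MASK_RATE:
            prob /= WORD_MASK_RATE
            if prob < 0.8:                    # 80%: replace with [MASK]
                tokens[i] = '[MASK]'
            elif prob < 0.9:                  # 10%: replace with a random vocab token
                tokens[i] = random.choice(list(TOY_VOCAB))
            # remaining 10%: keep the original token
            labels.append(TOY_VOCAB.get(token, TOY_VOCAB['[UNK]']))  # predict the original id
        else:
            labels.append(-1)                 # -1 is ignored by the LM loss
    return tokens, labels

print(mask_tokens(['a', 'cat', 'sits', 'here']))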
lxmert/src/pretrain/qa_answer_table.py
DELETED
@@ -1,158 +0,0 @@
# coding=utf-8
# Copyleft 2019 project LXRT.

import json
import torch


class AnswerTable:
    ANS_CONVERT = {
        "a man": "man",
        "the man": "man",
        "a woman": "woman",
        "the woman": "woman",
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
        'grey': 'gray',
    }

    def __init__(self, dsets=None):
        self.all_ans = json.load(open("data/lxmert/all_ans.json"))
        if dsets is not None:
            dsets = set(dsets)
            # If the answer is used in the dsets
            self.anss = [ans['ans'] for ans in self.all_ans if
                         len(set(ans['dsets']) & dsets) > 0]
        else:
            self.anss = [ans['ans'] for ans in self.all_ans]
        self.ans_set = set(self.anss)

        self._id2ans_map = self.anss
        self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}

        assert len(self._id2ans_map) == len(self._ans2id_map)
        for ans_id, ans in enumerate(self._id2ans_map):
            assert self._ans2id_map[ans] == ans_id

    def convert_ans(self, ans):
        if len(ans) == 0:
            return ""
        ans = ans.lower()
        if ans[-1] == '.':
            ans = ans[:-1].strip()
        if ans.startswith("a "):
            ans = ans[2:].strip()
        if ans.startswith("an "):
            ans = ans[3:].strip()
        if ans.startswith("the "):
            ans = ans[4:].strip()
        if ans in self.ANS_CONVERT:
            ans = self.ANS_CONVERT[ans]
        return ans

    def ans2id(self, ans):
        return self._ans2id_map[ans]

    def id2ans(self, ans_id):
        return self._id2ans_map[ans_id]

    def ans2id_map(self):
        return self._ans2id_map.copy()

    def id2ans_map(self):
        return self._id2ans_map.copy()

    def used(self, ans):
        return ans in self.ans_set

    def all_answers(self):
        return self.anss.copy()

    @property
    def num_answers(self):
        return len(self.anss)


def load_lxmert_qa(path, model, label2ans):
    """
    Load model weights from lxmert pre-training.
    The answers in the fine-tuned QA task (indicated by label2ans)
    would also be properly initialized with lxmert pre-trained
    QA heads.

    :param path: Path to lxmert snapshot.
    :param model: LXRT model instance.
    :param label2ans: The label2ans dict of fine-tuned QA datasets, like
        {0: 'cat', 1: 'dog', ...}
    :return:
    """
    print("Load QA pre-trained lxmert from %s " % path)
    loaded_state_dict = torch.load("%s_LXRT.pth" % path)
    model_state_dict = model.state_dict()

    # Handle Multi-GPU pre-training --> Single GPU fine-tuning
    for key in list(loaded_state_dict.keys()):
        loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)

    # Isolate bert model
    bert_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith('bert.'):
            bert_state_dict[key] = value

    # Isolate answer head
    answer_state_dict = {}
    for key, value in loaded_state_dict.items():
        if key.startswith("answer_head."):
            answer_state_dict[key.replace('answer_head.', '')] = value

    # Do surgery on answer state dict
    ans_weight = answer_state_dict['logit_fc.3.weight']
    ans_bias = answer_state_dict['logit_fc.3.bias']
    import copy
    new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
    new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
    answer_table = AnswerTable()
    loaded = 0
    unload = 0
    if type(label2ans) is list:
        label2ans = {label: ans for label, ans in enumerate(label2ans)}
    for label, ans in label2ans.items():
        new_ans = answer_table.convert_ans(ans)
        if answer_table.used(new_ans):
            ans_id_9500 = answer_table.ans2id(new_ans)
            new_answer_weight[label] = ans_weight[ans_id_9500]
            new_answer_bias[label] = ans_bias[ans_id_9500]
            loaded += 1
        else:
            new_answer_weight[label] = 0.
            new_answer_bias[label] = 0.
            unload += 1
    print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
    print()
    answer_state_dict['logit_fc.3.weight'] = new_answer_weight
    answer_state_dict['logit_fc.3.bias'] = new_answer_bias

    # Load Bert Weights
    bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
    bert_loaded_keys = set(bert_state_dict.keys())
    assert len(bert_model_keys - bert_loaded_keys) == 0
    model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)

    # Load Answer Logic FC Weights
    model_keys = set(model.state_dict().keys())
    ans_loaded_keys = set(answer_state_dict.keys())
    assert len(ans_loaded_keys - model_keys) == 0

    model.load_state_dict(answer_state_dict, strict=False)
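Aside (not part of the deleted file): the answer normalization performed by AnswerTable.convert_ans above can be sketched without loading data/lxmert/all_ans.json; only a small subset of ANS_CONVERT is repeated here for illustration.

# Standalone sketch of the normalization in AnswerTable.convert_ans (subset of ANS_CONVERT).
ANS_CONVERT = {"a man": "man", "the man": "man", "grey": "gray", "two": "2"}

def convert_ans(ans):
    if len(ans) == 0:
        return ""
    ans = ans.lower()
    if ans[-1] == '.':
        ans = ans[:-1].strip()
    if ans.startswith("a "):
        ans = ans[2:].strip()
    if ans.startswith("an "):
        ans = ans[3:].strip()
    if ans.startswith("the "):
        ans = ans[4:].strip()
    return ANS_CONVERT.get(ans, ans)

print(convert_ans("The man."))  # -> 'man'
print(convert_ans("Grey"))      # -> 'gray'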