WwYc committed on
Commit
08d7644
1 Parent(s): a40480d

Upload 61 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. lxmert/.gitignore +3 -0
  2. lxmert/.gitmodules +3 -0
  3. lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  4. lxmert/LICENSE +21 -0
  5. lxmert/__init__.py +0 -0
  6. lxmert/__pycache__/__init__.cpython-38.pyc +0 -0
  7. lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg +0 -0
  8. lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg +0 -0
  9. lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg +0 -0
  10. lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg +0 -0
  11. lxmert/experiments/paper/new.jpg +0 -0
  12. lxmert/perturbation.py +254 -0
  13. lxmert/requirements.txt +107 -0
  14. lxmert/run/README.md +49 -0
  15. lxmert/run/gqa_finetune.bash +17 -0
  16. lxmert/run/gqa_test.bash +15 -0
  17. lxmert/run/lxmert_pretrain.bash +21 -0
  18. lxmert/run/nlvr2_finetune.bash +18 -0
  19. lxmert/run/nlvr2_test.bash +14 -0
  20. lxmert/run/vqa_finetune.bash +17 -0
  21. lxmert/run/vqa_test.bash +16 -0
  22. lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb +81 -0
  23. lxmert/src/ExplanationGenerator.py +665 -0
  24. lxmert/src/__init__.py +0 -0
  25. lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc +0 -0
  26. lxmert/src/__pycache__/__init__.cpython-38.pyc +0 -0
  27. lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc +0 -0
  28. lxmert/src/__pycache__/layers.cpython-38.pyc +0 -0
  29. lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc +0 -0
  30. lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc +0 -0
  31. lxmert/src/__pycache__/processing_image.cpython-38.pyc +0 -0
  32. lxmert/src/__pycache__/vqa_utils.cpython-38.pyc +0 -0
  33. lxmert/src/huggingface_lxmert.py +1472 -0
  34. lxmert/src/layers.py +292 -0
  35. lxmert/src/lxmert_lrp.py +1693 -0
  36. lxmert/src/lxrt/__init__.py +0 -0
  37. lxmert/src/lxrt/entry.py +156 -0
  38. lxmert/src/lxrt/file_utils.py +247 -0
  39. lxmert/src/lxrt/modeling.py +1018 -0
  40. lxmert/src/lxrt/optimization.py +180 -0
  41. lxmert/src/lxrt/tokenization.py +388 -0
  42. lxmert/src/modeling_frcnn.py +1922 -0
  43. lxmert/src/param.py +126 -0
  44. lxmert/src/pretrain/__init__.py +0 -0
  45. lxmert/src/pretrain/lxmert_data.py +255 -0
  46. lxmert/src/pretrain/lxmert_pretrain.py +435 -0
  47. lxmert/src/pretrain/qa_answer_table.py +158 -0
  48. lxmert/src/processing_image.py +147 -0
  49. lxmert/src/tasks/__init__.py +0 -0
  50. lxmert/src/tasks/gqa.py +210 -0
lxmert/.gitignore ADDED
@@ -0,0 +1,3 @@
+ *.caffemodel
+ *.tsv
+ /snap
lxmert/.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "data/nlvr2/nlvr"]
+ 	path = data/nlvr2/nlvr
+ 	url = https://github.com/lil-lab/nlvr.git
lxmert/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+  "cells": [],
+  "metadata": {},
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
lxmert/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2019 Hao Tan
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
lxmert/__init__.py ADDED
File without changes
lxmert/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (144 Bytes).
lxmert/experiments/paper/COCO_val2014_000000127510/COCO_val2014_000000127510.jpg ADDED
lxmert/experiments/paper/COCO_val2014_000000185590/COCO_val2014_000000185590.jpg ADDED
lxmert/experiments/paper/COCO_val2014_000000200717/COCO_val2014_000000200717.jpg ADDED
lxmert/experiments/paper/COCO_val2014_000000324266/COCO_val2014_000000324266.jpg ADDED
lxmert/experiments/paper/new.jpg ADDED
lxmert/perturbation.py ADDED
@@ -0,0 +1,254 @@
+ from lxmert.lxmert.src.tasks import vqa_data
+ from lxmert.lxmert.src.modeling_frcnn import GeneralizedRCNN
+ import lxmert.lxmert.src.vqa_utils as utils
+ from lxmert.lxmert.src.processing_image import Preprocess
+ from transformers import LxmertTokenizer
+ from lxmert.lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
+ from lxmert.lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
+ from tqdm import tqdm
+ from lxmert.lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
+ import random
+ from lxmert.lxmert.src.param import args
+
+ OBJ_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt"
+ ATTR_URL = "https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt"
+ VQA_URL = "https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json"
+
+ class ModelPert:
+     def __init__(self, COCO_val_path, use_lrp=False):
+         self.COCO_VAL_PATH = COCO_val_path
+         self.vqa_answers = utils.get_data(VQA_URL)
+
+         # load models and model components
+         self.frcnn_cfg = utils.Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
+         self.frcnn_cfg.MODEL.DEVICE = "cuda"
+
+         self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
+
+         self.image_preprocess = Preprocess(self.frcnn_cfg)
+
+         self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
+
+         if use_lrp:
+             self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
+         else:
+             self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased").to("cuda")
+
+         self.lxmert_vqa.eval()
+         self.model = self.lxmert_vqa
+
+         self.vqa_dataset = vqa_data.VQADataset(splits="valid")
+
+         self.pert_steps = [0, 0.25, 0.5, 0.75, 0.8, 0.85, 0.9, 0.95, 1]
+         self.pert_acc = [0] * len(self.pert_steps)
+
+     def forward(self, item):
+         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
+         self.image_file_path = image_file_path
+         self.image_id = item['img_id']
+         # run frcnn
+         images, sizes, scales_yx = self.image_preprocess(image_file_path)
+         output_dict = self.frcnn(
+             images,
+             sizes,
+             scales_yx=scales_yx,
+             padding="max_detections",
+             max_detections=self.frcnn_cfg.max_detections,
+             return_tensors="pt"
+         )
+         inputs = self.lxmert_tokenizer(
+             item['sent'],
+             truncation=True,
+             return_token_type_ids=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt"
+         )
+         self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
+         self.text_len = len(self.question_tokens)
+         # Very important that the boxes are normalized
+         normalized_boxes = output_dict.get("normalized_boxes")
+         features = output_dict.get("roi_features")
+         self.image_boxes_len = features.shape[1]
+         self.bboxes = output_dict.get("boxes")
+         self.output = self.lxmert_vqa(
+             input_ids=inputs.input_ids.to("cuda"),
+             attention_mask=inputs.attention_mask.to("cuda"),
+             visual_feats=features.to("cuda"),
+             visual_pos=normalized_boxes.to("cuda"),
+             token_type_ids=inputs.token_type_ids.to("cuda"),
+             return_dict=True,
+             output_attentions=False,
+         )
+         return self.output
+
+     def perturbation_image(self, item, cam_image, cam_text, is_positive_pert=False):
+         if is_positive_pert:
+             cam_image = cam_image * (-1)
+         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
+         # run frcnn
+         images, sizes, scales_yx = self.image_preprocess(image_file_path)
+         output_dict = self.frcnn(
+             images,
+             sizes,
+             scales_yx=scales_yx,
+             padding="max_detections",
+             max_detections=self.frcnn_cfg.max_detections,
+             return_tensors="pt"
+         )
+         inputs = self.lxmert_tokenizer(
+             item['sent'],
+             truncation=True,
+             return_token_type_ids=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt"
+         )
+         # Very important that the boxes are normalized
+         normalized_boxes = output_dict.get("normalized_boxes")
+         features = output_dict.get("roi_features")
+         for step_idx, step in enumerate(self.pert_steps):
+             # find top step boxes
+             curr_num_boxes = int((1 - step) * self.image_boxes_len)
+             _, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1)
+             top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()
+
+             curr_features = features[:, top_bboxes_indices, :]
+             curr_pos = normalized_boxes[:, top_bboxes_indices, :]
+
+             output = self.lxmert_vqa(
+                 input_ids=inputs.input_ids.to("cuda"),
+                 attention_mask=inputs.attention_mask.to("cuda"),
+                 visual_feats=curr_features.to("cuda"),
+                 visual_pos=curr_pos.to("cuda"),
+                 token_type_ids=inputs.token_type_ids.to("cuda"),
+                 return_dict=True,
+                 output_attentions=False,
+             )
+
+             answer = self.vqa_answers[output.question_answering_score.argmax()]
+             accuracy = item["label"].get(answer, 0)
+             self.pert_acc[step_idx] += accuracy
+
+         return self.pert_acc
+
+     def perturbation_text(self, item, cam_image, cam_text, is_positive_pert=False):
+         if is_positive_pert:
+             cam_text = cam_text * (-1)
+         image_file_path = self.COCO_VAL_PATH + item['img_id'] + '.jpg'
+         # run frcnn
+         images, sizes, scales_yx = self.image_preprocess(image_file_path)
+         output_dict = self.frcnn(
+             images,
+             sizes,
+             scales_yx=scales_yx,
+             padding="max_detections",
+             max_detections=self.frcnn_cfg.max_detections,
+             return_tensors="pt"
+         )
+         inputs = self.lxmert_tokenizer(
+             item['sent'],
+             truncation=True,
+             return_token_type_ids=True,
+             return_attention_mask=True,
+             add_special_tokens=True,
+             return_tensors="pt"
+         )
+         # Very important that the boxes are normalized
+         normalized_boxes = output_dict.get("normalized_boxes")
+         features = output_dict.get("roi_features")
+         for step_idx, step in enumerate(self.pert_steps):
+             # we must keep the [CLS] token in order to have the classification
+             # we also keep the [SEP] token
+             cam_pure_text = cam_text[1:-1]
+             text_len = cam_pure_text.shape[0]
+             # find top step tokens, without the [CLS] token and the [SEP] token
+             curr_num_tokens = int((1 - step) * text_len)
+             _, top_bboxes_indices = cam_pure_text.topk(k=curr_num_tokens, dim=-1)
+             top_bboxes_indices = top_bboxes_indices.cpu().data.numpy()
+
+             # add back [CLS], [SEP] tokens
+             top_bboxes_indices = [0, cam_text.shape[0] - 1] +\
+                                  [top_bboxes_indices[i] + 1 for i in range(len(top_bboxes_indices))]
+             # text tokens must be sorted for positional embedding to work
+             top_bboxes_indices = sorted(top_bboxes_indices)
+
+             curr_input_ids = inputs.input_ids[:, top_bboxes_indices]
+             curr_attention_mask = inputs.attention_mask[:, top_bboxes_indices]
+             curr_token_ids = inputs.token_type_ids[:, top_bboxes_indices]
+
+             output = self.lxmert_vqa(
+                 input_ids=curr_input_ids.to("cuda"),
+                 attention_mask=curr_attention_mask.to("cuda"),
+                 visual_feats=features.to("cuda"),
+                 visual_pos=normalized_boxes.to("cuda"),
+                 token_type_ids=curr_token_ids.to("cuda"),
+                 return_dict=True,
+                 output_attentions=False,
+             )
+
+             answer = self.vqa_answers[output.question_answering_score.argmax()]
+             accuracy = item["label"].get(answer, 0)
+             self.pert_acc[step_idx] += accuracy
+
+         return self.pert_acc
+
+ def main(args):
+     model_pert = ModelPert(args.COCO_path, use_lrp=True)
+     ours = GeneratorOurs(model_pert)
+     baselines = GeneratorBaselines(model_pert)
+     oursNoAggAblation = GeneratorOursAblationNoAggregation(model_pert)
+     vqa_dataset = vqa_data.VQADataset(splits="valid")
+     vqa_answers = utils.get_data(VQA_URL)
+     method_name = args.method
+
+     items = vqa_dataset.data
+     random.seed(1234)
+     r = list(range(len(items)))
+     random.shuffle(r)
+     pert_samples_indices = r[:args.num_samples]
+     iterator = tqdm([vqa_dataset.data[i] for i in pert_samples_indices])
+
+     test_type = "positive" if args.is_positive_pert else "negative"
+     modality = "text" if args.is_text_pert else "image"
+     print("running {0} pert test for {1} modality with method {2}".format(test_type, modality, args.method))
+
+     for index, item in enumerate(iterator):
+         if method_name == 'transformer_att':
+             R_t_t, R_t_i = baselines.generate_transformer_attr(item)
+         elif method_name == 'attn_gradcam':
+             R_t_t, R_t_i = baselines.generate_attn_gradcam(item)
+         elif method_name == 'partial_lrp':
+             R_t_t, R_t_i = baselines.generate_partial_lrp(item)
+         elif method_name == 'raw_attn':
+             R_t_t, R_t_i = baselines.generate_raw_attn(item)
+         elif method_name == 'rollout':
+             R_t_t, R_t_i = baselines.generate_rollout(item)
+         elif method_name == "ours_with_lrp_no_normalization":
+             R_t_t, R_t_i = ours.generate_ours(item, normalize_self_attention=False)
+         elif method_name == "ours_no_lrp":
+             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False)
+         elif method_name == "ours_no_lrp_no_norm":
+             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, normalize_self_attention=False)
+         elif method_name == "ours_with_lrp":
+             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=True)
+         elif method_name == "ablation_no_self_in_10":
+             R_t_t, R_t_i = ours.generate_ours(item, use_lrp=False, apply_self_in_rule_10=False)
+         elif method_name == "ablation_no_aggregation":
+             R_t_t, R_t_i = oursNoAggAblation.generate_ours_no_agg(item, use_lrp=False, normalize_self_attention=False)
+         else:
+             print("Please enter a valid method name")
+             return
+         cam_image = R_t_i[0]
+         cam_text = R_t_t[0]
+         cam_image = (cam_image - cam_image.min()) / (cam_image.max() - cam_image.min())
+         cam_text = (cam_text - cam_text.min()) / (cam_text.max() - cam_text.min())
+         if args.is_text_pert:
+             curr_pert_result = model_pert.perturbation_text(item, cam_image, cam_text, args.is_positive_pert)
+         else:
+             curr_pert_result = model_pert.perturbation_image(item, cam_image, cam_text, args.is_positive_pert)
+         curr_pert_result = [round(res / (index+1) * 100, 2) for res in curr_pert_result]
+         iterator.set_description("Acc: {}".format(curr_pert_result))
+
+ if __name__ == "__main__":
+     main(args)
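To make the box-selection logic in `perturbation_image` easier to follow, here is a minimal, self-contained sketch with toy tensors; the 2048-dimensional feature size is only an assumption for illustration and is not taken from this file.

```python
import torch

# Toy relevancy scores for 5 detected boxes and toy ROI features (batch of 1).
cam_image = torch.tensor([0.10, 0.90, 0.30, 0.70, 0.50])
features = torch.randn(1, 5, 2048)  # feature dim is a hypothetical value

# At perturbation step `step`, only the top (1 - step) fraction of boxes,
# ranked by relevancy, is kept. Without negation this is the negative test
# (least-relevant boxes are dropped first); perturbation_image negates
# cam_image for the positive test, so the most-relevant boxes are dropped first.
step = 0.8
curr_num_boxes = int((1 - step) * features.shape[1])              # -> keep 1 box
_, top_bboxes_indices = cam_image.topk(k=curr_num_boxes, dim=-1)
curr_features = features[:, top_bboxes_indices.cpu().numpy(), :]  # shape (1, 1, 2048)
print(curr_features.shape)
```

The text perturbation follows the same pattern, except that it always re-inserts the [CLS] and [SEP] positions and sorts the kept indices so the positional embeddings stay valid.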
lxmert/requirements.txt ADDED
@@ -0,0 +1,107 @@
+ argon2-cffi==20.1.0
+ async-generator==1.10
+ attrs==20.3.0
+ backcall==0.2.0
+ bleach==3.3.0
+ certifi==2020.12.5
+ cffi==1.14.5
+ chardet==3.0.4
+ click==7.1.2
+ cycler==0.10.0
+ Cython==0.29.22
+ dataclasses==0.6
+ decorator==4.4.2
+ defusedxml==0.6.0
+ demjson==2.2.4
+ editdistance==0.5.3
+ einops==0.3.0
+ entrypoints==0.3
+ fasttext==0.9.1
+ filelock==3.0.12
+ future==0.18.2
+ gitdb==4.0.5
+ GitPython==3.1.0
+ idna==2.10
+ imageio==2.9.0
+ importlib-metadata==3.4.0
+ ipykernel==5.4.3
+ ipython==7.20.0
+ ipython-genutils==0.2.0
+ ipywidgets==7.6.3
+ jedi==0.18.0
+ Jinja2==2.11.3
+ joblib==0.17.0
+ jsonschema==3.2.0
+ jupyter-client==6.1.11
+ jupyter-console==6.2.0
+ jupyter-core==4.7.1
+ jupyterlab-pygments==0.1.2
+ jupyterlab-widgets==1.0.0
+ kiwisolver==1.3.1
+ lmdb==0.98
+ MarkupSafe==1.1.1
+ matplotlib==3.3.4
+ mistune==0.8.4
+ nbclient==0.5.2
+ nbconvert==6.0.7
+ nbformat==5.1.2
+ nest-asyncio==1.5.1
+ networkx==2.4
+ nltk==3.4.5
+ notebook==6.2.0
+ numpy==1.19.2
+ omegaconf==2.0.1rc4
+ opencv-python==4.5.1.48
+ packaging==20.9
+ pandocfilters==1.4.3
+ parso==0.8.1
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==8.1.2
+ prometheus-client==0.9.0
+ prompt-toolkit==3.0.16
+ protobuf==3.15.6
+ ptyprocess==0.7.0
+ pybind11==2.6.2
+ pycocotools==2.0.2
+ pycparser==2.20
+ pyparsing==2.4.7
+ pyrsistent==0.17.3
+ python-dateutil==2.8.1
+ PyWavelets==1.1.1
+ PyYAML==5.4.1
+ pyzmq==22.0.3
+ qtconsole==5.0.2
+ QtPy==1.9.0
+ regex==2020.11.13
+ requests==2.23.0
+ sacremoses==0.0.43
+ scikit-image==0.17.2
+ scikit-learn==0.23.2
+ scipy==1.6.1
+ Send2Trash==1.5.0
+ sentencepiece==0.1.91
+ six==1.15.0
+ sklearn==0.0
+ smmap==3.0.5
+ termcolor==1.1.0
+ terminado==0.9.2
+ testpath==0.4.4
+ threadpoolctl==2.1.0
+ tifffile==2021.2.1
+ tokenizers==0.9.3
+ torch==1.7.1
+ torchtext==0.5.0
+ torchvision==0.8.2
+ tornado==6.1
+ tqdm==4.51.0
+ traitlets==5.0.5
+ transformers==3.5.1
+ typing-extensions==3.7.4.3
+ urllib3==1.25.11
+ utils==1.0.1
+ wcwidth==0.2.5
+ webencodings==0.5.1
+ wget==3.2
+ widgetsnbextension==3.5.1
+ zipp==3.4.0
lxmert/run/README.md ADDED
@@ -0,0 +1,49 @@
+ # Running Script Arguments
+
+ ```
+ Data Splits:
+     --train [str,str,...]: use the splits (separated by comma) in training.
+     --valid [str,str,...]: use the splits (separated by comma) in validation.
+     --test [str,str,...]: use the splits (separated by comma) in testing.
+ Model Architecture:
+     --llayers [int]: number of layers in language encoder.
+     --xlayers [int]: number of layers in cross-modality encoder.
+     --rlayers [int]: number of layers in object relationship encoder.
+ Load Weights:
+     --load [str='path/to/saved_model']: load fine-tuned model path/to/saved_model.pth.
+     --loadLXMERT [str='path/to/saved_model']: load pre-trained model without answer heads from path/to/saved_model_LXRT.pth.
+     --loadLXMERTQA [str='path/to/saved_model']: load pre-trained model with answer head path/to/saved_model_LXRT.pth.
+     --fromScratch: If none of the above loading parameters are set, the default mode would
+         load the pre-trained BERT weights.
+         As we promised to EMNLP reviewers, the language encoder would be re-initialized with this one-line argument to test the performance without BERT weights.
+ Training Hyper Parameters:
+     --batchSize [int]: batch size.
+     --optim [str]: optimizers.
+     --lr [float]: peak learning rate.
+     --epochs [int]: training epochs.
+ Debugging:
+     --tiny: Load 512 images for each data split. (Note: number of images might be changed due to dataset specification)
+     --fast: Load 5000 images for each data split. (Note: number of images might be changed due to dataset specification)
+ ```
+
+ # Pre-training-Specific Arguments
+ ```
+ Pre-training Tasks:
+     --taskMaskLM: use the masked language model task.
+     --taskObjPredict: use the masked object prediction task.
+     --taskMatched: use the cross-modality matched task.
+     --taskQA: use the image QA task.
+ Visual Pre-training Losses (Tasks):
+     --visualLosses [str,str,...]: The sub-tasks in pre-training visual modality. Each one is from 'obj,attr,feat'.
+         obj: detected-object-label classification.
+         attr: detected-object-attribute classification.
+         feat: RoI-feature regression.
+ Mask Rate in Pre-training:
+     --wordMaskRate [float]: The prob of masking a word.
+     --objMaskRate [float]: The prob of masking an object.
+ Initialization:
+     --fromScratch: The default mode would load the pre-trained BERT weights into the model.
+         As we promised to EMNLP reviewers, this option would re-initialize the language encoder.
+ ```
+
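For readers who want to see how flags like these are typically consumed, below is a minimal sketch of an argparse parser mirroring the options documented above. The authoritative definitions live in lxmert/src/param.py, which is not part of this view, so the flag names are taken only from the list above and the defaults only from the run/*.bash scripts that follow; treat everything here as an assumption, not the project's actual parser.

```python
import argparse

# Sketch only: a hypothetical mirror of the documented options.
parser = argparse.ArgumentParser()
parser.add_argument("--train", type=str, default="train")
parser.add_argument("--valid", type=str, default="valid")
parser.add_argument("--llayers", type=int, default=9)    # language encoder layers
parser.add_argument("--xlayers", type=int, default=5)    # cross-modality layers
parser.add_argument("--rlayers", type=int, default=5)    # object relationship layers
parser.add_argument("--batchSize", type=int, default=32)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--tiny", action="store_true")       # debug: 512 images per split

args = parser.parse_args(["--llayers", "9", "--batchSize", "32"])
print(args.llayers, args.batchSize, args.tiny)
```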
lxmert/run/gqa_finetune.bash ADDED
@@ -0,0 +1,17 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/gqa; make backup.
+ output=snap/gqa/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/gqa.py \
+     --train train,valid --valid testdev \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --loadLXMERTQA snap/pretrained/model \
+     --batchSize 32 --optim bert --lr 1e-5 --epochs 4 \
+     --tqdm --output $output ${@:3}
lxmert/run/gqa_test.bash ADDED
@@ -0,0 +1,15 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/gqa; make backup.
+ output=snap/gqa/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/gqa.py \
+     --tiny --train train --valid "" \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --tqdm --output $output ${@:3}
lxmert/run/lxmert_pretrain.bash ADDED
@@ -0,0 +1,21 @@
+ # The name of experiment
+ name=lxmert
+
+ # Create dirs and make backup
+ output=snap/pretrain/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # Pre-training
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/pretrain/lxmert_pretrain.py \
+     --taskMaskLM --taskObjPredict --taskMatched --taskQA \
+     --visualLosses obj,attr,feat \
+     --wordMaskRate 0.15 --objMaskRate 0.15 \
+     --train mscoco_train,mscoco_nominival,vgnococo --valid mscoco_minival \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --fromScratch \
+     --batchSize 256 --optim bert --lr 1e-4 --epochs 20 \
+     --tqdm --output $output ${@:2}
+
lxmert/run/nlvr2_finetune.bash ADDED
@@ -0,0 +1,18 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/nlvr2; Make backup.
+ output=snap/nlvr2/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See run/Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/nlvr2.py \
+     --train train --valid valid \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --loadLXMERT snap/pretrained/model \
+     --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
+     --tqdm --output $output ${@:3}
+
lxmert/run/nlvr2_test.bash ADDED
@@ -0,0 +1,14 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/nlvr2; make backup.
+ output=snap/nlvr2/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/nlvr2.py \
+     --tiny --llayers 9 --xlayers 5 --rlayers 5 \
+     --tqdm --output $output ${@:3}
lxmert/run/vqa_finetune.bash ADDED
@@ -0,0 +1,17 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/vqa; make backup.
+ output=snap/vqa/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/vqa.py \
+     --train train,nominival --valid minival \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --loadLXMERTQA snap/pretrained/model \
+     --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
+     --tqdm --output $output ${@:3}
lxmert/run/vqa_test.bash ADDED
@@ -0,0 +1,16 @@
+ # The name of this experiment.
+ name=$2
+
+ # Save logs and models under snap/vqa; make backup.
+ output=snap/vqa/$name
+ mkdir -p $output/src
+ cp -r src/* $output/src/
+ cp $0 $output/run.bash
+
+ # See Readme.md for option details.
+ CUDA_VISIBLE_DEVICES=$1 PYTHONPATH=$PYTHONPATH:./src \
+     python src/tasks/vqa.py \
+     --tiny --train train --valid "" \
+     --llayers 9 --xlayers 5 --rlayers 5 \
+     --batchSize 32 --optim bert --lr 5e-5 --epochs 4 \
+     --tqdm --output $output ${@:3}
lxmert/src/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,81 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "loose-wrong",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'src'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-7-b03239bcd702>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mlxmert_lrp\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLxmertForQuestionAnswering\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mLxmertForQuestionAnsweringLRP\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtasks\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mvqa_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodeling_frcnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGeneralizedRCNN\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvqa_utils\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mutils\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocessing_image\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPreprocess\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/media/data2/hila_chefer/lxmert/lxmert/src/lxmert_lrp.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSmoothL1Loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m from transformers.file_utils import (\n\u001b[1;32m 29\u001b[0m \u001b[0mModelOutput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'src'"
+ ]
+ }
+ ],
+ "source": [
+ "from lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP\n",
+ "from src.tasks import vqa_data\n",
+ "from src.modeling_frcnn import GeneralizedRCNN\n",
+ "import src.vqa_utils as utils\n",
+ "from src.processing_image import Preprocess\n",
+ "from transformers import LxmertTokenizer\n",
+ "from src.huggingface_lxmert import LxmertForQuestionAnswering\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "from src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines\n",
+ "import random\n",
+ "import cv2\n",
+ "\n",
+ "COCO_VAL_PATH = '/media/data2/hila_chefer/env_MMF/datasets/coco/subset_val/images/val2014/'\n",
+ "\n",
+ "OBJ_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/objects_vocab.txt\"\n",
+ "ATTR_URL = \"https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/genome/1600-400-20/attributes_vocab.txt\"\n",
+ "VQA_URL = \"https://raw.githubusercontent.com/airsplay/lxmert/master/data/vqa/trainval_label2ans.json\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "emerging-trace",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "royal-small",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
lxmert/src/ExplanationGenerator.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import copy
4
+
5
+
6
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
7
+ # adding residual consideration
8
+ num_tokens = all_layer_matrices[0].shape[1]
9
+ eye = torch.eye(num_tokens).to(all_layer_matrices[0].device)
10
+ all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
11
+ matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
12
+ for i in range(len(all_layer_matrices))]
13
+ joint_attention = matrices_aug[start_layer]
14
+ for i in range(start_layer + 1, len(matrices_aug)):
15
+ joint_attention = matrices_aug[i].matmul(joint_attention)
16
+ return joint_attention
17
+
18
+
19
+ # rule 5 from paper
20
+ def avg_heads(cam, grad):
21
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
22
+ grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
23
+ cam = grad * cam
24
+ cam = cam.clamp(min=0).mean(dim=0)
25
+ return cam
26
+
27
+
28
+ # rules 6 + 7 from paper
29
+ def apply_self_attention_rules(R_ss, R_sq, cam_ss):
30
+ R_sq_addition = torch.matmul(cam_ss, R_sq)
31
+ R_ss_addition = torch.matmul(cam_ss, R_ss)
32
+ return R_ss_addition, R_sq_addition
33
+
34
+
35
+ # rules 10 + 11 from paper
36
+ def apply_mm_attention_rules(R_ss, R_qq, R_qs, cam_sq, apply_normalization=True, apply_self_in_rule_10=True):
37
+ R_ss_normalized = R_ss
38
+ R_qq_normalized = R_qq
39
+ if apply_normalization:
40
+ R_ss_normalized = handle_residual(R_ss)
41
+ R_qq_normalized = handle_residual(R_qq)
42
+ R_sq_addition = torch.matmul(R_ss_normalized.t(), torch.matmul(cam_sq, R_qq_normalized))
43
+ if not apply_self_in_rule_10:
44
+ R_sq_addition = cam_sq
45
+ R_ss_addition = torch.matmul(cam_sq, R_qs)
46
+ return R_sq_addition, R_ss_addition
47
+
48
+
49
+ # normalization- eq. 8+9
50
+ def handle_residual(orig_self_attention):
51
+ self_attention = orig_self_attention.clone()
52
+ diag_idx = range(self_attention.shape[-1])
53
+ # computing R hat
54
+ self_attention -= torch.eye(self_attention.shape[-1]).to(self_attention.device)
55
+ assert self_attention[diag_idx, diag_idx].min() >= 0
56
+ # normalizing R hat
57
+ self_attention = self_attention / self_attention.sum(dim=-1, keepdim=True)
58
+ self_attention += torch.eye(self_attention.shape[-1]).to(self_attention.device)
59
+ return self_attention
60
+
61
+
62
+ class GeneratorOurs:
63
+ def __init__(self, model_usage, save_visualization=False):
64
+ self.model_usage = model_usage
65
+ self.save_visualization = save_visualization
66
+
67
+ def handle_self_attention_lang(self, blocks):
68
+ for blk in blocks:
69
+ grad = blk.attention.self.get_attn_gradients().detach()
70
+ if self.use_lrp:
71
+ cam = blk.attention.self.get_attn_cam().detach()
72
+ else:
73
+ cam = blk.attention.self.get_attn().detach()
74
+ cam = avg_heads(cam, grad)
75
+ R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
76
+ self.R_t_t += R_t_t_add
77
+ self.R_t_i += R_t_i_add
78
+
79
+ def handle_self_attention_image(self, blocks):
80
+ for blk in blocks:
81
+ grad = blk.attention.self.get_attn_gradients().detach()
82
+ if self.use_lrp:
83
+ cam = blk.attention.self.get_attn_cam().detach()
84
+ else:
85
+ cam = blk.attention.self.get_attn().detach()
86
+ cam = avg_heads(cam, grad)
87
+ R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
88
+ self.R_i_i += R_i_i_add
89
+ self.R_i_t += R_i_t_add
90
+
91
+ def handle_co_attn_self_lang(self, block):
92
+ grad = block.lang_self_att.self.get_attn_gradients().detach()
93
+ if self.use_lrp:
94
+ cam = block.lang_self_att.self.get_attn_cam().detach()
95
+ else:
96
+ cam = block.lang_self_att.self.get_attn().detach()
97
+ cam = avg_heads(cam, grad)
98
+ R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
99
+ self.R_t_t += R_t_t_add
100
+ self.R_t_i += R_t_i_add
101
+
102
+ def handle_co_attn_self_image(self, block):
103
+ grad = block.visn_self_att.self.get_attn_gradients().detach()
104
+ if self.use_lrp:
105
+ cam = block.visn_self_att.self.get_attn_cam().detach()
106
+ else:
107
+ cam = block.visn_self_att.self.get_attn().detach()
108
+ cam = avg_heads(cam, grad)
109
+ R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
110
+ self.R_i_i += R_i_i_add
111
+ self.R_i_t += R_i_t_add
112
+
113
+ def handle_co_attn_lang(self, block):
114
+ if self.use_lrp:
115
+ cam_t_i = block.visual_attention.att.get_attn_cam().detach()
116
+ else:
117
+ cam_t_i = block.visual_attention.att.get_attn().detach()
118
+ grad_t_i = block.visual_attention.att.get_attn_gradients().detach()
119
+ cam_t_i = avg_heads(cam_t_i, grad_t_i)
120
+ R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i,
121
+ apply_normalization=self.normalize_self_attention,
122
+ apply_self_in_rule_10=self.apply_self_in_rule_10)
123
+ return R_t_i_addition, R_t_t_addition
124
+
125
+ def handle_co_attn_image(self, block):
126
+ if self.use_lrp:
127
+ cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach()
128
+ else:
129
+ cam_i_t = block.visual_attention_copy.att.get_attn().detach()
130
+ grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach()
131
+ cam_i_t = avg_heads(cam_i_t, grad_i_t)
132
+ R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t,
133
+ apply_normalization=self.normalize_self_attention,
134
+ apply_self_in_rule_10=self.apply_self_in_rule_10)
135
+ return R_i_t_addition, R_i_i_addition
136
+
137
+ def generate_ours(self, input, index=None, use_lrp=True, normalize_self_attention=True, apply_self_in_rule_10=True,
138
+ method_name="ours"):
139
+ self.use_lrp = use_lrp
140
+ self.normalize_self_attention = normalize_self_attention
141
+ self.apply_self_in_rule_10 = apply_self_in_rule_10
142
+ kwargs = {"alpha": 1}
143
+ output = self.model_usage.forward(input).question_answering_score
144
+ model = self.model_usage.model
145
+
146
+ # initialize relevancy matrices
147
+ text_tokens = self.model_usage.text_len
148
+ image_bboxes = self.model_usage.image_boxes_len
149
+
150
+ # text self attention matrix
151
+ self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
152
+ # image self attention matrix
153
+ self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
154
+ # impact of images on text
155
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
156
+ # impact of text on images
157
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
158
+
159
+ if index is None:
160
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
161
+
162
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
163
+ one_hot[0, index] = 1
164
+ one_hot_vector = one_hot
165
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
166
+ one_hot = torch.sum(one_hot.cuda() * output)
167
+
168
+ model.zero_grad()
169
+ one_hot.backward(retain_graph=True)
170
+ if self.use_lrp:
171
+ model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
172
+
173
+ # language self attention
174
+ blocks = model.lxmert.encoder.layer
175
+ self.handle_self_attention_lang(blocks)
176
+
177
+ # image self attention
178
+ blocks = model.lxmert.encoder.r_layers
179
+ self.handle_self_attention_image(blocks)
180
+
181
+ # cross attn layers
182
+ blocks = model.lxmert.encoder.x_layers
183
+ for i, blk in enumerate(blocks):
184
+ # in the last cross attention module, only the text cross modal
185
+ # attention has an impact on the CLS token, since it's the first
186
+ # token in the language tokens
187
+ if i == len(blocks) - 1:
188
+ break
189
+ # cross attn- first for language then for image
190
+ R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
191
+ R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk)
192
+
193
+ self.R_t_i += R_t_i_addition
194
+ self.R_t_t += R_t_t_addition
195
+ self.R_i_t += R_i_t_addition
196
+ self.R_i_i += R_i_i_addition
197
+
198
+ # language self attention
199
+ self.handle_co_attn_self_lang(blk)
200
+
201
+ # image self attention
202
+ self.handle_co_attn_self_image(blk)
203
+
204
+ # take care of last cross attention layer- only text
205
+ blk = model.lxmert.encoder.x_layers[-1]
206
+ # cross attn- first for language then for image
207
+ R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
208
+ self.R_t_i += R_t_i_addition
209
+ self.R_t_t += R_t_t_addition
210
+
211
+ # language self attention
212
+ self.handle_co_attn_self_lang(blk)
213
+
214
+ # disregard the [CLS] token itself
215
+ self.R_t_t[0, 0] = 0
216
+ return self.R_t_t, self.R_t_i
217
+
218
+
219
+ class GeneratorOursAblationNoAggregation:
220
+ def __init__(self, model_usage, save_visualization=False):
221
+ self.model_usage = model_usage
222
+ self.save_visualization = save_visualization
223
+
224
+ def handle_self_attention_lang(self, blocks):
225
+ for blk in blocks:
226
+ grad = blk.attention.self.get_attn_gradients().detach()
227
+ if self.use_lrp:
228
+ cam = blk.attention.self.get_attn_cam().detach()
229
+ else:
230
+ cam = blk.attention.self.get_attn().detach()
231
+ cam = avg_heads(cam, grad)
232
+ R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
233
+ self.R_t_t = R_t_t_add
234
+ self.R_t_i = R_t_i_add
235
+
236
+ def handle_self_attention_image(self, blocks):
237
+ for blk in blocks:
238
+ grad = blk.attention.self.get_attn_gradients().detach()
239
+ if self.use_lrp:
240
+ cam = blk.attention.self.get_attn_cam().detach()
241
+ else:
242
+ cam = blk.attention.self.get_attn().detach()
243
+ cam = avg_heads(cam, grad)
244
+ R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
245
+ self.R_i_i = R_i_i_add
246
+ self.R_i_t = R_i_t_add
247
+
248
+ def handle_co_attn_self_lang(self, block):
249
+ grad = block.lang_self_att.self.get_attn_gradients().detach()
250
+ if self.use_lrp:
251
+ cam = block.lang_self_att.self.get_attn_cam().detach()
252
+ else:
253
+ cam = block.lang_self_att.self.get_attn().detach()
254
+ cam = avg_heads(cam, grad)
255
+ R_t_t_add, R_t_i_add = apply_self_attention_rules(self.R_t_t, self.R_t_i, cam)
256
+ self.R_t_t = R_t_t_add
257
+ self.R_t_i = R_t_i_add
258
+
259
+ def handle_co_attn_self_image(self, block):
260
+ grad = block.visn_self_att.self.get_attn_gradients().detach()
261
+ if self.use_lrp:
262
+ cam = block.visn_self_att.self.get_attn_cam().detach()
263
+ else:
264
+ cam = block.visn_self_att.self.get_attn().detach()
265
+ cam = avg_heads(cam, grad)
266
+ R_i_i_add, R_i_t_add = apply_self_attention_rules(self.R_i_i, self.R_i_t, cam)
267
+ self.R_i_i = R_i_i_add
268
+ self.R_i_t = R_i_t_add
269
+
270
+ def handle_co_attn_lang(self, block):
271
+ if self.use_lrp:
272
+ cam_t_i = block.visual_attention.att.get_attn_cam().detach()
273
+ else:
274
+ cam_t_i = block.visual_attention.att.get_attn().detach()
275
+ grad_t_i = block.visual_attention.att.get_attn_gradients().detach()
276
+ cam_t_i = avg_heads(cam_t_i, grad_t_i)
277
+ R_t_i_addition, R_t_t_addition = apply_mm_attention_rules(self.R_t_t, self.R_i_i, self.R_i_t, cam_t_i,
278
+ apply_normalization=self.normalize_self_attention)
279
+ return R_t_i_addition, R_t_t_addition
280
+
281
+ def handle_co_attn_image(self, block):
282
+ if self.use_lrp:
283
+ cam_i_t = block.visual_attention_copy.att.get_attn_cam().detach()
284
+ else:
285
+ cam_i_t = block.visual_attention_copy.att.get_attn().detach()
286
+ grad_i_t = block.visual_attention_copy.att.get_attn_gradients().detach()
287
+ cam_i_t = avg_heads(cam_i_t, grad_i_t)
288
+ R_i_t_addition, R_i_i_addition = apply_mm_attention_rules(self.R_i_i, self.R_t_t, self.R_t_i, cam_i_t,
289
+ apply_normalization=self.normalize_self_attention)
290
+ return R_i_t_addition, R_i_i_addition
291
+
292
+ def generate_ours_no_agg(self, input, index=None, use_lrp=False, normalize_self_attention=True,
293
+ method_name="ours_no_agg"):
294
+ self.use_lrp = use_lrp
295
+ self.normalize_self_attention = normalize_self_attention
296
+ kwargs = {"alpha": 1}
297
+ output = self.model_usage.forward(input).question_answering_score
298
+ model = self.model_usage.model
299
+
300
+ # initialize relevancy matrices
301
+ text_tokens = self.model_usage.text_len
302
+ image_bboxes = self.model_usage.image_boxes_len
303
+
304
+ # text self attention matrix
305
+ self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
306
+ # image self attention matrix
307
+ self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
308
+ # impact of images on text
309
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
310
+ # impact of text on images
311
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
312
+
313
+ if index is None:
314
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
315
+
316
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
317
+ one_hot[0, index] = 1
318
+ one_hot_vector = one_hot
319
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
320
+ one_hot = torch.sum(one_hot.cuda() * output)
321
+
322
+ model.zero_grad()
323
+ one_hot.backward(retain_graph=True)
324
+ if self.use_lrp:
325
+ model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
326
+
327
+ # language self attention
328
+ blocks = model.lxmert.encoder.layer
329
+ self.handle_self_attention_lang(blocks)
330
+
331
+ # image self attention
332
+ blocks = model.lxmert.encoder.r_layers
333
+ self.handle_self_attention_image(blocks)
334
+
335
+ # cross attn layers
336
+ blocks = model.lxmert.encoder.x_layers
337
+ for i, blk in enumerate(blocks):
338
+ # in the last cross attention module, only the text cross modal
339
+ # attention has an impact on the CLS token, since it's the first
340
+ # token in the language tokens
341
+ if i == len(blocks) - 1:
342
+ break
343
+ # cross attn- first for language then for image
344
+ R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
345
+ R_i_t_addition, R_i_i_addition = self.handle_co_attn_image(blk)
346
+
347
+ self.R_t_i = R_t_i_addition
348
+ self.R_t_t = R_t_t_addition
349
+ self.R_i_t = R_i_t_addition
350
+ self.R_i_i = R_i_i_addition
351
+
352
+ # language self attention
353
+ self.handle_co_attn_self_lang(blk)
354
+
355
+ # image self attention
356
+ self.handle_co_attn_self_image(blk)
357
+
358
+ # take care of last cross attention layer- only text
359
+ blk = model.lxmert.encoder.x_layers[-1]
360
+ # cross attn- first for language then for image
361
+ R_t_i_addition, R_t_t_addition = self.handle_co_attn_lang(blk)
362
+ self.R_t_i = R_t_i_addition
363
+ self.R_t_t = R_t_t_addition
364
+
365
+ # language self attention
366
+ self.handle_co_attn_self_lang(blk)
367
+
368
+ # disregard the [CLS] token itself
369
+ self.R_t_t[0, 0] = 0
370
+ return self.R_t_t, self.R_t_i
371
+
372
+
373
+ class GeneratorBaselines:
374
+ def __init__(self, model_usage, save_visualization=False):
375
+ self.model_usage = model_usage
376
+ self.save_visualization = save_visualization
377
+
378
+ def generate_transformer_attr(self, input, index=None, method_name="transformer_attr"):
379
+ kwargs = {"alpha": 1}
380
+ output = self.model_usage.forward(input).question_answering_score
381
+ model = self.model_usage.model
382
+
383
+ # initialize relevancy matrices
384
+ text_tokens = self.model_usage.text_len
385
+ image_bboxes = self.model_usage.image_boxes_len
386
+
387
+ # text self attention matrix
388
+ self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
389
+ # image self attention matrix
390
+ self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
391
+ # impact of images on text
392
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
393
+ # impact of text on images
394
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
395
+
396
+ if index == None:
397
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
398
+
399
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
400
+ one_hot[0, index] = 1
401
+ one_hot_vector = one_hot
402
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
403
+ one_hot = torch.sum(one_hot.cuda() * output)
404
+
405
+ model.zero_grad()
406
+ one_hot.backward(retain_graph=True)
407
+ model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
408
+
409
+ # language self attention
410
+ blocks = model.lxmert.encoder.layer
411
+ for blk in blocks:
412
+ grad = blk.attention.self.get_attn_gradients().detach()
413
+ cam = blk.attention.self.get_attn_cam().detach()
414
+ cam = avg_heads(cam, grad)
415
+ self.R_t_t += torch.matmul(cam, self.R_t_t)
416
+
417
+ # image self attention
418
+ blocks = model.lxmert.encoder.r_layers
419
+ for blk in blocks:
420
+ grad = blk.attention.self.get_attn_gradients().detach()
421
+ cam = blk.attention.self.get_attn_cam().detach()
422
+ cam = avg_heads(cam, grad)
423
+ self.R_i_i += torch.matmul(cam, self.R_i_i)
424
+
425
+ # cross attn layers
426
+ blocks = model.lxmert.encoder.x_layers
427
+ for i, blk in enumerate(blocks):
428
+ # in the last cross attention module, only the text cross modal
429
+ # attention has an impact on the CLS token, since it's the first
430
+ # token in the language tokens
431
+ if i == len(blocks) - 1:
432
+ break
433
+
434
+ # language self attention
435
+ grad = blk.lang_self_att.self.get_attn_gradients().detach()
436
+ cam = blk.lang_self_att.self.get_attn_cam().detach()
437
+ cam = avg_heads(cam, grad)
438
+ self.R_t_t += torch.matmul(cam, self.R_t_t)
439
+
440
+ # image self attention
441
+ grad = blk.visn_self_att.self.get_attn_gradients().detach()
442
+ cam = blk.visn_self_att.self.get_attn_cam().detach()
443
+ cam = avg_heads(cam, grad)
444
+ self.R_i_i += torch.matmul(cam, self.R_i_i)
445
+
446
+ # take care of last cross attention layer- only text
447
+ blk = model.lxmert.encoder.x_layers[-1]
448
+ # cross attn cam will be the one used for the R_t_i matrix
449
+ cam_t_i = blk.visual_attention.att.get_attn_cam().detach()
450
+ grad_t_i = blk.visual_attention.att.get_attn_gradients().detach()
451
+ cam_t_i = avg_heads(cam_t_i, grad_t_i)
452
+ # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
453
+ self.R_t_i = cam_t_i
454
+
455
+ # language self attention
456
+ grad = blk.lang_self_att.self.get_attn_gradients().detach()
457
+ cam = blk.lang_self_att.self.get_attn_cam().detach()
458
+ cam = avg_heads(cam, grad)
459
+ self.R_t_t += torch.matmul(cam, self.R_t_t)
460
+
461
+ self.R_t_t[0, 0] = 0
462
+ return self.R_t_t, self.R_t_i
463
+
464
+ def generate_partial_lrp(self, input, index=None, method_name="partial_lrp"):
465
+ kwargs = {"alpha": 1}
466
+ output = self.model_usage.forward(input).question_answering_score
467
+ model = self.model_usage.model
468
+
469
+ # initialize relevancy matrices
470
+ text_tokens = self.model_usage.text_len
471
+ image_bboxes = self.model_usage.image_boxes_len
472
+
473
+ # text self attention matrix
474
+ self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device)
475
+ # image self attention matrix
476
+ self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device)
477
+ # impact of images on text
478
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
479
+ # impact of text on images
480
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
481
+
482
+ if index == None:
483
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
484
+
485
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
486
+ one_hot[0, index] = 1
487
+ one_hot_vector = one_hot
488
+ model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)
489
+
490
+ # last cross attention + self- attention layer
491
+ blk = model.lxmert.encoder.x_layers[-1]
492
+ # cross attn cam will be the one used for the R_t_i matrix
493
+ cam_t_i = blk.visual_attention.att.get_attn_cam().detach()
494
+ cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
495
+ self.R_t_i = cam_t_i
496
+
497
+ # language self attention
498
+ cam = blk.lang_self_att.self.get_attn_cam().detach()
499
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
500
+ self.R_t_t = cam
501
+
502
+ # normalize to get non-negative cams
503
+ self.R_t_t = (self.R_t_t - self.R_t_t.min()) / (self.R_t_t.max() - self.R_t_t.min())
504
+ self.R_t_i = (self.R_t_i - self.R_t_i.min()) / (self.R_t_i.max() - self.R_t_i.min())
505
+ # disregard the [CLS] token itself
506
+ self.R_t_t[0, 0] = 0
507
+ return self.R_t_t, self.R_t_i
508
+
509
+ def generate_raw_attn(self, input, method_name="raw_attention"):
510
+ output = self.model_usage.forward(input).question_answering_score
511
+ model = self.model_usage.model
512
+
513
+ # initialize relevancy matrices
514
+ text_tokens = self.model_usage.text_len
515
+ image_bboxes = self.model_usage.image_boxes_len
516
+
517
+ # text self attention matrix
518
+ self.R_t_t = torch.zeros(text_tokens, text_tokens).to(model.device)
519
+ # image self attention matrix
520
+ self.R_i_i = torch.zeros(image_bboxes, image_bboxes).to(model.device)
521
+ # impact of images on text
522
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
523
+ # impact of text on images
524
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
525
+
526
+ # last cross attention + self- attention layer
527
+ blk = model.lxmert.encoder.x_layers[-1]
528
+ # cross attn cam will be the one used for the R_t_i matrix
529
+ cam_t_i = blk.visual_attention.att.get_attn().detach()
530
+ cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
531
+ # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
532
+ self.R_t_i = cam_t_i
533
+
534
+ # language self attention
535
+ cam = blk.lang_self_att.self.get_attn().detach()
536
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
537
+ self.R_t_t = cam
538
+
539
+ # disregard the [CLS] token itself
540
+ self.R_t_t[0, 0] = 0
541
+ return self.R_t_t, self.R_t_i
542
+
543
+ def gradcam(self, cam, grad):
544
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
545
+ grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
546
+ grad = grad.mean(dim=[1, 2], keepdim=True)
547
+ cam = (cam * grad).mean(0).clamp(min=0)
548
+ return cam
549
+
550
+ def generate_attn_gradcam(self, input, index=None, method_name="gradcam"):
551
+ output = self.model_usage.forward(input).question_answering_score
552
+ model = self.model_usage.model
553
+
554
+ # initialize relevancy matrices
555
+ text_tokens = self.model_usage.text_len
556
+ image_bboxes = self.model_usage.image_boxes_len
557
+
558
+ # text self attention matrix
559
+ self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
560
+ # image self attention matrix
561
+ self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
562
+ # impact of images on text
563
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
564
+ # impact of text on images
565
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
566
+
567
+ if index == None:
568
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
569
+
570
+ one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
571
+ one_hot[0, index] = 1
572
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
573
+ one_hot = torch.sum(one_hot.cuda() * output)
574
+
575
+ model.zero_grad()
576
+ one_hot.backward(retain_graph=True)
577
+
578
+ # last cross attention + self- attention layer
579
+ blk = model.lxmert.encoder.x_layers[-1]
580
+ # cross attn cam will be the one used for the R_t_i matrix
581
+ grad_t_i = blk.visual_attention.att.get_attn_gradients().detach()
582
+ cam_t_i = blk.visual_attention.att.get_attn().detach()
583
+ cam_t_i = self.gradcam(cam_t_i, grad_t_i)
584
+ # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
585
+ self.R_t_i = cam_t_i
586
+
587
+ # language self attention
588
+ grad = blk.lang_self_att.self.get_attn_gradients().detach()
589
+ cam = blk.lang_self_att.self.get_attn().detach()
590
+ self.R_t_t = self.gradcam(cam, grad)
591
+
592
+ # disregard the [CLS] token itself
593
+ self.R_t_t[0, 0] = 0
594
+ return self.R_t_t, self.R_t_i
595
+
596
+ def generate_rollout(self, input, method_name="rollout"):
597
+ output = self.model_usage.forward(input).question_answering_score
598
+ model = self.model_usage.model
599
+
600
+ # initialize relevancy matrices
601
+ text_tokens = self.model_usage.text_len
602
+ image_bboxes = self.model_usage.image_boxes_len
603
+
604
+ # text self attention matrix
605
+ self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
606
+ # image self attention matrix
607
+ self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
608
+ # impact of images on text
609
+ self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
610
+ # impact of text on images
611
+ self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)
612
+
613
+ cams_text = []
614
+ cams_image = []
615
+ # language self attention
616
+ blocks = model.lxmert.encoder.layer
617
+ for blk in blocks:
618
+ cam = blk.attention.self.get_attn().detach()
619
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
620
+ cams_text.append(cam)
621
+
622
+ # image self attention
623
+ blocks = model.lxmert.encoder.r_layers
624
+ for blk in blocks:
625
+ cam = blk.attention.self.get_attn().detach()
626
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
627
+ cams_image.append(cam)
628
+
629
+ # cross attn layers
630
+ blocks = model.lxmert.encoder.x_layers
631
+ for i, blk in enumerate(blocks):
632
+ # in the last cross-attention module, only the text cross-modal
633
+ # attention affects the [CLS] token, since [CLS] is the first
634
+ # of the language tokens
635
+ if i == len(blocks) - 1:
636
+ break
637
+
638
+ # language self attention
639
+ cam = blk.lang_self_att.self.get_attn().detach()
640
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
641
+ cams_text.append(cam)
642
+
643
+ # image self attention
644
+ cam = blk.visn_self_att.self.get_attn().detach()
645
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
646
+ cams_image.append(cam)
647
+
648
+ # take care of the last cross-attention layer - text only
649
+ blk = model.lxmert.encoder.x_layers[-1]
650
+ # cross attn cam will be the one used for the R_t_i matrix
651
+ cam_t_i = blk.visual_attention.att.get_attn().detach()
652
+ cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2], cam_t_i.shape[-1]).mean(dim=0)
653
+ self.R_t_t = compute_rollout_attention(copy.deepcopy(cams_text))
654
+ self.R_i_i = compute_rollout_attention(cams_image)
655
+ self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
656
+ # language self attention
657
+ cam = blk.lang_self_att.self.get_attn().detach()
658
+ cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
659
+ cams_text.append(cam)
660
+
661
+ self.R_t_t = compute_rollout_attention(cams_text)
662
+
663
+ # disregard the [CLS] token itself
664
+ self.R_t_t[0, 0] = 0
665
+ return self.R_t_t, self.R_t_i
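`generate_rollout` delegates the layer chaining to `compute_rollout_attention`, defined earlier in this file. Attention rollout adds the identity matrix to each head-averaged map (to model the residual connection), renormalizes the rows, and multiplies the layers together. A minimal sketch of that rule under the assumption of square, head-averaged maps (the real helper may additionally handle a batch dimension):

import torch

def rollout(per_layer_attn):
    # per_layer_attn: list of (tokens, tokens) head-averaged attention maps, first layer first
    tokens = per_layer_attn[0].shape[-1]
    eye = torch.eye(tokens)
    joint = eye.clone()
    for attn in per_layer_attn:
        aug = attn + eye                           # account for the residual (skip) connection
        aug = aug / aug.sum(dim=-1, keepdim=True)  # renormalize each row to sum to 1
        joint = torch.matmul(aug, joint)           # accumulate layer by layer
    return joint

layers, tokens = 4, 10                             # illustrative sizes
maps = [torch.softmax(torch.randn(tokens, tokens), dim=-1) for _ in range(layers)]
print(rollout(maps).shape)                         # torch.Size([10, 10])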
lxmert/src/__init__.py ADDED
File without changes
lxmert/src/__pycache__/ExplanationGenerator.cpython-38.pyc ADDED
Binary file (15.3 kB)
lxmert/src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (148 Bytes)
lxmert/src/__pycache__/huggingface_lxmert.cpython-38.pyc ADDED
Binary file (47.6 kB)
lxmert/src/__pycache__/layers.cpython-38.pyc ADDED
Binary file (10.5 kB)
lxmert/src/__pycache__/lxmert_lrp.cpython-38.pyc ADDED
Binary file (53.6 kB)
lxmert/src/__pycache__/modeling_frcnn.cpython-38.pyc ADDED
Binary file (56.8 kB)
lxmert/src/__pycache__/processing_image.cpython-38.pyc ADDED
Binary file (5.73 kB)
lxmert/src/__pycache__/vqa_utils.cpython-38.pyc ADDED
Binary file (14.4 kB)
lxmert/src/huggingface_lxmert.py ADDED
@@ -0,0 +1,1472 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch lxmert model. """
16
+
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss, SmoothL1Loss
27
+
28
+ from transformers.activations import ACT2FN, gelu
29
+ from transformers.file_utils import (
30
+ ModelOutput,
31
+ add_code_sample_docstrings,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ replace_return_docstrings,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import logging
38
+ from transformers.configuration_lxmert import LxmertConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "LxmertConfig"
44
+ _TOKENIZER_FOR_DOC = "LxmertTokenizer"
45
+
46
+ LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
47
+ "unc-nlp/lxmert-base-uncased",
48
+ ]
49
+
50
+
51
+ class GeLU(nn.Module):
52
+ def __init__(self):
53
+ super().__init__()
54
+
55
+ def forward(self, x):
56
+ return gelu(x)
57
+
58
+
59
+ @dataclass
60
+ class LxmertModelOutput(ModelOutput):
61
+ """
62
+ Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
63
+ visual, and cross-modality encoders. (Note: the visual encoder in Lxmert is referred to as the "relationship"
64
+ encoder.)
65
+
66
+
67
+ Args:
68
+ language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
69
+ Sequence of hidden-states at the output of the last layer of the language encoder.
70
+ vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
71
+ Sequence of hidden-states at the output of the last layer of the visual encoder.
72
+ pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
73
+ Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
74
+ by a Linear layer and a Tanh activation function.
75
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
76
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
77
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
78
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
79
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
80
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
81
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
82
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
83
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
84
+ weighted average in the self-attention heads.
85
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
86
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
87
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
88
+ weighted average in the self-attention heads.
89
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
90
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
91
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
92
+ weighted average in the self-attention heads.
93
+ """
94
+
95
+ language_output: Optional[torch.FloatTensor] = None
96
+ vision_output: Optional[torch.FloatTensor] = None
97
+ pooled_output: Optional[torch.FloatTensor] = None
98
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
99
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
100
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
101
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
102
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
103
+
104
+
105
+ @dataclass
106
+ class LxmertForQuestionAnsweringOutput(ModelOutput):
107
+ """
108
+ Output type of :class:`~transformers.LxmertForQuestionAnswering`.
109
+
110
+ Args:
111
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
112
+ Total loss of the question answering objective (the cross-entropy between the predicted answer
113
+ scores and the labels).
114
+ question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`):
115
+ Prediction scores of question answering objective (classification).
116
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
117
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
118
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
119
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
120
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
121
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
122
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
123
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
124
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
125
+ weighted average in the self-attention heads.
126
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
127
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
128
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
129
+ weighted average in the self-attention heads.
130
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
131
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
132
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
133
+ weighted average in the self-attention heads.
134
+ """
135
+
136
+ loss: Optional[torch.FloatTensor] = None
137
+ question_answering_score: Optional[torch.FloatTensor] = None
138
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
139
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
140
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
141
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
142
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
143
+
144
+
145
+ @dataclass
146
+ class LxmertForPreTrainingOutput(ModelOutput):
147
+ """
148
+ Output type of :class:`~transformers.LxmertForPreTraining`.
149
+
150
+ Args:
151
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
152
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
153
+ (classification) loss.
154
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
155
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
156
+ cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
157
+ Prediction scores of the textual matching objective (classification) head (scores of True/False
158
+ continuation before SoftMax).
159
+ question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`):
160
+ Prediction scores of question answering objective (classification).
161
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
162
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
163
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
164
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
165
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
166
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
167
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
168
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
169
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
170
+ weighted average in the self-attention heads.
171
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
172
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
173
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
174
+ weighted average in the self-attention heads.
175
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
176
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
177
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
178
+ weighted average in the self-attention heads.
179
+
180
+ """
181
+
182
+ loss: Optional[torch.FloatTensor] = None
183
+ prediction_logits: Optional[torch.FloatTensor] = None
184
+ cross_relationship_score: Optional[torch.FloatTensor] = None
185
+ question_answering_score: Optional[torch.FloatTensor] = None
186
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
187
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
188
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
189
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
190
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
191
+
192
+
193
+ def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
194
+ """Load tf checkpoints in a pytorch model."""
195
+ try:
196
+ import re
197
+
198
+ import numpy as np
199
+ import tensorflow as tf
200
+ except ImportError:
201
+ logger.error(
202
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
203
+ "https://www.tensorflow.org/install/ for installation instructions."
204
+ )
205
+ raise
206
+ tf_path = os.path.abspath(tf_checkpoint_path)
207
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
208
+ # Load weights from TF model
209
+ init_vars = tf.train.list_variables(tf_path)
210
+ names = []
211
+ arrays = []
212
+ for name, shape in init_vars:
213
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
214
+ array = tf.train.load_variable(tf_path, name)
215
+ names.append(name)
216
+ arrays.append(array)
217
+
218
+ for name, array in zip(names, arrays):
219
+ name = name.split("/")
220
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
221
+ # which are not required for using pretrained model
222
+ if any(
223
+ n
224
+ in [
225
+ "adam_v",
226
+ "adam_m",
227
+ "AdamWeightDecayOptimizer",
228
+ "AdamWeightDecayOptimizer_1",
229
+ "global_step",
230
+ ]
231
+ for n in name
232
+ ):
233
+ logger.info("Skipping {}".format("/".join(name)))
234
+ continue
235
+ pointer = model
236
+ for m_name in name:
237
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
238
+ scope_names = re.split(r"_(\d+)", m_name)
239
+ else:
240
+ scope_names = [m_name]
241
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
242
+ pointer = getattr(pointer, "weight")
243
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
244
+ pointer = getattr(pointer, "bias")
245
+ elif scope_names[0] == "output_weights":
246
+ pointer = getattr(pointer, "weight")
247
+ elif scope_names[0] == "squad":
248
+ pointer = getattr(pointer, "classifier")
249
+ else:
250
+ try:
251
+ pointer = getattr(pointer, scope_names[0])
252
+ except AttributeError:
253
+ logger.info("Skipping {}".format("/".join(name)))
254
+ continue
255
+ if len(scope_names) >= 2:
256
+ num = int(scope_names[1])
257
+ pointer = pointer[num]
258
+ if m_name[-11:] == "_embeddings":
259
+ pointer = getattr(pointer, "weight")
260
+ elif m_name == "kernel":
261
+ array = np.transpose(array)
262
+ try:
263
+ assert pointer.shape == array.shape
264
+ except AssertionError as e:
265
+ e.args += (pointer.shape, array.shape)
266
+ raise
267
+ logger.info("Initialize PyTorch weight {}".format(name))
268
+ pointer.data = torch.from_numpy(array)
269
+ return model
270
+
271
+
272
+ class LxmertEmbeddings(nn.Module):
273
+ """Construct the embeddings from word, position and token_type embeddings."""
274
+
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
278
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
279
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
280
+
281
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
282
+ # any TensorFlow checkpoint file
283
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
284
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
285
+
286
+ def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
287
+ if input_ids is not None:
288
+ input_shape = input_ids.size()
289
+ device = input_ids.device
290
+ else:
291
+ input_shape = inputs_embeds.size()[:-1]
292
+ device = inputs_embeds.device
293
+ seq_length = input_shape[1]
294
+
295
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
296
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
297
+
298
+ if token_type_ids is None:
299
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
300
+
301
+ if inputs_embeds is None:
302
+ inputs_embeds = self.word_embeddings(input_ids)
303
+ position_embeddings = self.position_embeddings(position_ids)
304
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
305
+
306
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
307
+ embeddings = self.LayerNorm(embeddings)
308
+ embeddings = self.dropout(embeddings)
309
+ return embeddings
310
+
311
+
312
+ class LxmertAttention(nn.Module):
313
+ def __init__(self, config, ctx_dim=None, save_cams=False):
314
+ super().__init__()
315
+ if config.hidden_size % config.num_attention_heads != 0:
316
+ raise ValueError(
317
+ "The hidden size (%d) is not a multiple of the number of attention "
318
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
319
+ )
320
+ self.num_attention_heads = config.num_attention_heads
321
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
322
+ self.head_size = self.num_attention_heads * self.attention_head_size
323
+
324
+ # visual_dim = 2048
325
+ if ctx_dim is None:
326
+ ctx_dim = config.hidden_size
327
+ self.query = nn.Linear(config.hidden_size, self.head_size)
328
+ self.key = nn.Linear(ctx_dim, self.head_size)
329
+ self.value = nn.Linear(ctx_dim, self.head_size)
330
+
331
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
332
+
333
+ self.save_cams = save_cams
334
+ self.attn = None
335
+ self.attn_gradients = None
336
+
337
+ def get_attn(self):
338
+ ret = self.attn
339
+ self.attn = None
340
+ return ret
341
+
342
+ def save_attn(self, attn):
343
+ if self.attn is not None:
344
+ self.attn = [self.attn, attn]
345
+ else:
346
+ self.attn = attn
347
+
348
+ def save_attn_gradients(self, attn_gradients):
349
+ if self.attn_gradients is not None:
350
+ self.attn_gradients = [self.attn_gradients, attn_gradients]
351
+ else:
352
+ self.attn_gradients = attn_gradients
353
+
354
+ def get_attn_gradients(self):
355
+ ret = self.attn_gradients
356
+ self.attn_gradients = None
357
+ return ret
358
+
359
+ def reset(self):
360
+ self.attn = None
361
+ self.attn_gradients = None
362
+
363
+ def transpose_for_scores(self, x):
364
+ new_x_shape = x.size()[:-1] + (
365
+ self.num_attention_heads,
366
+ self.attention_head_size,
367
+ )
368
+ x = x.view(*new_x_shape)
369
+ return x.permute(0, 2, 1, 3)
370
+
371
+ def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
372
+ mixed_query_layer = self.query(hidden_states)
373
+ mixed_key_layer = self.key(context)
374
+ mixed_value_layer = self.value(context)
375
+
376
+ query_layer = self.transpose_for_scores(mixed_query_layer)
377
+ key_layer = self.transpose_for_scores(mixed_key_layer)
378
+ value_layer = self.transpose_for_scores(mixed_value_layer)
379
+
380
+ # Take the dot product between "query" and "key" to get the raw attention scores.
381
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
382
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
383
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
384
+ if attention_mask is not None:
385
+ attention_scores = attention_scores + attention_mask
386
+
387
+ # Normalize the attention scores to probabilities.
388
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
389
+
390
+ # if self.save_cams:
391
+ self.save_attn(attention_probs)
392
+ attention_probs.register_hook(self.save_attn_gradients)
393
+
394
+ # This is actually dropping out entire tokens to attend to, which might
395
+ # seem a bit unusual, but is taken from the original Transformer paper.
396
+ attention_probs = self.dropout(attention_probs)
397
+
398
+ context_layer = torch.matmul(attention_probs, value_layer)
399
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
400
+ new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
401
+ context_layer = context_layer.view(*new_context_layer_shape)
402
+
403
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
404
+ return outputs
405
+
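The `save_attn` / `save_attn_gradients` pair, together with the `register_hook` call in `LxmertAttention.forward` above, is what lets the explanation code in ExplanationGenerator.py read both an attention map and its gradient after a backward pass. A minimal sketch of the same capture pattern on a toy softmax attention (everything here is illustrative, not the model's own modules):

import torch

captured = {}

scores = torch.randn(2, 4, 4, requires_grad=True)        # toy (heads, queries, keys) scores
probs = torch.softmax(scores, dim=-1)

captured["attn"] = probs                                  # save the forward attention map
probs.register_hook(lambda g: captured.update(grad=g))    # save its gradient during backward

loss = (probs * torch.randn_like(probs)).sum()
loss.backward()
print(captured["attn"].shape, captured["grad"].shape)     # both torch.Size([2, 4, 4])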
406
+
407
+ class LxmertAttentionOutput(nn.Module):
408
+ def __init__(self, config):
409
+ super().__init__()
410
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
411
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
412
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
413
+
414
+ def forward(self, hidden_states, input_tensor):
415
+ hidden_states = self.dense(hidden_states)
416
+ hidden_states = self.dropout(hidden_states)
417
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
418
+ return hidden_states
419
+
420
+
421
+ class LxmertCrossAttentionLayer(nn.Module):
422
+ def __init__(self, config, save_cams=False):
423
+ super().__init__()
424
+ self.att = LxmertAttention(config, save_cams=save_cams)
425
+ self.output = LxmertAttentionOutput(config)
426
+
427
+ def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
428
+ output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
429
+ if output_attentions:
430
+ attention_probs = output[1]
431
+ attention_output = self.output(output[0], input_tensor)
432
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
433
+ return outputs
434
+
435
+
436
+ class LxmertSelfAttentionLayer(nn.Module):
437
+ def __init__(self, config, save_cams=False):
438
+ super().__init__()
439
+ self.self = LxmertAttention(config, save_cams=save_cams)
440
+ self.output = LxmertAttentionOutput(config)
441
+
442
+ def forward(self, input_tensor, attention_mask, output_attentions=False):
443
+ # Self attention attends to itself, thus keys and queries are the same (input_tensor).
444
+ output = self.self(
445
+ input_tensor,
446
+ input_tensor,
447
+ attention_mask,
448
+ output_attentions=output_attentions,
449
+ )
450
+ if output_attentions:
451
+ attention_probs = output[1]
452
+ attention_output = self.output(output[0], input_tensor)
453
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
454
+ return outputs
455
+
456
+
457
+ class LxmertIntermediate(nn.Module):
458
+ def __init__(self, config):
459
+ super().__init__()
460
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
461
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
462
+
463
+ def forward(self, hidden_states):
464
+ hidden_states = self.dense(hidden_states)
465
+ hidden_states = self.intermediate_act_fn(hidden_states)
466
+ return hidden_states
467
+
468
+
469
+ class LxmertOutput(nn.Module):
470
+ def __init__(self, config):
471
+ super().__init__()
472
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
473
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
474
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
475
+
476
+ def forward(self, hidden_states, input_tensor):
477
+ hidden_states = self.dense(hidden_states)
478
+ hidden_states = self.dropout(hidden_states)
479
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
480
+ return hidden_states
481
+
482
+
483
+ class LxmertLayer(nn.Module):
484
+ def __init__(self, config, save_cams=False):
485
+ super().__init__()
486
+ self.attention = LxmertSelfAttentionLayer(config, save_cams=save_cams)
487
+ self.intermediate = LxmertIntermediate(config)
488
+ self.output = LxmertOutput(config)
489
+
490
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
491
+ outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
492
+ attention_output = outputs[0]
493
+ intermediate_output = self.intermediate(attention_output)
494
+ layer_output = self.output(intermediate_output, attention_output)
495
+ outputs = (layer_output,) + outputs[1:] # add attentions if we output them
496
+ return outputs
497
+
498
+
499
+ class LxmertXLayer(nn.Module):
500
+ def __init__(self, config, save_cams=False):
501
+ super().__init__()
502
+ # The cross-attention Layer
503
+ self.visual_attention = LxmertCrossAttentionLayer(config, save_cams=save_cams)
504
+
505
+ # Self-attention Layers
506
+ self.lang_self_att = LxmertSelfAttentionLayer(config)
507
+ self.visn_self_att = LxmertSelfAttentionLayer(config)
508
+
509
+ # Intermediate and Output Layers (FFNs)
510
+ self.lang_inter = LxmertIntermediate(config)
511
+ self.lang_output = LxmertOutput(config)
512
+ self.visn_inter = LxmertIntermediate(config)
513
+ self.visn_output = LxmertOutput(config)
514
+
515
+ def cross_att(
516
+ self,
517
+ lang_input,
518
+ lang_attention_mask,
519
+ visual_input,
520
+ visual_attention_mask,
521
+ output_x_attentions=False,
522
+ ):
523
+ # Cross Attention
524
+ lang_att_output = self.visual_attention(
525
+ lang_input,
526
+ visual_input,
527
+ ctx_att_mask=visual_attention_mask,
528
+ output_attentions=output_x_attentions,
529
+ )
530
+ visual_att_output = self.visual_attention(
531
+ visual_input,
532
+ lang_input,
533
+ ctx_att_mask=lang_attention_mask,
534
+ output_attentions=False,
535
+ )
536
+ return lang_att_output, visual_att_output
537
+
538
+ def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
539
+ # Self Attention
540
+ lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
541
+ visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
542
+ return lang_att_output[0], visual_att_output[0]
543
+
544
+ def output_fc(self, lang_input, visual_input):
545
+ # FC layers
546
+ lang_inter_output = self.lang_inter(lang_input)
547
+ visual_inter_output = self.visn_inter(visual_input)
548
+
549
+ # Layer output
550
+ lang_output = self.lang_output(lang_inter_output, lang_input)
551
+ visual_output = self.visn_output(visual_inter_output, visual_input)
552
+
553
+ return lang_output, visual_output
554
+
555
+ def forward(
556
+ self,
557
+ lang_feats,
558
+ lang_attention_mask,
559
+ visual_feats,
560
+ visual_attention_mask,
561
+ output_attentions=False,
562
+ ):
563
+
564
+ lang_att_output, visual_att_output = self.cross_att(
565
+ lang_input=lang_feats,
566
+ lang_attention_mask=lang_attention_mask,
567
+ visual_input=visual_feats,
568
+ visual_attention_mask=visual_attention_mask,
569
+ output_x_attentions=output_attentions,
570
+ )
571
+ attention_probs = lang_att_output[1:]
572
+ lang_att_output, visual_att_output = self.self_att(
573
+ lang_att_output[0],
574
+ lang_attention_mask,
575
+ visual_att_output[0],
576
+ visual_attention_mask,
577
+ )
578
+
579
+ lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
580
+ return (
581
+ (
582
+ lang_output,
583
+ visual_output,
584
+ attention_probs[0],
585
+ )
586
+ if output_attentions
587
+ else (lang_output, visual_output)
588
+ )
589
+
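Note that `LxmertXLayer.cross_att` reuses the single `self.visual_attention` module for both directions, so language-to-vision and vision-to-language cross-attention share their projection weights; each stream then gets its own self-attention and feed-forward block. A minimal sketch of that bidirectional reuse, with `torch.nn.MultiheadAttention` standing in for the model's own attention class (sizes are illustrative):

import torch
from torch import nn

cross_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)  # one shared module

lang = torch.randn(1, 20, 32)   # (batch, text tokens, hidden)
visn = torch.randn(1, 36, 32)   # (batch, image boxes, hidden)

lang_out, _ = cross_attn(query=lang, key=visn, value=visn)   # text attends to vision
visn_out, _ = cross_attn(query=visn, key=lang, value=lang)   # vision attends to text
print(lang_out.shape, visn_out.shape)                        # (1, 20, 32) (1, 36, 32)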
590
+
591
+ class LxmertVisualFeatureEncoder(nn.Module):
592
+ def __init__(self, config):
593
+ super().__init__()
594
+ feat_dim = config.visual_feat_dim
595
+ pos_dim = config.visual_pos_dim
596
+
597
+ # Object feature encoding
598
+ self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
599
+ self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
600
+
601
+ # Box position encoding
602
+ self.box_fc = nn.Linear(pos_dim, config.hidden_size)
603
+ self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
604
+
605
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
606
+
607
+ def forward(self, visual_feats, visual_pos):
608
+ x = self.visn_fc(visual_feats)
609
+ x = self.visn_layer_norm(x)
610
+ y = self.box_fc(visual_pos)
611
+ y = self.box_layer_norm(y)
612
+ output = (x + y) / 2
613
+
614
+ output = self.dropout(output)
615
+ return output
616
+
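`LxmertVisualFeatureEncoder` embeds the ROI features and the normalized box coordinates separately, layer-normalizes each, and simply averages the two. A shape-level sketch under the usual assumptions (2048-dimensional Faster R-CNN features, 4-dimensional boxes, 768 hidden units, 36 regions per image; these are common defaults, not values read from the config here):

import torch
from torch import nn

feat_dim, pos_dim, hidden = 2048, 4, 768
visn_fc, box_fc = nn.Linear(feat_dim, hidden), nn.Linear(pos_dim, hidden)
visn_ln, box_ln = nn.LayerNorm(hidden), nn.LayerNorm(hidden)

feats = torch.randn(1, 36, feat_dim)   # ROI-pooled object features
boxes = torch.rand(1, 36, pos_dim)     # bounding boxes normalized to [0, 1]

out = (visn_ln(visn_fc(feats)) + box_ln(box_fc(boxes))) / 2   # average the two embeddings
print(out.shape)                        # torch.Size([1, 36, 768])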
617
+
618
+ class LxmertEncoder(nn.Module):
619
+ def __init__(self, config, save_cams=False):
620
+ super().__init__()
621
+
622
+ # Obj-level image embedding layer
623
+ self.visn_fc = LxmertVisualFeatureEncoder(config)
624
+ self.config = config
625
+
626
+ # Number of layers
627
+ self.num_l_layers = config.l_layers
628
+ self.num_x_layers = config.x_layers
629
+ self.num_r_layers = config.r_layers
630
+
631
+ # Layers
632
+ # Using self.layer instead of self.l_layer to support loading BERT weights.
633
+ self.layer = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_l_layers)])
634
+ self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
635
+ self.r_layers = nn.ModuleList([LxmertLayer(config, save_cams=save_cams) for _ in range(self.num_r_layers)])
636
+
637
+ def forward(
638
+ self,
639
+ lang_feats,
640
+ lang_attention_mask,
641
+ visual_feats,
642
+ visual_pos,
643
+ visual_attention_mask=None,
644
+ output_attentions=None,
645
+ ):
646
+
647
+ vision_hidden_states = ()
648
+ language_hidden_states = ()
649
+ vision_attentions = () if output_attentions or self.config.output_attentions else None
650
+ language_attentions = () if output_attentions or self.config.output_attentions else None
651
+ cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
652
+
653
+ visual_feats = self.visn_fc(visual_feats, visual_pos)
654
+
655
+ # Run language layers
656
+ for layer_module in self.layer:
657
+ l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
658
+ lang_feats = l_outputs[0]
659
+ language_hidden_states = language_hidden_states + (lang_feats,)
660
+ if language_attentions is not None:
661
+ language_attentions = language_attentions + (l_outputs[1],)
662
+
663
+ # Run relational layers
664
+ for layer_module in self.r_layers:
665
+ v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
666
+ visual_feats = v_outputs[0]
667
+ vision_hidden_states = vision_hidden_states + (visual_feats,)
668
+ if vision_attentions is not None:
669
+ vision_attentions = vision_attentions + (v_outputs[1],)
670
+
671
+ # Run cross-modality layers
672
+ for layer_module in self.x_layers:
673
+ x_outputs = layer_module(
674
+ lang_feats,
675
+ lang_attention_mask,
676
+ visual_feats,
677
+ visual_attention_mask,
678
+ output_attentions=output_attentions,
679
+ )
680
+ lang_feats, visual_feats = x_outputs[:2]
681
+ vision_hidden_states = vision_hidden_states + (visual_feats,)
682
+ language_hidden_states = language_hidden_states + (lang_feats,)
683
+ if cross_encoder_attentions is not None:
684
+ cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
685
+ visual_encoder_outputs = (
686
+ vision_hidden_states,
687
+ vision_attentions if output_attentions else None,
688
+ )
689
+ lang_encoder_outputs = (
690
+ language_hidden_states,
691
+ language_attentions if output_attentions else None,
692
+ )
693
+ return (
694
+ visual_encoder_outputs,
695
+ lang_encoder_outputs,
696
+ cross_encoder_attentions if output_attentions else None,
697
+ )
698
+
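`LxmertEncoder.forward` runs the three stacks strictly in sequence: the language-only layers (`self.layer`), the vision-only relational layers (`self.r_layers`), and finally the cross-modality layers (`self.x_layers`) that update both streams. For the released lxmert-base configuration these depths should be 9, 5 and 5 respectively; a quick way to check, assuming the `transformers` defaults match that checkpoint:

from transformers import LxmertConfig

config = LxmertConfig()                                   # library defaults
print(config.l_layers, config.x_layers, config.r_layers)  # expected: 9 5 5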
699
+
700
+ class LxmertPooler(nn.Module):
701
+ def __init__(self, config):
702
+ super(LxmertPooler, self).__init__()
703
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
704
+ self.activation = nn.Tanh()
705
+
706
+ def forward(self, hidden_states):
707
+ # We "pool" the model by simply taking the hidden state corresponding
708
+ # to the first token.
709
+ first_token_tensor = hidden_states[:, 0]
710
+ pooled_output = self.dense(first_token_tensor)
711
+ pooled_output = self.activation(pooled_output)
712
+ return pooled_output
713
+
714
+
715
+ class LxmertPredictionHeadTransform(nn.Module):
716
+ def __init__(self, config):
717
+ super(LxmertPredictionHeadTransform, self).__init__()
718
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
719
+ self.transform_act_fn = ACT2FN[config.hidden_act]
720
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
721
+
722
+ def forward(self, hidden_states):
723
+ hidden_states = self.dense(hidden_states)
724
+ hidden_states = self.transform_act_fn(hidden_states)
725
+ hidden_states = self.LayerNorm(hidden_states)
726
+ return hidden_states
727
+
728
+
729
+ class LxmertLMPredictionHead(nn.Module):
730
+ def __init__(self, config, lxmert_model_embedding_weights):
731
+ super(LxmertLMPredictionHead, self).__init__()
732
+ self.transform = LxmertPredictionHeadTransform(config)
733
+
734
+ # The output weights are the same as the input embeddings, but there is
735
+ # an output-only bias for each token.
736
+ self.decoder = nn.Linear(
737
+ lxmert_model_embedding_weights.size(1),
738
+ lxmert_model_embedding_weights.size(0),
739
+ bias=False,
740
+ )
741
+ self.decoder.weight = lxmert_model_embedding_weights
742
+ self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
743
+
744
+ def forward(self, hidden_states):
745
+ hidden_states = self.transform(hidden_states)
746
+ hidden_states = self.decoder(hidden_states) + self.bias
747
+ return hidden_states
748
+
749
+
750
+ class LxmertVisualAnswerHead(nn.Module):
751
+ def __init__(self, config, num_labels):
752
+ super().__init__()
753
+ hid_dim = config.hidden_size
754
+ self.logit_fc = nn.Sequential(
755
+ nn.Linear(hid_dim, hid_dim * 2),
756
+ GeLU(),
757
+ nn.LayerNorm(hid_dim * 2, eps=1e-12),
758
+ nn.Linear(hid_dim * 2, num_labels),
759
+ )
760
+
761
+ def forward(self, hidden_states):
762
+ return self.logit_fc(hidden_states)
763
+
764
+
765
+ class LxmertVisualObjHead(nn.Module):
766
+ def __init__(self, config):
767
+ super().__init__()
768
+ self.transform = LxmertPredictionHeadTransform(config)
769
+ # Decide the use of visual losses
770
+ visual_losses = {}
771
+ if config.visual_obj_loss:
772
+ visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
773
+ if config.visual_attr_loss:
774
+ visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
775
+ if config.visual_feat_loss:
776
+ visual_losses["feat"] = {
777
+ "shape": (-1, config.visual_feat_dim),
778
+ "num": config.visual_feat_dim,
779
+ }
780
+ self.visual_losses = visual_losses
781
+
782
+ # The output weights are the same as the input embeddings, but there is
783
+ # an output-only bias for each token.
784
+ self.decoder_dict = nn.ModuleDict(
785
+ {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
786
+ )
787
+
788
+ def forward(self, hidden_states):
789
+ hidden_states = self.transform(hidden_states)
790
+ output = {}
791
+ for key in self.visual_losses:
792
+ output[key] = self.decoder_dict[key](hidden_states)
793
+ return output
794
+
795
+
796
+ class LxmertPreTrainingHeads(nn.Module):
797
+ def __init__(self, config, lxmert_model_embedding_weights):
798
+ super(LxmertPreTrainingHeads, self).__init__()
799
+ self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
800
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
801
+
802
+ def forward(self, sequence_output, pooled_output):
803
+ prediction_scores = self.predictions(sequence_output)
804
+ seq_relationship_score = self.seq_relationship(pooled_output)
805
+ return prediction_scores, seq_relationship_score
806
+
807
+
808
+ class LxmertPreTrainedModel(PreTrainedModel):
809
+ """
810
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
811
+ models.
812
+ """
813
+
814
+ config_class = LxmertConfig
815
+ load_tf_weights = load_tf_weights_in_lxmert
816
+ base_model_prefix = "lxmert"
817
+
818
+ def _init_weights(self, module):
819
+ """ Initialize the weights """
820
+ if isinstance(module, (nn.Linear, nn.Embedding)):
821
+ # Slightly different from the TF version which uses truncated_normal for initialization
822
+ # cf https://github.com/pytorch/pytorch/pull/5617
823
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
824
+ elif isinstance(module, nn.LayerNorm):
825
+ module.bias.data.zero_()
826
+ module.weight.data.fill_(1.0)
827
+ if isinstance(module, nn.Linear) and module.bias is not None:
828
+ module.bias.data.zero_()
829
+
830
+
831
+ LXMERT_START_DOCSTRING = r"""
832
+
833
+ The lxmert model was proposed in `lxmert: Learning Cross-Modality Encoder Representations from Transformers
834
+ <https://arxiv.org/abs/1908.07490>`__ by Hao Tan and Mohit Bansal. It's a vision and language transformer model,
835
+ pretrained on a variety of multi-modal datasets comprising GQA, VQA v2.0, MSCOCO captions, and Visual Genome,
836
+ using a combination of masked language modeling, region of interest feature regression, cross entropy loss for
837
+ question answering attribute prediction, and object tag prediction.
838
+
839
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
840
+ methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
841
+ pruning heads etc.)
842
+
843
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
844
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
845
+ general usage and behavior.
846
+
847
+ Parameters:
848
+ config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model.
849
+ Initializing with a config file does not load the weights associated with the model, only the
850
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
851
+ weights.
852
+ """
853
+
854
+ LXMERT_INPUTS_DOCSTRING = r"""
855
+
856
+ Args:
857
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
858
+ Indices of input sequence tokens in the vocabulary.
859
+
860
+ Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See
861
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
862
+ details.
863
+
864
+ `What are input IDs? <../glossary.html#input-ids>`__
865
+ visual_feats: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_feat_dim)`):
866
+ This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
867
+ Faster R-CNN model.
868
+
869
+ These are currently not provided by the transformers library.
870
+ visual_pos: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_pos_dim)`):
871
+ This input represents spatial features corresponding (by index) to the visual features above. The
872
+ pre-trained lxmert model expects these spatial features to be normalized bounding boxes on a scale of 0 to
873
+ 1.
874
+
875
+ These are currently not provided by the transformers library.
876
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
877
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
878
+
879
+ - 1 for tokens that are **not masked**,
880
+ - 0 for tokens that are **masked**.
881
+
882
+ `What are attention masks? <../glossary.html#attention-mask>`__
883
+ visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
884
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
885
+
886
+ - 1 for tokens that are **not masked**,
887
+ - 0 for tokens that are **masked**.
888
+
889
+ `What are attention masks? <../glossary.html#attention-mask>`__
890
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
891
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
892
+ 1]``:
893
+
894
+ - 0 corresponds to a `sentence A` token,
895
+ - 1 corresponds to a `sentence B` token.
896
+
897
+ `What are token type IDs? <../glossary.html#token-type-ids>`__
898
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
899
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
900
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
901
+ vectors than the model's internal embedding lookup matrix.
902
+ output_attentions (:obj:`bool`, `optional`):
903
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
904
+ tensors for more detail.
905
+ output_hidden_states (:obj:`bool`, `optional`):
906
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
907
+ more detail.
908
+ return_dict (:obj:`bool`, `optional`):
909
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
910
+ """
911
+
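Because the library does not produce `visual_feats` / `visual_pos` itself, callers typically extract them with an external Faster R-CNN (for instance the `modeling_frcnn.py` and `processing_image.py` utilities included in this upload) and pass plain tensors. A minimal sketch of the expected shapes with random stand-in features (the 36-box count and 2048 feature size are common defaults, assumed here):

import torch
from transformers import LxmertTokenizer

tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
inputs = tokenizer("Where is the cat sitting?", return_tensors="pt")

visual_feats = torch.randn(1, 36, 2048)   # stand-in for ROI-pooled region features
visual_pos = torch.rand(1, 36, 4)         # normalized (x1, y1, x2, y2) boxes

# these tensors accompany input_ids / attention_mask in every forward call below
print(inputs["input_ids"].shape, visual_feats.shape, visual_pos.shape)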
912
+
913
+ @add_start_docstrings(
914
+ "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.",
915
+ LXMERT_START_DOCSTRING,
916
+ )
917
+ class LxmertModel(LxmertPreTrainedModel):
918
+ def __init__(self, config, save_cams=False):
919
+ super().__init__(config)
920
+ self.embeddings = LxmertEmbeddings(config)
921
+ self.encoder = LxmertEncoder(config, save_cams=save_cams)
922
+ self.pooler = LxmertPooler(config)
923
+ self.init_weights()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.embeddings.word_embeddings
927
+
928
+ def set_input_embeddings(self, new_embeddings):
929
+ self.embeddings.word_embeddings = new_embeddings
930
+
931
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
932
+ @add_code_sample_docstrings(
933
+ tokenizer_class=_TOKENIZER_FOR_DOC,
934
+ checkpoint="unc-nlp/lxmert-base-uncased",
935
+ output_type=LxmertModelOutput,
936
+ config_class=_CONFIG_FOR_DOC,
937
+ )
938
+ def forward(
939
+ self,
940
+ input_ids=None,
941
+ visual_feats=None,
942
+ visual_pos=None,
943
+ attention_mask=None,
944
+ visual_attention_mask=None,
945
+ token_type_ids=None,
946
+ inputs_embeds=None,
947
+ output_attentions=None,
948
+ output_hidden_states=None,
949
+ return_dict=None,
950
+ ):
951
+
952
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
953
+ output_hidden_states = (
954
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
955
+ )
956
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
957
+
958
+ if input_ids is not None and inputs_embeds is not None:
959
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
960
+ elif input_ids is not None:
961
+ input_shape = input_ids.size()
962
+ elif inputs_embeds is not None:
963
+ input_shape = inputs_embeds.size()[:-1]
964
+ else:
965
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
966
+
967
+ assert visual_feats is not None, "`visual_feats` cannot be `None`"
968
+ assert visual_pos is not None, "`visual_pos` cannot be `None`"
969
+
970
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
971
+
972
+ if attention_mask is None:
973
+ attention_mask = torch.ones(input_shape, device=device)
974
+ if token_type_ids is None:
975
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
976
+
977
+ # We create a 3D attention mask from a 2D tensor mask.
978
+ # Sizes are [batch_size, 1, 1, to_seq_length]
979
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
980
+ # this attention mask is more simple than the triangular masking of causal attention
981
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
982
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
983
+
984
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
985
+ # masked positions, this operation will create a tensor which is 0.0 for
986
+ # positions we want to attend and -10000.0 for masked positions.
987
+ # Since we are adding it to the raw scores before the softmax, this is
988
+ # effectively the same as removing these entirely.
989
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
990
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
991
+
992
+ # Process the visual attention mask
993
+ if visual_attention_mask is not None:
994
+ extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
995
+ extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
996
+ extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
997
+ else:
998
+ extended_visual_attention_mask = None
999
+
1000
+ # Positional Word Embeddings
1001
+ embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
1002
+
1003
+ # Run Lxmert encoder
1004
+ encoder_outputs = self.encoder(
1005
+ embedding_output,
1006
+ extended_attention_mask,
1007
+ visual_feats=visual_feats,
1008
+ visual_pos=visual_pos,
1009
+ visual_attention_mask=extended_visual_attention_mask,
1010
+ output_attentions=output_attentions,
1011
+ )
1012
+
1013
+ visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
1014
+ vision_hidden_states = visual_encoder_outputs[0]
1015
+ language_hidden_states = lang_encoder_outputs[0]
1016
+
1017
+ all_attentions = ()
1018
+ if output_attentions:
1019
+ language_attentions = lang_encoder_outputs[1]
1020
+ vision_attentions = visual_encoder_outputs[1]
1021
+ cross_encoder_attentions = encoder_outputs[2]
1022
+ all_attentions = (
1023
+ language_attentions,
1024
+ vision_attentions,
1025
+ cross_encoder_attentions,
1026
+ )
1027
+
1028
+ hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
1029
+
1030
+ visual_output = vision_hidden_states[-1]
1031
+ lang_output = language_hidden_states[-1]
1032
+ pooled_output = self.pooler(lang_output)
1033
+
1034
+ if not return_dict:
1035
+ return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
1036
+
1037
+ return LxmertModelOutput(
1038
+ pooled_output=pooled_output,
1039
+ language_output=lang_output,
1040
+ vision_output=visual_output,
1041
+ language_hidden_states=language_hidden_states if output_hidden_states else None,
1042
+ vision_hidden_states=vision_hidden_states if output_hidden_states else None,
1043
+ language_attentions=language_attentions if output_attentions else None,
1044
+ vision_attentions=vision_attentions if output_attentions else None,
1045
+ cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
1046
+ )
1047
+
1048
+
1049
+ @add_start_docstrings(
1050
+ """Lxmert Model with a specified pretraining head on top. """,
1051
+ LXMERT_START_DOCSTRING,
1052
+ )
1053
+ class LxmertForPreTraining(LxmertPreTrainedModel):
1054
+ def __init__(self, config, save_cams=False):
1055
+ super().__init__(config)
1056
+ # Configuration
1057
+ self.config = config
1058
+ self.num_qa_labels = config.num_qa_labels
1059
+ self.visual_loss_normalizer = config.visual_loss_normalizer
1060
+
1061
+ # Use of pretraining tasks
1062
+ self.task_mask_lm = config.task_mask_lm
1063
+ self.task_obj_predict = config.task_obj_predict
1064
+ self.task_matched = config.task_matched
1065
+ self.task_qa = config.task_qa
1066
+
1067
+ # Lxmert backbone
1068
+ self.lxmert = LxmertModel(config, save_cams=save_cams)
1069
+
1070
+ # Pre-training heads
1071
+ self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
1072
+ if self.task_obj_predict:
1073
+ self.obj_predict_head = LxmertVisualObjHead(config)
1074
+ if self.task_qa:
1075
+ self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
1076
+
1077
+ # Weight initialization
1078
+ self.init_weights()
1079
+
1080
+ # Loss functions
1081
+ self.loss_fcts = {
1082
+ "l2": SmoothL1Loss(reduction="none"),
1083
+ "visual_ce": CrossEntropyLoss(reduction="none"),
1084
+ "ce": CrossEntropyLoss(),
1085
+ }
1086
+
1087
+ visual_losses = {}
1088
+ if config.visual_obj_loss:
1089
+ visual_losses["obj"] = {
1090
+ "shape": (-1,),
1091
+ "num": config.num_object_labels,
1092
+ "loss": "visual_ce",
1093
+ }
1094
+ if config.visual_attr_loss:
1095
+ visual_losses["attr"] = {
1096
+ "shape": (-1,),
1097
+ "num": config.num_attr_labels,
1098
+ "loss": "visual_ce",
1099
+ }
1100
+ if config.visual_feat_loss:
1101
+ visual_losses["feat"] = {
1102
+ "shape": (-1, config.visual_feat_dim),
1103
+ "num": config.visual_feat_dim,
1104
+ "loss": "l2",
1105
+ }
1106
+ self.visual_losses = visual_losses
1107
+
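Keeping the visual objectives in the `visual_losses` dict lets the pre-training loss loop treat them uniformly: `obj` and `attr` use a per-element cross-entropy (`visual_ce`), while `feat` regresses the ROI features with `SmoothL1Loss` (`l2`). A minimal sketch of evaluating such a dict on random predictions and targets (the 1600 object classes and 36 boxes are assumed, illustrative numbers):

import torch
from torch.nn import CrossEntropyLoss, SmoothL1Loss

loss_fcts = {"visual_ce": CrossEntropyLoss(reduction="none"), "l2": SmoothL1Loss(reduction="none")}
visual_losses = {
    "obj": {"shape": (-1,), "num": 1600, "loss": "visual_ce"},   # object-label classification
    "feat": {"shape": (-1, 2048), "num": 2048, "loss": "l2"},    # feature regression
}

preds = {"obj": torch.randn(1, 36, 1600), "feat": torch.randn(1, 36, 2048)}
labels = {"obj": torch.randint(0, 1600, (1, 36)), "feat": torch.randn(1, 36, 2048)}

for key, info in visual_losses.items():
    loss_fct = loss_fcts[info["loss"]]
    pred = preds[key].view(-1, info["num"])       # fold the boxes into the batch dimension
    target = labels[key].view(*info["shape"])
    print(key, loss_fct(pred, target).mean().item())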
1108
+ def resize_num_qa_labels(self, num_labels):
1109
+ """
1110
+ Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
1111
+ will add newly initialized weights. Reducing the size will remove weights from the end
1112
+
1113
+ Args:
1114
+ num_labels (:obj:`int`, `optional`):
1115
+ New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
1116
+ weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
1117
+ just returns a pointer to the qa labels :obj:`torch.nn.Linear`` module of the model without doing
1118
+ anything.
1119
+
1120
+ Return:
1121
+ :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
1122
+ """
1123
+
1124
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1125
+ if num_labels is None or cur_qa_logit_layer is None:
1126
+ return
1127
+ new_qa_logit_layer = self._resize_qa_labels(num_labels)
1128
+ self.config.num_qa_labels = num_labels
1129
+ self.num_qa_labels = num_labels
1130
+
1131
+ return new_qa_logit_layer
1132
+
1133
+ def _resize_qa_labels(self, num_labels):
1134
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1135
+ new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
1136
+ self._set_qa_logit_layer(new_qa_logit_layer)
1137
+ return self.get_qa_logit_layer()
1138
+
1139
+ def get_qa_logit_layer(self) -> nn.Module:
1140
+ """
1141
+ Returns the linear layer that produces the question answering logits.
1142
+
1143
+ Returns:
1144
+ :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states or :obj:`None` if
1145
+ lxmert does not have a visual answering head.
1146
+ """
1147
+ if hasattr(self, "answer_head"):
1148
+ return self.answer_head.logit_fc[-1]
1149
+
1150
+ def _set_qa_logit_layer(self, qa_logit_layer):
1151
+ self.answer_head.logit_fc[-1] = qa_logit_layer
1152
+
1153
+ def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
1154
+
1155
+ if num_labels is None:
1156
+ return cur_qa_logit_layer
1157
+
1158
+ cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
1159
+ if cur_qa_labels == num_labels:
1160
+ return cur_qa_logit_layer
1161
+
1162
+ # Build new linear output
1163
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1164
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
1165
+ else:
1166
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
1167
+
1168
+ new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
1169
+
1170
+ # initialize all new labels
1171
+ self._init_weights(new_qa_logit_layer)
1172
+
1173
+ # Copy labels from the previous weights
1174
+ num_labels_to_copy = min(cur_qa_labels, num_labels)
1175
+ new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
1176
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1177
+ new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
1178
+
1179
+ return new_qa_logit_layer
1180
+
1181
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1182
+ @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1183
+ def forward(
1184
+ self,
1185
+ input_ids=None,
1186
+ visual_feats=None,
1187
+ visual_pos=None,
1188
+ attention_mask=None,
1189
+ visual_attention_mask=None,
1190
+ token_type_ids=None,
1191
+ inputs_embeds=None,
1192
+ labels=None,
1193
+ obj_labels=None,
1194
+ matched_label=None,
1195
+ ans=None,
1196
+ output_attentions=None,
1197
+ output_hidden_states=None,
1198
+ return_dict=None,
1199
+ **kwargs,
1200
+ ):
1201
+ r"""
1202
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
1203
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1204
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1205
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1206
+ obj_labels: (``Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]``, `optional`):
1207
+ each key is named after each one of the visual losses and each element of the tuple is of the shape
1208
+ ``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)`` for the label id
1209
+ and the label score respectively
1210
+ matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1211
+ Labels for computing whether or not the text input matches the image (classification) loss. Input
1212
+ should be a sequence pair (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1213
+
1214
+ - 0 indicates that the sentence does not match the image,
1215
+ - 1 indicates that the sentence does match the image.
1216
+ ans: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`):
1217
+ a one-hot representation of the correct answer
1218
+
1219
+ Returns:
1220
+ """
1221
+
1222
+ if "masked_lm_labels" in kwargs:
1223
+ warnings.warn(
1224
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1225
+ FutureWarning,
1226
+ )
1227
+ labels = kwargs.pop("masked_lm_labels")
1228
+
1229
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1230
+
1231
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1232
+ lxmert_output = self.lxmert(
1233
+ input_ids=input_ids,
1234
+ visual_feats=visual_feats,
1235
+ visual_pos=visual_pos,
1236
+ token_type_ids=token_type_ids,
1237
+ attention_mask=attention_mask,
1238
+ visual_attention_mask=visual_attention_mask,
1239
+ inputs_embeds=inputs_embeds,
1240
+ output_hidden_states=output_hidden_states,
1241
+ output_attentions=output_attentions,
1242
+ return_dict=return_dict,
1243
+ )
1244
+
1245
+ lang_output, visual_output, pooled_output = (
1246
+ lxmert_output[0],
1247
+ lxmert_output[1],
1248
+ lxmert_output[2],
1249
+ )
1250
+ lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
1251
+ if self.task_qa:
1252
+ answer_score = self.answer_head(pooled_output)
1253
+ else:
1254
+ answer_score = pooled_output[0][0]
1255
+
1256
+ total_loss = (
1257
+ None
1258
+ if (labels is None and matched_label is None and obj_labels is None and ans is None)
1259
+ else torch.tensor(0.0, device=device)
1260
+ )
1261
+ if labels is not None and self.task_mask_lm:
1262
+ masked_lm_loss = self.loss_fcts["ce"](
1263
+ lang_prediction_scores.view(-1, self.config.vocab_size),
1264
+ labels.view(-1),
1265
+ )
1266
+ total_loss += masked_lm_loss
1267
+ if matched_label is not None and self.task_matched:
1268
+ matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
1269
+ total_loss += matched_loss
1270
+ if obj_labels is not None and self.task_obj_predict:
1271
+ total_visual_loss = torch.tensor(0.0, device=input_ids.device)
1272
+ visual_prediction_scores_dict = self.obj_predict_head(visual_output)
1273
+ for key, key_info in self.visual_losses.items():
1274
+ label, mask_conf = obj_labels[key]
1275
+ output_dim = key_info["num"]
1276
+ loss_fct_name = key_info["loss"]
1277
+ label_shape = key_info["shape"]
1278
+ weight = self.visual_loss_normalizer
1279
+ visual_loss_fct = self.loss_fcts[loss_fct_name]
1280
+ visual_prediction_scores = visual_prediction_scores_dict[key]
1281
+ visual_loss = visual_loss_fct(
1282
+ visual_prediction_scores.view(-1, output_dim),
1283
+ label.view(*label_shape),
1284
+ )
1285
+ if visual_loss.dim() > 1: # Regression Losses
1286
+ visual_loss = visual_loss.mean(1)
1287
+ visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
1288
+ total_visual_loss += visual_loss
1289
+ total_loss += total_visual_loss
1290
+ if ans is not None and self.task_qa:
1291
+ answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
1292
+ total_loss += answer_loss
1293
+
1294
+ if not return_dict:
1295
+ output = (
1296
+ lang_prediction_scores,
1297
+ cross_relationship_score,
1298
+ answer_score,
1299
+ ) + lxmert_output[3:]
1300
+ return ((total_loss,) + output) if total_loss is not None else output
1301
+
1302
+ return LxmertForPreTrainingOutput(
1303
+ loss=total_loss,
1304
+ prediction_logits=lang_prediction_scores,
1305
+ cross_relationship_score=cross_relationship_score,
1306
+ question_answering_score=answer_score,
1307
+ language_hidden_states=lxmert_output.language_hidden_states,
1308
+ vision_hidden_states=lxmert_output.vision_hidden_states,
1309
+ language_attentions=lxmert_output.language_attentions,
1310
+ vision_attentions=lxmert_output.vision_attentions,
1311
+ cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
1312
+ )
1313
+
1314
+
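For orientation, a minimal sketch of how the resize_num_qa_labels helper defined above could be used when the downstream answer vocabulary differs from the pretraining one. The default config and the label count are illustrative assumptions, not taken from this upload:

    # Illustrative sketch only; 1842 is an arbitrary example answer-vocabulary size.
    from transformers import LxmertConfig

    config = LxmertConfig()                     # task_qa is enabled by default
    model = LxmertForPreTraining(config)

    model.resize_num_qa_labels(1842)            # rebuild the QA logit layer for 1842 answers
    assert model.num_qa_labels == 1842
    assert model.get_qa_logit_layer().out_features == 1842

As implemented in _get_resized_qa_labels above, resizing copies the first min(old, new) label rows (and biases) from the previous layer and freshly initializes any newly added ones.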
1315
+ @add_start_docstrings(
1316
+ """Lxmert Model with a visual-answering head on top for downstream QA tasks""",
1317
+ LXMERT_START_DOCSTRING,
1318
+ )
1319
+ class LxmertForQuestionAnswering(LxmertPreTrainedModel):
1320
+ def __init__(self, config):
1321
+ super().__init__(config)
1322
+ # Configuration
1323
+ self.config = config
1324
+ self.num_qa_labels = config.num_qa_labels
1325
+ self.visual_loss_normalizer = config.visual_loss_normalizer
1326
+
1327
+ # Lxmert backbone
1328
+ self.lxmert = LxmertModel(config)
1329
+
1330
+ self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
1331
+
1332
+ # Weight initialization
1333
+ self.init_weights()
1334
+
1335
+ # Loss function
1336
+ self.loss = CrossEntropyLoss()
1337
+
1338
+ def resize_num_qa_labels(self, num_labels):
1339
+ """
1340
+ Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
1341
+ will add newly initialized weights. Reducing the size will remove weights from the end
1342
+
1343
+ Args:
1344
+ num_labels (:obj:`int`, `optional`):
1345
+ New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
1346
+ weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
1347
+ just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
1348
+ anything.
1349
+
1350
+ Return:
1351
+ :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
1352
+ """
1353
+
1354
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1355
+ if num_labels is None or cur_qa_logit_layer is None:
1356
+ return
1357
+ new_qa_logit_layer = self._resize_qa_labels(num_labels)
1358
+ self.config.num_qa_labels = num_labels
1359
+ self.num_qa_labels = num_labels
1360
+
1361
+ return new_qa_logit_layer
1362
+
1363
+ def _resize_qa_labels(self, num_labels):
1364
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1365
+ new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
1366
+ self._set_qa_logit_layer(new_qa_logit_layer)
1367
+ return self.get_qa_logit_layer()
1368
+
1369
+ def get_qa_logit_layer(self) -> nn.Module:
1370
+ """
1371
+ Returns the linear layer that produces question answering logits
1372
+
1373
+ Returns:
1374
+ :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states. :obj:`None`: A
1375
+ NoneType object if Lxmert does not have the visual answering head.
1376
+ """
1377
+
1378
+ if hasattr(self, "answer_head"):
1379
+ return self.answer_head.logit_fc[-1]
1380
+
1381
+ def _set_qa_logit_layer(self, qa_logit_layer):
1382
+ self.answer_head.logit_fc[-1] = qa_logit_layer
1383
+
1384
+ def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
1385
+
1386
+ if num_labels is None:
1387
+ return cur_qa_logit_layer
1388
+
1389
+ cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
1390
+ if cur_qa_labels == num_labels:
1391
+ return cur_qa_logit_layer
1392
+
1393
+ # Build new linear output
1394
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1395
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
1396
+ else:
1397
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
1398
+
1399
+ new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
1400
+
1401
+ # initialize all new labels
1402
+ self._init_weights(new_qa_logit_layer)
1403
+
1404
+ # Copy labels from the previous weights
1405
+ num_labels_to_copy = min(cur_qa_labels, num_labels)
1406
+ new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
1407
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1408
+ new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
1409
+
1410
+ return new_qa_logit_layer
1411
+
1412
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1413
+ @add_code_sample_docstrings(
1414
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1415
+ checkpoint="unc-nlp/lxmert-base-uncased",
1416
+ output_type=LxmertForQuestionAnsweringOutput,
1417
+ config_class=_CONFIG_FOR_DOC,
1418
+ )
1419
+ def forward(
1420
+ self,
1421
+ input_ids=None,
1422
+ visual_feats=None,
1423
+ visual_pos=None,
1424
+ attention_mask=None,
1425
+ visual_attention_mask=None,
1426
+ token_type_ids=None,
1427
+ inputs_embeds=None,
1428
+ labels=None,
1429
+ output_attentions=None,
1430
+ output_hidden_states=None,
1431
+ return_dict=None,
1432
+ ):
1433
+ r"""
1434
+ labels: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`):
1435
+ A one-hot representation of the correct answer
1436
+
1437
+ Returns:
1438
+ """
1439
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1440
+
1441
+ lxmert_output = self.lxmert(
1442
+ input_ids=input_ids,
1443
+ visual_feats=visual_feats,
1444
+ visual_pos=visual_pos,
1445
+ token_type_ids=token_type_ids,
1446
+ attention_mask=attention_mask,
1447
+ visual_attention_mask=visual_attention_mask,
1448
+ inputs_embeds=inputs_embeds,
1449
+ output_hidden_states=output_hidden_states,
1450
+ output_attentions=output_attentions,
1451
+ return_dict=return_dict,
1452
+ )
1453
+
1454
+ pooled_output = lxmert_output[2]
1455
+ answer_score = self.answer_head(pooled_output)
1456
+ loss = None
1457
+ if labels is not None:
1458
+ loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
1459
+
1460
+ if not return_dict:
1461
+ output = (answer_score,) + lxmert_output[3:]
1462
+ return (loss,) + output if loss is not None else output
1463
+
1464
+ return LxmertForQuestionAnsweringOutput(
1465
+ loss=loss,
1466
+ question_answering_score=answer_score,
1467
+ language_hidden_states=lxmert_output.language_hidden_states,
1468
+ vision_hidden_states=lxmert_output.vision_hidden_states,
1469
+ language_attentions=lxmert_output.language_attentions,
1470
+ vision_attentions=lxmert_output.vision_attentions,
1471
+ cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
1472
+ )
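To show how the question-answering head above is driven end to end, a hedged usage sketch. The checkpoint name is the one referenced in the docstring decorator above; the random visual features stand in for detector output (in practice they would come from a Faster R-CNN feature extractor such as modeling_frcnn.py), and the answer head may be randomly initialized unless a fine-tuned checkpoint is loaded:

    # Illustrative sketch; shapes follow the default LxmertConfig (visual_feat_dim=2048, visual_pos_dim=4).
    import torch
    from transformers import LxmertTokenizer

    tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
    model = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-base-uncased")

    inputs = tokenizer("What color is the cat?", return_tensors="pt")
    visual_feats = torch.rand(1, 36, 2048)   # 36 region features (random stand-in for FRCNN output)
    visual_pos = torch.rand(1, 36, 4)        # normalized box coordinates

    out = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        token_type_ids=inputs.token_type_ids,
        visual_feats=visual_feats,
        visual_pos=visual_pos,
        return_dict=True,
    )
    predicted_answer_idx = out.question_answering_score.argmax(-1)  # index into the QA answer vocabulary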
lxmert/src/layers.py ADDED
@@ -0,0 +1,292 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye', 'Tanh', 'MatMul', 'Mul']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+
49
+ class RelPropSimple(RelProp):
50
+ def relprop(self, R, alpha):
51
+ Z = self.forward(self.X)
52
+ S = safe_divide(R, Z)
53
+ C = self.gradprop(Z, self.X, S)
54
+
55
+ if torch.is_tensor(self.X) == False:
56
+ outputs = []
57
+ outputs.append(self.X[0] * C[0])
58
+ outputs.append(self.X[1] * C[1])
59
+ else:
60
+ outputs = self.X * (C[0])
61
+ return outputs
62
+
63
+ class AddEye(RelPropSimple):
64
+ # input of shape B, C, seq_len, seq_len
65
+ def forward(self, input):
66
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67
+
68
+ class ReLU(nn.ReLU, RelProp):
69
+ pass
70
+
71
+ class GELU(nn.GELU, RelProp):
72
+ pass
73
+
74
+ class Softmax(nn.Softmax, RelProp):
75
+ pass
76
+
77
+ class Mul(RelPropSimple):
78
+ def forward(self, inputs):
79
+ return torch.mul(*inputs)
80
+
81
+ class Tanh(nn.Tanh, RelProp):
82
+ pass
83
+ class LayerNorm(nn.LayerNorm, RelProp):
84
+ pass
85
+
86
+ class Dropout(nn.Dropout, RelProp):
87
+ pass
88
+
89
+ class MatMul(RelPropSimple):
90
+ def forward(self, inputs):
91
+ return torch.matmul(*inputs)
92
+
93
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
94
+ pass
95
+
96
+ class LayerNorm(nn.LayerNorm, RelProp):
97
+ pass
98
+
99
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
100
+ pass
101
+
102
+
103
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
104
+ pass
105
+
106
+
107
+ class Add(RelPropSimple):
108
+ def forward(self, inputs):
109
+ return torch.add(*inputs)
110
+
111
+ def relprop(self, R, alpha):
112
+ Z = self.forward(self.X)
113
+ S = safe_divide(R, Z)
114
+ C = self.gradprop(Z, self.X, S)
115
+
116
+ a = self.X[0] * C[0]
117
+ b = self.X[1] * C[1]
118
+
119
+ a_sum = a.sum()
120
+ b_sum = b.sum()
121
+
122
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
123
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
124
+
125
+ a = a * safe_divide(a_fact, a.sum())
126
+ b = b * safe_divide(b_fact, b.sum())
127
+
128
+ outputs = [a, b]
129
+
130
+ return outputs
131
+
132
+ class einsum(RelPropSimple):
133
+ def __init__(self, equation):
134
+ super().__init__()
135
+ self.equation = equation
136
+ def forward(self, *operands):
137
+ return torch.einsum(self.equation, *operands)
138
+
139
+ class IndexSelect(RelProp):
140
+ def forward(self, inputs, dim, indices):
141
+ self.__setattr__('dim', dim)
142
+ self.__setattr__('indices', indices)
143
+
144
+ return torch.index_select(inputs, dim, indices)
145
+
146
+ def relprop(self, R, alpha):
147
+ Z = self.forward(self.X, self.dim, self.indices)
148
+ S = safe_divide(R, Z)
149
+ C = self.gradprop(Z, self.X, S)
150
+
151
+ if torch.is_tensor(self.X) == False:
152
+ outputs = []
153
+ outputs.append(self.X[0] * C[0])
154
+ outputs.append(self.X[1] * C[1])
155
+ else:
156
+ outputs = self.X * (C[0])
157
+ return outputs
158
+
159
+
160
+
161
+ class Clone(RelProp):
162
+ def forward(self, input, num):
163
+ self.__setattr__('num', num)
164
+ outputs = []
165
+ for _ in range(num):
166
+ outputs.append(input)
167
+
168
+ return outputs
169
+
170
+ def relprop(self, R, alpha):
171
+ Z = []
172
+ for _ in range(self.num):
173
+ Z.append(self.X)
174
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
175
+ C = self.gradprop(Z, self.X, S)[0]
176
+
177
+ R = self.X * C
178
+
179
+ return R
180
+
181
+
182
+ class Cat(RelProp):
183
+ def forward(self, inputs, dim):
184
+ self.__setattr__('dim', dim)
185
+ return torch.cat(inputs, dim)
186
+
187
+ def relprop(self, R, alpha):
188
+ Z = self.forward(self.X, self.dim)
189
+ S = safe_divide(R, Z)
190
+ C = self.gradprop(Z, self.X, S)
191
+
192
+ outputs = []
193
+ for x, c in zip(self.X, C):
194
+ outputs.append(x * c)
195
+
196
+ return outputs
197
+
198
+
199
+ class Sequential(nn.Sequential):
200
+ def relprop(self, R, alpha):
201
+ for m in reversed(self._modules.values()):
202
+ R = m.relprop(R, alpha)
203
+ return R
204
+
205
+
206
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
207
+ def relprop(self, R, alpha):
208
+ X = self.X
209
+ beta = 1 - alpha
210
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
211
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
212
+ Z = X * weight + 1e-9
213
+ S = R / Z
214
+ Ca = S * weight
215
+ R = self.X * (Ca)
216
+ return R
217
+
218
+
219
+ class Linear(nn.Linear, RelProp):
220
+ def relprop(self, R, alpha):
221
+ beta = alpha - 1
222
+ pw = torch.clamp(self.weight, min=0)
223
+ nw = torch.clamp(self.weight, max=0)
224
+ px = torch.clamp(self.X, min=0)
225
+ nx = torch.clamp(self.X, max=0)
226
+
227
+ def f(w1, w2, x1, x2):
228
+ Z1 = F.linear(x1, w1)
229
+ Z2 = F.linear(x2, w2)
230
+ S1 = safe_divide(R, Z1 + Z2)
231
+ S2 = safe_divide(R, Z1 + Z2)
232
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
233
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
234
+
235
+ return C1 + C2
236
+
237
+ activator_relevances = f(pw, nw, px, nx)
238
+ inhibitor_relevances = f(nw, pw, px, nx)
239
+
240
+ R = alpha * activator_relevances - beta * inhibitor_relevances
241
+
242
+ return R
243
+
244
+
245
+ class Conv2d(nn.Conv2d, RelProp):
246
+ def gradprop2(self, DY, weight):
247
+ Z = self.forward(self.X)
248
+
249
+ output_padding = self.X.size()[2] - (
250
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
251
+
252
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
253
+
254
+ def relprop(self, R, alpha):
255
+ if self.X.shape[1] == 3:
256
+ pw = torch.clamp(self.weight, min=0)
257
+ nw = torch.clamp(self.weight, max=0)
258
+ X = self.X
259
+ L = self.X * 0 + \
260
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
261
+ keepdim=True)[0]
262
+ H = self.X * 0 + \
263
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
264
+ keepdim=True)[0]
265
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
266
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
267
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
268
+
269
+ S = R / Za
270
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
271
+ R = C
272
+ else:
273
+ beta = alpha - 1
274
+ pw = torch.clamp(self.weight, min=0)
275
+ nw = torch.clamp(self.weight, max=0)
276
+ px = torch.clamp(self.X, min=0)
277
+ nx = torch.clamp(self.X, max=0)
278
+
279
+ def f(w1, w2, x1, x2):
280
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
281
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
282
+ S1 = safe_divide(R, Z1)
283
+ S2 = safe_divide(R, Z2)
284
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
285
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
286
+ return C1 + C2
287
+
288
+ activator_relevances = f(pw, nw, px, nx)
289
+ inhibitor_relevances = f(nw, pw, px, nx)
290
+
291
+ R = alpha * activator_relevances - beta * inhibitor_relevances
292
+ return R
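All of the wrappers in this file follow the same contract: forward_hook caches the layer input as self.X, and relprop redistributes a relevance tensor R from the output back onto that input. A small self-contained sketch using the Linear class above; the shapes, the relevance seed, and alpha=1 are arbitrary choices for illustration:

    # Minimal sketch of the relprop contract (arbitrary shapes and alpha).
    import torch

    lin = Linear(4, 3)                    # the LRP-aware Linear defined above
    x = torch.randn(2, 4)
    y = lin(x)                            # forward_hook caches lin.X with requires_grad=True

    R_out = torch.softmax(y, dim=-1)      # any relevance tensor with the output's shape
    R_in = lin.relprop(R_out, alpha=1)    # alpha=1, beta=0 -> LRP-alpha1beta0 rule
    print(R_in.shape)                     # torch.Size([2, 4]): relevance redistributed to the input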
lxmert/src/lxmert_lrp.py ADDED
@@ -0,0 +1,1693 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch lxmert model. """
16
+
17
+ import math
18
+ import os
19
+ import warnings
20
+ import copy
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss, SmoothL1Loss
27
+ from lxmert.lxmert.src.layers import *
28
+ from transformers.file_utils import (
29
+ ModelOutput,
30
+ add_code_sample_docstrings,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging
37
+ from transformers.configuration_lxmert import LxmertConfig
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _CONFIG_FOR_DOC = "LxmertConfig"
42
+ _TOKENIZER_FOR_DOC = "LxmertTokenizer"
43
+
44
+ LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ "unc-nlp/lxmert-base-uncased",
46
+ ]
47
+
48
+ ACT2FN = {
49
+ "relu": ReLU,
50
+ "tanh": Tanh,
51
+ "gelu": GELU,
52
+ }
53
+
54
+
55
+ @dataclass
56
+ class LxmertModelOutput(ModelOutput):
57
+ """
58
+ Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
59
+ visual, and cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
60
+ encoder)
61
+
62
+
63
+ Args:
64
+ language_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
65
+ Sequence of hidden-states at the output of the last layer of the language encoder.
66
+ vision_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
67
+ Sequence of hidden-states at the output of the last layer of the visual encoder.
68
+ pooled_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
69
+ Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
70
+ by a Linear layer and a Tanh activation function.
71
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
72
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
73
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
74
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
75
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
76
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
77
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
78
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
79
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
80
+ weighted average in the self-attention heads.
81
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
82
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
83
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
84
+ weighted average in the self-attention heads.
85
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
86
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
87
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
88
+ weighted average in the self-attention heads.
89
+ """
90
+
91
+ language_output: Optional[torch.FloatTensor] = None
92
+ vision_output: Optional[torch.FloatTensor] = None
93
+ pooled_output: Optional[torch.FloatTensor] = None
94
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
95
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
96
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
97
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
98
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
99
+
100
+
101
+ @dataclass
102
+ class LxmertForQuestionAnsweringOutput(ModelOutput):
103
+ """
104
+ Output type of :class:`~transformers.LxmertForQuestionAnswering`.
105
+
106
+ Args:
107
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
108
+ The question answering classification loss, i.e. the cross-entropy between the predicted answer
109
+ scores and the provided ``labels``.
110
+ question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`, `optional`):
111
+ Prediction scores of question answering objective (classification).
112
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
113
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
114
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
115
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
116
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
117
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
118
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
119
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
120
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
121
+ weighted average in the self-attention heads.
122
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
123
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
124
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
125
+ weighted average in the self-attention heads.
126
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
127
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
128
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
129
+ weighted average in the self-attention heads.
130
+ """
131
+
132
+ loss: Optional[torch.FloatTensor] = None
133
+ question_answering_score: Optional[torch.FloatTensor] = None
134
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
135
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
136
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
137
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
138
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
139
+
140
+
141
+ @dataclass
142
+ class LxmertForPreTrainingOutput(ModelOutput):
143
+ """
144
+ Output type of :class:`~transformers.LxmertForPreTraining`.
145
+
146
+ Args:
147
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
148
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
149
+ (classification) loss.
150
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
151
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
152
+ cross_relationship_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
153
+ Prediction scores of the textual matching objective (classification) head (scores of True/False
154
+ continuation before SoftMax).
155
+ question_answering_score: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_qa_answers)`):
156
+ Prediction scores of question answering objective (classification).
157
+ language_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
158
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
159
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
160
+ vision_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
161
+ Tuple of :obj:`torch.FloatTensor` (one for input features + one for the output of each cross-modality
162
+ layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
163
+ language_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
164
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
165
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
166
+ weighted average in the self-attention heads.
167
+ vision_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
168
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
169
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
170
+ weighted average in the self-attention heads.
171
+ cross_encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
172
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
173
+ sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
174
+ weighted average in the self-attention heads.
175
+
176
+ """
177
+
178
+ loss: Optional[torch.FloatTensor] = None
179
+ prediction_logits: Optional[torch.FloatTensor] = None
180
+ cross_relationship_score: Optional[torch.FloatTensor] = None
181
+ question_answering_score: Optional[torch.FloatTensor] = None
182
+ language_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
183
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
184
+ language_attentions: Optional[Tuple[torch.FloatTensor]] = None
185
+ vision_attentions: Optional[Tuple[torch.FloatTensor]] = None
186
+ cross_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
187
+
188
+
189
+ def load_tf_weights_in_lxmert(model, config, tf_checkpoint_path):
190
+ """Load tf checkpoints in a pytorch model."""
191
+ try:
192
+ import re
193
+
194
+ import numpy as np
195
+ import tensorflow as tf
196
+ except ImportError:
197
+ logger.error(
198
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
199
+ "https://www.tensorflow.org/install/ for installation instructions."
200
+ )
201
+ raise
202
+ tf_path = os.path.abspath(tf_checkpoint_path)
203
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
204
+ # Load weights from TF model
205
+ init_vars = tf.train.list_variables(tf_path)
206
+ names = []
207
+ arrays = []
208
+ for name, shape in init_vars:
209
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
210
+ array = tf.train.load_variable(tf_path, name)
211
+ names.append(name)
212
+ arrays.append(array)
213
+
214
+ for name, array in zip(names, arrays):
215
+ name = name.split("/")
216
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
217
+ # which are not required for using pretrained model
218
+ if any(
219
+ n
220
+ in [
221
+ "adam_v",
222
+ "adam_m",
223
+ "AdamWeightDecayOptimizer",
224
+ "AdamWeightDecayOptimizer_1",
225
+ "global_step",
226
+ ]
227
+ for n in name
228
+ ):
229
+ logger.info("Skipping {}".format("/".join(name)))
230
+ continue
231
+ pointer = model
232
+ for m_name in name:
233
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
234
+ scope_names = re.split(r"_(\d+)", m_name)
235
+ else:
236
+ scope_names = [m_name]
237
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
238
+ pointer = getattr(pointer, "weight")
239
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
240
+ pointer = getattr(pointer, "bias")
241
+ elif scope_names[0] == "output_weights":
242
+ pointer = getattr(pointer, "weight")
243
+ elif scope_names[0] == "squad":
244
+ pointer = getattr(pointer, "classifier")
245
+ else:
246
+ try:
247
+ pointer = getattr(pointer, scope_names[0])
248
+ except AttributeError:
249
+ logger.info("Skipping {}".format("/".join(name)))
250
+ continue
251
+ if len(scope_names) >= 2:
252
+ num = int(scope_names[1])
253
+ pointer = pointer[num]
254
+ if m_name[-11:] == "_embeddings":
255
+ pointer = getattr(pointer, "weight")
256
+ elif m_name == "kernel":
257
+ array = np.transpose(array)
258
+ try:
259
+ assert pointer.shape == array.shape
260
+ except AssertionError as e:
261
+ e.args += (pointer.shape, array.shape)
262
+ raise
263
+ logger.info("Initialize PyTorch weight {}".format(name))
264
+ pointer.data = torch.from_numpy(array)
265
+ return model
266
+
267
+
268
+ class LxmertEmbeddings(nn.Module):
269
+ """Construct the embeddings from word, position and token_type embeddings."""
270
+
271
+ def __init__(self, config):
272
+ super().__init__()
273
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
274
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
275
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
276
+
277
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
278
+ # any TensorFlow checkpoint file
279
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
280
+ self.dropout = Dropout(config.hidden_dropout_prob)
281
+
282
+ self.add1 = Add()
283
+ self.add2 = Add()
284
+
285
+ def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
286
+ if input_ids is not None:
287
+ input_shape = input_ids.size()
288
+ device = input_ids.device
289
+ else:
290
+ input_shape = inputs_embeds.size()[:-1]
291
+ device = inputs_embeds.device
292
+ seq_length = input_shape[1]
293
+
294
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
295
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
296
+
297
+ if token_type_ids is None:
298
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
299
+
300
+ if inputs_embeds is None:
301
+ inputs_embeds = self.word_embeddings(input_ids)
302
+ position_embeddings = self.position_embeddings(position_ids)
303
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
304
+
305
+ # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
306
+ embeddings = self.add1([token_type_embeddings, position_embeddings])
307
+ embeddings = self.add2([embeddings, inputs_embeds])
308
+ embeddings = self.LayerNorm(embeddings)
309
+ embeddings = self.dropout(embeddings)
310
+ return embeddings
311
+
312
+ def relprop(self, cam, **kwargs):
313
+ cam = self.dropout.relprop(cam, **kwargs)
314
+ cam = self.LayerNorm.relprop(cam, **kwargs)
315
+
316
+ # [inputs_embeds, position_embeddings, token_type_embeddings]
317
+ (cam) = self.add2.relprop(cam, **kwargs)
318
+
319
+ return cam
320
+
321
+
322
+ class LxmertAttention(nn.Module):
323
+ def __init__(self, config, ctx_dim=None):
324
+ super().__init__()
325
+ if config.hidden_size % config.num_attention_heads != 0:
326
+ raise ValueError(
327
+ "The hidden size (%d) is not a multiple of the number of attention "
328
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
329
+ )
330
+ self.num_attention_heads = config.num_attention_heads
331
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
332
+ self.head_size = self.num_attention_heads * self.attention_head_size
333
+
334
+ # visual_dim = 2048
335
+ if ctx_dim is None:
336
+ ctx_dim = config.hidden_size
337
+ self.query = Linear(config.hidden_size, self.head_size)
338
+ self.key = Linear(ctx_dim, self.head_size)
339
+ self.value = Linear(ctx_dim, self.head_size)
340
+
341
+ self.dropout = Dropout(config.attention_probs_dropout_prob)
342
+
343
+ self.matmul1 = MatMul()
344
+ self.matmul2 = MatMul()
345
+ self.softmax = Softmax(dim=-1)
346
+ self.add = Add()
347
+ self.mul = Mul()
348
+ self.head_mask = None
349
+ self.attention_mask = None
350
+ self.clone = Clone()
351
+
352
+ self.attn = None
353
+ self.attn_gradients = None
354
+ self.attn_cam = None
355
+
356
+ def get_attn(self):
357
+ return self.attn
358
+
359
+ def save_attn(self, attn):
360
+ self.attn = attn
361
+
362
+ def get_attn_cam(self):
363
+ return self.attn_cam
364
+
365
+ def save_attn_cam(self, attn_cam):
366
+ self.attn_cam = attn_cam
367
+
368
+ def save_attn_gradients(self, attn_gradients):
369
+ self.attn_gradients = attn_gradients
370
+
371
+ def get_attn_gradients(self):
372
+ return self.attn_gradients
373
+
374
+ def transpose_for_scores(self, x):
375
+ new_x_shape = x.size()[:-1] + (
376
+ self.num_attention_heads,
377
+ self.attention_head_size,
378
+ )
379
+ x = x.view(*new_x_shape)
380
+ return x.permute(0, 2, 1, 3)
381
+
382
+ def transpose_for_scores_relprop(self, x):
383
+ return x.permute(0, 2, 1, 3).flatten(2)
384
+
385
+ def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
386
+ key, value = self.clone(context, 2)
387
+ mixed_query_layer = self.query(hidden_states)
388
+ # mixed_key_layer = self.key(context)
389
+ # mixed_value_layer = self.value(context)
390
+ mixed_key_layer = self.key(key)
391
+ mixed_value_layer = self.value(value)
392
+
393
+ query_layer = self.transpose_for_scores(mixed_query_layer)
394
+ key_layer = self.transpose_for_scores(mixed_key_layer)
395
+ value_layer = self.transpose_for_scores(mixed_value_layer)
396
+
397
+ # Take the dot product between "query" and "key" to get the raw attention scores.
398
+ attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
399
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
400
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
401
+ if attention_mask is not None:
402
+ attention_scores = self.add([attention_scores, attention_mask])
403
+
404
+ # Normalize the attention scores to probabilities.
405
+ attention_probs = self.softmax(attention_scores)
406
+
407
+ self.save_attn(attention_probs)
408
+ attention_probs.register_hook(self.save_attn_gradients)
409
+
410
+ # This is actually dropping out entire tokens to attend to, which might
411
+ # seem a bit unusual, but is taken from the original Transformer paper.
412
+ attention_probs = self.dropout(attention_probs)
413
+
414
+ context_layer = self.matmul2([attention_probs, value_layer])
415
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
416
+ new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
417
+ context_layer = context_layer.view(*new_context_layer_shape)
418
+
419
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
420
+ return outputs
421
+
422
+ def relprop(self, cam, **kwargs):
423
+ # Assume output_attentions == False
424
+ cam = self.transpose_for_scores(cam)
425
+
426
+ # [attention_probs, value_layer]
427
+ (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
428
+ cam1 /= 2
429
+ cam2 /= 2
430
+
431
+ self.save_attn_cam(cam1)
432
+
433
+ cam1 = self.dropout.relprop(cam1, **kwargs)
434
+
435
+ cam1 = self.softmax.relprop(cam1, **kwargs)
436
+
437
+ if self.attention_mask is not None:
438
+ # [attention_scores, attention_mask]
439
+ (cam1, _) = self.add.relprop(cam1, **kwargs)
440
+
441
+ # [query_layer, key_layer.transpose(-1, -2)]
442
+ (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
443
+ cam1_1 /= 2
444
+ cam1_2 /= 2
445
+
446
+ # query
447
+ cam1_1 = self.transpose_for_scores_relprop(cam1_1)
448
+ cam1_1 = self.query.relprop(cam1_1, **kwargs)
449
+
450
+ # key
451
+ cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
452
+ cam1_2 = self.key.relprop(cam1_2, **kwargs)
453
+
454
+ # value
455
+ cam2 = self.transpose_for_scores_relprop(cam2)
456
+ cam2 = self.value.relprop(cam2, **kwargs)
457
+
458
+ cam = self.clone.relprop((cam1_2, cam2), **kwargs)
459
+
460
+ # returning two cams- one for the hidden state and one for the context
461
+ return (cam1_1, cam)
462
+
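The save_attn / save_attn_cam / register_hook plumbing above exists so that attention maps, their gradients, and their relevance (CAM) can be read back after relprop. As a hedged sketch of how these are typically fused downstream (mirroring the gradient-weighted head averaging used in the accompanying ExplanationGenerator.py; batching and reshape details omitted):

    # Hedged sketch, not defined in this file: fuse per-head attention relevance with gradients.
    import torch

    def avg_heads(cam, grad):
        # cam, grad: (num_heads, seq_len_q, seq_len_k) from get_attn_cam() / get_attn_gradients()
        cam = grad * cam            # weight each head's relevance by its gradient
        cam = cam.clamp(min=0)      # keep positive evidence only
        return cam.mean(dim=0)      # average over heads -> (seq_len_q, seq_len_k)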
463
+
464
+ class LxmertAttentionOutput(nn.Module):
465
+ def __init__(self, config):
466
+ super().__init__()
467
+ self.dense = Linear(config.hidden_size, config.hidden_size)
468
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
469
+ self.dropout = Dropout(config.hidden_dropout_prob)
470
+ self.add = Add()
471
+
472
+ def forward(self, hidden_states, input_tensor):
473
+ hidden_states = self.dense(hidden_states)
474
+ hidden_states = self.dropout(hidden_states)
475
+ add = self.add([hidden_states, input_tensor])
476
+ hidden_states = self.LayerNorm(add)
477
+ return hidden_states
478
+
479
+ def relprop(self, cam, **kwargs):
480
+ cam = self.LayerNorm.relprop(cam, **kwargs)
481
+ # [hidden_states, input_tensor]
482
+ (cam1, cam2) = self.add.relprop(cam, **kwargs)
483
+ cam1 = self.dropout.relprop(cam1, **kwargs)
484
+ cam1 = self.dense.relprop(cam1, **kwargs)
485
+
486
+ return (cam1, cam2)
487
+
488
+
489
+ class LxmertCrossAttentionLayer(nn.Module):
490
+ def __init__(self, config):
491
+ super().__init__()
492
+ self.att = LxmertAttention(config)
493
+ self.output = LxmertAttentionOutput(config)
494
+ self.clone = Clone()
495
+
496
+ def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
497
+ inp1, inp2 = self.clone(input_tensor, 2)
498
+ output = self.att(inp1, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
499
+ if output_attentions:
500
+ attention_probs = output[1]
501
+ attention_output = self.output(output[0], inp2)
502
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
503
+ return outputs
504
+
505
+ def relprop(self, cam, **kwargs):
506
+ cam_output, cam_inp2 = self.output.relprop(cam, **kwargs)
507
+ cam_inp1, cam_ctx = self.att.relprop(cam_output, **kwargs)
508
+ cam_inp = self.clone.relprop((cam_inp1, cam_inp2), **kwargs)
509
+
510
+ return (cam_inp, cam_ctx)
511
+
512
+
513
+ class LxmertSelfAttentionLayer(nn.Module):
514
+ def __init__(self, config):
515
+ super().__init__()
516
+ self.self = LxmertAttention(config)
517
+ self.output = LxmertAttentionOutput(config)
518
+ self.clone = Clone()
519
+
520
+ def forward(self, input_tensor, attention_mask, output_attentions=False):
521
+ inp1, inp2, inp3 = self.clone(input_tensor, 3)
522
+ # Self attention attends to itself, thus keys and queries are the same (input_tensor).
523
+ output = self.self(
524
+ inp1,
525
+ inp2,
526
+ attention_mask,
527
+ output_attentions=output_attentions,
528
+ )
529
+ if output_attentions:
530
+ attention_probs = output[1]
531
+ attention_output = self.output(output[0], inp3)
532
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
533
+ return outputs
534
+
535
+ def relprop(self, cam, **kwargs):
536
+ cam_output, cam_inp3 = self.output.relprop(cam, **kwargs)
537
+ cam_inp1, cam_inp2 = self.self.relprop(cam_output, **kwargs)
538
+ cam_inp = self.clone.relprop((cam_inp1, cam_inp2, cam_inp3), **kwargs)
539
+
540
+ return cam_inp
541
+
542
+
543
+ class LxmertIntermediate(nn.Module):
544
+ def __init__(self, config):
545
+ super().__init__()
546
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
547
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]()
548
+
549
+ def forward(self, hidden_states):
550
+ hidden_states = self.dense(hidden_states)
551
+ hidden_states = self.intermediate_act_fn(hidden_states)
552
+ return hidden_states
553
+
554
+ def relprop(self, cam, **kwargs):
555
+ cam = self.intermediate_act_fn.relprop(cam, **kwargs)
556
+ cam = self.dense.relprop(cam, **kwargs)
557
+ return cam
558
+
559
+
560
+ class LxmertOutput(nn.Module):
561
+ def __init__(self, config):
562
+ super().__init__()
563
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
564
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
565
+ self.dropout = Dropout(config.hidden_dropout_prob)
566
+ self.add = Add()
567
+
568
+ def forward(self, hidden_states, input_tensor):
569
+ hidden_states = self.dense(hidden_states)
570
+ hidden_states = self.dropout(hidden_states)
571
+ add = self.add([hidden_states, input_tensor])
572
+ hidden_states = self.LayerNorm(add)
573
+ return hidden_states
574
+
575
+ def relprop(self, cam, **kwargs):
576
+ cam = self.LayerNorm.relprop(cam, **kwargs)
577
+ # [hidden_states, input_tensor]
578
+ (cam1, cam2)= self.add.relprop(cam, **kwargs)
579
+ cam1 = self.dropout.relprop(cam1, **kwargs)
580
+ cam1 = self.dense.relprop(cam1, **kwargs)
581
+ return (cam1, cam2)
582
+
583
+
584
+ class LxmertLayer(nn.Module):
585
+ def __init__(self, config):
586
+ super().__init__()
587
+ self.attention = LxmertSelfAttentionLayer(config)
588
+ self.intermediate = LxmertIntermediate(config)
589
+ self.output = LxmertOutput(config)
590
+ self.clone = Clone()
591
+
592
+ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
593
+ outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
594
+ attention_output = outputs[0]
595
+ ao1, ao2 = self.clone(attention_output, 2)
596
+ intermediate_output = self.intermediate(ao1)
597
+ layer_output = self.output(intermediate_output, ao2)
598
+ outputs = (layer_output,) + outputs[1:] # add attentions if we output them
599
+ return outputs
600
+
601
+ def relprop(self, cam, **kwargs):
602
+ (cam1, cam2) = self.output.relprop(cam, **kwargs)
603
+ cam1 = self.intermediate.relprop(cam1, **kwargs)
604
+ cam = self.clone.relprop((cam1, cam2), **kwargs)
605
+ cam = self.attention.relprop(cam, **kwargs)
606
+ return cam
607
+
608
+
609
+ class LxmertXLayer(nn.Module):
610
+ def __init__(self, config):
611
+ super().__init__()
612
+ # The cross-attention Layer
613
+ self.visual_attention = LxmertCrossAttentionLayer(config)
614
+
615
+ # Self-attention Layers
616
+ self.lang_self_att = LxmertSelfAttentionLayer(config)
617
+ self.visn_self_att = LxmertSelfAttentionLayer(config)
618
+
619
+ # Intermediate and Output Layers (FFNs)
620
+ self.lang_inter = LxmertIntermediate(config)
621
+ self.lang_output = LxmertOutput(config)
622
+ self.visn_inter = LxmertIntermediate(config)
623
+ self.visn_output = LxmertOutput(config)
624
+
625
+ self.clone1 = Clone()
626
+ self.clone2 = Clone()
627
+ self.clone3 = Clone()
628
+ self.clone4 = Clone()
629
+
630
+ def cross_att(
631
+ self,
632
+ lang_input,
633
+ lang_attention_mask,
634
+ visual_input,
635
+ visual_attention_mask,
636
+ output_x_attentions=False,
637
+ ):
638
+ lang_input1, lang_input2 = self.clone1(lang_input, 2)
639
+ visual_input1, visual_input2 = self.clone2(visual_input, 2)
640
+ if not hasattr(self, 'visual_attention_copy'):
641
+ self.visual_attention_copy = copy.deepcopy(self.visual_attention)
642
+ # Cross Attention
643
+ lang_att_output = self.visual_attention(
644
+ lang_input1,
645
+ visual_input1,
646
+ ctx_att_mask=visual_attention_mask,
647
+ output_attentions=output_x_attentions,
648
+ )
649
+ visual_att_output = self.visual_attention_copy(
650
+ visual_input2,
651
+ lang_input2,
652
+ ctx_att_mask=lang_attention_mask,
653
+ output_attentions=False,
654
+ )
655
+ return lang_att_output, visual_att_output
656
+
657
+ def relprop_cross(self, cam, **kwargs):
658
+ cam_lang, cam_vis = cam
659
+ cam_vis2, cam_lang2 = self.visual_attention_copy.relprop(cam_vis, **kwargs)
660
+ cam_lang1, cam_vis1 = self.visual_attention.relprop(cam_lang, **kwargs)
661
+ cam_vis = self.clone2.relprop((cam_vis1, cam_vis2), **kwargs)
662
+ cam_lang = self.clone1.relprop((cam_lang1, cam_lang2), **kwargs)
663
+ return cam_lang, cam_vis
664
+
665
+
666
+ def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
667
+ # Self Attention
668
+ lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
669
+ visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
670
+ return lang_att_output[0], visual_att_output[0]
671
+
672
+ def relprop_self(self, cam, **kwargs):
673
+ cam_lang, cam_vis = cam
674
+ cam_vis = self.visn_self_att.relprop(cam_vis, **kwargs)
675
+ cam_lang = self.lang_self_att.relprop(cam_lang, **kwargs)
676
+ return cam_lang, cam_vis
677
+
678
+ def output_fc(self, lang_input, visual_input):
679
+ lang_input1, lang_input2 = self.clone3(lang_input, 2)
680
+ visual_input1, visual_input2 = self.clone4(visual_input, 2)
681
+ # FC layers
682
+ lang_inter_output = self.lang_inter(lang_input1)
683
+ visual_inter_output = self.visn_inter(visual_input1)
684
+
685
+ # Layer output
686
+ lang_output = self.lang_output(lang_inter_output, lang_input2)
687
+ visual_output = self.visn_output(visual_inter_output, visual_input2)
688
+
689
+ return lang_output, visual_output
690
+
691
+ def relprop_output(self, cam, **kwargs):
692
+ cam_lang, cam_vis = cam
693
+ cam_vis_inter, cam_vis2 = self.visn_output.relprop(cam_vis, **kwargs)
694
+ cam_lang_inter, cam_lang2 = self.lang_output.relprop(cam_lang, **kwargs)
695
+ cam_vis1 = self.visn_inter.relprop(cam_vis_inter, **kwargs)
696
+ cam_lang1 = self.lang_inter.relprop(cam_lang_inter, **kwargs)
697
+ cam_vis = self.clone4.relprop((cam_vis1, cam_vis2), **kwargs)
698
+ cam_lang = self.clone3.relprop((cam_lang1, cam_lang2), **kwargs)
699
+ return cam_lang, cam_vis
700
+
701
+ def forward(
702
+ self,
703
+ lang_feats,
704
+ lang_attention_mask,
705
+ visual_feats,
706
+ visual_attention_mask,
707
+ output_attentions=False,
708
+ ):
709
+ lang_att_output, visual_att_output = self.cross_att(
710
+ lang_input=lang_feats,
711
+ lang_attention_mask=lang_attention_mask,
712
+ visual_input=visual_feats,
713
+ visual_attention_mask=visual_attention_mask,
714
+ output_x_attentions=output_attentions,
715
+ )
716
+ attention_probs = lang_att_output[1:]
717
+ lang_att_output, visual_att_output = self.self_att(
718
+ lang_att_output[0],
719
+ lang_attention_mask,
720
+ visual_att_output[0],
721
+ visual_attention_mask,
722
+ )
723
+
724
+ lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
725
+ return (
726
+ (
727
+ lang_output,
728
+ visual_output,
729
+ attention_probs[0],
730
+ )
731
+ if output_attentions
732
+ else (lang_output, visual_output)
733
+ )
734
+
735
+ def relprop(self, cam, **kwargs):
736
+ cam_lang, cam_vis = cam
737
+ cam_lang, cam_vis = self.relprop_output((cam_lang, cam_vis), **kwargs)
738
+ cam_lang, cam_vis = self.relprop_self((cam_lang, cam_vis), **kwargs)
739
+ cam_lang, cam_vis = self.relprop_cross((cam_lang, cam_vis), **kwargs)
740
+ return cam_lang, cam_vis
741
+
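+ # A minimal sketch (not part of the model) of the forward/relprop pairing used
+ # throughout this file: each module redistributes relevance ("cam") through its
+ # sub-modules in the reverse of the order used in forward, exactly as
+ # LxmertXLayer.relprop runs output -> self-attention -> cross-attention.
+ # `ToyStage` and `ToyBlock` are illustrative stand-ins only.
+ # import torch.nn as nn
+ #
+ # class ToyStage(nn.Module):
+ #     def __init__(self, dim):
+ #         super().__init__()
+ #         self.fc = nn.Linear(dim, dim)
+ #     def forward(self, x):
+ #         return self.fc(x)
+ #     def relprop(self, cam, **kwargs):
+ #         # the real layers redistribute relevance through self.fc here;
+ #         # a pass-through is enough to show the call structure
+ #         return cam
+ #
+ # class ToyBlock(nn.Module):
+ #     def __init__(self, dim):
+ #         super().__init__()
+ #         self.first = ToyStage(dim)
+ #         self.second = ToyStage(dim)
+ #     def forward(self, x):
+ #         return self.second(self.first(x))
+ #     def relprop(self, cam, **kwargs):
+ #         cam = self.second.relprop(cam, **kwargs)   # reverse of forward
+ #         cam = self.first.relprop(cam, **kwargs)
+ #         return cam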
742
+ class LxmertVisualFeatureEncoder(nn.Module):
743
+ def __init__(self, config):
744
+ super().__init__()
745
+ feat_dim = config.visual_feat_dim
746
+ pos_dim = config.visual_pos_dim
747
+
748
+ # Object feature encoding
749
+ self.visn_fc = Linear(feat_dim, config.hidden_size)
750
+ self.visn_layer_norm = LayerNorm(config.hidden_size, eps=1e-12)
751
+
752
+ # Box position encoding
753
+ self.box_fc = Linear(pos_dim, config.hidden_size)
754
+ self.box_layer_norm = LayerNorm(config.hidden_size, eps=1e-12)
755
+
756
+ self.dropout = Dropout(config.hidden_dropout_prob)
757
+
758
+ def forward(self, visual_feats, visual_pos):
759
+ x = self.visn_fc(visual_feats)
760
+ x = self.visn_layer_norm(x)
761
+ y = self.box_fc(visual_pos)
762
+ y = self.box_layer_norm(y)
763
+ output = (x + y) / 2
764
+
765
+ output = self.dropout(output)
766
+ return output
767
+
768
+ def relprop(self, cam, **kwargs):
769
+ cam = self.dropout.relprop(cam, **kwargs)
770
+ cam = self.visn_layer_norm.relprop(cam, **kwargs)
771
+ cam = self.visn_fc.relprop(cam, **kwargs)
772
+ return cam
773
+
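+ # A hedged shape sketch for LxmertVisualFeatureEncoder: ROI features and
+ # normalized box coordinates are each projected to hidden_size, layer-normed,
+ # and averaged. The dimensions below (2048-d features, 4-d boxes, 36 regions,
+ # 768-d hidden states) are the usual lxmert-base defaults and are assumptions
+ # of this example, not requirements of the class.
+ # import torch
+ # batch_size, num_regions = 1, 36
+ # visual_feats = torch.randn(batch_size, num_regions, 2048)  # Faster R-CNN ROI features
+ # visual_pos = torch.rand(batch_size, num_regions, 4)        # normalized (x1, y1, x2, y2) boxes
+ # encoder = LxmertVisualFeatureEncoder(config)               # config.visual_feat_dim=2048, visual_pos_dim=4
+ # out = encoder(visual_feats, visual_pos)                    # -> (1, 36, config.hidden_size), e.g. (1, 36, 768)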
774
+ class LxmertEncoder(nn.Module):
775
+ def __init__(self, config):
776
+ super().__init__()
777
+
778
+ # Obj-level image embedding layer
779
+ self.visn_fc = LxmertVisualFeatureEncoder(config)
780
+ self.config = config
781
+
782
+ # Number of layers
783
+ self.num_l_layers = config.l_layers
784
+ self.num_x_layers = config.x_layers
785
+ self.num_r_layers = config.r_layers
786
+
787
+ # Layers
788
+ # Using self.layer instead of self.l_layer to support loading BERT weights.
789
+ self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)])
790
+ self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
791
+ self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)])
792
+
793
+ def forward(
794
+ self,
795
+ lang_feats,
796
+ lang_attention_mask,
797
+ visual_feats,
798
+ visual_pos,
799
+ visual_attention_mask=None,
800
+ output_attentions=None,
801
+ ):
802
+
803
+ vision_hidden_states = ()
804
+ language_hidden_states = ()
805
+ vision_attentions = () if output_attentions or self.config.output_attentions else None
806
+ language_attentions = () if output_attentions or self.config.output_attentions else None
807
+ cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
808
+
809
+ visual_feats = self.visn_fc(visual_feats, visual_pos)
810
+
811
+ # Run language layers
812
+ for layer_module in self.layer:
813
+ l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
814
+ lang_feats = l_outputs[0]
815
+ language_hidden_states = language_hidden_states + (lang_feats,)
816
+ if language_attentions is not None:
817
+ language_attentions = language_attentions + (l_outputs[1],)
818
+
819
+ # Run relational layers
820
+ for layer_module in self.r_layers:
821
+ v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
822
+ visual_feats = v_outputs[0]
823
+ vision_hidden_states = vision_hidden_states + (visual_feats,)
824
+ if vision_attentions is not None:
825
+ vision_attentions = vision_attentions + (v_outputs[1],)
826
+
827
+ # Run cross-modality layers
828
+ for layer_module in self.x_layers:
829
+ x_outputs = layer_module(
830
+ lang_feats,
831
+ lang_attention_mask,
832
+ visual_feats,
833
+ visual_attention_mask,
834
+ output_attentions=output_attentions,
835
+ )
836
+ lang_feats, visual_feats = x_outputs[:2]
837
+ vision_hidden_states = vision_hidden_states + (visual_feats,)
838
+ language_hidden_states = language_hidden_states + (lang_feats,)
839
+ if cross_encoder_attentions is not None:
840
+ cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
841
+ visual_encoder_outputs = (
842
+ vision_hidden_states,
843
+ vision_attentions if output_attentions else None,
844
+ )
845
+ lang_encoder_outputs = (
846
+ language_hidden_states,
847
+ language_attentions if output_attentions else None,
848
+ )
849
+ return (
850
+ visual_encoder_outputs,
851
+ lang_encoder_outputs,
852
+ cross_encoder_attentions if output_attentions else None,
853
+ )
854
+
855
+ def relprop(self, cam, **kwargs):
856
+ cam_lang, cam_vis = cam
857
+ for layer_module in reversed(self.x_layers):
858
+ cam_lang, cam_vis = layer_module.relprop((cam_lang, cam_vis), **kwargs)
859
+
860
+ for layer_module in reversed(self.r_layers):
861
+ cam_vis = layer_module.relprop(cam_vis, **kwargs)
862
+
863
+ for layer_module in reversed(self.layer):
864
+ cam_lang = layer_module.relprop(cam_lang, **kwargs)
865
+ return cam_lang, cam_vis
866
+
867
+
868
+ class LxmertPooler(nn.Module):
869
+ def __init__(self, config):
870
+ super(LxmertPooler, self).__init__()
871
+ self.dense = Linear(config.hidden_size, config.hidden_size)
872
+ self.activation = Tanh()
873
+
874
+ self.pool = IndexSelect()
875
+
876
+ def forward(self, hidden_states):
877
+ # We "pool" the model by simply taking the hidden state corresponding
878
+ # to the first token.
879
+ # first_token_tensor = hidden_states[:, 0]
880
+ first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
881
+ first_token_tensor = first_token_tensor.squeeze(1)
882
+ pooled_output = self.dense(first_token_tensor)
883
+ pooled_output = self.activation(pooled_output)
884
+ return pooled_output
885
+
886
+ def relprop(self, cam, **kwargs):
887
+ cam = self.activation.relprop(cam, **kwargs)
888
+ cam = self.dense.relprop(cam, **kwargs)
889
+ cam = cam.unsqueeze(1)
890
+ cam = self.pool.relprop(cam, **kwargs)
891
+
892
+ return cam
893
+
894
+
895
+ class LxmertPredictionHeadTransform(nn.Module):
896
+ def __init__(self, config):
897
+ super(LxmertPredictionHeadTransform, self).__init__()
898
+ self.dense = Linear(config.hidden_size, config.hidden_size)
899
+ self.transform_act_fn = ACT2FN[config.hidden_act]
900
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
901
+
902
+ def forward(self, hidden_states):
903
+ hidden_states = self.dense(hidden_states)
904
+ hidden_states = self.transform_act_fn(hidden_states)
905
+ hidden_states = self.LayerNorm(hidden_states)
906
+ return hidden_states
907
+
908
+ def relprop(self, cam, **kwargs):
909
+ cam = self.LayerNorm.relprop(cam, **kwargs)
910
+ cam = self.transform_act_fn.relprop(cam, **kwargs)
911
+ cam = self.dense.relprop(cam, **kwargs)
912
+ return cam
913
+
914
+
915
+ class LxmertLMPredictionHead(nn.Module):
916
+ def __init__(self, config, lxmert_model_embedding_weights):
917
+ super(LxmertLMPredictionHead, self).__init__()
918
+ self.transform = LxmertPredictionHeadTransform(config)
919
+
920
+ # The output weights are the same as the input embeddings, but there is
921
+ # an output-only bias for each token.
922
+ self.decoder = Linear(
923
+ lxmert_model_embedding_weights.size(1),
924
+ lxmert_model_embedding_weights.size(0),
925
+ bias=False,
926
+ )
927
+ self.decoder.weight = lxmert_model_embedding_weights
928
+ self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))
929
+
930
+ def forward(self, hidden_states):
931
+ hidden_states = self.transform(hidden_states)
932
+ hidden_states = self.decoder(hidden_states) + self.bias
933
+ return hidden_states
934
+
935
+ def relprop(self, cam, **kwargs):
936
+ cam = self.decoder.relprop(cam, **kwargs)
937
+ cam = self.transform.relprop(cam, **kwargs)
938
+ return cam
939
+
940
+
941
+ class LxmertVisualAnswerHead(nn.Module):
942
+ def __init__(self, config, num_labels):
943
+ super().__init__()
944
+ hid_dim = config.hidden_size
945
+ self.logit_fc = nn.Sequential(
946
+ Linear(hid_dim, hid_dim * 2),
947
+ GELU(),
948
+ LayerNorm(hid_dim * 2, eps=1e-12),
949
+ Linear(hid_dim * 2, num_labels),
950
+ )
951
+
952
+ def forward(self, hidden_states):
953
+ return self.logit_fc(hidden_states)
954
+
955
+ def relprop(self, cam, **kwargs):
956
+ for m in reversed(self.logit_fc._modules.values()):
957
+ cam = m.relprop(cam, **kwargs)
958
+ return cam
959
+
960
+
961
+ class LxmertVisualObjHead(nn.Module):
962
+ def __init__(self, config):
963
+ super().__init__()
964
+ self.transform = LxmertPredictionHeadTransform(config)
965
+ # Decide the use of visual losses
966
+ visual_losses = {}
967
+ if config.visual_obj_loss:
968
+ visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
969
+ if config.visual_attr_loss:
970
+ visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
971
+ if config.visual_obj_loss:
972
+ visual_losses["feat"] = {
973
+ "shape": (-1, config.visual_feat_dim),
974
+ "num": config.visual_feat_dim,
975
+ }
976
+ self.visual_losses = visual_losses
977
+
978
+ # The output weights are the same as the input embeddings, but there is
979
+ # an output-only bias for each token.
980
+ self.decoder_dict = nn.ModuleDict(
981
+ {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
982
+ )
983
+
984
+ def forward(self, hidden_states):
985
+ hidden_states = self.transform(hidden_states)
986
+ output = {}
987
+ for key in self.visual_losses:
988
+ output[key] = self.decoder_dict[key](hidden_states)
989
+ return output
990
+
991
+ def relprop(self, cam, **kwargs):
992
+ return self.transform.relprop(cam, **kwargs)
993
+
994
+
995
+ class LxmertPreTrainingHeads(nn.Module):
996
+ def __init__(self, config, lxmert_model_embedding_weights):
997
+ super(LxmertPreTrainingHeads, self).__init__()
998
+ self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
999
+ self.seq_relationship = Linear(config.hidden_size, 2)
1000
+
1001
+ def forward(self, sequence_output, pooled_output):
1002
+ prediction_scores = self.predictions(sequence_output)
1003
+ seq_relationship_score = self.seq_relationship(pooled_output)
1004
+ return prediction_scores, seq_relationship_score
1005
+
1006
+ def relprop(self, cam, **kwargs):
1007
+ cam_seq, cam_pooled = cam
1008
+ cam_pooled = self.seq_relationship.relprop(cam_pooled, **kwargs)
1009
+ cam_seq = self.predictions.relprop(cam_seq, **kwargs)
1010
+ return cam_seq, cam_pooled
1011
+
1012
+
1013
+ class LxmertPreTrainedModel(PreTrainedModel):
1014
+ """
1015
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1016
+ models.
1017
+ """
1018
+
1019
+ config_class = LxmertConfig
1020
+ load_tf_weights = load_tf_weights_in_lxmert
1021
+ base_model_prefix = "lxmert"
1022
+
1023
+ def _init_weights(self, module):
1024
+ """ Initialize the weights """
1025
+ if isinstance(module, (nn.Linear, nn.Embedding)):
1026
+ # Slightly different from the TF version which uses truncated_normal for initialization
1027
+ # cf https://github.com/pytorch/pytorch/pull/5617
1028
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1029
+ elif isinstance(module, nn.LayerNorm):
1030
+ module.bias.data.zero_()
1031
+ module.weight.data.fill_(1.0)
1032
+ if isinstance(module, nn.Linear) and module.bias is not None:
1033
+ module.bias.data.zero_()
1034
+
1035
+
1036
+ LXMERT_START_DOCSTRING = r"""
1037
+
1038
+ The lxmert model was proposed in `lxmert: Learning Cross-Modality Encoder Representations from Transformers
1039
+ <https://arxiv.org/abs/1908.07490>`__ by Hao Tan and Mohit Bansal. It's a vision and language transformer model,
1040
+ pretrained on a variety of multi-modal datasets comprising GQA, VQAv2.0, MSCOCO captions, and Visual Genome,
1041
+ using a combination of masked language modeling, region-of-interest feature regression, cross-entropy losses for
1042
+ question answering and attribute prediction, and object tag prediction.
1043
+
1044
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
1045
+ methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
1046
+ pruning heads etc.)
1047
+
1048
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
1049
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
1050
+ general usage and behavior.
1051
+
1052
+ Parameters:
1053
+ config (:class:`~transformers.LxmertConfig`): Model configuration class with all the parameters of the model.
1054
+ Initializing with a config file does not load the weights associated with the model, only the
1055
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
1056
+ weights.
1057
+ """
1058
+
1059
+ LXMERT_INPUTS_DOCSTRING = r"""
1060
+
1061
+ Args:
1062
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
1063
+ Indices of input sequence tokens in the vocabulary.
1064
+
1065
+ Indices can be obtained using :class:`~transformers.LxmertTokenizer`. See
1066
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
1067
+ details.
1068
+
1069
+ `What are input IDs? <../glossary.html#input-ids>`__
1070
+ visual_feats: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_feat_dim)`):
1071
+ This input represents visual features. These are ROI-pooled object features extracted from bounding boxes
1072
+ with a Faster R-CNN model.
1073
+
1074
+ These are currently not provided by the transformers library.
1075
+ visual_pos: (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_visual_features, visual_pos_dim)`):
1076
+ This input represents spatial features corresponding (by index) to the visual features. The
1077
+ pre-trained lxmert model expects these spatial features to be normalized bounding boxes on a scale of 0 to
1078
+ 1.
1079
+
1080
+ These are currently not provided by the transformers library.
1081
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
1082
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
1083
+
1084
+ - 1 for tokens that are **not masked**,
1085
+ - 0 for tokens that are **masked**.
1086
+
1087
+ `What are attention masks? <../glossary.html#attention-mask>`__
1088
+ visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
1089
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
1090
+
1091
+ - 1 for tokens that are **not masked**,
1092
+ - 0 for tokens that are **masked**.
1093
+
1094
+ `What are attention masks? <../glossary.html#attention-mask>`__
1095
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
1096
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1097
+ 1]``:
1098
+
1099
+ - 0 corresponds to a `sentence A` token,
1100
+ - 1 corresponds to a `sentence B` token.
1101
+
1102
+ `What are token type IDs? <../glossary.html#token-type-ids>`__
1103
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
1104
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
1105
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
1106
+ vectors than the model's internal embedding lookup matrix.
1107
+ output_attentions (:obj:`bool`, `optional`):
1108
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
1109
+ tensors for more detail.
1110
+ output_hidden_states (:obj:`bool`, `optional`):
1111
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
1112
+ more detail.
1113
+ return_dict (:obj:`bool`, `optional`):
1114
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
1115
+ """
1116
+
1117
+
1118
+ @add_start_docstrings(
1119
+ "The bare Lxmert Model transformer outputting raw hidden-states without any specific head on top.",
1120
+ LXMERT_START_DOCSTRING,
1121
+ )
1122
+ class LxmertModel(LxmertPreTrainedModel):
1123
+ def __init__(self, config):
1124
+ super().__init__(config)
1125
+ self.embeddings = LxmertEmbeddings(config)
1126
+ self.encoder = LxmertEncoder(config)
1127
+ self.pooler = LxmertPooler(config)
1128
+ self.init_weights()
1129
+
1130
+ def get_input_embeddings(self):
1131
+ return self.embeddings.word_embeddings
1132
+
1133
+ def set_input_embeddings(self, new_embeddings):
1134
+ self.embeddings.word_embeddings = new_embeddings
1135
+
1136
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1137
+ @add_code_sample_docstrings(
1138
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1139
+ checkpoint="unc-nlp/lxmert-base-uncased",
1140
+ output_type=LxmertModelOutput,
1141
+ config_class=_CONFIG_FOR_DOC,
1142
+ )
1143
+ def forward(
1144
+ self,
1145
+ input_ids=None,
1146
+ visual_feats=None,
1147
+ visual_pos=None,
1148
+ attention_mask=None,
1149
+ visual_attention_mask=None,
1150
+ token_type_ids=None,
1151
+ inputs_embeds=None,
1152
+ output_attentions=None,
1153
+ output_hidden_states=None,
1154
+ return_dict=None,
1155
+ ):
1156
+
1157
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1158
+ output_hidden_states = (
1159
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1160
+ )
1161
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1162
+
1163
+ if input_ids is not None and inputs_embeds is not None:
1164
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1165
+ elif input_ids is not None:
1166
+ input_shape = input_ids.size()
1167
+ elif inputs_embeds is not None:
1168
+ input_shape = inputs_embeds.size()[:-1]
1169
+ else:
1170
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1171
+
1172
+ assert visual_feats is not None, "`visual_feats` cannot be `None`"
1173
+ assert visual_pos is not None, "`visual_pos` cannot be `None`"
1174
+
1175
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1176
+
1177
+ if attention_mask is None:
1178
+ attention_mask = torch.ones(input_shape, device=device)
1179
+ if token_type_ids is None:
1180
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1181
+
1182
+ # We create a 3D attention mask from a 2D tensor mask.
1183
+ # Sizes are [batch_size, 1, 1, to_seq_length]
1184
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1185
+ # this attention mask is more simple than the triangular masking of causal attention
1186
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
1187
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1188
+
1189
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1190
+ # masked positions, this operation will create a tensor which is 0.0 for
1191
+ # positions we want to attend and -10000.0 for masked positions.
1192
+ # Since we are adding it to the raw scores before the softmax, this is
1193
+ # effectively the same as removing these entirely.
1194
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
1195
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
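+ # Worked example of the additive mask built above, for one sequence of length 3
+ # whose last token is padding:
+ #   attention_mask          = [[1, 1, 0]]                # shape (1, 3)
+ #   after unsqueeze(1), (2) -> shape (1, 1, 1, 3)
+ #   (1.0 - mask) * -10000.0 -> [[[[0., 0., -10000.]]]]
+ # The -10000.0 entries are added to the raw attention scores, so after softmax
+ # the padded positions receive (approximately) zero attention weight.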
1196
+
1197
+ # Process the visual attention mask
1198
+ if visual_attention_mask is not None:
1199
+ extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
1200
+ extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
1201
+ extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
1202
+ else:
1203
+ extended_visual_attention_mask = None
1204
+
1205
+ # Positional Word Embeddings
1206
+ embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
1207
+
1208
+ # Run Lxmert encoder
1209
+ encoder_outputs = self.encoder(
1210
+ embedding_output,
1211
+ extended_attention_mask,
1212
+ visual_feats=visual_feats,
1213
+ visual_pos=visual_pos,
1214
+ visual_attention_mask=extended_visual_attention_mask,
1215
+ output_attentions=output_attentions,
1216
+ )
1217
+
1218
+ visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
1219
+ vision_hidden_states = visual_encoder_outputs[0]
1220
+ language_hidden_states = lang_encoder_outputs[0]
1221
+
1222
+ all_attentions = ()
1223
+ if output_attentions:
1224
+ language_attentions = lang_encoder_outputs[1]
1225
+ vision_attentions = visual_encoder_outputs[1]
1226
+ cross_encoder_attentions = encoder_outputs[2]
1227
+ all_attentions = (
1228
+ language_attentions,
1229
+ vision_attentions,
1230
+ cross_encoder_attentions,
1231
+ )
1232
+
1233
+ hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
1234
+
1235
+ visual_output = vision_hidden_states[-1]
1236
+ lang_output = language_hidden_states[-1]
1237
+ pooled_output = self.pooler(lang_output)
1238
+
1239
+ if not return_dict:
1240
+ return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
1241
+
1242
+ return LxmertModelOutput(
1243
+ pooled_output=pooled_output,
1244
+ language_output=lang_output,
1245
+ vision_output=visual_output,
1246
+ language_hidden_states=language_hidden_states if output_hidden_states else None,
1247
+ vision_hidden_states=vision_hidden_states if output_hidden_states else None,
1248
+ language_attentions=language_attentions if output_attentions else None,
1249
+ vision_attentions=vision_attentions if output_attentions else None,
1250
+ cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
1251
+ )
1252
+
1253
+ def relprop(self, cam, **kwargs):
1254
+ cam_lang, cam_vis = cam
1255
+ cam_lang = self.pooler.relprop(cam_lang, **kwargs)
1256
+ cam_lang, cam_vis = self.encoder.relprop((cam_lang, cam_vis), **kwargs)
1257
+ return cam_lang, cam_vis
1258
+
1259
+
1260
+
1261
+ @add_start_docstrings(
1262
+ """Lxmert Model with a specified pretraining head on top. """,
1263
+ LXMERT_START_DOCSTRING,
1264
+ )
1265
+ class LxmertForPreTraining(LxmertPreTrainedModel):
1266
+ def __init__(self, config):
1267
+ super().__init__(config)
1268
+ # Configuration
1269
+ self.config = config
1270
+ self.num_qa_labels = config.num_qa_labels
1271
+ self.visual_loss_normalizer = config.visual_loss_normalizer
1272
+
1273
+ # Use of pretraining tasks
1274
+ self.task_mask_lm = config.task_mask_lm
1275
+ self.task_obj_predict = config.task_obj_predict
1276
+ self.task_matched = config.task_matched
1277
+ self.task_qa = config.task_qa
1278
+
1279
+ # Lxmert backbone
1280
+ self.lxmert = LxmertModel(config)
1281
+
1282
+ # Pre-training heads
1283
+ self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
1284
+ if self.task_obj_predict:
1285
+ self.obj_predict_head = LxmertVisualObjHead(config)
1286
+ if self.task_qa:
1287
+ self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
1288
+
1289
+ # Weight initialization
1290
+ self.init_weights()
1291
+
1292
+ # Loss functions
1293
+ self.loss_fcts = {
1294
+ "l2": SmoothL1Loss(reduction="none"),
1295
+ "visual_ce": CrossEntropyLoss(reduction="none"),
1296
+ "ce": CrossEntropyLoss(),
1297
+ }
1298
+
1299
+ visual_losses = {}
1300
+ if config.visual_obj_loss:
1301
+ visual_losses["obj"] = {
1302
+ "shape": (-1,),
1303
+ "num": config.num_object_labels,
1304
+ "loss": "visual_ce",
1305
+ }
1306
+ if config.visual_attr_loss:
1307
+ visual_losses["attr"] = {
1308
+ "shape": (-1,),
1309
+ "num": config.num_attr_labels,
1310
+ "loss": "visual_ce",
1311
+ }
1312
+ if config.visual_feat_loss:
1313
+ visual_losses["feat"] = {
1314
+ "shape": (-1, config.visual_feat_dim),
1315
+ "num": config.visual_feat_dim,
1316
+ "loss": "l2",
1317
+ }
1318
+ self.visual_losses = visual_losses
1319
+
1320
+ def resize_num_qa_labels(self, num_labels):
1321
+ """
1322
+ Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
1323
+ will add newly initialized weights. Reducing the size will remove weights from the end
1324
+
1325
+ Args:
1326
+ num_labels (:obj:`int`, `optional`):
1327
+ New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
1328
+ weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
1329
+ just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
1330
+ anything.
1331
+
1332
+ Return:
1333
+ :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
1334
+ """
1335
+
1336
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1337
+ if num_labels is None or cur_qa_logit_layer is None:
1338
+ return
1339
+ new_qa_logit_layer = self._resize_qa_labels(num_labels)
1340
+ self.config.num_qa_labels = num_labels
1341
+ self.num_qa_labels = num_labels
1342
+
1343
+ return new_qa_logit_layer
1344
+
1345
+ def _resize_qa_labels(self, num_labels):
1346
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1347
+ new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
1348
+ self._set_qa_logit_layer(new_qa_logit_layer)
1349
+ return self.get_qa_logit_layer()
1350
+
1351
+ def get_qa_logit_layer(self) -> nn.Module:
1352
+ """
1353
+ Returns the linear layer that produces question answering logits.
1354
+
1355
+ Returns:
1356
+ :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states or :obj:`None` if
1357
+ lxmert does not have a visual answering head.
1358
+ """
1359
+ if hasattr(self, "answer_head"):
1360
+ return self.answer_head.logit_fc[-1]
1361
+
1362
+ def _set_qa_logit_layer(self, qa_logit_layer):
1363
+ self.answer_head.logit_fc[-1] = qa_logit_layer
1364
+
1365
+ def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
1366
+
1367
+ if num_labels is None:
1368
+ return cur_qa_logit_layer
1369
+
1370
+ cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
1371
+ if cur_qa_labels == num_labels:
1372
+ return cur_qa_logit_layer
1373
+
1374
+ # Build new linear output
1375
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1376
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
1377
+ else:
1378
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
1379
+
1380
+ new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
1381
+
1382
+ # initialize all new labels
1383
+ self._init_weights(new_qa_logit_layer)
1384
+
1385
+ # Copy labels from the previous weights
1386
+ num_labels_to_copy = min(cur_qa_labels, num_labels)
1387
+ new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
1388
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1389
+ new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
1390
+
1391
+ return new_qa_logit_layer
1392
+
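+ # A hedged usage sketch for resize_num_qa_labels: grow (or shrink) the QA head
+ # while keeping the already-trained label rows. The label count 3129 is a
+ # hypothetical example value, not something mandated by this class.
+ # model = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
+ # old_head = model.get_qa_logit_layer()        # nn.Linear(hidden_dim, config.num_qa_labels)
+ # new_head = model.resize_num_qa_labels(3129)  # copies min(old, new) label rows, re-inits the rest
+ # assert new_head.weight.shape[0] == 3129 and model.config.num_qa_labels == 3129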
1393
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1394
+ @replace_return_docstrings(output_type=LxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1395
+ def forward(
1396
+ self,
1397
+ input_ids=None,
1398
+ visual_feats=None,
1399
+ visual_pos=None,
1400
+ attention_mask=None,
1401
+ visual_attention_mask=None,
1402
+ token_type_ids=None,
1403
+ inputs_embeds=None,
1404
+ labels=None,
1405
+ obj_labels=None,
1406
+ matched_label=None,
1407
+ ans=None,
1408
+ output_attentions=None,
1409
+ output_hidden_states=None,
1410
+ return_dict=None,
1411
+ **kwargs,
1412
+ ):
1413
+ r"""
1414
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
1415
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1416
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1417
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1418
+ obj_labels (``Dict[str, Tuple[torch.FloatTensor, torch.FloatTensor]]``, `optional`):
1419
+ each key is named after one of the visual losses, and each element of the tuple is of the shape
1420
+ ``(batch_size, num_features)`` and ``(batch_size, num_features, visual_feature_dim)`` for the label id
1421
+ and the label score, respectively.
1422
+ matched_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1423
+ Labels for computing whether or not the text input matches the image (classification loss). Input
1424
+ should be a sequence pair (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``:
1425
+
1426
+ - 0 indicates that the sentence does not match the image,
1427
+ - 1 indicates that the sentence does match the image.
1428
+ ans (``torch.Tensor`` of shape ``(batch_size)``, `optional`):
1429
+ a one-hot representation of the correct answer
1430
+
1431
+ Returns:
1432
+ """
1433
+
1434
+ if "masked_lm_labels" in kwargs:
1435
+ warnings.warn(
1436
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1437
+ FutureWarning,
1438
+ )
1439
+ labels = kwargs.pop("masked_lm_labels")
1440
+
1441
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1442
+
1443
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1444
+ lxmert_output = self.lxmert(
1445
+ input_ids=input_ids,
1446
+ visual_feats=visual_feats,
1447
+ visual_pos=visual_pos,
1448
+ token_type_ids=token_type_ids,
1449
+ attention_mask=attention_mask,
1450
+ visual_attention_mask=visual_attention_mask,
1451
+ inputs_embeds=inputs_embeds,
1452
+ output_hidden_states=output_hidden_states,
1453
+ output_attentions=output_attentions,
1454
+ return_dict=return_dict,
1455
+ )
1456
+
1457
+ lang_output, visual_output, pooled_output = (
1458
+ lxmert_output[0],
1459
+ lxmert_output[1],
1460
+ lxmert_output[2],
1461
+ )
1462
+ lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
1463
+ if self.task_qa:
1464
+ answer_score = self.answer_head(pooled_output)
1465
+ else:
1466
+ answer_score = pooled_output[0][0]
1467
+
1468
+ total_loss = (
1469
+ None
1470
+ if (labels is None and matched_label is None and obj_labels is None and ans is None)
1471
+ else torch.tensor(0.0, device=device)
1472
+ )
1473
+ if labels is not None and self.task_mask_lm:
1474
+ masked_lm_loss = self.loss_fcts["ce"](
1475
+ lang_prediction_scores.view(-1, self.config.vocab_size),
1476
+ labels.view(-1),
1477
+ )
1478
+ total_loss += masked_lm_loss
1479
+ if matched_label is not None and self.task_matched:
1480
+ matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
1481
+ total_loss += matched_loss
1482
+ if obj_labels is not None and self.task_obj_predict:
1483
+ total_visual_loss = torch.tensor(0.0, device=input_ids.device)
1484
+ visual_prediction_scores_dict = self.obj_predict_head(visual_output)
1485
+ for key, key_info in self.visual_losses.items():
1486
+ label, mask_conf = obj_labels[key]
1487
+ output_dim = key_info["num"]
1488
+ loss_fct_name = key_info["loss"]
1489
+ label_shape = key_info["shape"]
1490
+ weight = self.visual_loss_normalizer
1491
+ visual_loss_fct = self.loss_fcts[loss_fct_name]
1492
+ visual_prediction_scores = visual_prediction_scores_dict[key]
1493
+ visual_loss = visual_loss_fct(
1494
+ visual_prediction_scores.view(-1, output_dim),
1495
+ label.view(*label_shape),
1496
+ )
1497
+ if visual_loss.dim() > 1: # Regression Losses
1498
+ visual_loss = visual_loss.mean(1)
1499
+ visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
1500
+ total_visual_loss += visual_loss
1501
+ total_loss += total_visual_loss
1502
+ if ans is not None and self.task_qa:
1503
+ answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
1504
+ total_loss += answer_loss
1505
+
1506
+ if not return_dict:
1507
+ output = (
1508
+ lang_prediction_scores,
1509
+ cross_relationship_score,
1510
+ answer_score,
1511
+ ) + lxmert_output[3:]
1512
+ return ((total_loss,) + output) if total_loss is not None else output
1513
+
1514
+ return LxmertForPreTrainingOutput(
1515
+ loss=total_loss,
1516
+ prediction_logits=lang_prediction_scores,
1517
+ cross_relationship_score=cross_relationship_score,
1518
+ question_answering_score=answer_score,
1519
+ language_hidden_states=lxmert_output.language_hidden_states,
1520
+ vision_hidden_states=lxmert_output.vision_hidden_states,
1521
+ language_attentions=lxmert_output.language_attentions,
1522
+ vision_attentions=lxmert_output.vision_attentions,
1523
+ cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
1524
+ )
1525
+
1526
+
1527
+
1528
+ @add_start_docstrings(
1529
+ """Lxmert Model with a visual-answering head on top for downstream QA tasks""",
1530
+ LXMERT_START_DOCSTRING,
1531
+ )
1532
+ class LxmertForQuestionAnswering(LxmertPreTrainedModel):
1533
+ def __init__(self, config):
1534
+ super().__init__(config)
1535
+ # Configuration
1536
+ self.config = config
1537
+ self.num_qa_labels = config.num_qa_labels
1538
+ self.visual_loss_normalizer = config.visual_loss_normalizer
1539
+
1540
+ # Lxmert backbone
1541
+ self.lxmert = LxmertModel(config)
1542
+
1543
+ self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
1544
+
1545
+ # Weight initialization
1546
+ self.init_weights()
1547
+
1548
+ # Loss function
1549
+ self.loss = CrossEntropyLoss()
1550
+
1551
+ def resize_num_qa_labels(self, num_labels):
1552
+ """
1553
+ Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
1554
+ will add newly initialized weights. Reducing the size will remove weights from the end
1555
+
1556
+ Args:
1557
+ num_labels (:obj:`int`, `optional`):
1558
+ New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
1559
+ weights at the end. Reducing the size will remove weights from the end. If not provided or :obj:`None`,
1560
+ just returns a pointer to the qa labels :obj:`torch.nn.Linear` module of the model without doing
1561
+ anything.
1562
+
1563
+ Return:
1564
+ :obj:`torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
1565
+ """
1566
+
1567
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1568
+ if num_labels is None or cur_qa_logit_layer is None:
1569
+ return
1570
+ new_qa_logit_layer = self._resize_qa_labels(num_labels)
1571
+ self.config.num_qa_labels = num_labels
1572
+ self.num_qa_labels = num_labels
1573
+
1574
+ return new_qa_logit_layer
1575
+
1576
+ def _resize_qa_labels(self, num_labels):
1577
+ cur_qa_logit_layer = self.get_qa_logit_layer()
1578
+ new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
1579
+ self._set_qa_logit_layer(new_qa_logit_layer)
1580
+ return self.get_qa_logit_layer()
1581
+
1582
+ def get_qa_logit_layer(self) -> nn.Module:
1583
+ """
1584
+ Returns the linear layer that produces question answering logits.
1585
+
1586
+ Returns:
1587
+ :obj:`nn.Module`: A torch module mapping the question answering prediction hidden states. :obj:`None`: A
1588
+ NoneType object if Lxmert does not have the visual answering head.
1589
+ """
1590
+
1591
+ if hasattr(self, "answer_head"):
1592
+ return self.answer_head.logit_fc[-1]
1593
+
1594
+ def _set_qa_logit_layer(self, qa_logit_layer):
1595
+ self.answer_head.logit_fc[-1] = qa_logit_layer
1596
+
1597
+ def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
1598
+
1599
+ if num_labels is None:
1600
+ return cur_qa_logit_layer
1601
+
1602
+ cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
1603
+ if cur_qa_labels == num_labels:
1604
+ return cur_qa_logit_layer
1605
+
1606
+ # Build new linear output
1607
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1608
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
1609
+ else:
1610
+ new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
1611
+
1612
+ new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
1613
+
1614
+ # initialize all new labels
1615
+ self._init_weights(new_qa_logit_layer)
1616
+
1617
+ # Copy labels from the previous weights
1618
+ num_labels_to_copy = min(cur_qa_labels, num_labels)
1619
+ new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
1620
+ if getattr(cur_qa_logit_layer, "bias", None) is not None:
1621
+ new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
1622
+
1623
+ return new_qa_logit_layer
1624
+
1625
+ @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1626
+ @add_code_sample_docstrings(
1627
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1628
+ checkpoint="unc-nlp/lxmert-base-uncased",
1629
+ output_type=LxmertForQuestionAnsweringOutput,
1630
+ config_class=_CONFIG_FOR_DOC,
1631
+ )
1632
+ def forward(
1633
+ self,
1634
+ input_ids=None,
1635
+ visual_feats=None,
1636
+ visual_pos=None,
1637
+ attention_mask=None,
1638
+ visual_attention_mask=None,
1639
+ token_type_ids=None,
1640
+ inputs_embeds=None,
1641
+ labels=None,
1642
+ output_attentions=None,
1643
+ output_hidden_states=None,
1644
+ return_dict=None,
1645
+ ):
1646
+ r"""
1647
+ labels: (``Torch.Tensor`` of shape ``(batch_size)``, `optional`):
1648
+ A one-hot representation of the correct answer
1649
+
1650
+ Returns:
1651
+ """
1652
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1653
+
1654
+ lxmert_output = self.lxmert(
1655
+ input_ids=input_ids,
1656
+ visual_feats=visual_feats,
1657
+ visual_pos=visual_pos,
1658
+ token_type_ids=token_type_ids,
1659
+ attention_mask=attention_mask,
1660
+ visual_attention_mask=visual_attention_mask,
1661
+ inputs_embeds=inputs_embeds,
1662
+ output_hidden_states=output_hidden_states,
1663
+ output_attentions=output_attentions,
1664
+ return_dict=return_dict,
1665
+ )
1666
+
1667
+ pooled_output = lxmert_output[2]
1668
+ answer_score = self.answer_head(pooled_output)
1669
+ loss = None
1670
+ if labels is not None:
1671
+ loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
1672
+
1673
+ if not return_dict:
1674
+ output = (answer_score,) + lxmert_output[3:]
1675
+ return (loss,) + output if loss is not None else output
1676
+
1677
+ self.vis_shape = lxmert_output.vision_output.shape
1678
+
1679
+ return LxmertForQuestionAnsweringOutput(
1680
+ loss=loss,
1681
+ question_answering_score=answer_score,
1682
+ language_hidden_states=lxmert_output.language_hidden_states,
1683
+ vision_hidden_states=lxmert_output.vision_hidden_states,
1684
+ language_attentions=lxmert_output.language_attentions,
1685
+ vision_attentions=lxmert_output.vision_attentions,
1686
+ cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
1687
+ )
1688
+
1689
+ def relprop(self, cam, **kwargs):
1690
+ cam_lang = self.answer_head.relprop(cam, **kwargs)
1691
+ cam_vis = torch.zeros(self.vis_shape).to(cam_lang.device)
1692
+ cam_lang, cam_vis = self.lxmert.relprop((cam_lang, cam_vis), **kwargs)
1693
+ return cam_lang, cam_vis
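+ # A hedged sketch of how relprop is typically seeded for an explanation: run a
+ # forward pass, build a one-hot relevance vector at the predicted answer, and
+ # propagate it back through the answer head and the backbone. The exact kwargs
+ # (e.g. alpha) follow this repo's relprop convention; everything else below is
+ # illustrative, not the only way to drive it.
+ # outputs = model(input_ids=input_ids, attention_mask=attention_mask,
+ #                 visual_feats=visual_feats, visual_pos=visual_pos)
+ # logits = outputs.question_answering_score             # (batch_size, num_qa_labels)
+ # one_hot = torch.zeros_like(logits)
+ # one_hot[0, logits.argmax(dim=-1)] = 1.0
+ # (logits * one_hot).sum().backward(retain_graph=True)  # populate gradients used by some rules
+ # cam_lang, cam_vis = model.relprop(one_hot, alpha=1)   # relevance for text tokens and image regions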
lxmert/src/lxrt/__init__.py ADDED
File without changes
lxmert/src/lxrt/entry.py ADDED
@@ -0,0 +1,156 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 project LXRT.
3
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from ..lxrt.tokenization import BertTokenizer
24
+ from ..lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG
25
+
26
+
27
+ class InputFeatures(object):
28
+ """A single set of features of data."""
29
+
30
+ def __init__(self, input_ids, input_mask, segment_ids):
31
+ self.input_ids = input_ids
32
+ self.input_mask = input_mask
33
+ self.segment_ids = segment_ids
34
+
35
+
36
+ def convert_sents_to_features(sents, max_seq_length, tokenizer):
37
+ """Loads a data file into a list of `InputBatch`s."""
38
+
39
+ features = []
40
+ for (i, sent) in enumerate(sents):
41
+ tokens_a = tokenizer.tokenize(sent.strip())
42
+
43
+ # Account for [CLS] and [SEP] with "- 2"
44
+ if len(tokens_a) > max_seq_length - 2:
45
+ tokens_a = tokens_a[:(max_seq_length - 2)]
46
+
47
+ # Keep segment id which allows loading BERT-weights.
48
+ tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
49
+ segment_ids = [0] * len(tokens)
50
+
51
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
52
+
53
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
54
+ # tokens are attended to.
55
+ input_mask = [1] * len(input_ids)
56
+
57
+ # Zero-pad up to the sequence length.
58
+ padding = [0] * (max_seq_length - len(input_ids))
59
+ input_ids += padding
60
+ input_mask += padding
61
+ segment_ids += padding
62
+
63
+ assert len(input_ids) == max_seq_length
64
+ assert len(input_mask) == max_seq_length
65
+ assert len(segment_ids) == max_seq_length
66
+
67
+ features.append(
68
+ InputFeatures(input_ids=input_ids,
69
+ input_mask=input_mask,
70
+ segment_ids=segment_ids))
71
+ return features
72
+
73
+
74
+ def set_visual_config(args):
75
+ VISUAL_CONFIG.l_layers = args.llayers
76
+ VISUAL_CONFIG.x_layers = args.xlayers
77
+ VISUAL_CONFIG.r_layers = args.rlayers
78
+
79
+
80
+ class LXRTEncoder(nn.Module):
81
+ def __init__(self, args, max_seq_length, mode='x'):
82
+ super().__init__()
83
+ self.max_seq_length = max_seq_length
84
+ set_visual_config(args)
85
+
86
+ # Using the bert tokenizer
87
+ self.tokenizer = BertTokenizer.from_pretrained(
88
+ "bert-base-uncased",
89
+ do_lower_case=True
90
+ )
91
+
92
+ # Build LXRT Model
93
+ self.model = VisualBertForLXRFeature.from_pretrained(
94
+ "bert-base-uncased",
95
+ mode=mode
96
+ )
97
+
98
+ if args.from_scratch:
99
+ print("initializing all the weights")
100
+ self.model.apply(self.model.init_bert_weights)
101
+
102
+ def multi_gpu(self):
103
+ self.model = nn.DataParallel(self.model)
104
+
105
+ @property
106
+ def dim(self):
107
+ return 768
108
+
109
+ def forward(self, sents, feats, visual_attention_mask=None):
110
+ train_features = convert_sents_to_features(
111
+ sents, self.max_seq_length, self.tokenizer)
112
+
113
+ input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
114
+ input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
115
+ segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()
116
+
117
+ output = self.model(input_ids, segment_ids, input_mask,
118
+ visual_feats=feats,
119
+ visual_attention_mask=visual_attention_mask)
120
+ return output
121
+
122
+ def save(self, path):
123
+ torch.save(self.model.state_dict(),
124
+ os.path.join("%s_LXRT.pth" % path))
125
+
126
+ def load(self, path):
127
+ # Load state_dict from snapshot file
128
+ print("Load lxmert pre-trained model from %s" % path)
129
+ state_dict = torch.load("%s_LXRT.pth" % path)
130
+ new_state_dict = {}
131
+ for key, value in state_dict.items():
132
+ if key.startswith("module."):
133
+ new_state_dict[key[len("module."):]] = value
134
+ else:
135
+ new_state_dict[key] = value
136
+ state_dict = new_state_dict
137
+
138
+ # Print out the differences of pre-trained and model weights.
139
+ load_keys = set(state_dict.keys())
140
+ model_keys = set(self.model.state_dict().keys())
141
+ print()
142
+ print("Weights in loaded but not in model:")
143
+ for key in sorted(load_keys.difference(model_keys)):
144
+ print(key)
145
+ print()
146
+ print("Weights in model but not in loaded:")
147
+ for key in sorted(model_keys.difference(load_keys)):
148
+ print(key)
149
+ print()
150
+
151
+ # Load weights to model
152
+ self.model.load_state_dict(state_dict, strict=False)
153
+
154
+
155
+
156
+
lxmert/src/lxrt/file_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ import json
7
+ import logging
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+ from functools import wraps
12
+ from hashlib import sha256
13
+ import sys
14
+ from io import open
15
+
16
+ import boto3
17
+ import requests
18
+ from botocore.exceptions import ClientError
19
+ from tqdm import tqdm
20
+
21
+ try:
22
+ from urllib.parse import urlparse
23
+ except ImportError:
24
+ from urlparse import urlparse
25
+
26
+ try:
27
+ from pathlib import Path
28
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
29
+ Path.home() / '.pytorch_pretrained_bert'))
30
+ except (AttributeError, ImportError):
31
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
32
+ os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
33
+
34
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def url_to_filename(url, etag=None):
38
+ """
39
+ Convert `url` into a hashed filename in a repeatable way.
40
+ If `etag` is specified, append its hash to the url's, delimited
41
+ by a period.
42
+ """
43
+ url_bytes = url.encode('utf-8')
44
+ url_hash = sha256(url_bytes)
45
+ filename = url_hash.hexdigest()
46
+
47
+ if etag:
48
+ etag_bytes = etag.encode('utf-8')
49
+ etag_hash = sha256(etag_bytes)
50
+ filename += '.' + etag_hash.hexdigest()
51
+
52
+ return filename
53
+
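+ # A hedged sketch of the naming scheme above: the cache filename is the sha256
+ # of the url, optionally followed by '.' and the sha256 of the ETag, so the
+ # same url + ETag pair always maps to the same cache entry. The url and etag
+ # values are illustrative only.
+ # from hashlib import sha256
+ # url, etag = "https://example.com/model.bin", '"abc123"'
+ # expected = sha256(url.encode("utf-8")).hexdigest() + "." + sha256(etag.encode("utf-8")).hexdigest()
+ # assert url_to_filename(url, etag) == expected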
54
+
55
+ def filename_to_url(filename, cache_dir=None):
56
+ """
57
+ Return the url and etag (which may be ``None``) stored for `filename`.
58
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
59
+ """
60
+ if cache_dir is None:
61
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
62
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
63
+ cache_dir = str(cache_dir)
64
+
65
+ cache_path = os.path.join(cache_dir, filename)
66
+ if not os.path.exists(cache_path):
67
+ raise EnvironmentError("file {} not found".format(cache_path))
68
+
69
+ meta_path = cache_path + '.json'
70
+ if not os.path.exists(meta_path):
71
+ raise EnvironmentError("file {} not found".format(meta_path))
72
+
73
+ with open(meta_path, encoding="utf-8") as meta_file:
74
+ metadata = json.load(meta_file)
75
+ url = metadata['url']
76
+ etag = metadata['etag']
77
+
78
+ return url, etag
79
+
80
+
81
+ def cached_path(url_or_filename, cache_dir=None):
82
+ """
83
+ Given something that might be a URL (or might be a local path),
84
+ determine which. If it's a URL, download the file and cache it, and
85
+ return the path to the cached file. If it's already a local path,
86
+ make sure the file exists and then return the path.
87
+ """
88
+ if cache_dir is None:
89
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
90
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
91
+ url_or_filename = str(url_or_filename)
92
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
93
+ cache_dir = str(cache_dir)
94
+
95
+ parsed = urlparse(url_or_filename)
96
+
97
+ if parsed.scheme in ('http', 'https', 's3'):
98
+ # URL, so get it from the cache (downloading if necessary)
99
+ return get_from_cache(url_or_filename, cache_dir)
100
+ elif os.path.exists(url_or_filename):
101
+ # File, and it exists.
102
+ return url_or_filename
103
+ elif parsed.scheme == '':
104
+ # File, but it doesn't exist.
105
+ raise EnvironmentError("file {} not found".format(url_or_filename))
106
+ else:
107
+ # Something unknown
108
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
109
+
110
+
111
+ def split_s3_path(url):
112
+ """Split a full s3 path into the bucket name and path."""
113
+ parsed = urlparse(url)
114
+ if not parsed.netloc or not parsed.path:
115
+ raise ValueError("bad s3 path {}".format(url))
116
+ bucket_name = parsed.netloc
117
+ s3_path = parsed.path
118
+ # Remove '/' at beginning of path.
119
+ if s3_path.startswith("/"):
120
+ s3_path = s3_path[1:]
121
+ return bucket_name, s3_path
122
+
123
+
124
+ def s3_request(func):
125
+ """
126
+ Wrapper function for s3 requests in order to create more helpful error
127
+ messages.
128
+ """
129
+
130
+ @wraps(func)
131
+ def wrapper(url, *args, **kwargs):
132
+ try:
133
+ return func(url, *args, **kwargs)
134
+ except ClientError as exc:
135
+ if int(exc.response["Error"]["Code"]) == 404:
136
+ raise EnvironmentError("file {} not found".format(url))
137
+ else:
138
+ raise
139
+
140
+ return wrapper
141
+
142
+
143
+ @s3_request
144
+ def s3_etag(url):
145
+ """Check ETag on S3 object."""
146
+ s3_resource = boto3.resource("s3")
147
+ bucket_name, s3_path = split_s3_path(url)
148
+ s3_object = s3_resource.Object(bucket_name, s3_path)
149
+ return s3_object.e_tag
150
+
151
+
152
+ @s3_request
153
+ def s3_get(url, temp_file):
154
+ """Pull a file directly from S3."""
155
+ s3_resource = boto3.resource("s3")
156
+ bucket_name, s3_path = split_s3_path(url)
157
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
158
+
159
+
160
+ def http_get(url, temp_file):
161
+ req = requests.get(url, stream=True)
162
+ content_length = req.headers.get('Content-Length')
163
+ total = int(content_length) if content_length is not None else None
164
+ progress = tqdm(unit="B", total=total)
165
+ for chunk in req.iter_content(chunk_size=1024):
166
+ if chunk: # filter out keep-alive new chunks
167
+ progress.update(len(chunk))
168
+ temp_file.write(chunk)
169
+ progress.close()
170
+
171
+
172
+ def get_from_cache(url, cache_dir=None):
173
+ """
174
+ Given a URL, look for the corresponding dataset in the local cache.
175
+ If it's not there, download it. Then return the path to the cached file.
176
+ """
177
+ if cache_dir is None:
178
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
179
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
180
+ cache_dir = str(cache_dir)
181
+
182
+ if not os.path.exists(cache_dir):
183
+ os.makedirs(cache_dir)
184
+
185
+ # Get eTag to add to filename, if it exists.
186
+ if url.startswith("s3://"):
187
+ etag = s3_etag(url)
188
+ else:
189
+ response = requests.head(url, allow_redirects=True)
190
+ if response.status_code != 200:
191
+ raise IOError("HEAD request failed for url {} with status code {}"
192
+ .format(url, response.status_code))
193
+ etag = response.headers.get("ETag")
194
+
195
+ filename = url_to_filename(url, etag)
196
+
197
+ # get cache path to put the file
198
+ cache_path = os.path.join(cache_dir, filename)
199
+
200
+ if not os.path.exists(cache_path):
201
+ # Download to temporary file, then copy to cache dir once finished.
202
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
203
+ with tempfile.NamedTemporaryFile() as temp_file:
204
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
205
+
206
+ # GET file object
207
+ if url.startswith("s3://"):
208
+ s3_get(url, temp_file)
209
+ else:
210
+ http_get(url, temp_file)
211
+
212
+ # we are copying the file before closing it, so flush to avoid truncation
213
+ temp_file.flush()
214
+ # shutil.copyfileobj() starts at the current position, so go to the start
215
+ temp_file.seek(0)
216
+
217
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
218
+ with open(cache_path, 'wb') as cache_file:
219
+ shutil.copyfileobj(temp_file, cache_file)
220
+
221
+ logger.info("creating metadata file for %s", cache_path)
222
+ meta = {'url': url, 'etag': etag}
223
+ meta_path = cache_path + '.json'
224
+ with open(meta_path, 'w', encoding="utf-8") as meta_file:
225
+ json.dump(meta, meta_file)
226
+
227
+ logger.info("removing temp file %s", temp_file.name)
228
+
229
+ return cache_path
230
+
231
+
232
+ def read_set_from_file(filename):
233
+ '''
234
+ Extract a de-duped collection (set) of text from a file.
235
+ Expected file format is one item per line.
236
+ '''
237
+ collection = set()
238
+ with open(filename, 'r', encoding='utf-8') as file_:
239
+ for line in file_:
240
+ collection.add(line.rstrip())
241
+ return collection
242
+
243
+
244
+ def get_file_extension(path, dot=True, lower=True):
245
+ ext = os.path.splitext(path)[1]
246
+ ext = ext if dot else ext[1:]
247
+ return ext.lower() if lower else ext
lxmert/src/lxrt/modeling.py ADDED
@@ -0,0 +1,1018 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 project LXRT.
3
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """PyTorch LXRT model."""
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import shutil
25
+ import tarfile
26
+ import tempfile
27
+ import sys
28
+ from io import open
29
+
30
+ import torch
31
+ from torch import nn
32
+ from torch.nn import CrossEntropyLoss, SmoothL1Loss
33
+
34
+ from .file_utils import cached_path
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
39
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
40
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
41
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
42
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
43
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
44
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
45
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
46
+ }
47
+ CONFIG_NAME = 'bert_config.json'
48
+ WEIGHTS_NAME = 'pytorch_model.bin'
49
+ TF_WEIGHTS_NAME = 'model.ckpt'
50
+
51
+ def load_tf_weights_in_bert(model, tf_checkpoint_path):
52
+ """ Load tf checkpoints in a pytorch model
53
+ """
54
+ try:
55
+ import re
56
+ import numpy as np
57
+ import tensorflow as tf
58
+ except ImportError:
59
+ print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
60
+ "https://www.tensorflow.org/install/ for installation instructions.")
61
+ raise
62
+ tf_path = os.path.abspath(tf_checkpoint_path)
63
+ print("Converting TensorFlow checkpoint from {}".format(tf_path))
64
+ # Load weights from TF model
65
+ init_vars = tf.train.list_variables(tf_path)
66
+ names = []
67
+ arrays = []
68
+ for name, shape in init_vars:
69
+ print("Loading TF weight {} with shape {}".format(name, shape))
70
+ array = tf.train.load_variable(tf_path, name)
71
+ names.append(name)
72
+ arrays.append(array)
73
+
74
+ for name, array in zip(names, arrays):
75
+ name = name.split('/')
76
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
77
+ # which are not required for using pretrained model
78
+ if any(n in ["adam_v", "adam_m"] for n in name):
79
+ print("Skipping {}".format("/".join(name)))
80
+ continue
81
+ pointer = model
82
+ for m_name in name:
83
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
84
+ l = re.split(r'_(\d+)', m_name)
85
+ else:
86
+ l = [m_name]
87
+ if l[0] == 'kernel' or l[0] == 'gamma':
88
+ pointer = getattr(pointer, 'weight')
89
+ elif l[0] == 'output_bias' or l[0] == 'beta':
90
+ pointer = getattr(pointer, 'bias')
91
+ elif l[0] == 'output_weights':
92
+ pointer = getattr(pointer, 'weight')
93
+ else:
94
+ pointer = getattr(pointer, l[0])
95
+ if len(l) >= 2:
96
+ num = int(l[1])
97
+ pointer = pointer[num]
98
+ if m_name[-11:] == '_embeddings':
99
+ pointer = getattr(pointer, 'weight')
100
+ elif m_name == 'kernel':
101
+ array = np.transpose(array)
102
+ try:
103
+ assert pointer.shape == array.shape
104
+ except AssertionError as e:
105
+ e.args += (pointer.shape, array.shape)
106
+ raise
107
+ print("Initialize PyTorch weight {}".format(name))
108
+ pointer.data = torch.from_numpy(array)
109
+ return model
110
+
111
+
112
+ def gelu(x):
113
+ """Implementation of the gelu activation function.
114
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
115
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
116
+ Also see https://arxiv.org/abs/1606.08415
117
+ """
118
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
119
+
120
+
121
+ class GeLU(nn.Module):
122
+ """Implementation of the gelu activation function.
123
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
124
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
125
+ Also see https://arxiv.org/abs/1606.08415
126
+ """
127
+ def __init__(self):
128
+ super().__init__()
129
+
130
+ def forward(self, x):
131
+ return gelu(x)
132
+
133
+
134
+ def swish(x):
135
+ return x * torch.sigmoid(x)
136
+
137
+
138
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
139
+
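The docstrings above contrast the exact erf-based GELU with OpenAI GPT's tanh approximation. A small standalone PyTorch check (not part of the repo) shows the two stay very close:

import math
import torch

def gelu_erf(x):
    # Exact GELU, as implemented above.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # The tanh approximation quoted in the docstring.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-4, 4, steps=9)
print(torch.max(torch.abs(gelu_erf(x) - gelu_tanh(x))))  # small, on the order of 1e-3 or below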
140
+
141
+ class VisualConfig(object):
142
+ VISUAL_LOSSES = ['obj', 'attr', 'feat']
143
+ def __init__(self,
144
+ l_layers=12,
145
+ x_layers=5,
146
+ r_layers=0):
147
+ self.l_layers = l_layers
148
+ self.x_layers = x_layers
149
+ self.r_layers = r_layers
150
+
151
+ self.visual_feat_dim = 2048
152
+ self.visual_pos_dim = 4
153
+
154
+ self.obj_id_num = 1600
155
+ self.attr_id_num = 400
156
+
157
+ self.visual_losses = self.VISUAL_LOSSES
158
+ self.visual_loss_config = {
159
+ 'obj': (self.obj_id_num, 'ce', (-1,), 1/0.15),
160
+ 'attr': (self.attr_id_num, 'ce', (-1,), 1/0.15),
161
+ 'feat': (2048, 'l2', (-1, 2048), 1/0.15),
162
+ }
163
+
164
+ def set_visual_dims(self, feat_dim, pos_dim):
165
+ self.visual_feat_dim = feat_dim
166
+ self.visual_pos_dim = pos_dim
167
+
168
+
169
+ VISUAL_CONFIG = VisualConfig()
170
+
171
+
172
+ class BertConfig(object):
173
+ """Configuration class to store the configuration of a `BertModel`.
174
+ """
175
+ def __init__(self,
176
+ vocab_size_or_config_json_file,
177
+ hidden_size=768,
178
+ num_hidden_layers=12,
179
+ num_attention_heads=12,
180
+ intermediate_size=3072,
181
+ hidden_act="gelu",
182
+ hidden_dropout_prob=0.1,
183
+ attention_probs_dropout_prob=0.1,
184
+ max_position_embeddings=512,
185
+ type_vocab_size=2,
186
+ initializer_range=0.02):
187
+ """Constructs BertConfig.
188
+
189
+ Args:
190
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
191
+ hidden_size: Size of the encoder layers and the pooler layer.
192
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
193
+ num_attention_heads: Number of attention heads for each attention layer in
194
+ the Transformer encoder.
195
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
196
+ layer in the Transformer encoder.
197
+ hidden_act: The non-linear activation function (function or string) in the
198
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
199
+ hidden_dropout_prob: The dropout probability for all fully connected
200
+ layers in the embeddings, encoder, and pooler.
201
+ attention_probs_dropout_prob: The dropout ratio for the attention
202
+ probabilities.
203
+ max_position_embeddings: The maximum sequence length that this model might
204
+ ever be used with. Typically set this to something large just in case
205
+ (e.g., 512 or 1024 or 2048).
206
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
207
+ `BertModel`.
208
+ initializer_range: The stddev of the truncated_normal_initializer for
209
+ initializing all weight matrices.
210
+ """
211
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
212
+ and isinstance(vocab_size_or_config_json_file, unicode)):
213
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
214
+ json_config = json.loads(reader.read())
215
+ for key, value in json_config.items():
216
+ self.__dict__[key] = value
217
+ elif isinstance(vocab_size_or_config_json_file, int):
218
+ self.vocab_size = vocab_size_or_config_json_file
219
+ self.hidden_size = hidden_size
220
+ self.num_hidden_layers = num_hidden_layers
221
+ self.num_attention_heads = num_attention_heads
222
+ self.hidden_act = hidden_act
223
+ self.intermediate_size = intermediate_size
224
+ self.hidden_dropout_prob = hidden_dropout_prob
225
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
226
+ self.max_position_embeddings = max_position_embeddings
227
+ self.type_vocab_size = type_vocab_size
228
+ self.initializer_range = initializer_range
229
+ else:
230
+ raise ValueError("First argument must be either a vocabulary size (int) "
231
+ "or the path to a pretrained model config file (str)")
232
+
233
+ @classmethod
234
+ def from_dict(cls, json_object):
235
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
236
+ config = BertConfig(vocab_size_or_config_json_file=-1)
237
+ for key, value in json_object.items():
238
+ config.__dict__[key] = value
239
+ return config
240
+
241
+ @classmethod
242
+ def from_json_file(cls, json_file):
243
+ """Constructs a `BertConfig` from a json file of parameters."""
244
+ with open(json_file, "r", encoding='utf-8') as reader:
245
+ text = reader.read()
246
+ return cls.from_dict(json.loads(text))
247
+
248
+ def __repr__(self):
249
+ return str(self.to_json_string())
250
+
251
+ def to_dict(self):
252
+ """Serializes this instance to a Python dictionary."""
253
+ output = copy.deepcopy(self.__dict__)
254
+ return output
255
+
256
+ def to_json_string(self):
257
+ """Serializes this instance to a JSON string."""
258
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
259
+
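A minimal sketch of round-tripping a BertConfig through its dict/JSON helpers; it assumes this file is importable (the lxrt.modeling path is an assumption):

from lxrt.modeling import BertConfig

config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=12)
print(config.to_json_string())                  # pretty-printed JSON of all hyperparameters
clone = BertConfig.from_dict(config.to_dict())
assert clone.hidden_size == config.hidden_size and clone.vocab_size == 30522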
260
+
261
+ BertLayerNorm = torch.nn.LayerNorm
262
+
263
+
264
+ class BertEmbeddings(nn.Module):
265
+ """Construct the embeddings from word, position and token_type embeddings.
266
+ """
267
+ def __init__(self, config):
268
+ super(BertEmbeddings, self).__init__()
269
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
270
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
271
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
272
+
273
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
274
+ # any TensorFlow checkpoint file
275
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
276
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
277
+
278
+ def forward(self, input_ids, token_type_ids=None):
279
+ seq_length = input_ids.size(1)
280
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
281
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
282
+ if token_type_ids is None:
283
+ token_type_ids = torch.zeros_like(input_ids)
284
+
285
+ words_embeddings = self.word_embeddings(input_ids)
286
+ position_embeddings = self.position_embeddings(position_ids)
287
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
288
+
289
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
290
+ embeddings = self.LayerNorm(embeddings)
291
+ embeddings = self.dropout(embeddings)
292
+ return embeddings
293
+
294
+
295
+ class BertAttention(nn.Module):
296
+ def __init__(self, config, ctx_dim=None):
297
+ super().__init__()
298
+ if config.hidden_size % config.num_attention_heads != 0:
299
+ raise ValueError(
300
+ "The hidden size (%d) is not a multiple of the number of attention "
301
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
302
+ self.num_attention_heads = config.num_attention_heads
303
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
304
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
305
+
306
+ # visual_dim = 2048
307
+ if ctx_dim is None:
308
+ ctx_dim = config.hidden_size
309
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
310
+ self.key = nn.Linear(ctx_dim, self.all_head_size)
311
+ self.value = nn.Linear(ctx_dim, self.all_head_size)
312
+
313
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
314
+
315
+ def transpose_for_scores(self, x):
316
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
317
+ x = x.view(*new_x_shape)
318
+ return x.permute(0, 2, 1, 3)
319
+
320
+ def forward(self, hidden_states, context, attention_mask=None):
321
+ mixed_query_layer = self.query(hidden_states)
322
+ mixed_key_layer = self.key(context)
323
+ mixed_value_layer = self.value(context)
324
+
325
+ query_layer = self.transpose_for_scores(mixed_query_layer)
326
+ key_layer = self.transpose_for_scores(mixed_key_layer)
327
+ value_layer = self.transpose_for_scores(mixed_value_layer)
328
+
329
+ # Take the dot product between "query" and "key" to get the raw attention scores.
330
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
331
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
332
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
333
+ if attention_mask is not None:
334
+ attention_scores = attention_scores + attention_mask
335
+
336
+ # Normalize the attention scores to probabilities.
337
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
338
+
339
+ # This is actually dropping out entire tokens to attend to, which might
340
+ # seem a bit unusual, but is taken from the original Transformer paper.
341
+ attention_probs = self.dropout(attention_probs)
342
+
343
+ context_layer = torch.matmul(attention_probs, value_layer)
344
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
345
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
346
+ context_layer = context_layer.view(*new_context_layer_shape)
347
+ return context_layer
348
+
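To make the reshaping in BertAttention easier to follow, here is a standalone shape walk-through with illustrative sizes (plain PyTorch, no repo code required):

import torch

batch, seq, heads, head_size = 2, 5, 12, 64
hidden = heads * head_size                                          # 768

x = torch.randn(batch, seq, hidden)
# transpose_for_scores: (batch, seq, hidden) -> (batch, heads, seq, head_size)
x = x.view(batch, seq, heads, head_size).permute(0, 2, 1, 3)

scores = torch.matmul(x, x.transpose(-1, -2)) / head_size ** 0.5    # (batch, heads, seq, seq)
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, x)                                    # (batch, heads, seq, head_size)
context = context.permute(0, 2, 1, 3).reshape(batch, seq, hidden)
print(context.shape)                                                # torch.Size([2, 5, 768])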
349
+
350
+ class BertAttOutput(nn.Module):
351
+ def __init__(self, config):
352
+ super(BertAttOutput, self).__init__()
353
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
354
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
355
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
356
+
357
+ def forward(self, hidden_states, input_tensor):
358
+ hidden_states = self.dense(hidden_states)
359
+ hidden_states = self.dropout(hidden_states)
360
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
361
+ return hidden_states
362
+
363
+
364
+ class BertCrossattLayer(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.att = BertAttention(config)
368
+ self.output = BertAttOutput(config)
369
+
370
+ def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
371
+ output = self.att(input_tensor, ctx_tensor, ctx_att_mask)
372
+ attention_output = self.output(output, input_tensor)
373
+ return attention_output
374
+
375
+
376
+ class BertSelfattLayer(nn.Module):
377
+ def __init__(self, config):
378
+ super(BertSelfattLayer, self).__init__()
379
+ self.self = BertAttention(config)
380
+ self.output = BertAttOutput(config)
381
+
382
+ def forward(self, input_tensor, attention_mask):
383
+ # Self-attention attends to the input itself, so keys and queries are the same (input_tensor).
384
+ self_output = self.self(input_tensor, input_tensor, attention_mask)
385
+ attention_output = self.output(self_output, input_tensor)
386
+ return attention_output
387
+
388
+
389
+ class BertIntermediate(nn.Module):
390
+ def __init__(self, config):
391
+ super(BertIntermediate, self).__init__()
392
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
393
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
394
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
395
+ else:
396
+ self.intermediate_act_fn = config.hidden_act
397
+
398
+ def forward(self, hidden_states):
399
+ hidden_states = self.dense(hidden_states)
400
+ hidden_states = self.intermediate_act_fn(hidden_states)
401
+ return hidden_states
402
+
403
+
404
+ class BertOutput(nn.Module):
405
+ def __init__(self, config):
406
+ super(BertOutput, self).__init__()
407
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
408
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
409
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
410
+
411
+ def forward(self, hidden_states, input_tensor):
412
+ hidden_states = self.dense(hidden_states)
413
+ hidden_states = self.dropout(hidden_states)
414
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
415
+ return hidden_states
416
+
417
+
418
+ class BertLayer(nn.Module):
419
+ def __init__(self, config):
420
+ super(BertLayer, self).__init__()
421
+ self.attention = BertSelfattLayer(config)
422
+ self.intermediate = BertIntermediate(config)
423
+ self.output = BertOutput(config)
424
+
425
+ def forward(self, hidden_states, attention_mask):
426
+ attention_output = self.attention(hidden_states, attention_mask)
427
+ intermediate_output = self.intermediate(attention_output)
428
+ layer_output = self.output(intermediate_output, attention_output)
429
+ return layer_output
430
+
431
+
432
+ """
433
+ ---------------------------------------------------------------------------------------
434
+ Above modules are copied from BERT (pytorch-transformer) with modifications.
435
+ ---------------------------------------------------------------------------------------
436
+ """
437
+
438
+
439
+ class LXRTXLayer(nn.Module):
440
+ def __init__(self, config):
441
+ super().__init__()
442
+ # The cross-attention Layer
443
+ self.visual_attention = BertCrossattLayer(config)
444
+
445
+ # Self-attention Layers
446
+ self.lang_self_att = BertSelfattLayer(config)
447
+ self.visn_self_att = BertSelfattLayer(config)
448
+
449
+ # Intermediate and Output Layers (FFNs)
450
+ self.lang_inter = BertIntermediate(config)
451
+ self.lang_output = BertOutput(config)
452
+ self.visn_inter = BertIntermediate(config)
453
+ self.visn_output = BertOutput(config)
454
+
455
+ def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
456
+ # Cross Attention
457
+ lang_att_output = self.visual_attention(lang_input, visn_input, ctx_att_mask=visn_attention_mask)
458
+ visn_att_output = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
459
+ return lang_att_output, visn_att_output
460
+
461
+ def self_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
462
+ # Self Attention
463
+ lang_att_output = self.lang_self_att(lang_input, lang_attention_mask)
464
+ visn_att_output = self.visn_self_att(visn_input, visn_attention_mask)
465
+ return lang_att_output, visn_att_output
466
+
467
+ def output_fc(self, lang_input, visn_input):
468
+ # FC layers
469
+ lang_inter_output = self.lang_inter(lang_input)
470
+ visn_inter_output = self.visn_inter(visn_input)
471
+
472
+ # Layer output
473
+ lang_output = self.lang_output(lang_inter_output, lang_input)
474
+ visn_output = self.visn_output(visn_inter_output, visn_input)
475
+ return lang_output, visn_output
476
+
477
+ def forward(self, lang_feats, lang_attention_mask,
478
+ visn_feats, visn_attention_mask):
479
+ lang_att_output = lang_feats
480
+ visn_att_output = visn_feats
481
+
482
+ lang_att_output, visn_att_output = self.cross_att(lang_att_output, lang_attention_mask,
483
+ visn_att_output, visn_attention_mask)
484
+ lang_att_output, visn_att_output = self.self_att(lang_att_output, lang_attention_mask,
485
+ visn_att_output, visn_attention_mask)
486
+ lang_output, visn_output = self.output_fc(lang_att_output, visn_att_output)
487
+
488
+ return lang_output, visn_output
489
+
490
+
491
+ class VisualFeatEncoder(nn.Module):
492
+ def __init__(self, config):
493
+ super().__init__()
494
+ feat_dim = VISUAL_CONFIG.visual_feat_dim
495
+ pos_dim = VISUAL_CONFIG.visual_pos_dim
496
+
497
+ # Object feature encoding
498
+ self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
499
+ self.visn_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
500
+
501
+ # Box position encoding
502
+ self.box_fc = nn.Linear(pos_dim, config.hidden_size)
503
+ self.box_layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
504
+
505
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
506
+
507
+ def forward(self, visn_input):
508
+ feats, boxes = visn_input
509
+
510
+ x = self.visn_fc(feats)
511
+ x = self.visn_layer_norm(x)
512
+ y = self.box_fc(boxes)
513
+ y = self.box_layer_norm(y)
514
+ output = (x + y) / 2
515
+
516
+ output = self.dropout(output)
517
+ return output
518
+
519
+
520
+ class LXRTEncoder(nn.Module):
521
+ def __init__(self, config):
522
+ super().__init__()
523
+
524
+ # Obj-level image embedding layer
525
+ self.visn_fc = VisualFeatEncoder(config)
526
+
527
+ # Number of layers
528
+ self.num_l_layers = VISUAL_CONFIG.l_layers
529
+ self.num_x_layers = VISUAL_CONFIG.x_layers
530
+ self.num_r_layers = VISUAL_CONFIG.r_layers
531
+ print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." %
532
+ (self.num_l_layers, self.num_x_layers, self.num_r_layers))
533
+
534
+ # Layers
535
+ # Using self.layer instead of self.l_layer to support loading BERT weights.
536
+ self.layer = nn.ModuleList(
537
+ [BertLayer(config) for _ in range(self.num_l_layers)]
538
+ )
539
+ self.x_layers = nn.ModuleList(
540
+ [LXRTXLayer(config) for _ in range(self.num_x_layers)]
541
+ )
542
+ self.r_layers = nn.ModuleList(
543
+ [BertLayer(config) for _ in range(self.num_r_layers)]
544
+ )
545
+
546
+ def forward(self, lang_feats, lang_attention_mask,
547
+ visn_feats, visn_attention_mask=None):
548
+ # Run visual embedding layer
549
+ # Note: Word embedding layer was executed outside this module.
550
+ # Keep this design to allow loading BERT weights.
551
+ visn_feats = self.visn_fc(visn_feats)
552
+
553
+ # Run language layers
554
+ for layer_module in self.layer:
555
+ lang_feats = layer_module(lang_feats, lang_attention_mask)
556
+
557
+ # Run relational layers
558
+ for layer_module in self.r_layers:
559
+ visn_feats = layer_module(visn_feats, visn_attention_mask)
560
+
561
+ # Run cross-modality layers
562
+ for layer_module in self.x_layers:
563
+ lang_feats, visn_feats = layer_module(lang_feats, lang_attention_mask,
564
+ visn_feats, visn_attention_mask)
565
+
566
+ return lang_feats, visn_feats
567
+
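A rough usage sketch of the backbone above on random inputs; it assumes src/lxrt is importable as lxrt, and the layer counts are shrunk only so the demo is cheap (the defaults above are 12/5/0):

import torch
from lxrt.modeling import BertConfig, LXRTEncoder, VISUAL_CONFIG

VISUAL_CONFIG.l_layers, VISUAL_CONFIG.x_layers, VISUAL_CONFIG.r_layers = 2, 1, 1
config = BertConfig(vocab_size_or_config_json_file=30522)
encoder = LXRTEncoder(config)

lang_feats = torch.randn(2, 20, config.hidden_size)                 # already-embedded tokens
visn_feats = (torch.randn(2, 36, 2048), torch.rand(2, 36, 4))       # (RoI features, boxes)
lang_out, visn_out = encoder(lang_feats, None, visn_feats)
print(lang_out.shape, visn_out.shape)                               # (2, 20, 768) (2, 36, 768)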
568
+
569
+ class BertPooler(nn.Module):
570
+ def __init__(self, config):
571
+ super(BertPooler, self).__init__()
572
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
573
+ self.activation = nn.Tanh()
574
+
575
+ def forward(self, hidden_states):
576
+ # We "pool" the model by simply taking the hidden state corresponding
577
+ # to the first token.
578
+ first_token_tensor = hidden_states[:, 0]
579
+ pooled_output = self.dense(first_token_tensor)
580
+ pooled_output = self.activation(pooled_output)
581
+ return pooled_output
582
+
583
+
584
+ class BertPredictionHeadTransform(nn.Module):
585
+ def __init__(self, config):
586
+ super(BertPredictionHeadTransform, self).__init__()
587
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
588
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
589
+ self.transform_act_fn = ACT2FN[config.hidden_act]
590
+ else:
591
+ self.transform_act_fn = config.hidden_act
592
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
593
+
594
+ def forward(self, hidden_states):
595
+ hidden_states = self.dense(hidden_states)
596
+ hidden_states = self.transform_act_fn(hidden_states)
597
+ hidden_states = self.LayerNorm(hidden_states)
598
+ return hidden_states
599
+
600
+
601
+ class BertLMPredictionHead(nn.Module):
602
+ def __init__(self, config, bert_model_embedding_weights):
603
+ super(BertLMPredictionHead, self).__init__()
604
+ self.transform = BertPredictionHeadTransform(config)
605
+
606
+ # The output weights are the same as the input embeddings, but there is
607
+ # an output-only bias for each token.
608
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
609
+ bert_model_embedding_weights.size(0),
610
+ bias=False)
611
+ self.decoder.weight = bert_model_embedding_weights
612
+ self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
613
+
614
+ def forward(self, hidden_states):
615
+ hidden_states = self.transform(hidden_states)
616
+ hidden_states = self.decoder(hidden_states) + self.bias
617
+ return hidden_states
618
+
619
+
620
+ class BertVisualAnswerHead(nn.Module):
621
+ def __init__(self, config, num_answers):
622
+ super().__init__()
623
+ hid_dim = config.hidden_size
624
+ self.logit_fc = nn.Sequential(
625
+ nn.Linear(hid_dim, hid_dim * 2),
626
+ GeLU(),
627
+ BertLayerNorm(hid_dim * 2, eps=1e-12),
628
+ nn.Linear(hid_dim * 2, num_answers)
629
+ )
630
+
631
+ def forward(self, hidden_states):
632
+ return self.logit_fc(hidden_states)
633
+
634
+
635
+ class BertVisualObjHead(nn.Module):
636
+ def __init__(self, config, visual_losses):
637
+ super().__init__()
638
+ self.transform = BertPredictionHeadTransform(config)
639
+
640
+ # Decide the use of visual losses
641
+ visual_losses = visual_losses.split(",")
642
+ for loss in visual_losses:
643
+ assert loss in VISUAL_CONFIG.VISUAL_LOSSES
644
+ self.visual_losses = visual_losses
645
+
646
+ # The output weights are the same as the input embeddings, but there is
647
+ # an output-only bias for each token.
648
+ self.decoder_dict = nn.ModuleDict({
649
+ key: nn.Linear(config.hidden_size, VISUAL_CONFIG.visual_loss_config[key][0])
650
+ for key in self.visual_losses
651
+ })
652
+
653
+ def forward(self, hidden_states):
654
+ hidden_states = self.transform(hidden_states)
655
+ output = {}
656
+ for key in self.visual_losses:
657
+ output[key] = self.decoder_dict[key](hidden_states)
658
+ return output
659
+
660
+
661
+ class BertPreTrainingHeads(nn.Module):
662
+ def __init__(self, config, bert_model_embedding_weights):
663
+ super(BertPreTrainingHeads, self).__init__()
664
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
665
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
666
+
667
+ def forward(self, sequence_output, pooled_output):
668
+ prediction_scores = self.predictions(sequence_output)
669
+ seq_relationship_score = self.seq_relationship(pooled_output)
670
+ return prediction_scores, seq_relationship_score
671
+
672
+
673
+ class BertPreTrainedModel(nn.Module):
674
+ """ An abstract class to handle weights initialization and
675
+ a simple interface for downloading and loading pretrained models.
676
+ """
677
+ def __init__(self, config, *inputs, **kwargs):
678
+ super(BertPreTrainedModel, self).__init__()
679
+ if not isinstance(config, BertConfig):
680
+ raise ValueError(
681
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
682
+ "To create a model from a Google pretrained model use "
683
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
684
+ self.__class__.__name__, self.__class__.__name__
685
+ ))
686
+ self.config = config
687
+
688
+ def init_bert_weights(self, module):
689
+ """ Initialize the weights.
690
+ """
691
+ if isinstance(module, (nn.Linear, nn.Embedding)):
692
+ # Slightly different from the TF version which uses truncated_normal for initialization
693
+ # cf https://github.com/pytorch/pytorch/pull/5617
694
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
695
+ elif isinstance(module, BertLayerNorm):
696
+ module.bias.data.zero_()
697
+ module.weight.data.fill_(1.0)
698
+ if isinstance(module, nn.Linear) and module.bias is not None:
699
+ module.bias.data.zero_()
700
+
701
+ @classmethod
702
+ def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
703
+ from_tf=False, *inputs, **kwargs):
704
+ """
705
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
706
+ Download and cache the pre-trained model file if needed.
707
+
708
+ Params:
709
+ pretrained_model_name_or_path: either:
710
+ - a str with the name of a pre-trained model to load selected in the list of:
711
+ . `bert-base-uncased`
712
+ . `bert-large-uncased`
713
+ . `bert-base-cased`
714
+ . `bert-large-cased`
715
+ . `bert-base-multilingual-uncased`
716
+ . `bert-base-multilingual-cased`
717
+ . `bert-base-chinese`
718
+ - a path or url to a pretrained model archive containing:
719
+ . `bert_config.json` a configuration file for the model
720
+ . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
721
+ - a path or url to a pretrained model archive containing:
722
+ . `bert_config.json` a configuration file for the model
723
+ . `model.ckpt` a TensorFlow checkpoint
724
+ from_tf: should we load the weights from a locally saved TensorFlow checkpoint
725
+ cache_dir: an optional path to a folder in which the pre-trained models will be cached.
726
+ state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
727
+ *inputs, **kwargs: additional input for the specific Bert class
728
+ (ex: num_labels for BertForSequenceClassification)
729
+ """
730
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
731
+ archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
732
+ else:
733
+ archive_file = pretrained_model_name_or_path
734
+ # redirect to the cache, if necessary
735
+ try:
736
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
737
+ except EnvironmentError:
738
+ if pretrained_model_name_or_path == 'bert-base-uncased':
739
+ try:
740
+ print("The BERT-weight-downloading request to AWS timed out; "
741
+ "trying to download from UNC servers")
742
+ archive_file = "https://nlp.cs.unc.edu/data/bert/bert-base-uncased.tar.gz"
743
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
744
+ except EnvironmentError:
745
+ print("The weight download still failed with link: %s, "
746
+ "please check your network connection" % archive_file)
747
+ return None
748
+ else:
749
+ logger.error(
750
+ "Model name '{}' was not found in model name list ({}). "
751
+ "We assumed '{}' was a path or url but couldn't find any file "
752
+ "associated to this path or url.".format(
753
+ pretrained_model_name_or_path,
754
+ ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
755
+ archive_file))
+ return None
756
+ if resolved_archive_file == archive_file:
757
+ logger.info("loading archive file {}".format(archive_file))
758
+ else:
759
+ logger.info("loading archive file {} from cache at {}".format(
760
+ archive_file, resolved_archive_file))
761
+ tempdir = None
762
+ if os.path.isdir(resolved_archive_file) or from_tf:
763
+ serialization_dir = resolved_archive_file
764
+ else:
765
+ # Extract archive to temp dir
766
+ tempdir = tempfile.mkdtemp()
767
+ logger.info("extracting archive file {} to temp dir {}".format(
768
+ resolved_archive_file, tempdir))
769
+ with tarfile.open(resolved_archive_file, 'r:gz') as archive:
770
+ archive.extractall(tempdir)
771
+ serialization_dir = tempdir
772
+ # Load config
773
+ config_file = os.path.join(serialization_dir, CONFIG_NAME)
774
+ config = BertConfig.from_json_file(config_file)
775
+ logger.info("Model config {}".format(config))
776
+ # Instantiate model.
777
+ model = cls(config, *inputs, **kwargs)
778
+ if state_dict is None and not from_tf:
779
+ weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
780
+ state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
781
+ if tempdir:
782
+ # Clean up temp dir
783
+ shutil.rmtree(tempdir)
784
+ if from_tf:
785
+ # Directly load from a TensorFlow checkpoint
786
+ weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
787
+ return load_tf_weights_in_bert(model, weights_path)
788
+ # Load from a PyTorch state_dict
789
+ old_keys = []
790
+ new_keys = []
791
+ for key in state_dict.keys():
792
+ new_key = None
793
+ if 'gamma' in key:
794
+ new_key = key.replace('gamma', 'weight')
795
+ if 'beta' in key:
796
+ new_key = key.replace('beta', 'bias')
797
+ if new_key:
798
+ old_keys.append(key)
799
+ new_keys.append(new_key)
800
+ for old_key, new_key in zip(old_keys, new_keys):
801
+ state_dict[new_key] = state_dict.pop(old_key)
802
+
803
+ missing_keys = []
804
+ unexpected_keys = []
805
+ error_msgs = []
806
+ # copy state_dict so _load_from_state_dict can modify it
807
+ metadata = getattr(state_dict, '_metadata', None)
808
+ state_dict = state_dict.copy()
809
+ if metadata is not None:
810
+ state_dict._metadata = metadata
811
+
812
+ def load(module, prefix=''):
813
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
814
+ module._load_from_state_dict(
815
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
816
+ for name, child in module._modules.items():
817
+ if child is not None:
818
+ load(child, prefix + name + '.')
819
+ start_prefix = ''
820
+ if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
821
+ start_prefix = 'bert.'
822
+ load(model, prefix=start_prefix)
823
+ # if len(missing_keys) > 0:
824
+ # logger.info("Weights of {} not initialized from pretrained model: {}".format(
825
+ # model.__class__.__name__, missing_keys))
826
+ # if len(unexpected_keys) > 0:
827
+ # logger.info("Weights from pretrained model not used in {}: {}".format(
828
+ # model.__class__.__name__, unexpected_keys))
829
+ if len(error_msgs) > 0:
830
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
831
+ model.__class__.__name__, "\n\t".join(error_msgs)))
832
+ return model
833
+
834
+
835
+ class LXRTModel(BertPreTrainedModel):
836
+ """LXRT Model."""
837
+
838
+ def __init__(self, config):
839
+ super().__init__(config)
840
+ self.embeddings = BertEmbeddings(config)
841
+ self.encoder = LXRTEncoder(config)
842
+ self.pooler = BertPooler(config)
843
+ self.apply(self.init_bert_weights)
844
+
845
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None,
846
+ visual_feats=None, visual_attention_mask=None):
847
+ if attention_mask is None:
848
+ attention_mask = torch.ones_like(input_ids)
849
+ if token_type_ids is None:
850
+ token_type_ids = torch.zeros_like(input_ids)
851
+
852
+ # We create a 3D attention mask from a 2D tensor mask.
853
+ # Sizes are [batch_size, 1, 1, to_seq_length]
854
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
855
+ # this attention mask is more simple than the triangular masking of causal attention
856
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
857
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
858
+
859
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
860
+ # masked positions, this operation will create a tensor which is 0.0 for
861
+ # positions we want to attend and -10000.0 for masked positions.
862
+ # Since we are adding it to the raw scores before the softmax, this is
863
+ # effectively the same as removing these entirely.
864
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
865
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
866
+
867
+ # Process the visual attention mask
868
+ if visual_attention_mask is not None:
869
+ extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
870
+ extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
871
+ extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
872
+ else:
873
+ extended_visual_attention_mask = None
874
+
875
+ # Positional Word Embeddings
876
+ embedding_output = self.embeddings(input_ids, token_type_ids)
877
+
878
+ # Run LXRT backbone
879
+ lang_feats, visn_feats = self.encoder(
880
+ embedding_output,
881
+ extended_attention_mask,
882
+ visn_feats=visual_feats,
883
+ visn_attention_mask=extended_visual_attention_mask)
884
+ pooled_output = self.pooler(lang_feats)
885
+
886
+ return (lang_feats, visn_feats), pooled_output
887
+
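The additive mask trick described in the comments of forward() can be checked in isolation; this toy example (standalone PyTorch) shows how a padded key position is pushed to a large negative score and ends up with near-zero softmax weight:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
extended = attention_mask[:, None, None, :].float()       # (batch, 1, 1, to_seq_length)
extended = (1.0 - extended) * -10000.0
print(extended[0, 0, 0])                                   # tensor([-0., -0., -0., -10000.])
print(torch.softmax(torch.zeros(4) + extended[0, 0, 0], dim=-1))   # last position gets ~0 probability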
888
+
889
+ class LXRTPretraining(BertPreTrainedModel):
890
+ def __init__(self,
891
+ config,
892
+ task_mask_lm=True,
893
+ task_matched=True,
894
+ task_obj_predict=True,
895
+ visual_losses='',
896
+ task_qa=True,
897
+ num_answers=2):
898
+ super().__init__(config)
899
+ # Configuration
900
+ self.config = config
901
+ self.num_answers = num_answers
902
+
903
+ # Use of pre-training tasks
904
+ self.task_mask_lm = task_mask_lm
905
+ self.task_obj_predict = task_obj_predict
906
+ self.task_matched = task_matched
907
+ self.task_qa = task_qa
908
+
909
+ # LXRT backbone
910
+ self.bert = LXRTModel(config)
911
+
912
+ # Pre-training heads
913
+ self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
914
+ if self.task_obj_predict:
915
+ self.obj_predict_head = BertVisualObjHead(config, visual_losses)
916
+ if self.task_qa:
917
+ self.answer_head = BertVisualAnswerHead(config, self.num_answers)
918
+
919
+ # Weight initialization
920
+ self.apply(self.init_bert_weights)
921
+
922
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
923
+ visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
924
+ (lang_output, visn_output), pooled_output = self.bert(
925
+ input_ids, token_type_ids, attention_mask,
926
+ visual_feats=(visual_feats, pos),
927
+ )
928
+
929
+ lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
930
+ if self.task_qa:
931
+ answer_score = self.answer_head(pooled_output)
932
+ else:
933
+ # This answer_score would not be used anywhere,
934
+ # just to keep a constant return function signature.
935
+ answer_score = pooled_output[0][0]
936
+
937
+ total_loss = 0.
938
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
939
+ losses = ()
940
+ if masked_lm_labels is not None and self.task_mask_lm:
941
+ masked_lm_loss = loss_fct(
942
+ lang_prediction_scores.view(-1, self.config.vocab_size),
943
+ masked_lm_labels.view(-1)
944
+ )
945
+ total_loss += masked_lm_loss
946
+ losses += (masked_lm_loss.detach(),)
947
+ if matched_label is not None and self.task_matched:
948
+ matched_loss = loss_fct(
949
+ cross_relationship_score.view(-1, 2),
950
+ matched_label.view(-1)
951
+ )
952
+ total_loss += matched_loss
953
+ losses += (matched_loss.detach(),)
954
+ if obj_labels is not None and self.task_obj_predict:
955
+ loss_fcts = {
956
+ 'l2': SmoothL1Loss(reduction='none'),
957
+ 'ce': CrossEntropyLoss(ignore_index=-1, reduction='none')
958
+ }
959
+ total_visn_loss = 0.
960
+ visn_prediction_scores_dict = self.obj_predict_head(visn_output)
961
+ for key in VISUAL_CONFIG.visual_losses:
962
+ label, mask_conf = obj_labels[key]
963
+ output_dim, loss_fct_name, label_shape, weight = VISUAL_CONFIG.visual_loss_config[key]
964
+ visn_loss_fct = loss_fcts[loss_fct_name]
965
+ visn_prediction_scores = visn_prediction_scores_dict[key]
966
+ visn_loss = visn_loss_fct(
967
+ visn_prediction_scores.view(-1, output_dim),
968
+ label.view(*label_shape),
969
+ )
970
+ if visn_loss.dim() > 1: # Regression Losses
971
+ visn_loss = visn_loss.mean(1)
972
+ visn_loss = (visn_loss * mask_conf.view(-1)).mean() * weight
973
+ total_visn_loss += visn_loss
974
+ losses += (visn_loss.detach(),)
975
+ total_loss += total_visn_loss
976
+ if ans is not None and self.task_qa:
977
+ answer_loss = loss_fct(
978
+ answer_score.view(-1, self.num_answers),
979
+ ans.view(-1)
980
+ )
981
+ # Since this Github version pre-trains with QA loss from the beginning,
982
+ # I exclude "*2" here to match the effect of QA losses.
983
+ # Previous: (loss *0) for 6 epochs, (loss *2) for 6 epochs. (Used 10 instead of 6 in EMNLP paper)
984
+ # Now : (loss *1) for 12 epochs
985
+ #
986
+ # * 2 # Multiply by 2 because > half of the data will not have label
987
+ total_loss += answer_loss
988
+ losses += (answer_loss.detach(),)
989
+ return total_loss, torch.stack(losses).unsqueeze(0), answer_score.detach()
990
+
991
+
992
+ class LXRTFeatureExtraction(BertPreTrainedModel):
993
+ """
994
+ LXRT model for feature extraction (no task-specific head attached).
995
+ """
996
+ def __init__(self, config, mode='lxr'):
997
+ """
998
+
999
+ :param config:
1000
+ :param mode: which outputs to return: 'x' for the pooled cross-modality output, 'l'/'r' for the language/vision feature sequences, or a combination such as the default 'lxr'
1001
+ """
1002
+ super().__init__(config)
1003
+ self.bert = LXRTModel(config)
1004
+ self.mode = mode
1005
+ self.apply(self.init_bert_weights)
1006
+
1007
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, visual_feats=None,
1008
+ visual_attention_mask=None):
1009
+ feat_seq, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
1010
+ visual_feats=visual_feats,
1011
+ visual_attention_mask=visual_attention_mask)
1012
+ if 'x' == self.mode:
1013
+ return pooled_output
1014
+ elif 'x' in self.mode and ('l' in self.mode or 'r' in self.mode):
1015
+ return feat_seq, pooled_output
1016
+ elif 'l' in self.mode or 'r' in self.mode:
1017
+ return feat_seq
1018
+
lxmert/src/lxrt/optimization.py ADDED
@@ -0,0 +1,180 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 project LXRT
3
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch optimization for BERT model."""
17
+
18
+ import math
19
+ import torch
20
+ from torch.optim import Optimizer
21
+ from torch.optim.optimizer import required
22
+ import logging
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def warmup_cosine(x, warmup=0.002):
27
+ if x < warmup:
28
+ return x/warmup
29
+ # x is a plain Python float (fraction of total steps), so use math.cos rather than torch.cos.
+ return 0.5 * (1.0 + math.cos(math.pi * x))
30
+
31
+ def warmup_constant(x, warmup=0.002):
32
+ """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
33
+ Learning rate is 1. afterwards. """
34
+ if x < warmup:
35
+ return x/warmup
36
+ return 1.0
37
+
38
+ def warmup_linear(x, warmup=0.002):
39
+ """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
40
+ After `t_total`-th training step, learning rate is zero. """
41
+ if x < warmup:
42
+ return x/warmup
43
+ return max((x-1.)/(warmup-1.), 0)
44
+
45
+ SCHEDULES = {
46
+ 'warmup_cosine': warmup_cosine,
47
+ 'warmup_constant': warmup_constant,
48
+ 'warmup_linear': warmup_linear,
49
+ }
50
+
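The schedules above all take the training progress x = step / t_total and return an LR multiplier. Restating warmup_linear standalone (repeated here only so the snippet runs on its own) makes the shape easy to verify:

def warmup_linear(x, warmup=0.002):
    # Linear ramp up to 1.0 over the warmup fraction, then linear decay to 0.0 at x = 1.
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

for x in (0.0, 0.05, 0.1, 0.5, 1.0):
    print(x, round(warmup_linear(x, warmup=0.1), 3))
# 0.0 -> 0.0 | 0.05 -> 0.5 | 0.1 -> 1.0 | 0.5 -> 0.556 | 1.0 -> 0.0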
51
+
52
+ class BertAdam(Optimizer):
53
+ """Implements BERT version of Adam algorithm with weight decay fix.
54
+ Params:
55
+ lr: learning rate
56
+ warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
57
+ t_total: total number of training steps for the learning
58
+ rate schedule, -1 means constant learning rate. Default: -1
59
+ schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
60
+ b1: Adams b1. Default: 0.9
61
+ b2: Adams b2. Default: 0.999
62
+ e: Adams epsilon. Default: 1e-6
63
+ weight_decay: Weight decay. Default: 0.01
64
+ max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
65
+ """
66
+ def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
67
+ b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
68
+ max_grad_norm=1.0):
69
+ if lr is not required and lr < 0.0:
70
+ raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
71
+ if schedule not in SCHEDULES:
72
+ raise ValueError("Invalid schedule parameter: {}".format(schedule))
73
+ if not 0.0 <= warmup < 1.0 and not warmup == -1:
74
+ raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
75
+ if not 0.0 <= b1 < 1.0:
76
+ raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
77
+ if not 0.0 <= b2 < 1.0:
78
+ raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
79
+ if not e >= 0.0:
80
+ raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
81
+ defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
82
+ b1=b1, b2=b2, e=e, weight_decay=weight_decay,
83
+ max_grad_norm=max_grad_norm)
84
+ super(BertAdam, self).__init__(params, defaults)
85
+
86
+ def get_lr(self):
87
+ lr = []
88
+ for group in self.param_groups:
89
+ for p in group['params']:
90
+ state = self.state[p]
91
+ if len(state) == 0:
92
+ return [0]
93
+ if group['t_total'] != -1:
94
+ schedule_fct = SCHEDULES[group['schedule']]
95
+ lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
96
+ else:
97
+ lr_scheduled = group['lr']
98
+ lr.append(lr_scheduled)
99
+ return lr
100
+
101
+ def step(self, closure=None):
102
+ """Performs a single optimization step.
103
+
104
+ Arguments:
105
+ closure (callable, optional): A closure that reevaluates the model
106
+ and returns the loss.
107
+ """
108
+ loss = None
109
+ if closure is not None:
110
+ loss = closure()
111
+
112
+ warned_for_t_total = False
113
+
114
+ for group in self.param_groups:
115
+ for p in group['params']:
116
+ if p.grad is None:
117
+ continue
118
+ grad = p.grad.data
119
+ if grad.is_sparse:
120
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
121
+
122
+ state = self.state[p]
123
+
124
+ # State initialization
125
+ if len(state) == 0:
126
+ state['step'] = 0
127
+ # Exponential moving average of gradient values
128
+ state['next_m'] = torch.zeros_like(p.data)
129
+ # Exponential moving average of squared gradient values
130
+ state['next_v'] = torch.zeros_like(p.data)
131
+
132
+ next_m, next_v = state['next_m'], state['next_v']
133
+ beta1, beta2 = group['b1'], group['b2']
134
+
135
+ # LXRT: grad is clipped outside.
136
+ # Add grad clipping
137
+ # if group['max_grad_norm'] > 0:
138
+ # clip_grad_norm_(p, group['max_grad_norm'])
139
+
140
+ # Decay the first and second moment running average coefficient
141
+ # In-place operations to update the averages at the same time
142
+ next_m.mul_(beta1).add_(1 - beta1, grad)
143
+ next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
144
+ update = next_m / (next_v.sqrt() + group['e'])
145
+
146
+ # Just adding the square of the weights to the loss function is *not*
147
+ # the correct way of using L2 regularization/weight decay with Adam,
148
+ # since that will interact with the m and v parameters in strange ways.
149
+ #
150
+ # Instead we want to decay the weights in a manner that doesn't interact
151
+ # with the m/v parameters. This is equivalent to adding the square
152
+ # of the weights to the loss with plain (non-momentum) SGD.
153
+ if group['weight_decay'] > 0.0:
154
+ update += group['weight_decay'] * p.data
155
+
156
+ if group['t_total'] != -1:
157
+ schedule_fct = SCHEDULES[group['schedule']]
158
+ progress = state['step']/group['t_total']
159
+ lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
160
+ # warning for exceeding t_total (only active with warmup_linear)
161
+ if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
162
+ logger.warning(
163
+ "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
164
+ "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
165
+ warned_for_t_total = True
166
+ # end warning
167
+ else:
168
+ lr_scheduled = group['lr']
169
+
170
+ update_with_lr = lr_scheduled * update
171
+ p.data.add_(-update_with_lr)
172
+
173
+ state['step'] += 1
174
+
175
+ # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
176
+ # No bias correction
177
+ # bias_correction1 = 1 - beta1 ** state['step']
178
+ # bias_correction2 = 1 - beta2 ** state['step']
179
+
180
+ return loss
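A minimal training-step sketch with BertAdam; it assumes lxrt.optimization is importable and a PyTorch version contemporary with this code (newer releases may reject the deprecated add_/addcmul_ overloads used in step() above):

import torch
from lxrt.optimization import BertAdam

model = torch.nn.Linear(10, 2)
optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1, t_total=1000)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()          # LR follows warmup_linear(step / t_total, warmup)
optimizer.zero_grad()
print(optimizer.get_lr())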
lxmert/src/lxrt/tokenization.py ADDED
@@ -0,0 +1,388 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ import collections
18
+ import logging
19
+ import os
20
+ import unicodedata
21
+ from io import open
22
+
23
+ from .file_utils import cached_path
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ PRETRAINED_VOCAB_ARCHIVE_MAP = {
28
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
29
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
30
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
31
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
32
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
33
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
34
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
35
+ }
36
+ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
37
+ 'bert-base-uncased': 512,
38
+ 'bert-large-uncased': 512,
39
+ 'bert-base-cased': 512,
40
+ 'bert-large-cased': 512,
41
+ 'bert-base-multilingual-uncased': 512,
42
+ 'bert-base-multilingual-cased': 512,
43
+ 'bert-base-chinese': 512,
44
+ }
45
+ VOCAB_NAME = 'vocab.txt'
46
+
47
+
48
+ def load_vocab(vocab_file):
49
+ """Loads a vocabulary file into a dictionary."""
50
+ vocab = collections.OrderedDict()
51
+ index = 0
52
+ with open(vocab_file, "r", encoding="utf-8") as reader:
53
+ while True:
54
+ token = reader.readline()
55
+ if not token:
56
+ break
57
+ token = token.strip()
58
+ vocab[token] = index
59
+ index += 1
60
+ return vocab
61
+
62
+
63
+ def whitespace_tokenize(text):
64
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
65
+ text = text.strip()
66
+ if not text:
67
+ return []
68
+ tokens = text.split()
69
+ return tokens
70
+
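A quick look at what load_vocab and whitespace_tokenize produce on a toy vocabulary (assuming this file is importable as lxrt.tokenization):

import os
import tempfile

from lxrt.tokenization import load_vocab, whitespace_tokenize  # import path is an assumption

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("[PAD]\n[UNK]\n[CLS]\n[SEP]\nthe\ndog\n##s\n")
    vocab_path = f.name

vocab = load_vocab(vocab_path)                   # OrderedDict: '[PAD]' -> 0, ..., '##s' -> 6
print(vocab["the"])                              # 4
print(whitespace_tokenize("  the dog s  "))      # ['the', 'dog', 's']
os.remove(vocab_path)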
71
+
72
+ class BertTokenizer(object):
73
+ """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
74
+
75
+ def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
76
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
77
+ """Constructs a BertTokenizer.
78
+
79
+ Args:
80
+ vocab_file: Path to a one-wordpiece-per-line vocabulary file
81
+ do_lower_case: Whether to lower case the input
82
+ Only has an effect when do_wordpiece_only=False
83
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
84
+ max_len: An artificial maximum length to truncate tokenized sequences to;
85
+ Effective maximum length is always the minimum of this
86
+ value (if specified) and the underlying BERT model's
87
+ sequence length.
88
+ never_split: List of tokens which will never be split during tokenization.
89
+ Only has an effect when do_wordpiece_only=False
90
+ """
91
+ if not os.path.isfile(vocab_file):
92
+ raise ValueError(
93
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
94
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
95
+ self.vocab = load_vocab(vocab_file)
96
+ self.ids_to_tokens = collections.OrderedDict(
97
+ [(ids, tok) for tok, ids in self.vocab.items()])
98
+ self.do_basic_tokenize = do_basic_tokenize
99
+ if do_basic_tokenize:
100
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
101
+ never_split=never_split)
102
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
103
+ self.max_len = max_len if max_len is not None else int(1e12)
104
+
105
+ def tokenize(self, text):
106
+ if self.do_basic_tokenize:
107
+ split_tokens = []
108
+ for token in self.basic_tokenizer.tokenize(text):
109
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
110
+ split_tokens.append(sub_token)
111
+ else:
112
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
113
+ return split_tokens
114
+
115
+ def convert_tokens_to_ids(self, tokens):
116
+ """Converts a sequence of tokens into ids using the vocab."""
117
+ ids = []
118
+ for token in tokens:
119
+ ids.append(self.vocab[token])
120
+ if len(ids) > self.max_len:
121
+ logger.warning(
122
+ "Token indices sequence length is longer than the specified maximum "
123
+ "sequence length for this BERT model ({} > {}). Running this"
124
+ " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
125
+ )
126
+ return ids
127
+
128
+ def convert_ids_to_tokens(self, ids):
129
+ """Converts a sequence of ids in wordpiece tokens using the vocab."""
130
+ tokens = []
131
+ for i in ids:
132
+ tokens.append(self.ids_to_tokens[i])
133
+ return tokens
134
+
135
+ @classmethod
136
+ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
137
+ """
138
+ Instantiate a PreTrainedBertModel from a pre-trained model file.
139
+ Download and cache the pre-trained model file if needed.
140
+ """
141
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
142
+ vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
143
+ else:
144
+ vocab_file = pretrained_model_name_or_path
145
+ if os.path.isdir(vocab_file):
146
+ vocab_file = os.path.join(vocab_file, VOCAB_NAME)
147
+ # redirect to the cache, if necessary
148
+ try:
149
+ resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
150
+ except EnvironmentError:
151
+ logger.error(
152
+ "Model name '{}' was not found in model name list ({}). "
153
+ "We assumed '{}' was a path or url but couldn't find any file "
154
+ "associated to this path or url.".format(
155
+ pretrained_model_name_or_path,
156
+ ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
157
+ vocab_file))
158
+ return None
159
+ if resolved_vocab_file == vocab_file:
160
+ logger.info("loading vocabulary file {}".format(vocab_file))
161
+ else:
162
+ logger.info("loading vocabulary file {} from cache at {}".format(
163
+ vocab_file, resolved_vocab_file))
164
+ if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
165
+ # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
166
+ # than the number of positional embeddings
167
+ max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
168
+ kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
169
+ # Instantiate tokenizer.
170
+ tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
171
+ return tokenizer
172
+
173
+
174
+ class BasicTokenizer(object):
175
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
176
+
177
+ def __init__(self,
178
+ do_lower_case=True,
179
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
180
+ """Constructs a BasicTokenizer.
181
+
182
+ Args:
183
+ do_lower_case: Whether to lower case the input.
184
+ """
185
+ self.do_lower_case = do_lower_case
186
+ self.never_split = never_split
187
+
188
+ def tokenize(self, text):
189
+ """Tokenizes a piece of text."""
190
+ text = self._clean_text(text)
191
+ # This was added on November 1st, 2018 for the multilingual and Chinese
192
+ # models. This is also applied to the English models now, but it doesn't
193
+ # matter since the English models were not trained on any Chinese data
194
+ # and generally don't have any Chinese data in them (there are Chinese
195
+ # characters in the vocabulary because Wikipedia does have some Chinese
196
+ # words in the English Wikipedia).
197
+ text = self._tokenize_chinese_chars(text)
198
+ orig_tokens = whitespace_tokenize(text)
199
+ split_tokens = []
200
+ for token in orig_tokens:
201
+ if self.do_lower_case and token not in self.never_split:
202
+ token = token.lower()
203
+ token = self._run_strip_accents(token)
204
+ split_tokens.extend(self._run_split_on_punc(token))
205
+
206
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
207
+ return output_tokens
208
+
209
+ def _run_strip_accents(self, text):
210
+ """Strips accents from a piece of text."""
211
+ text = unicodedata.normalize("NFD", text)
212
+ output = []
213
+ for char in text:
214
+ cat = unicodedata.category(char)
215
+ if cat == "Mn":
216
+ continue
217
+ output.append(char)
218
+ return "".join(output)
219
+
220
+ def _run_split_on_punc(self, text):
221
+ """Splits punctuation on a piece of text."""
222
+ if text in self.never_split:
223
+ return [text]
224
+ chars = list(text)
225
+ i = 0
226
+ start_new_word = True
227
+ output = []
228
+ while i < len(chars):
229
+ char = chars[i]
230
+ if _is_punctuation(char):
231
+ output.append([char])
232
+ start_new_word = True
233
+ else:
234
+ if start_new_word:
235
+ output.append([])
236
+ start_new_word = False
237
+ output[-1].append(char)
238
+ i += 1
239
+
240
+ return ["".join(x) for x in output]
241
+
242
+ def _tokenize_chinese_chars(self, text):
243
+ """Adds whitespace around any CJK character."""
244
+ output = []
245
+ for char in text:
246
+ cp = ord(char)
247
+ if self._is_chinese_char(cp):
248
+ output.append(" ")
249
+ output.append(char)
250
+ output.append(" ")
251
+ else:
252
+ output.append(char)
253
+ return "".join(output)
254
+
255
+ def _is_chinese_char(self, cp):
256
+ """Checks whether CP is the codepoint of a CJK character."""
257
+ # This defines a "chinese character" as anything in the CJK Unicode block:
258
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
259
+ #
260
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
261
+ # despite its name. The modern Korean Hangul alphabet is a different block,
262
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
263
+ # space-separated words, so they are not treated specially and handled
264
+ # like all of the other languages.
265
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
266
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
267
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
268
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
269
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
270
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
271
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
272
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
273
+ return True
274
+
275
+ return False
276
+
277
+ def _clean_text(self, text):
278
+ """Performs invalid character removal and whitespace cleanup on text."""
279
+ output = []
280
+ for char in text:
281
+ cp = ord(char)
282
+ if cp == 0 or cp == 0xfffd or _is_control(char):
283
+ continue
284
+ if _is_whitespace(char):
285
+ output.append(" ")
286
+ else:
287
+ output.append(char)
288
+ return "".join(output)
289
+
290
+
291
+ class WordpieceTokenizer(object):
292
+ """Runs WordPiece tokenization."""
293
+
294
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
295
+ self.vocab = vocab
296
+ self.unk_token = unk_token
297
+ self.max_input_chars_per_word = max_input_chars_per_word
298
+
299
+ def tokenize(self, text):
300
+ """Tokenizes a piece of text into its word pieces.
301
+
302
+ This uses a greedy longest-match-first algorithm to perform tokenization
303
+ using the given vocabulary.
304
+
305
+ For example:
306
+ input = "unaffable"
307
+ output = ["un", "##aff", "##able"]
308
+
309
+ Args:
310
+ text: A single token or whitespace separated tokens. This should have
311
+ already been passed through `BasicTokenizer`.
312
+
313
+ Returns:
314
+ A list of wordpiece tokens.
315
+ """
316
+
317
+ output_tokens = []
318
+ for token in whitespace_tokenize(text):
319
+ chars = list(token)
320
+ if len(chars) > self.max_input_chars_per_word:
321
+ output_tokens.append(self.unk_token)
322
+ continue
323
+
324
+ is_bad = False
325
+ start = 0
326
+ sub_tokens = []
327
+ while start < len(chars):
328
+ end = len(chars)
329
+ cur_substr = None
330
+ while start < end:
331
+ substr = "".join(chars[start:end])
332
+ if start > 0:
333
+ substr = "##" + substr
334
+ if substr in self.vocab:
335
+ cur_substr = substr
336
+ break
337
+ end -= 1
338
+ if cur_substr is None:
339
+ is_bad = True
340
+ break
341
+ sub_tokens.append(cur_substr)
342
+ start = end
343
+
344
+ if is_bad:
345
+ output_tokens.append(self.unk_token)
346
+ else:
347
+ output_tokens.extend(sub_tokens)
348
+ return output_tokens
349
+
350
+
351
+ def _is_whitespace(char):
352
+ """Checks whether `chars` is a whitespace character."""
353
+ # \t, \n, and \r are technically control characters but we treat them
354
+ # as whitespace since they are generally considered as such.
355
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
356
+ return True
357
+ cat = unicodedata.category(char)
358
+ if cat == "Zs":
359
+ return True
360
+ return False
361
+
362
+
363
+ def _is_control(char):
364
+ """Checks whether `chars` is a control character."""
365
+ # These are technically control characters but we count them as whitespace
366
+ # characters.
367
+ if char == "\t" or char == "\n" or char == "\r":
368
+ return False
369
+ cat = unicodedata.category(char)
370
+ if cat.startswith("C"):
371
+ return True
372
+ return False
373
+
374
+
375
+ def _is_punctuation(char):
376
+ """Checks whether `chars` is a punctuation character."""
377
+ cp = ord(char)
378
+ # We treat all non-letter/number ASCII as punctuation.
379
+ # Characters such as "^", "$", and "`" are not in the Unicode
380
+ # Punctuation class but we treat them as punctuation anyways, for
381
+ # consistency.
382
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
383
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
384
+ return True
385
+ cat = unicodedata.category(char)
386
+ if cat.startswith("P"):
387
+ return True
388
+ return False
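
A minimal usage sketch of the tokenizer defined above, assuming a local one-wordpiece-per-line vocabulary at a hypothetical path `vocab.txt` and a module path inferred from the file layout (both are assumptions, not taken from this upload); the exact subword splits depend on the vocabulary:

    from lxmert.src.lxrt.tokenization import BertTokenizer  # assumed import path

    tokenizer = BertTokenizer(vocab_file="vocab.txt", do_lower_case=True)  # hypothetical vocab path
    tokens = tokenizer.tokenize("Is the unaffable man wearing glasses?")
    # BasicTokenizer lower-cases and splits punctuation, then WordpieceTokenizer applies
    # greedy longest-match, e.g. "unaffable" -> ["un", "##aff", "##able"]
    ids = tokenizer.convert_tokens_to_ids(tokens)
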
lxmert/src/modeling_frcnn.py ADDED
@@ -0,0 +1,1922 @@
 
1
+ """
2
+ coding=utf-8
3
+ Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
4
+ Adapted From Facebook Inc, Detectron2 && Huggingface Co.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+ import itertools
19
+ import math
20
+ import os
21
+ from abc import ABCMeta, abstractmethod
22
+ from collections import OrderedDict, namedtuple
23
+ from typing import Dict, List, Tuple
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch import nn
28
+ from torch.nn import functional as F
29
+ from torch.nn.modules.batchnorm import BatchNorm2d
30
+ from torchvision.ops import RoIPool
31
+ from torchvision.ops.boxes import batched_nms, nms
32
+
33
+ from lxmert.lxmert.src.vqa_utils import WEIGHTS_NAME, Config, cached_path, hf_bucket_url, is_remote_url, load_checkpoint
34
+
35
+
36
+ # other:
37
+ def norm_box(boxes, raw_sizes):
38
+ if not isinstance(boxes, torch.Tensor):
39
+ normalized_boxes = boxes.copy()
40
+ else:
41
+ normalized_boxes = boxes.clone()
42
+ normalized_boxes[:, :, (0, 2)] /= raw_sizes[:, 1]
43
+ normalized_boxes[:, :, (1, 3)] /= raw_sizes[:, 0]
44
+ return normalized_boxes
45
+
46
+
47
+ def pad_list_tensors(
48
+ list_tensors,
49
+ preds_per_image,
50
+ max_detections=None,
51
+ return_tensors=None,
52
+ padding=None,
53
+ pad_value=0,
54
+ location=None,
55
+ ):
56
+ """
57
+ location will always be cpu for np tensors
58
+ """
59
+ if location is None:
60
+ location = "cpu"
61
+ assert return_tensors in {"pt", "np", None}
62
+ assert padding in {"max_detections", "max_batch", None}
63
+ new = []
64
+ if padding is None:
65
+ if return_tensors is None:
66
+ return list_tensors
67
+ elif return_tensors == "pt":
68
+ if not isinstance(list_tensors, torch.Tensor):
69
+ return torch.stack(list_tensors).to(location)
70
+ else:
71
+ return list_tensors.to(location)
72
+ else:
73
+ if not isinstance(list_tensors, list):
74
+ return np.array(list_tensors.to(location))
75
+ else:
76
+ return list_tensors.to(location)
77
+ if padding == "max_detections":
78
+ assert max_detections is not None, "specify max number of detections per batch"
79
+ elif padding == "max_batch":
80
+ max_detections = max(preds_per_image)
81
+ for i in range(len(list_tensors)):
82
+ too_small = False
83
+ tensor_i = list_tensors.pop(0)
84
+ if tensor_i.ndim < 2:
85
+ too_small = True
86
+ tensor_i = tensor_i.unsqueeze(-1)
87
+ assert isinstance(tensor_i, torch.Tensor)
88
+ tensor_i = F.pad(
89
+ input=tensor_i,
90
+ pad=(0, 0, 0, max_detections - preds_per_image[i]),
91
+ mode="constant",
92
+ value=pad_value,
93
+ )
94
+ if too_small:
95
+ tensor_i = tensor_i.squeeze(-1)
96
+ if return_tensors is None:
97
+ if location == "cpu":
98
+ tensor_i = tensor_i.cpu()
99
+ tensor_i = tensor_i.tolist()
100
+ if return_tensors == "np":
101
+ if location == "cpu":
102
+ tensor_i = tensor_i.cpu()
103
+ tensor_i = tensor_i.numpy()
104
+ else:
105
+ if location == "cpu":
106
+ tensor_i = tensor_i.cpu()
107
+ new.append(tensor_i)
108
+ if return_tensors == "np":
109
+ return np.stack(new, axis=0)
110
+ elif return_tensors == "pt" and not isinstance(new, torch.Tensor):
111
+ return torch.stack(new, dim=0)
112
+ else:
113
+ return list_tensors
114
+
115
+
116
+ def do_nms(boxes, scores, image_shape, score_thresh, nms_thresh, mind, maxd):
117
+ scores = scores[:, :-1]
118
+ num_bbox_reg_classes = boxes.shape[1] // 4
119
+ # Convert to Boxes to use the `clip` function ...
120
+ boxes = boxes.reshape(-1, 4)
121
+ _clip_box(boxes, image_shape)
122
+ boxes = boxes.view(-1, num_bbox_reg_classes, 4) # R x C x 4
123
+
124
+ # Select max scores
125
+ max_scores, max_classes = scores.max(1) # R x C --> R
126
+ num_objs = boxes.size(0)
127
+ boxes = boxes.view(-1, 4)
128
+ idxs = torch.arange(num_objs).to(boxes.device) * num_bbox_reg_classes + max_classes
129
+ max_boxes = boxes[idxs] # Select max boxes according to the max scores.
130
+
131
+ # Apply NMS
132
+ keep = nms(max_boxes, max_scores, nms_thresh)
133
+ keep = keep[:maxd]
134
+ if keep.shape[-1] >= mind and keep.shape[-1] <= maxd:
135
+ max_boxes, max_scores = max_boxes[keep], max_scores[keep]
136
+ classes = max_classes[keep]
137
+ return max_boxes, max_scores, classes, keep
138
+ else:
139
+ return None
140
+
141
+
142
+ # Helper Functions
143
+ def _clip_box(tensor, box_size: Tuple[int, int]):
144
+ assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
145
+ h, w = box_size
146
+ tensor[:, 0].clamp_(min=0, max=w)
147
+ tensor[:, 1].clamp_(min=0, max=h)
148
+ tensor[:, 2].clamp_(min=0, max=w)
149
+ tensor[:, 3].clamp_(min=0, max=h)
150
+
151
+
152
+ def _nonempty_boxes(box, threshold: float = 0.0) -> torch.Tensor:
153
+ widths = box[:, 2] - box[:, 0]
154
+ heights = box[:, 3] - box[:, 1]
155
+ keep = (widths > threshold) & (heights > threshold)
156
+ return keep
157
+
158
+
159
+ def get_norm(norm, out_channels):
160
+ if isinstance(norm, str):
161
+ if len(norm) == 0:
162
+ return None
163
+ norm = {
164
+ "BN": BatchNorm2d,
165
+ "GN": lambda channels: nn.GroupNorm(32, channels),
166
+ "nnSyncBN": nn.SyncBatchNorm, # keep for debugging
167
+ "": lambda x: x,
168
+ }[norm]
169
+ return norm(out_channels)
170
+
171
+
172
+ def _create_grid_offsets(size: List[int], stride: int, offset: float, device):
173
+
174
+ grid_height, grid_width = size
175
+ shifts_x = torch.arange(
176
+ offset * stride,
177
+ grid_width * stride,
178
+ step=stride,
179
+ dtype=torch.float32,
180
+ device=device,
181
+ )
182
+ shifts_y = torch.arange(
183
+ offset * stride,
184
+ grid_height * stride,
185
+ step=stride,
186
+ dtype=torch.float32,
187
+ device=device,
188
+ )
189
+
190
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
191
+ shift_x = shift_x.reshape(-1)
192
+ shift_y = shift_y.reshape(-1)
193
+ return shift_x, shift_y
194
+
195
+
196
+ def build_backbone(cfg):
197
+ input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))
198
+ norm = cfg.RESNETS.NORM
199
+ stem = BasicStem(
200
+ in_channels=input_shape.channels,
201
+ out_channels=cfg.RESNETS.STEM_OUT_CHANNELS,
202
+ norm=norm,
203
+ caffe_maxpool=cfg.MODEL.MAX_POOL,
204
+ )
205
+ freeze_at = cfg.BACKBONE.FREEZE_AT
206
+
207
+ if freeze_at >= 1:
208
+ for p in stem.parameters():
209
+ p.requires_grad = False
210
+
211
+ out_features = cfg.RESNETS.OUT_FEATURES
212
+ depth = cfg.RESNETS.DEPTH
213
+ num_groups = cfg.RESNETS.NUM_GROUPS
214
+ width_per_group = cfg.RESNETS.WIDTH_PER_GROUP
215
+ bottleneck_channels = num_groups * width_per_group
216
+ in_channels = cfg.RESNETS.STEM_OUT_CHANNELS
217
+ out_channels = cfg.RESNETS.RES2_OUT_CHANNELS
218
+ stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1
219
+ res5_dilation = cfg.RESNETS.RES5_DILATION
220
+ assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)
221
+
222
+ num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
223
+
224
+ stages = []
225
+ out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features]
226
+ max_stage_idx = max(out_stage_idx)
227
+ for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
228
+ dilation = res5_dilation if stage_idx == 5 else 1
229
+ first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
230
+ stage_kargs = {
231
+ "num_blocks": num_blocks_per_stage[idx],
232
+ "first_stride": first_stride,
233
+ "in_channels": in_channels,
234
+ "bottleneck_channels": bottleneck_channels,
235
+ "out_channels": out_channels,
236
+ "num_groups": num_groups,
237
+ "norm": norm,
238
+ "stride_in_1x1": stride_in_1x1,
239
+ "dilation": dilation,
240
+ }
241
+
242
+ stage_kargs["block_class"] = BottleneckBlock
243
+ blocks = ResNet.make_stage(**stage_kargs)
244
+ in_channels = out_channels
245
+ out_channels *= 2
246
+ bottleneck_channels *= 2
247
+
248
+ if freeze_at >= stage_idx:
249
+ for block in blocks:
250
+ block.freeze()
251
+ stages.append(blocks)
252
+
253
+ return ResNet(stem, stages, out_features=out_features)
254
+
255
+
256
+ def find_top_rpn_proposals(
257
+ proposals,
258
+ pred_objectness_logits,
259
+ images,
260
+ image_sizes,
261
+ nms_thresh,
262
+ pre_nms_topk,
263
+ post_nms_topk,
264
+ min_box_side_len,
265
+ training,
266
+ ):
267
+ """Args:
268
+ proposals (list[Tensor]): (L, N, Hi*Wi*A, 4).
269
+ pred_objectness_logits: tensors of length L.
270
+ nms_thresh (float): IoU threshold to use for NMS
271
+ pre_nms_topk (int): before nms
272
+ post_nms_topk (int): after nms
273
+ min_box_side_len (float): minimum proposal box side
274
+ training (bool): True if proposals are to be used in training,
275
+ Returns:
276
+ results (List[Dict]): stores post_nms_topk object proposals for image i.
277
+ """
278
+ num_images = len(images)
279
+ device = proposals[0].device
280
+
281
+ # 1. Select top-k anchor for every level and every image
282
+ topk_scores = [] # #lvl Tensor, each of shape N x topk
283
+ topk_proposals = []
284
+ level_ids = [] # #lvl Tensor, each of shape (topk,)
285
+ batch_idx = torch.arange(num_images, device=device)
286
+ for level_id, proposals_i, logits_i in zip(itertools.count(), proposals, pred_objectness_logits):
287
+ Hi_Wi_A = logits_i.shape[1]
288
+ num_proposals_i = min(pre_nms_topk, Hi_Wi_A)
289
+
290
+ # sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812)
291
+ # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
292
+ logits_i, idx = logits_i.sort(descending=True, dim=1)
293
+ topk_scores_i = logits_i[batch_idx, :num_proposals_i]
294
+ topk_idx = idx[batch_idx, :num_proposals_i]
295
+
296
+ # each is N x topk
297
+ topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4
298
+
299
+ topk_proposals.append(topk_proposals_i)
300
+ topk_scores.append(topk_scores_i)
301
+ level_ids.append(torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device))
302
+
303
+ # 2. Concat all levels together
304
+ topk_scores = torch.cat(topk_scores, dim=1)
305
+ topk_proposals = torch.cat(topk_proposals, dim=1)
306
+ level_ids = torch.cat(level_ids, dim=0)
307
+
308
+ # if I change to batched_nms, I wonder if this will make a difference
309
+ # 3. For each image, run a per-level NMS, and choose topk results.
310
+ results = []
311
+ for n, image_size in enumerate(image_sizes):
312
+ boxes = topk_proposals[n]
313
+ scores_per_img = topk_scores[n]
314
+ # I will have to take a look at the boxes clip method
315
+ _clip_box(boxes, image_size)
316
+ # filter empty boxes
317
+ keep = _nonempty_boxes(boxes, threshold=min_box_side_len)
318
+ lvl = level_ids
319
+ if keep.sum().item() != len(boxes):
320
+ boxes, scores_per_img, lvl = (
321
+ boxes[keep],
322
+ scores_per_img[keep],
323
+ level_ids[keep],
324
+ )
325
+
326
+ keep = batched_nms(boxes, scores_per_img, lvl, nms_thresh)
327
+ keep = keep[:post_nms_topk]
328
+
329
+ res = (boxes[keep], scores_per_img[keep])
330
+ results.append(res)
331
+
332
+ # I wonder if it would be possible for me to pad all these things.
333
+ return results
334
+
335
+
336
+ def subsample_labels(labels, num_samples, positive_fraction, bg_label):
337
+ """
338
+ Returns:
339
+ pos_idx, neg_idx (Tensor):
340
+ 1D vector of indices. The total length of both is `num_samples` or fewer.
341
+ """
342
+ positive = torch.nonzero((labels != -1) & (labels != bg_label)).squeeze(1)
343
+ negative = torch.nonzero(labels == bg_label).squeeze(1)
344
+
345
+ num_pos = int(num_samples * positive_fraction)
346
+ # protect against not enough positive examples
347
+ num_pos = min(positive.numel(), num_pos)
348
+ num_neg = num_samples - num_pos
349
+ # protect against not enough negative examples
350
+ num_neg = min(negative.numel(), num_neg)
351
+
352
+ # randomly select positive and negative examples
353
+ perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
354
+ perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
355
+
356
+ pos_idx = positive[perm1]
357
+ neg_idx = negative[perm2]
358
+ return pos_idx, neg_idx
359
+
360
+
361
+ def add_ground_truth_to_proposals(gt_boxes, proposals):
362
+ raise NotImplementedError()
363
+
364
+
365
+ def add_ground_truth_to_proposals_single_image(gt_boxes, proposals):
366
+ raise NotImplementedError()
367
+
368
+
369
+ def _fmt_box_list(box_tensor, batch_index: int):
370
+ repeated_index = torch.full(
371
+ (len(box_tensor), 1),
372
+ batch_index,
373
+ dtype=box_tensor.dtype,
374
+ device=box_tensor.device,
375
+ )
376
+ return torch.cat((repeated_index, box_tensor), dim=1)
377
+
378
+
379
+ def convert_boxes_to_pooler_format(box_lists: List[torch.Tensor]):
380
+ pooler_fmt_boxes = torch.cat(
381
+ [_fmt_box_list(box_list, i) for i, box_list in enumerate(box_lists)],
382
+ dim=0,
383
+ )
384
+ return pooler_fmt_boxes
385
+
386
+
387
+ def assign_boxes_to_levels(
388
+ box_lists: List[torch.Tensor],
389
+ min_level: int,
390
+ max_level: int,
391
+ canonical_box_size: int,
392
+ canonical_level: int,
393
+ ):
394
+
395
+ box_sizes = torch.sqrt(torch.cat([boxes.area() for boxes in box_lists]))
396
+ # Eqn.(1) in FPN paper
397
+ level_assignments = torch.floor(canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8))
398
+ # clamp level to (min, max), in case the box size is too large or too small
399
+ # for the available feature maps
400
+ level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
401
+ return level_assignments.to(torch.int64) - min_level
402
+
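# A worked instance of Eqn.(1) above, using the canonical defaults that appear later in
# ROIPooler (canonical_box_size=224, canonical_level=4): a 224x224 box gives
# log2(224/224) = 0 and is assigned level 4; a 112x112 box gives log2(0.5) = -1 -> level 3;
# a 448x448 box gives log2(2) = 1 -> level 5. The assignment is then clamped to
# [min_level, max_level] and shifted by -min_level to index the available feature maps.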
403
+
404
+ # Helper Classes
405
+ class _NewEmptyTensorOp(torch.autograd.Function):
406
+ @staticmethod
407
+ def forward(ctx, x, new_shape):
408
+ ctx.shape = x.shape
409
+ return x.new_empty(new_shape)
410
+
411
+ @staticmethod
412
+ def backward(ctx, grad):
413
+ shape = ctx.shape
414
+ return _NewEmptyTensorOp.apply(grad, shape), None
415
+
416
+
417
+ class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
418
+ def __new__(cls, *, channels=None, height=None, width=None, stride=None):
419
+ return super().__new__(cls, channels, height, width, stride)
420
+
421
+
422
+ class Box2BoxTransform(object):
423
+ """
424
+ This R-CNN transformation scales the box's width and height
425
+ by exp(dw), exp(dh) and shifts a box's center by the offset
426
+ (dx * width, dy * height).
427
+ """
428
+
429
+ def __init__(self, weights: Tuple[float, float, float, float], scale_clamp: float = None):
430
+ """
431
+ Args:
432
+ weights (4-element tuple): Scaling factors that are applied to the
433
+ (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set
434
+ such that the deltas have unit variance; now they are treated as
435
+ hyperparameters of the system.
436
+ scale_clamp (float): When predicting deltas, the predicted box scaling
437
+ factors (dw and dh) are clamped such that they are <= scale_clamp.
438
+ """
439
+ self.weights = weights
440
+ if scale_clamp is not None:
441
+ self.scale_clamp = scale_clamp
442
+ else:
443
+ """
444
+ Value for clamping large dw and dh predictions.
445
+ The heuristic is that we clamp such that dw and dh are no larger
446
+ than what would transform a 16px box into a 1000px box
447
+ (based on a small anchor, 16px, and a typical image size, 1000px).
448
+ """
449
+ self.scale_clamp = math.log(1000.0 / 16)
450
+
451
+ def get_deltas(self, src_boxes, target_boxes):
452
+ """
453
+ Get box regression transformation deltas (dx, dy, dw, dh) that can be used
454
+ to transform the `src_boxes` into the `target_boxes`. That is, the relation
455
+ ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless
456
+ any delta is too large and is clamped).
457
+ Args:
458
+ src_boxes (Tensor): source boxes, e.g., object proposals
459
+ target_boxes (Tensor): target of the transformation, e.g., ground-truth
460
+ boxes.
461
+ """
462
+ assert isinstance(src_boxes, torch.Tensor), type(src_boxes)
463
+ assert isinstance(target_boxes, torch.Tensor), type(target_boxes)
464
+
465
+ src_widths = src_boxes[:, 2] - src_boxes[:, 0]
466
+ src_heights = src_boxes[:, 3] - src_boxes[:, 1]
467
+ src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths
468
+ src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights
469
+
470
+ target_widths = target_boxes[:, 2] - target_boxes[:, 0]
471
+ target_heights = target_boxes[:, 3] - target_boxes[:, 1]
472
+ target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths
473
+ target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights
474
+
475
+ wx, wy, ww, wh = self.weights
476
+ dx = wx * (target_ctr_x - src_ctr_x) / src_widths
477
+ dy = wy * (target_ctr_y - src_ctr_y) / src_heights
478
+ dw = ww * torch.log(target_widths / src_widths)
479
+ dh = wh * torch.log(target_heights / src_heights)
480
+
481
+ deltas = torch.stack((dx, dy, dw, dh), dim=1)
482
+ assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!"
483
+ return deltas
484
+
485
+ def apply_deltas(self, deltas, boxes):
486
+ """
487
+ Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
488
+ Args:
489
+ deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
490
+ deltas[i] represents k potentially different class-specific
491
+ box transformations for the single box boxes[i].
492
+ boxes (Tensor): boxes to transform, of shape (N, 4)
493
+ """
494
+ boxes = boxes.to(deltas.dtype)
495
+
496
+ widths = boxes[:, 2] - boxes[:, 0]
497
+ heights = boxes[:, 3] - boxes[:, 1]
498
+ ctr_x = boxes[:, 0] + 0.5 * widths
499
+ ctr_y = boxes[:, 1] + 0.5 * heights
500
+
501
+ wx, wy, ww, wh = self.weights
502
+ dx = deltas[:, 0::4] / wx
503
+ dy = deltas[:, 1::4] / wy
504
+ dw = deltas[:, 2::4] / ww
505
+ dh = deltas[:, 3::4] / wh
506
+
507
+ # Prevent sending too large values into torch.exp()
508
+ dw = torch.clamp(dw, max=self.scale_clamp)
509
+ dh = torch.clamp(dh, max=self.scale_clamp)
510
+
511
+ pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
512
+ pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
513
+ pred_w = torch.exp(dw) * widths[:, None]
514
+ pred_h = torch.exp(dh) * heights[:, None]
515
+
516
+ pred_boxes = torch.zeros_like(deltas)
517
+ pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # x1
518
+ pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # y1
519
+ pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w # x2
520
+ pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h # y2
521
+ return pred_boxes
522
+
523
+
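# A minimal sketch of the encode/decode round trip implemented by Box2BoxTransform above
# (illustrative toy numbers, unit weights assumed): applying get_deltas and then
# apply_deltas recovers the target box when no clamping is triggered.
import torch

transform = Box2BoxTransform(weights=(1.0, 1.0, 1.0, 1.0))
src = torch.tensor([[0.0, 0.0, 16.0, 16.0]])     # 16x16 proposal at the origin
tgt = torch.tensor([[4.0, 4.0, 28.0, 20.0]])     # 24x16 ground-truth box
deltas = transform.get_deltas(src, tgt)          # (dx, dy, dw, dh) = (0.5, 0.25, log 1.5, 0)
recovered = transform.apply_deltas(deltas, src)  # equals tgt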
524
+ class Matcher(object):
525
+ """
526
+ This class assigns to each predicted "element" (e.g., a box) a ground-truth
527
+ element. Each predicted element will have exactly zero or one matches; each
528
+ ground-truth element may be matched to zero or more predicted elements.
529
+ The matching is determined by the MxN match_quality_matrix, that characterizes
530
+ how well each (ground-truth, prediction)-pair match each other. For example,
531
+ if the elements are boxes, this matrix may contain box intersection-over-union
532
+ overlap values.
533
+ The matcher returns (a) a vector of length N containing the index of the
534
+ ground-truth element m in [0, M) that matches to prediction n in [0, N).
535
+ (b) a vector of length N containing the labels for each prediction.
536
+ """
537
+
538
+ def __init__(
539
+ self,
540
+ thresholds: List[float],
541
+ labels: List[int],
542
+ allow_low_quality_matches: bool = False,
543
+ ):
544
+ """
545
+ Args:
546
+ thresholds (list): a list of thresholds used to stratify predictions
547
+ into levels.
548
+ labels (list): a list of values to label predictions belonging at
549
+ each level. A label can be one of {-1, 0, 1} signifying
550
+ {ignore, negative class, positive class}, respectively.
551
+ allow_low_quality_matches (bool): if True, produce additional matches or predictions with maximum match quality lower than high_threshold.
552
+ For example, thresholds = [0.3, 0.5] labels = [0, -1, 1] All predictions with iou < 0.3 will be marked with 0 and
553
+ thus will be considered as false positives while training. All predictions with 0.3 <= iou < 0.5 will be marked with -1 and
554
+ thus will be ignored. All predictions with 0.5 <= iou will be marked with 1 and thus will be considered as true positives.
555
+ """
556
+ thresholds = thresholds[:]
557
+ assert thresholds[0] > 0
558
+ thresholds.insert(0, -float("inf"))
559
+ thresholds.append(float("inf"))
560
+ assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])])
561
+ assert all([label_i in [-1, 0, 1] for label_i in labels])
562
+ assert len(labels) == len(thresholds) - 1
563
+ self.thresholds = thresholds
564
+ self.labels = labels
565
+ self.allow_low_quality_matches = allow_low_quality_matches
566
+
567
+ def __call__(self, match_quality_matrix):
568
+ """
569
+ Args:
570
+ match_quality_matrix (Tensor[float]): an MxN tensor, containing the pairwise quality between M ground-truth elements and N predicted
571
+ elements. All elements must be >= 0 (due to the use of `torch.nonzero` for selecting indices in :meth:`set_low_quality_matches_`).
572
+ Returns:
573
+ matches (Tensor[int64]): a vector of length N, where matches[i] is a matched ground-truth index in [0, M)
574
+ match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates true or false positive or ignored
575
+ """
576
+ assert match_quality_matrix.dim() == 2
577
+ if match_quality_matrix.numel() == 0:
578
+ default_matches = match_quality_matrix.new_full((match_quality_matrix.size(1),), 0, dtype=torch.int64)
579
+ # When no gt boxes exist, we define IOU = 0 and therefore set labels
580
+ # to `self.labels[0]`, which usually defaults to background class 0
581
+ # To choose to ignore instead,
582
+ # can make labels=[-1,0,-1,1] + set appropriate thresholds
583
+ default_match_labels = match_quality_matrix.new_full(
584
+ (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8
585
+ )
586
+ return default_matches, default_match_labels
587
+
588
+ assert torch.all(match_quality_matrix >= 0)
589
+
590
+ # match_quality_matrix is M (gt) x N (predicted)
591
+ # Max over gt elements (dim 0) to find best gt candidate for each prediction
592
+ matched_vals, matches = match_quality_matrix.max(dim=0)
593
+
594
+ match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
595
+
596
+ for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
597
+ low_high = (matched_vals >= low) & (matched_vals < high)
598
+ match_labels[low_high] = l
599
+
600
+ if self.allow_low_quality_matches:
601
+ self.set_low_quality_matches_(match_labels, match_quality_matrix)
602
+
603
+ return matches, match_labels
604
+
605
+ def set_low_quality_matches_(self, match_labels, match_quality_matrix):
606
+ """
607
+ Produce additional matches for predictions that have only low-quality matches.
608
+ Specifically, for each ground-truth G find the set of predictions that have
609
+ maximum overlap with it (including ties); for each prediction in that set, if
610
+ it is unmatched, then match it to the ground-truth G.
611
+ This function implements the RPN assignment case (i)
612
+ in Sec. 3.1.2 of Faster R-CNN.
613
+ """
614
+ # For each gt, find the prediction with which it has highest quality
615
+ highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
616
+ # Find the highest quality match available, even if it is low, including ties.
617
+ # Note that the match qualities must be positive due to the use of
618
+ # `torch.nonzero`.
619
+ of_quality_inds = match_quality_matrix == highest_quality_foreach_gt[:, None]
620
+ if of_quality_inds.dim() == 0:
621
+ (_, pred_inds_with_highest_quality) = of_quality_inds.unsqueeze(0).nonzero().unbind(1)
622
+ else:
623
+ (_, pred_inds_with_highest_quality) = of_quality_inds.nonzero().unbind(1)
624
+ match_labels[pred_inds_with_highest_quality] = 1
625
+
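# A small illustration of the threshold/label scheme described in the Matcher docstring
# (toy IoU values, not taken from the model configs):
import torch

matcher = Matcher(thresholds=[0.3, 0.5], labels=[0, -1, 1], allow_low_quality_matches=False)
iou = torch.tensor([[0.1, 0.4, 0.8]])  # one ground-truth box scored against three predictions
matches, match_labels = matcher(iou)
# match_labels -> tensor([0, -1, 1], dtype=torch.int8): background, ignored, foreground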
626
+
627
+ class RPNOutputs(object):
628
+ def __init__(
629
+ self,
630
+ box2box_transform,
631
+ anchor_matcher,
632
+ batch_size_per_image,
633
+ positive_fraction,
634
+ images,
635
+ pred_objectness_logits,
636
+ pred_anchor_deltas,
637
+ anchors,
638
+ boundary_threshold=0,
639
+ gt_boxes=None,
640
+ smooth_l1_beta=0.0,
641
+ ):
642
+ """
643
+ Args:
644
+ box2box_transform (Box2BoxTransform): :class:`Box2BoxTransform` instance for anchor-proposal transformations.
645
+ anchor_matcher (Matcher): :class:`Matcher` instance for matching anchors to ground-truth boxes; used to determine training labels.
646
+ batch_size_per_image (int): number of proposals to sample when training
647
+ positive_fraction (float): target fraction of sampled proposals that should be positive
648
+ images (ImageList): :class:`ImageList` instance representing N input images
649
+ pred_objectness_logits (list[Tensor]): A list of L elements. Element i is a tensor of shape (N, A, Hi, Wi)
650
+ pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape (N, A*4, Hi, Wi)
651
+ anchors (list[torch.Tensor]): nested list of boxes. anchors[i][j] at (n, l) stores anchor array for feature map l
652
+ boundary_threshold (int): if >= 0, then anchors that extend beyond the image boundary by more than boundary_thresh are not used in training.
653
+ gt_boxes (list[Boxes], optional): A list of N elements.
654
+ smooth_l1_beta (float): The transition point between L1 and L2 loss. When set to 0, the loss becomes L1. When +inf, it is ignored.
655
+ """
656
+ self.box2box_transform = box2box_transform
657
+ self.anchor_matcher = anchor_matcher
658
+ self.batch_size_per_image = batch_size_per_image
659
+ self.positive_fraction = positive_fraction
660
+ self.pred_objectness_logits = pred_objectness_logits
661
+ self.pred_anchor_deltas = pred_anchor_deltas
662
+
663
+ self.anchors = anchors
664
+ self.gt_boxes = gt_boxes
665
+ self.num_feature_maps = len(pred_objectness_logits)
666
+ self.num_images = len(images)
667
+ self.boundary_threshold = boundary_threshold
668
+ self.smooth_l1_beta = smooth_l1_beta
669
+
670
+ def _get_ground_truth(self):
671
+ raise NotImplementedError()
672
+
673
+ def predict_proposals(self):
674
+ # pred_anchor_deltas: (L, N, ? Hi, Wi)
675
+ # anchors:(N, L, -1, B)
676
+ # here we loop over specific feature map, NOT images
677
+ proposals = []
678
+ anchors = self.anchors.transpose(0, 1)
679
+ for anchors_i, pred_anchor_deltas_i in zip(anchors, self.pred_anchor_deltas):
680
+ B = anchors_i.size(-1)
681
+ N, _, Hi, Wi = pred_anchor_deltas_i.shape
682
+ anchors_i = anchors_i.flatten(start_dim=0, end_dim=1)
683
+ pred_anchor_deltas_i = pred_anchor_deltas_i.view(N, -1, B, Hi, Wi).permute(0, 3, 4, 1, 2).reshape(-1, B)
684
+ proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i)
685
+ # Append feature map proposals with shape (N, Hi*Wi*A, B)
686
+ proposals.append(proposals_i.view(N, -1, B))
687
+ proposals = torch.stack(proposals)
688
+ return proposals
689
+
690
+ def predict_objectness_logits(self):
691
+ """
692
+ Returns:
693
+ pred_objectness_logits (list[Tensor]) -> (N, Hi*Wi*A).
694
+ """
695
+ pred_objectness_logits = [
696
+ # Reshape: (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
697
+ score.permute(0, 2, 3, 1).reshape(self.num_images, -1)
698
+ for score in self.pred_objectness_logits
699
+ ]
700
+ return pred_objectness_logits
701
+
702
+
703
+ # Main Classes
704
+ class Conv2d(torch.nn.Conv2d):
705
+ def __init__(self, *args, **kwargs):
706
+ norm = kwargs.pop("norm", None)
707
+ activation = kwargs.pop("activation", None)
708
+ super().__init__(*args, **kwargs)
709
+
710
+ self.norm = norm
711
+ self.activation = activation
712
+
713
+ def forward(self, x):
714
+ if x.numel() == 0 and self.training:
715
+ assert not isinstance(self.norm, torch.nn.SyncBatchNorm)
716
+ if x.numel() == 0:
717
+ assert not isinstance(self.norm, torch.nn.GroupNorm)
718
+ output_shape = [
719
+ (i + 2 * p - (di * (k - 1) + 1)) // s + 1
720
+ for i, p, di, k, s in zip(
721
+ x.shape[-2:],
722
+ self.padding,
723
+ self.dilation,
724
+ self.kernel_size,
725
+ self.stride,
726
+ )
727
+ ]
728
+ output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
729
+ empty = _NewEmptyTensorOp.apply(x, output_shape)
730
+ if self.training:
731
+ _dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
732
+ return empty + _dummy
733
+ else:
734
+ return empty
735
+
736
+ x = super().forward(x)
737
+ if self.norm is not None:
738
+ x = self.norm(x)
739
+ if self.activation is not None:
740
+ x = self.activation(x)
741
+ return x
742
+
743
+
744
+ class LastLevelMaxPool(nn.Module):
745
+ """
746
+ This module is used in the original FPN to generate a downsampled P6 feature from P5.
747
+ """
748
+
749
+ def __init__(self):
750
+ super().__init__()
751
+ self.num_levels = 1
752
+ self.in_feature = "p5"
753
+
754
+ def forward(self, x):
755
+ return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
756
+
757
+
758
+ class LastLevelP6P7(nn.Module):
759
+ """
760
+ This module is used in RetinaNet to generate extra layers, P6 and P7 from C5 feature.
761
+ """
762
+
763
+ def __init__(self, in_channels, out_channels):
764
+ super().__init__()
765
+ self.num_levels = 2
766
+ self.in_feature = "res5"
767
+ self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
768
+ self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
769
+
770
+ def forward(self, c5):
771
+ p6 = self.p6(c5)
772
+ p7 = self.p7(F.relu(p6))
773
+ return [p6, p7]
774
+
775
+
776
+ class BasicStem(nn.Module):
777
+ def __init__(self, in_channels=3, out_channels=64, norm="BN", caffe_maxpool=False):
778
+ super().__init__()
779
+ self.conv1 = Conv2d(
780
+ in_channels,
781
+ out_channels,
782
+ kernel_size=7,
783
+ stride=2,
784
+ padding=3,
785
+ bias=False,
786
+ norm=get_norm(norm, out_channels),
787
+ )
788
+ self.caffe_maxpool = caffe_maxpool
789
+ # use pad 1 instead of pad zero
790
+
791
+ def forward(self, x):
792
+ x = self.conv1(x)
793
+ x = F.relu_(x)
794
+ if self.caffe_maxpool:
795
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=0, ceil_mode=True)
796
+ else:
797
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
798
+ return x
799
+
800
+ @property
801
+ def out_channels(self):
802
+ return self.conv1.out_channels
803
+
804
+ @property
805
+ def stride(self):
806
+ return 4 # = stride 2 conv -> stride 2 max pool
807
+
808
+
809
+ class ResNetBlockBase(nn.Module):
810
+ def __init__(self, in_channels, out_channels, stride):
811
+ super().__init__()
812
+ self.in_channels = in_channels
813
+ self.out_channels = out_channels
814
+ self.stride = stride
815
+
816
+ def freeze(self):
817
+ for p in self.parameters():
818
+ p.requires_grad = False
819
+ return self
820
+
821
+
822
+ class BottleneckBlock(ResNetBlockBase):
823
+ def __init__(
824
+ self,
825
+ in_channels,
826
+ out_channels,
827
+ bottleneck_channels,
828
+ stride=1,
829
+ num_groups=1,
830
+ norm="BN",
831
+ stride_in_1x1=False,
832
+ dilation=1,
833
+ ):
834
+ super().__init__(in_channels, out_channels, stride)
835
+
836
+ if in_channels != out_channels:
837
+ self.shortcut = Conv2d(
838
+ in_channels,
839
+ out_channels,
840
+ kernel_size=1,
841
+ stride=stride,
842
+ bias=False,
843
+ norm=get_norm(norm, out_channels),
844
+ )
845
+ else:
846
+ self.shortcut = None
847
+
848
+ # The original MSRA ResNet models have stride in the first 1x1 conv
849
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
850
+ # stride in the 3x3 conv
851
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
852
+
853
+ self.conv1 = Conv2d(
854
+ in_channels,
855
+ bottleneck_channels,
856
+ kernel_size=1,
857
+ stride=stride_1x1,
858
+ bias=False,
859
+ norm=get_norm(norm, bottleneck_channels),
860
+ )
861
+
862
+ self.conv2 = Conv2d(
863
+ bottleneck_channels,
864
+ bottleneck_channels,
865
+ kernel_size=3,
866
+ stride=stride_3x3,
867
+ padding=1 * dilation,
868
+ bias=False,
869
+ groups=num_groups,
870
+ dilation=dilation,
871
+ norm=get_norm(norm, bottleneck_channels),
872
+ )
873
+
874
+ self.conv3 = Conv2d(
875
+ bottleneck_channels,
876
+ out_channels,
877
+ kernel_size=1,
878
+ bias=False,
879
+ norm=get_norm(norm, out_channels),
880
+ )
881
+
882
+ def forward(self, x):
883
+ out = self.conv1(x)
884
+ out = F.relu_(out)
885
+
886
+ out = self.conv2(out)
887
+ out = F.relu_(out)
888
+
889
+ out = self.conv3(out)
890
+
891
+ if self.shortcut is not None:
892
+ shortcut = self.shortcut(x)
893
+ else:
894
+ shortcut = x
895
+
896
+ out += shortcut
897
+ out = F.relu_(out)
898
+ return out
899
+
900
+
901
+ class Backbone(nn.Module, metaclass=ABCMeta):
902
+ def __init__(self):
903
+ super().__init__()
904
+
905
+ @abstractmethod
906
+ def forward(self):
907
+ pass
908
+
909
+ @property
910
+ def size_divisibility(self):
911
+ """
912
+ Some backbones require the input height and width to be divisible by a specific integer. This is
913
+ typically true for encoder / decoder type networks with lateral connection (e.g., FPN) for which feature maps need to match
914
+ dimension in the "bottom up" and "top down" paths. Set to 0 if no specific input size divisibility is required.
915
+ """
916
+ return 0
917
+
918
+ def output_shape(self):
919
+ return {
920
+ name: ShapeSpec(
921
+ channels=self._out_feature_channels[name],
922
+ stride=self._out_feature_strides[name],
923
+ )
924
+ for name in self._out_features
925
+ }
926
+
927
+ @property
928
+ def out_features(self):
929
+ """deprecated"""
930
+ return self._out_features
931
+
932
+ @property
933
+ def out_feature_strides(self):
934
+ """deprecated"""
935
+ return {f: self._out_feature_strides[f] for f in self._out_features}
936
+
937
+ @property
938
+ def out_feature_channels(self):
939
+ """deprecated"""
940
+ return {f: self._out_feature_channels[f] for f in self._out_features}
941
+
942
+
943
+ class ResNet(Backbone):
944
+ def __init__(self, stem, stages, num_classes=None, out_features=None):
945
+ """
946
+ Args:
947
+ stem (nn.Module): a stem module
948
+ stages (list[list[ResNetBlock]]): several (typically 4) stages, each contains multiple :class:`ResNetBlockBase`.
949
+ num_classes (None or int): if None, will not perform classification.
950
+ out_features (list[str]): name of the layers whose outputs should be returned in forward. Can be anything in:
951
+ "stem", "linear", or "res2" ... If None, will return the output of the last layer.
952
+ """
953
+ super(ResNet, self).__init__()
954
+ self.stem = stem
955
+ self.num_classes = num_classes
956
+
957
+ current_stride = self.stem.stride
958
+ self._out_feature_strides = {"stem": current_stride}
959
+ self._out_feature_channels = {"stem": self.stem.out_channels}
960
+
961
+ self.stages_and_names = []
962
+ for i, blocks in enumerate(stages):
963
+ for block in blocks:
964
+ assert isinstance(block, ResNetBlockBase), block
965
+ curr_channels = block.out_channels
966
+ stage = nn.Sequential(*blocks)
967
+ name = "res" + str(i + 2)
968
+ self.add_module(name, stage)
969
+ self.stages_and_names.append((stage, name))
970
+ self._out_feature_strides[name] = current_stride = int(
971
+ current_stride * np.prod([k.stride for k in blocks])
972
+ )
973
+ self._out_feature_channels[name] = blocks[-1].out_channels
974
+
975
+ if num_classes is not None:
976
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
977
+ self.linear = nn.Linear(curr_channels, num_classes)
978
+
979
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
980
+ # "The 1000-way fully-connected layer is initialized by
981
+ # drawing weights from a zero-mean Gaussian with std of 0.01."
982
+ nn.init.normal_(self.linear.weight, std=0.01)
983
+ name = "linear"
984
+
985
+ if out_features is None:
986
+ out_features = [name]
987
+ self._out_features = out_features
988
+ assert len(self._out_features)
989
+ children = [x[0] for x in self.named_children()]
990
+ for out_feature in self._out_features:
991
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
992
+
993
+ def forward(self, x):
994
+ outputs = {}
995
+ x = self.stem(x)
996
+ if "stem" in self._out_features:
997
+ outputs["stem"] = x
998
+ for stage, name in self.stages_and_names:
999
+ x = stage(x)
1000
+ if name in self._out_features:
1001
+ outputs[name] = x
1002
+ if self.num_classes is not None:
1003
+ x = self.avgpool(x)
1004
+ x = self.linear(x)
1005
+ if "linear" in self._out_features:
1006
+ outputs["linear"] = x
1007
+ return outputs
1008
+
1009
+ def output_shape(self):
1010
+ return {
1011
+ name: ShapeSpec(
1012
+ channels=self._out_feature_channels[name],
1013
+ stride=self._out_feature_strides[name],
1014
+ )
1015
+ for name in self._out_features
1016
+ }
1017
+
1018
+ @staticmethod
1019
+ def make_stage(
1020
+ block_class,
1021
+ num_blocks,
1022
+ first_stride=None,
1023
+ *,
1024
+ in_channels,
1025
+ out_channels,
1026
+ **kwargs,
1027
+ ):
1028
+ """
1029
+ Usually, layers that produce the same feature map spatial size
1030
+ are defined as one "stage".
1031
+ Under such definition, stride_per_block[1:] should all be 1.
1032
+ """
1033
+ if first_stride is not None:
1034
+ assert "stride" not in kwargs and "stride_per_block" not in kwargs
1035
+ kwargs["stride_per_block"] = [first_stride] + [1] * (num_blocks - 1)
1036
+ blocks = []
1037
+ for i in range(num_blocks):
1038
+ curr_kwargs = {}
1039
+ for k, v in kwargs.items():
1040
+ if k.endswith("_per_block"):
1041
+ assert len(v) == num_blocks, (
1042
+ f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
1043
+ )
1044
+ newk = k[: -len("_per_block")]
1045
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
1046
+ curr_kwargs[newk] = v[i]
1047
+ else:
1048
+ curr_kwargs[k] = v
1049
+
1050
+ blocks.append(block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs))
1051
+ in_channels = out_channels
1052
+
1053
+ return blocks
1054
+
1055
+
1056
+ class ROIPooler(nn.Module):
1057
+ """
1058
+ Region of interest feature map pooler that supports pooling from one or more
1059
+ feature maps.
1060
+ """
1061
+
1062
+ def __init__(
1063
+ self,
1064
+ output_size,
1065
+ scales,
1066
+ sampling_ratio,
1067
+ canonical_box_size=224,
1068
+ canonical_level=4,
1069
+ ):
1070
+ super().__init__()
1071
+ # assumption that stride is a power of 2.
1072
+ min_level = -math.log2(scales[0])
1073
+ max_level = -math.log2(scales[-1])
1074
+
1075
+ # a bunch of testing
1076
+ assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
1077
+ assert len(scales) == max_level - min_level + 1, "not pyramid"
1078
+ assert 0 < min_level and min_level <= max_level
1079
+ if isinstance(output_size, int):
1080
+ output_size = (output_size, output_size)
1081
+ assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
1082
+ if len(scales) > 1:
1083
+ assert min_level <= canonical_level and canonical_level <= max_level
1084
+ assert canonical_box_size > 0
1085
+
1086
+ self.output_size = output_size
1087
+ self.min_level = int(min_level)
1088
+ self.max_level = int(max_level)
1089
+ self.level_poolers = nn.ModuleList(RoIPool(output_size, spatial_scale=scale) for scale in scales)
1090
+ self.canonical_level = canonical_level
1091
+ self.canonical_box_size = canonical_box_size
1092
+
1093
+ def forward(self, feature_maps, boxes):
1094
+ """
1095
+ Args:
1096
+ feature_maps: List[torch.Tensor(N,C,W,H)]
1097
+ box_lists: list[torch.Tensor])
1098
+ Returns:
1099
+ A tensor of shape(N*B, Channels, output_size, output_size)
1100
+ """
1101
+ x = [v for v in feature_maps.values()]
1102
+ num_level_assignments = len(self.level_poolers)
1103
+ assert len(x) == num_level_assignments and len(boxes) == x[0].size(0)
1104
+
1105
+ pooler_fmt_boxes = convert_boxes_to_pooler_format(boxes)
1106
+
1107
+ if num_level_assignments == 1:
1108
+ return self.level_poolers[0](x[0], pooler_fmt_boxes)
1109
+
1110
+ level_assignments = assign_boxes_to_levels(
1111
+ boxes,
1112
+ self.min_level,
1113
+ self.max_level,
1114
+ self.canonical_box_size,
1115
+ self.canonical_level,
1116
+ )
1117
+
1118
+ num_boxes = len(pooler_fmt_boxes)
1119
+ num_channels = x[0].shape[1]
1120
+ output_size = self.output_size[0]
1121
+
1122
+ dtype, device = x[0].dtype, x[0].device
1123
+ output = torch.zeros(
1124
+ (num_boxes, num_channels, output_size, output_size),
1125
+ dtype=dtype,
1126
+ device=device,
1127
+ )
1128
+
1129
+ for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
1130
+ inds = torch.nonzero(level_assignments == level).squeeze(1)
1131
+ pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
1132
+ output[inds] = pooler(x_level, pooler_fmt_boxes_level)
1133
+
1134
+ return output
1135
+
1136
+
1137
+ class ROIOutputs(object):
1138
+ def __init__(self, cfg, training=False):
1139
+ self.smooth_l1_beta = cfg.ROI_BOX_HEAD.SMOOTH_L1_BETA
1140
+ self.box2box_transform = Box2BoxTransform(weights=cfg.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)
1141
+ self.training = training
1142
+ self.score_thresh = cfg.ROI_HEADS.SCORE_THRESH_TEST
1143
+ self.min_detections = cfg.MIN_DETECTIONS
1144
+ self.max_detections = cfg.MAX_DETECTIONS
1145
+
1146
+ nms_thresh = cfg.ROI_HEADS.NMS_THRESH_TEST
1147
+ if not isinstance(nms_thresh, list):
1148
+ nms_thresh = [nms_thresh]
1149
+ self.nms_thresh = nms_thresh
1150
+
1151
+ def _predict_boxes(self, proposals, box_deltas, preds_per_image):
1152
+ num_pred = box_deltas.size(0)
1153
+ B = proposals[0].size(-1)
1154
+ K = box_deltas.size(-1) // B
1155
+ box_deltas = box_deltas.view(num_pred * K, B)
1156
+ proposals = torch.cat(proposals, dim=0).unsqueeze(-2).expand(num_pred, K, B)
1157
+ proposals = proposals.reshape(-1, B)
1158
+ boxes = self.box2box_transform.apply_deltas(box_deltas, proposals)
1159
+ return boxes.view(num_pred, K * B).split(preds_per_image, dim=0)
1160
+
1161
+ def _predict_objs(self, obj_logits, preds_per_image):
1162
+ probs = F.softmax(obj_logits, dim=-1)
1163
+ probs = probs.split(preds_per_image, dim=0)
1164
+ return probs
1165
+
1166
+ def _predict_attrs(self, attr_logits, preds_per_image):
1167
+ attr_logits = attr_logits[..., :-1].softmax(-1)
1168
+ attr_probs, attrs = attr_logits.max(-1)
1169
+ return attr_probs.split(preds_per_image, dim=0), attrs.split(preds_per_image, dim=0)
1170
+
1171
+ @torch.no_grad()
1172
+ def inference(
1173
+ self,
1174
+ obj_logits,
1175
+ attr_logits,
1176
+ box_deltas,
1177
+ pred_boxes,
1178
+ features,
1179
+ sizes,
1180
+ scales=None,
1181
+ ):
1182
+ # only the pred boxes is the
1183
+ preds_per_image = [p.size(0) for p in pred_boxes]
1184
+ boxes_all = self._predict_boxes(pred_boxes, box_deltas, preds_per_image)
1185
+ obj_scores_all = self._predict_objs(obj_logits, preds_per_image) # list of length N
1186
+ attr_probs_all, attrs_all = self._predict_attrs(attr_logits, preds_per_image)
1187
+ features = features.split(preds_per_image, dim=0)
1188
+
1189
+ # fun for each image too, also I can experiment and do multiple images
1190
+ final_results = []
1191
+ zipped = zip(boxes_all, obj_scores_all, attr_probs_all, attrs_all, sizes)
1192
+ for i, (boxes, obj_scores, attr_probs, attrs, size) in enumerate(zipped):
1193
+ for nms_t in self.nms_thresh:
1194
+ outputs = do_nms(
1195
+ boxes,
1196
+ obj_scores,
1197
+ size,
1198
+ self.score_thresh,
1199
+ nms_t,
1200
+ self.min_detections,
1201
+ self.max_detections,
1202
+ )
1203
+ if outputs is not None:
1204
+ max_boxes, max_scores, classes, ids = outputs
1205
+ break
1206
+
1207
+ if scales is not None:
1208
+ scale_yx = scales[i]
1209
+ max_boxes[:, 0::2] *= scale_yx[1]
1210
+ max_boxes[:, 1::2] *= scale_yx[0]
1211
+
1212
+ final_results.append(
1213
+ (
1214
+ max_boxes,
1215
+ classes,
1216
+ max_scores,
1217
+ attrs[ids],
1218
+ attr_probs[ids],
1219
+ features[i][ids],
1220
+ )
1221
+ )
1222
+ boxes, classes, class_probs, attrs, attr_probs, roi_features = map(list, zip(*final_results))
1223
+ return boxes, classes, class_probs, attrs, attr_probs, roi_features
1224
+
1225
+ def training(self, obj_logits, attr_logits, box_deltas, pred_boxes, features, sizes):
1226
+ pass
1227
+
1228
+ def __call__(
1229
+ self,
1230
+ obj_logits,
1231
+ attr_logits,
1232
+ box_deltas,
1233
+ pred_boxes,
1234
+ features,
1235
+ sizes,
1236
+ scales=None,
1237
+ ):
1238
+ if self.training:
1239
+ raise NotImplementedError()
1240
+ return self.inference(
1241
+ obj_logits,
1242
+ attr_logits,
1243
+ box_deltas,
1244
+ pred_boxes,
1245
+ features,
1246
+ sizes,
1247
+ scales=scales,
1248
+ )
1249
+
1250
+
1251
+ class Res5ROIHeads(nn.Module):
1252
+ """
1253
+ ROIHeads perform all per-region computation in an R-CNN.
1254
+ It contains the logic for cropping the regions, extracting per-region features
1255
+ (by the res-5 block in this case), and making per-region predictions.
1256
+ """
1257
+
1258
+ def __init__(self, cfg, input_shape):
1259
+ super().__init__()
1260
+ self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE
1261
+ self.positive_sample_fraction = cfg.ROI_HEADS.POSITIVE_FRACTION
1262
+ self.in_features = cfg.ROI_HEADS.IN_FEATURES
1263
+ self.num_classes = cfg.ROI_HEADS.NUM_CLASSES
1264
+ self.proposal_append_gt = cfg.ROI_HEADS.PROPOSAL_APPEND_GT
1265
+ self.feature_strides = {k: v.stride for k, v in input_shape.items()}
1266
+ self.feature_channels = {k: v.channels for k, v in input_shape.items()}
1267
+ self.cls_agnostic_bbox_reg = cfg.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
1268
+ self.stage_channel_factor = 2 ** 3 # res5 is 8x res2
1269
+ self.out_channels = cfg.RESNETS.RES2_OUT_CHANNELS * self.stage_channel_factor
1270
+
1271
+ # self.proposal_matcher = Matcher(
1272
+ # cfg.ROI_HEADS.IOU_THRESHOLDS,
1273
+ # cfg.ROI_HEADS.IOU_LABELS,
1274
+ # allow_low_quality_matches=False,
1275
+ # )
1276
+
1277
+ pooler_resolution = cfg.ROI_BOX_HEAD.POOLER_RESOLUTION
1278
+ pooler_scales = (1.0 / self.feature_strides[self.in_features[0]],)
1279
+ sampling_ratio = cfg.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO
1280
+ res5_halve = cfg.ROI_BOX_HEAD.RES5HALVE
1281
+ use_attr = cfg.ROI_BOX_HEAD.ATTR
1282
+ num_attrs = cfg.ROI_BOX_HEAD.NUM_ATTRS
1283
+
1284
+ self.pooler = ROIPooler(
1285
+ output_size=pooler_resolution,
1286
+ scales=pooler_scales,
1287
+ sampling_ratio=sampling_ratio,
1288
+ )
1289
+
1290
+ self.res5 = self._build_res5_block(cfg)
1291
+ if not res5_halve:
1292
+ """
1293
+ Modifications for VG in RoI heads:
1294
+ 1. Change the stride of conv1 and shortcut in Res5.Block1 from 2 to 1
1295
+ 2. Modify all conv2 layers (padding: 1 --> 2) and (dilation: 1 --> 2)
1296
+ """
1297
+ self.res5[0].conv1.stride = (1, 1)
1298
+ self.res5[0].shortcut.stride = (1, 1)
1299
+ for i in range(3):
1300
+ self.res5[i].conv2.padding = (2, 2)
1301
+ self.res5[i].conv2.dilation = (2, 2)
1302
+
1303
+ self.box_predictor = FastRCNNOutputLayers(
1304
+ self.out_channels,
1305
+ self.num_classes,
1306
+ self.cls_agnostic_bbox_reg,
1307
+ use_attr=use_attr,
1308
+ num_attrs=num_attrs,
1309
+ )
1310
+
1311
+ def _build_res5_block(self, cfg):
1312
+ stage_channel_factor = self.stage_channel_factor # res5 is 8x res2
1313
+ num_groups = cfg.RESNETS.NUM_GROUPS
1314
+ width_per_group = cfg.RESNETS.WIDTH_PER_GROUP
1315
+ bottleneck_channels = num_groups * width_per_group * stage_channel_factor
1316
+ out_channels = self.out_channels
1317
+ stride_in_1x1 = cfg.RESNETS.STRIDE_IN_1X1
1318
+ norm = cfg.RESNETS.NORM
1319
+
1320
+ blocks = ResNet.make_stage(
1321
+ BottleneckBlock,
1322
+ 3,
1323
+ first_stride=2,
1324
+ in_channels=out_channels // 2,
1325
+ bottleneck_channels=bottleneck_channels,
1326
+ out_channels=out_channels,
1327
+ num_groups=num_groups,
1328
+ norm=norm,
1329
+ stride_in_1x1=stride_in_1x1,
1330
+ )
1331
+ return nn.Sequential(*blocks)
1332
+
1333
+ def _shared_roi_transform(self, features, boxes):
1334
+ x = self.pooler(features, boxes)
1335
+ return self.res5(x)
1336
+
1337
+ def forward(self, features, proposal_boxes, gt_boxes=None):
1338
+ if self.training:
1339
+ """
1340
+ see https://github.com/airsplay/py-bottom-up-attention/\
1341
+ blob/master/detectron2/modeling/roi_heads/roi_heads.py
1342
+ """
1343
+ raise NotImplementedError()
1344
+
1345
+ assert not proposal_boxes[0].requires_grad
1346
+ box_features = self._shared_roi_transform(features, proposal_boxes)
1347
+ feature_pooled = box_features.mean(dim=[2, 3]) # pooled to 1x1
1348
+ obj_logits, attr_logits, pred_proposal_deltas = self.box_predictor(feature_pooled)
1349
+ return obj_logits, attr_logits, pred_proposal_deltas, feature_pooled
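As a short note on the per-region shape flow above (the 2048 figure assumes the default RESNETS.RES2_OUT_CHANNELS of 256, which is not read from this file):

# pooler:  [R, C // 2, S, S] crops taken from the Res4 feature map
# res5:    [R, C, S', S'] with C = cfg.RESNETS.RES2_OUT_CHANNELS * 8 (2048 by default)
# mean over dims [2, 3]: feature_pooled is [R, C], one vector per region,
# which box_predictor turns into object logits, attribute logits and box deltas.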
1350
+
1351
+
1352
+ class AnchorGenerator(nn.Module):
1353
+ """
1354
+ For a set of image sizes and feature maps, computes a set of anchors.
1355
+ """
1356
+
1357
+ def __init__(self, cfg, input_shape: List[ShapeSpec]):
1358
+ super().__init__()
1359
+ sizes = cfg.ANCHOR_GENERATOR.SIZES
1360
+ aspect_ratios = cfg.ANCHOR_GENERATOR.ASPECT_RATIOS
1361
+ self.strides = [x.stride for x in input_shape]
1362
+ self.offset = cfg.ANCHOR_GENERATOR.OFFSET
1363
+ assert 0.0 <= self.offset < 1.0, self.offset
1364
+
1365
+ """
1366
+ sizes (list[list[int]]): sizes[i] is the list of anchor sizes for feat map i
1367
+ 1. given in absolute lengths in units of the input image;
1368
+ 2. they do not dynamically scale if the input image size changes.
1369
+ aspect_ratios (list[list[float]])
1370
+ strides (list[int]): stride of each input feature.
1371
+ """
1372
+
1373
+ self.num_features = len(self.strides)
1374
+ self.cell_anchors = nn.ParameterList(self._calculate_anchors(sizes, aspect_ratios))
1375
+ self._spacial_feat_dim = 4
1376
+
1377
+ def _calculate_anchors(self, sizes, aspect_ratios):
1378
+ # If one size (or aspect ratio) is specified and there are multiple feature
1379
+ # maps, then we "broadcast" anchors of that single size (or aspect ratio)
1380
+ if len(sizes) == 1:
1381
+ sizes *= self.num_features
1382
+ if len(aspect_ratios) == 1:
1383
+ aspect_ratios *= self.num_features
1384
+ assert self.num_features == len(sizes)
1385
+ assert self.num_features == len(aspect_ratios)
1386
+
1387
+ cell_anchors = [self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios)]
1388
+
1389
+ return cell_anchors
1390
+
1391
+ @property
1392
+ def box_dim(self):
1393
+ return self._spacial_feat_dim
1394
+
1395
+ @property
1396
+ def num_cell_anchors(self):
1397
+ """
1398
+ Returns:
1399
+ list[int]: Each int is the number of anchors at every pixel location, on that feature map.
1400
+ """
1401
+ return [len(cell_anchors) for cell_anchors in self.cell_anchors]
1402
+
1403
+ def grid_anchors(self, grid_sizes):
1404
+ anchors = []
1405
+ for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
1406
+ shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
1407
+ shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
1408
+
1409
+ anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
1410
+
1411
+ return anchors
1412
+
1413
+ def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)):
1414
+ """
1415
+ anchors are continuous geometric rectangles
1416
+ centered on one feature map point sample.
1417
+ We can later build the set of anchors
1418
+ for the entire feature map by tiling these tensors
1419
+ """
1420
+
1421
+ anchors = []
1422
+ for size in sizes:
1423
+ area = size ** 2.0
1424
+ for aspect_ratio in aspect_ratios:
1425
+ w = math.sqrt(area / aspect_ratio)
1426
+ h = aspect_ratio * w
1427
+ x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
1428
+ anchors.append([x0, y0, x1, y1])
1429
+ return nn.Parameter(torch.Tensor(anchors))
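A minimal standalone sketch of the geometry above: each anchor keeps the requested area while the aspect ratio (h / w) sets the height/width split. The concrete size and ratio below are illustrative only.

import math

size, aspect_ratio = 32, 0.5                    # one (size, ratio) pair, for illustration
area = size ** 2.0                              # 1024.0
w = math.sqrt(area / aspect_ratio)              # ~45.25
h = aspect_ratio * w                            # ~22.63, so w * h == area
print([-w / 2.0, -h / 2.0, w / 2.0, h / 2.0])   # XYXY anchor centered at (0, 0)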
1430
+
1431
+ def forward(self, features):
1432
+ """
1433
+ Args:
1434
+ features List[torch.Tensor]: list of feature maps on which to generate anchors.
1435
+ Returns:
1436
+ torch.Tensor: anchors over all feature maps, repeated once per image in the batch.
1437
+ """
1438
+ num_images = features[0].size(0)
1439
+ grid_sizes = [feature_map.shape[-2:] for feature_map in features]
1440
+ anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
1441
+ anchors_over_all_feature_maps = torch.stack(anchors_over_all_feature_maps)
1442
+ return anchors_over_all_feature_maps.unsqueeze(0).repeat_interleave(num_images, dim=0)
1443
+
1444
+
1445
+ class RPNHead(nn.Module):
1446
+ """
1447
+ RPN classification and regression heads. Uses a 3x3 conv to produce a shared
1448
+ hidden state from which one 1x1 conv predicts objectness logits for each anchor
1449
+ and a second 1x1 conv predicts bounding-box deltas specifying how to deform
1450
+ each anchor into an object proposal.
1451
+ """
1452
+
1453
+ def __init__(self, cfg, input_shape: List[ShapeSpec]):
1454
+ super().__init__()
1455
+
1456
+ # Standard RPN is shared across levels:
1457
+ in_channels = [s.channels for s in input_shape]
1458
+ assert len(set(in_channels)) == 1, "Each level must have the same channel!"
1459
+ in_channels = in_channels[0]
1460
+
1461
+ anchor_generator = AnchorGenerator(cfg, input_shape)
1462
+ num_cell_anchors = anchor_generator.num_cell_anchors
1463
+ box_dim = anchor_generator.box_dim
1464
+ assert len(set(num_cell_anchors)) == 1, "Each level must have the same number of cell anchors"
1465
+ num_cell_anchors = num_cell_anchors[0]
1466
+
1467
+ if cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS == -1:
1468
+ hid_channels = in_channels
1469
+ else:
1470
+ hid_channels = cfg.PROPOSAL_GENERATOR.HIDDEN_CHANNELS
1471
+ # Modifications for VG in RPN (modeling/proposal_generator/rpn.py)
1472
+ # Use a hidden dim instead of the same dim as Res4 (in_channels)
1473
+
1474
+ # 3x3 conv for the hidden representation
1475
+ self.conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, stride=1, padding=1)
1476
+ # 1x1 conv for predicting objectness logits
1477
+ self.objectness_logits = nn.Conv2d(hid_channels, num_cell_anchors, kernel_size=1, stride=1)
1478
+ # 1x1 conv for predicting box2box transform deltas
1479
+ self.anchor_deltas = nn.Conv2d(hid_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1)
1480
+
1481
+ for layer in [self.conv, self.objectness_logits, self.anchor_deltas]:
1482
+ nn.init.normal_(layer.weight, std=0.01)
1483
+ nn.init.constant_(layer.bias, 0)
1484
+
1485
+ def forward(self, features):
1486
+ """
1487
+ Args:
1488
+ features (list[Tensor]): list of feature maps
1489
+ """
1490
+ pred_objectness_logits = []
1491
+ pred_anchor_deltas = []
1492
+ for x in features:
1493
+ t = F.relu(self.conv(x))
1494
+ pred_objectness_logits.append(self.objectness_logits(t))
1495
+ pred_anchor_deltas.append(self.anchor_deltas(t))
1496
+ return pred_objectness_logits, pred_anchor_deltas
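A standalone shape check for the head above; the channel and anchor counts are illustrative assumptions, not values read from any config.

import torch
import torch.nn as nn

in_channels, hid_channels, num_cell_anchors, box_dim = 256, 256, 15, 4
conv = nn.Conv2d(in_channels, hid_channels, kernel_size=3, stride=1, padding=1)
objectness = nn.Conv2d(hid_channels, num_cell_anchors, kernel_size=1)
deltas = nn.Conv2d(hid_channels, num_cell_anchors * box_dim, kernel_size=1)

t = torch.relu(conv(torch.randn(1, in_channels, 38, 57)))   # one Res4-like feature map
print(objectness(t).shape)   # torch.Size([1, 15, 38, 57]) -> one logit per anchor per location
print(deltas(t).shape)       # torch.Size([1, 60, 38, 57]) -> 4 box deltas per anchor per location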
1497
+
1498
+
1499
+ class RPN(nn.Module):
1500
+ """
1501
+ Region Proposal Network, introduced by the Faster R-CNN paper.
1502
+ """
1503
+
1504
+ def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
1505
+ super().__init__()
1506
+
1507
+ self.min_box_side_len = cfg.PROPOSAL_GENERATOR.MIN_SIZE
1508
+ self.in_features = cfg.RPN.IN_FEATURES
1509
+ self.nms_thresh = cfg.RPN.NMS_THRESH
1510
+ self.batch_size_per_image = cfg.RPN.BATCH_SIZE_PER_IMAGE
1511
+ self.positive_fraction = cfg.RPN.POSITIVE_FRACTION
1512
+ self.smooth_l1_beta = cfg.RPN.SMOOTH_L1_BETA
1513
+ self.loss_weight = cfg.RPN.LOSS_WEIGHT
1514
+
1515
+ self.pre_nms_topk = {
1516
+ True: cfg.RPN.PRE_NMS_TOPK_TRAIN,
1517
+ False: cfg.RPN.PRE_NMS_TOPK_TEST,
1518
+ }
1519
+ self.post_nms_topk = {
1520
+ True: cfg.RPN.POST_NMS_TOPK_TRAIN,
1521
+ False: cfg.RPN.POST_NMS_TOPK_TEST,
1522
+ }
1523
+ self.boundary_threshold = cfg.RPN.BOUNDARY_THRESH
1524
+
1525
+ self.anchor_generator = AnchorGenerator(cfg, [input_shape[f] for f in self.in_features])
1526
+ self.box2box_transform = Box2BoxTransform(weights=cfg.RPN.BBOX_REG_WEIGHTS)
1527
+ self.anchor_matcher = Matcher(
1528
+ cfg.RPN.IOU_THRESHOLDS,
1529
+ cfg.RPN.IOU_LABELS,
1530
+ allow_low_quality_matches=True,
1531
+ )
1532
+ self.rpn_head = RPNHead(cfg, [input_shape[f] for f in self.in_features])
1533
+
1534
+ def training(self, images, image_shapes, features, gt_boxes):
1535
+ pass
1536
+
1537
+ def inference(self, outputs, images, image_shapes, features, gt_boxes=None):
1538
+ outputs = find_top_rpn_proposals(
1539
+ outputs.predict_proposals(),
1540
+ outputs.predict_objectness_logits(),
1541
+ images,
1542
+ image_shapes,
1543
+ self.nms_thresh,
1544
+ self.pre_nms_topk[self.training],
1545
+ self.post_nms_topk[self.training],
1546
+ self.min_box_side_len,
1547
+ self.training,
1548
+ )
1549
+
1550
+ results = []
1551
+ for img in outputs:
1552
+ im_boxes, img_box_logits = img
1553
+ img_box_logits, inds = img_box_logits.sort(descending=True)
1554
+ im_boxes = im_boxes[inds]
1555
+ results.append((im_boxes, img_box_logits))
1556
+
1557
+ (proposal_boxes, logits) = tuple(map(list, zip(*results)))
1558
+ return proposal_boxes, logits
1559
+
1560
+ def forward(self, images, image_shapes, features, gt_boxes=None):
1561
+ """
1562
+ Args:
1563
+ images (torch.Tensor): input images of length `N`
1564
+ features (dict[str: Tensor])
1565
+ gt_instances
1566
+ """
1567
+ # features is dict, key = block level, v = feature_map
1568
+ features = [features[f] for f in self.in_features]
1569
+ pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
1570
+ anchors = self.anchor_generator(features)
1571
+ outputs = RPNOutputs(
1572
+ self.box2box_transform,
1573
+ self.anchor_matcher,
1574
+ self.batch_size_per_image,
1575
+ self.positive_fraction,
1576
+ images,
1577
+ pred_objectness_logits,
1578
+ pred_anchor_deltas,
1579
+ anchors,
1580
+ self.boundary_threshold,
1581
+ gt_boxes,
1582
+ self.smooth_l1_beta,
1583
+ )
1584
+ # For RPN-only models, the proposals are the final output
1585
+
1586
+ if self.training:
1587
+ raise NotImplementedError()
1588
+ return self.training(outputs, images, image_shapes, features, gt_boxes)
1589
+ else:
1590
+ return self.inference(outputs, images, image_shapes, features, gt_boxes)
1591
+
1592
+
1593
+ class FastRCNNOutputLayers(nn.Module):
1594
+ """
1595
+ Two linear layers for predicting Fast R-CNN outputs:
1596
+ (1) proposal-to-detection box regression deltas
1597
+ (2) classification scores
1598
+ """
1599
+
1600
+ def __init__(
1601
+ self,
1602
+ input_size,
1603
+ num_classes,
1604
+ cls_agnostic_bbox_reg,
1605
+ box_dim=4,
1606
+ use_attr=False,
1607
+ num_attrs=-1,
1608
+ ):
1609
+ """
1610
+ Args:
1611
+ input_size (int): channels, or (channels, height, width)
1612
+ num_classes (int)
1613
+ cls_agnostic_bbox_reg (bool)
1614
+ box_dim (int)
1615
+ """
1616
+ super().__init__()
1617
+
1618
+ if not isinstance(input_size, int):
1619
+ input_size = np.prod(input_size)
1620
+
1621
+ # (do + 1 for background class)
1622
+ self.cls_score = nn.Linear(input_size, num_classes + 1)
1623
+ num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
1624
+ self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim)
1625
+
1626
+ self.use_attr = use_attr
1627
+ if use_attr:
1628
+ """
1629
+ Modifications for VG in RoI heads
1630
+ Embedding: {num_classes + 1} --> {input_size // 8}
1631
+ Linear: {input_size + input_size // 8} --> {input_size // 4}
1632
+ Linear: {input_size // 4} --> {num_attrs + 1}
1633
+ """
1634
+ self.cls_embedding = nn.Embedding(num_classes + 1, input_size // 8)
1635
+ self.fc_attr = nn.Linear(input_size + input_size // 8, input_size // 4)
1636
+ self.attr_score = nn.Linear(input_size // 4, num_attrs + 1)
1637
+
1638
+ nn.init.normal_(self.cls_score.weight, std=0.01)
1639
+ nn.init.normal_(self.bbox_pred.weight, std=0.001)
1640
+ for item in [self.cls_score, self.bbox_pred]:
1641
+ nn.init.constant_(item.bias, 0)
1642
+
1643
+ def forward(self, roi_features):
1644
+ if roi_features.dim() > 2:
1645
+ roi_features = torch.flatten(roi_features, start_dim=1)
1646
+ scores = self.cls_score(roi_features)
1647
+ proposal_deltas = self.bbox_pred(roi_features)
1648
+ if self.use_attr:
1649
+ _, max_class = scores.max(-1) # [b, c] --> [b]
1650
+ cls_emb = self.cls_embedding(max_class) # [b] --> [b, 256]
1651
+ roi_features = torch.cat([roi_features, cls_emb], -1) # [b, 2048] + [b, 256] --> [b, 2304]
1652
+ roi_features = self.fc_attr(roi_features)
1653
+ roi_features = F.relu(roi_features)
1654
+ attr_scores = self.attr_score(roi_features)
1655
+ return scores, attr_scores, proposal_deltas
1656
+ else:
1657
+ return scores, proposal_deltas
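A small usage sketch of the layer defined above. The 1600-class / 400-attribute setting is the usual Visual Genome configuration and is an assumption here, not something read from this file.

import torch

head = FastRCNNOutputLayers(input_size=2048, num_classes=1600,
                            cls_agnostic_bbox_reg=False, use_attr=True, num_attrs=400)
roi_features = torch.randn(36, 2048)                  # 36 pooled region vectors
scores, attr_scores, deltas = head(roi_features)
print(scores.shape, attr_scores.shape, deltas.shape)  # [36, 1601], [36, 401], [36, 6400]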
1658
+
1659
+
1660
+ class GeneralizedRCNN(nn.Module):
1661
+ def __init__(self, cfg):
1662
+ super().__init__()
1663
+
1664
+ self.device = torch.device(cfg.MODEL.DEVICE)
1665
+ self.backbone = build_backbone(cfg)
1666
+ self.proposal_generator = RPN(cfg, self.backbone.output_shape())
1667
+ self.roi_heads = Res5ROIHeads(cfg, self.backbone.output_shape())
1668
+ self.roi_outputs = ROIOutputs(cfg)
1669
+ self.to(self.device)
1670
+
1671
+ @classmethod
1672
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1673
+ config = kwargs.pop("config", None)
1674
+ state_dict = kwargs.pop("state_dict", None)
1675
+ cache_dir = kwargs.pop("cache_dir", None)
1676
+ from_tf = kwargs.pop("from_tf", False)
1677
+ force_download = kwargs.pop("force_download", False)
1678
+ resume_download = kwargs.pop("resume_download", False)
1679
+ proxies = kwargs.pop("proxies", None)
1680
+ local_files_only = kwargs.pop("local_files_only", False)
1681
+ use_cdn = kwargs.pop("use_cdn", True)
1682
+
1683
+ # Load config if we don't provide a configuration
1684
+ if not isinstance(config, Config):
1685
+ config_path = config if config is not None else pretrained_model_name_or_path
1686
+ # try:
1687
+ config = Config.from_pretrained(
1688
+ config_path,
1689
+ cache_dir=cache_dir,
1690
+ force_download=force_download,
1691
+ resume_download=resume_download,
1692
+ proxies=proxies,
1693
+ local_files_only=local_files_only,
1694
+ )
1695
+
1696
+ # Load model
1697
+ if pretrained_model_name_or_path is not None:
1698
+ if os.path.isdir(pretrained_model_name_or_path):
1699
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
1700
+ # Load from a PyTorch checkpoint
1701
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
1702
+ else:
1703
+ raise EnvironmentError(
1704
+ "Error no file named {} found in directory {} ".format(
1705
+ WEIGHTS_NAME,
1706
+ pretrained_model_name_or_path,
1707
+ )
1708
+ )
1709
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
1710
+ archive_file = pretrained_model_name_or_path
1711
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
1712
+ assert (
1713
+ from_tf
1714
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
1715
+ pretrained_model_name_or_path + ".index"
1716
+ )
1717
+ archive_file = pretrained_model_name_or_path + ".index"
1718
+ else:
1719
+ archive_file = hf_bucket_url(
1720
+ pretrained_model_name_or_path,
1721
+ filename=WEIGHTS_NAME,
1722
+ use_cdn=use_cdn,
1723
+ )
1724
+
1725
+ try:
1726
+ # Load from URL or cache if already cached
1727
+ resolved_archive_file = cached_path(
1728
+ archive_file,
1729
+ cache_dir=cache_dir,
1730
+ force_download=force_download,
1731
+ proxies=proxies,
1732
+ resume_download=resume_download,
1733
+ local_files_only=local_files_only,
1734
+ )
1735
+ if resolved_archive_file is None:
1736
+ raise EnvironmentError
1737
+ except EnvironmentError:
1738
+ msg = f"Can't load weights for '{pretrained_model_name_or_path}'."
1739
+ raise EnvironmentError(msg)
1740
+
1741
+ if resolved_archive_file == archive_file:
1742
+ print("loading weights file {}".format(archive_file))
1743
+ else:
1744
+ print("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
1745
+ else:
1746
+ resolved_archive_file = None
1747
+
1748
+ # Instantiate model.
1749
+ model = cls(config)
1750
+
1751
+ if state_dict is None:
1752
+ try:
1753
+ try:
1754
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
1755
+ except Exception:
1756
+ state_dict = load_checkpoint(resolved_archive_file)
1757
+
1758
+ except Exception:
1759
+ raise OSError(
1760
+ "Unable to load weights from pytorch checkpoint file. "
1761
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
1762
+ )
1763
+
1764
+ missing_keys = []
1765
+ unexpected_keys = []
1766
+ error_msgs = []
1767
+
1768
+ # Convert old format to new format if needed from a PyTorch state_dict
1769
+ old_keys = []
1770
+ new_keys = []
1771
+ for key in state_dict.keys():
1772
+ new_key = None
1773
+ if "gamma" in key:
1774
+ new_key = key.replace("gamma", "weight")
1775
+ if "beta" in key:
1776
+ new_key = key.replace("beta", "bias")
1777
+ if new_key:
1778
+ old_keys.append(key)
1779
+ new_keys.append(new_key)
1780
+ for old_key, new_key in zip(old_keys, new_keys):
1781
+ state_dict[new_key] = state_dict.pop(old_key)
1782
+
1783
+ # copy state_dict so _load_from_state_dict can modify it
1784
+ metadata = getattr(state_dict, "_metadata", None)
1785
+ state_dict = state_dict.copy()
1786
+ if metadata is not None:
1787
+ state_dict._metadata = metadata
1788
+
1789
+ model_to_load = model
1790
+ model_to_load.load_state_dict(state_dict)
1791
+
1792
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
1793
+ base_model_state_dict = model_to_load.state_dict().keys()
1794
+ head_model_state_dict_without_base_prefix = [
1795
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
1796
+ ]
1797
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
1798
+
1799
+ if len(unexpected_keys) > 0:
1800
+ print(
1801
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
1802
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
1803
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
1804
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
1805
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
1806
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
1807
+ )
1808
+ else:
1809
+ print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1810
+ if len(missing_keys) > 0:
1811
+ print(
1812
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
1813
+ f"and are newly initialized: {missing_keys}\n"
1814
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1815
+ )
1816
+ else:
1817
+ print(
1818
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
1819
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
1820
+ f"you can already use {model.__class__.__name__} for predictions without further training."
1821
+ )
1822
+ if len(error_msgs) > 0:
1823
+ raise RuntimeError(
1824
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
1825
+ model.__class__.__name__, "\n\t".join(error_msgs)
1826
+ )
1827
+ )
1828
+ # Set model in evaluation mode to deactivate DropOut modules by default
1829
+ model.eval()
1830
+
1831
+ return model
1832
+
1833
+ def forward(
1834
+ self,
1835
+ images,
1836
+ image_shapes,
1837
+ gt_boxes=None,
1838
+ proposals=None,
1839
+ scales_yx=None,
1840
+ **kwargs,
1841
+ ):
1842
+ """
1843
+ kwargs:
1844
+ max_detections (int), return_tensors {"np", "pt", None}, padding {None,
1845
+ "max_detections"}, pad_value (int), location = {"cuda", "cpu"}
1846
+ """
1847
+ if self.training:
1848
+ raise NotImplementedError()
1849
+ return self.inference(
1850
+ images=images,
1851
+ image_shapes=image_shapes,
1852
+ gt_boxes=gt_boxes,
1853
+ proposals=proposals,
1854
+ scales_yx=scales_yx,
1855
+ **kwargs,
1856
+ )
1857
+
1858
+ @torch.no_grad()
1859
+ def inference(
1860
+ self,
1861
+ images,
1862
+ image_shapes,
1863
+ gt_boxes=None,
1864
+ proposals=None,
1865
+ scales_yx=None,
1866
+ **kwargs,
1867
+ ):
1868
+ # run images through backbone
1869
+ original_sizes = image_shapes * scales_yx
1870
+ features = self.backbone(images)
1871
+
1872
+ # generate proposals if none are available
1873
+ if proposals is None:
1874
+ proposal_boxes, _ = self.proposal_generator(images, image_shapes, features, gt_boxes)
1875
+ else:
1876
+ assert proposals is not None
1877
+
1878
+ # pool object features from either gt_boxes, or from proposals
1879
+ obj_logits, attr_logits, box_deltas, feature_pooled = self.roi_heads(features, proposal_boxes, gt_boxes)
1880
+
1881
+ # prepare FRCNN Outputs and select top proposals
1882
+ boxes, classes, class_probs, attrs, attr_probs, roi_features = self.roi_outputs(
1883
+ obj_logits=obj_logits,
1884
+ attr_logits=attr_logits,
1885
+ box_deltas=box_deltas,
1886
+ pred_boxes=proposal_boxes,
1887
+ features=feature_pooled,
1888
+ sizes=image_shapes,
1889
+ scales=scales_yx,
1890
+ )
1891
+
1892
+ # will we pad???
1893
+ subset_kwargs = {
1894
+ "max_detections": kwargs.get("max_detections", None),
1895
+ "return_tensors": kwargs.get("return_tensors", None),
1896
+ "pad_value": kwargs.get("pad_value", 0),
1897
+ "padding": kwargs.get("padding", None),
1898
+ }
1899
+ preds_per_image = torch.tensor([p.size(0) for p in boxes])
1900
+ boxes = pad_list_tensors(boxes, preds_per_image, **subset_kwargs)
1901
+ classes = pad_list_tensors(classes, preds_per_image, **subset_kwargs)
1902
+ class_probs = pad_list_tensors(class_probs, preds_per_image, **subset_kwargs)
1903
+ attrs = pad_list_tensors(attrs, preds_per_image, **subset_kwargs)
1904
+ attr_probs = pad_list_tensors(attr_probs, preds_per_image, **subset_kwargs)
1905
+ roi_features = pad_list_tensors(roi_features, preds_per_image, **subset_kwargs)
1906
+ subset_kwargs["padding"] = None
1907
+ preds_per_image = pad_list_tensors(preds_per_image, None, **subset_kwargs)
1908
+ sizes = pad_list_tensors(image_shapes, None, **subset_kwargs)
1909
+ normalized_boxes = norm_box(boxes, original_sizes)
1910
+ return OrderedDict(
1911
+ {
1912
+ "obj_ids": classes,
1913
+ "obj_probs": class_probs,
1914
+ "attr_ids": attrs,
1915
+ "attr_probs": attr_probs,
1916
+ "boxes": boxes,
1917
+ "sizes": sizes,
1918
+ "preds_per_image": preds_per_image,
1919
+ "roi_features": roi_features,
1920
+ "normalized_boxes": normalized_boxes,
1921
+ }
1922
+ )
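A rough end-to-end sketch of driving the detector above to obtain region features for LXMERT. The config/checkpoint paths are placeholders, and the random tensors stand in for the output of an image preprocessing step, which would normally supply the images, sizes and y/x scale factors.

import torch

frcnn_cfg = Config.from_pretrained("path/to/frcnn-config")          # placeholder path
frcnn = GeneralizedRCNN.from_pretrained("path/to/frcnn-weights",    # placeholder path
                                        config=frcnn_cfg)

images = torch.randn(1, 3, 800, 1333)                 # stand-in for a preprocessed image
sizes = torch.tensor([[800.0, 1333.0]])               # (h, w) per image
scales_yx = torch.tensor([[1.0, 1.0]])                # resize factors back to original size

out = frcnn(images, sizes, scales_yx=scales_yx,
            padding="max_detections", max_detections=36, return_tensors="pt")
roi_features = out["roi_features"]                    # visual features fed to LXMERT
normalized_boxes = out["normalized_boxes"]            # boxes normalized by original size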
lxmert/src/param.py ADDED
@@ -0,0 +1,126 @@
1
+ # coding=utf-8
2
+ # Copyleft 2019 project LXRT.
3
+
4
+ import argparse
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def get_optimizer(optim):
12
+ # Bind the optimizer
13
+ if optim == 'rms':
14
+ print("Optimizer: Using RMSProp")
15
+ optimizer = torch.optim.RMSprop
16
+ elif optim == 'adam':
17
+ print("Optimizer: Using Adam")
18
+ optimizer = torch.optim.Adam
19
+ elif optim == 'adamax':
20
+ print("Optimizer: Using Adamax")
21
+ optimizer = torch.optim.Adamax
22
+ elif optim == 'sgd':
23
+ print("Optimizer: sgd")
24
+ optimizer = torch.optim.SGD
25
+ elif 'bert' in optim:
26
+ optimizer = 'bert' # The bert optimizer will be bound later.
27
+ else:
28
+ assert False, "Please add your optimizer %s in the list." % optim
29
+
30
+ return optimizer
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser()
35
+
36
+ # Data Splits
37
+ parser.add_argument("--train", default='train')
38
+ parser.add_argument("--valid", default='valid')
39
+ parser.add_argument("--test", default=None)
40
+
41
+ # Training Hyper-parameters
42
+ parser.add_argument('--batchSize', dest='batch_size', type=int, default=256)
43
+ parser.add_argument('--optim', default='bert')
44
+ parser.add_argument('--lr', type=float, default=1e-4)
45
+ parser.add_argument('--epochs', type=int, default=10)
46
+ parser.add_argument('--dropout', type=float, default=0.1)
47
+ parser.add_argument('--seed', type=int, default=9595, help='random seed')
48
+
49
+ # Debugging
50
+ parser.add_argument('--output', type=str, default='snap/test')
51
+ parser.add_argument("--fast", action='store_const', default=False, const=True)
52
+ parser.add_argument("--tiny", action='store_const', default=False, const=True)
53
+ parser.add_argument("--tqdm", action='store_const', default=False, const=True)
54
+
55
+ # Model Loading
56
+ parser.add_argument('--load', type=str, default=None,
57
+ help='Load the model (usually the fine-tuned model).')
58
+ parser.add_argument('--loadLXMERT', dest='load_lxmert', type=str, default=None,
59
+ help='Load the pre-trained lxmert model.')
60
+ parser.add_argument('--loadLXMERTQA', dest='load_lxmert_qa', type=str, default=None,
61
+ help='Load the pre-trained lxmert model with QA answer head.')
62
+ parser.add_argument("--fromScratch", dest='from_scratch', action='store_const', default=False, const=True,
63
+ help='If none of the --load, --loadLXMERT, --loadLXMERTQA is set, '
64
+ 'the model would be trained from scratch. If --fromScratch is'
65
+ ' not specified, the model would load BERT-pre-trained weights by'
66
+ ' default. ')
67
+
68
+ # Optimization
69
+ parser.add_argument("--mceLoss", dest='mce_loss', action='store_const', default=False, const=True)
70
+
71
+ # LXRT Model Config
72
+ # Note: LXRT = L, X, R (three encoders), Transformer
73
+ parser.add_argument("--llayers", default=9, type=int, help='Number of Language layers')
74
+ parser.add_argument("--xlayers", default=5, type=int, help='Number of CROSS-modality layers.')
75
+ parser.add_argument("--rlayers", default=5, type=int, help='Number of object Relationship layers.')
76
+
77
+ # lxmert Pre-training Config
78
+ parser.add_argument("--taskMatched", dest='task_matched', action='store_const', default=False, const=True)
79
+ parser.add_argument("--taskMaskLM", dest='task_mask_lm', action='store_const', default=False, const=True)
80
+ parser.add_argument("--taskObjPredict", dest='task_obj_predict', action='store_const', default=False, const=True)
81
+ parser.add_argument("--taskQA", dest='task_qa', action='store_const', default=False, const=True)
82
+ parser.add_argument("--visualLosses", dest='visual_losses', default='obj,attr,feat', type=str)
83
+ parser.add_argument("--qaSets", dest='qa_sets', default=None, type=str)
84
+ parser.add_argument("--wordMaskRate", dest='word_mask_rate', default=0.15, type=float)
85
+ parser.add_argument("--objMaskRate", dest='obj_mask_rate', default=0.15, type=float)
86
+
87
+ # Training configuration
88
+ parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
89
+ parser.add_argument("--numWorkers", dest='num_workers', default=0)
90
+
91
+
92
+ # perturbation configuration
93
+ parser.add_argument('--method', type=str,
94
+ default='ours_no_lrp',
95
+ choices=['ours_with_lrp', 'rollout', 'partial_lrp', 'transformer_att',
96
+ 'raw_attn', 'attn_gradcam', 'ours_with_lrp_no_normalization', 'ours_no_lrp',
97
+ 'ours_no_lrp_no_norm', 'ablation_no_aggregation', 'ablation_no_self_in_10'],
98
+ help='')
99
+ parser.add_argument('--num-samples', type=int,
100
+ default=10000,
101
+ help='')
102
+ parser.add_argument('--is-positive-pert', type=bool,
103
+ default=False,
104
+ help='')
105
+ parser.add_argument('--is-text-pert', type=bool,
106
+ default=False,
107
+ help='')
108
+ parser.add_argument('--COCO_path', type=str,
109
+ default='',
110
+ help='path to COCO 2014 validation set')
111
+
112
+ # Parse the arguments.
113
+ args = parser.parse_args()
114
+
115
+ # Bind optimizer class.
116
+ args.optimizer = get_optimizer(args.optim)
117
+
118
+ # Set seeds
119
+ torch.manual_seed(args.seed)
120
+ random.seed(args.seed)
121
+ np.random.seed(args.seed)
122
+
123
+ return args
124
+
125
+
126
+ args = parse_args()
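Since parse_args() runs at import time, other modules simply do `from param import args`; below is a minimal sketch of how the bound optimizer is meant to be used (the tiny linear model is only for illustration).

import torch.nn as nn
from param import args

model = nn.Linear(10, 2)                         # stand-in for a real model
if args.optimizer == 'bert':
    # 'bert' is resolved later to BertAdam (see lxrt.optimization in the pretraining code)
    pass
else:
    optimizer = args.optimizer(model.parameters(), lr=args.lr)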
lxmert/src/pretrain/__init__.py ADDED
File without changes
lxmert/src/pretrain/lxmert_data.py ADDED
@@ -0,0 +1,255 @@
1
+ # coding=utf-8
2
+ # Copyleft 2019 project LXRT.
3
+
4
+ from collections import defaultdict
5
+ import json
6
+ import random
7
+
8
+ import numpy as np
9
+ from torch.utils.data import Dataset
10
+
11
+ from param import args
12
+ from pretrain.qa_answer_table import AnswerTable
13
+ from utils import load_obj_tsv
14
+
15
+ TINY_IMG_NUM = 500
16
+ FAST_IMG_NUM = 5000
17
+
18
+ Split2ImgFeatPath = {
19
+ 'mscoco_train': 'data/mscoco_imgfeat/train2014_obj36.tsv',
20
+ 'mscoco_minival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
21
+ 'mscoco_nominival': 'data/mscoco_imgfeat/val2014_obj36.tsv',
22
+ 'vgnococo': 'data/vg_gqa_imgfeat/vg_gqa_obj36.tsv',
23
+ }
24
+
25
+
26
+ class InputExample(object):
27
+ """A single training/test example for the language model."""
28
+ def __init__(self, uid, sent, visual_feats=None,
29
+ obj_labels=None, attr_labels=None,
30
+ is_matched=None, label=None):
31
+ self.uid = uid
32
+ self.sent = sent
33
+ self.visual_feats = visual_feats
34
+ self.obj_labels = obj_labels
35
+ self.attr_labels = attr_labels
36
+ self.is_matched = is_matched # whether the visual and obj matched
37
+ self.label = label
38
+
39
+
40
+ class LXMERTDataset:
41
+ def __init__(self, splits: str, qa_sets=None):
42
+ """
43
+ :param splits: The data sources to be loaded
44
+ :param qa_sets: if None, no action
45
+ o.w., only takes the answers appearing in these dsets
46
+ and removes all unlabeled data (MSCOCO captions)
47
+ """
48
+ self.name = splits
49
+ self.sources = splits.split(',')
50
+
51
+ # Loading datasets to data
52
+ self.data = []
53
+ for source in self.sources:
54
+ self.data.extend(json.load(open("data/lxmert/%s.json" % source)))
55
+ print("Load %d data from %s" % (len(self.data), self.name))
56
+
57
+ # Create answer table according to the qa_sets
58
+ self.answer_table = AnswerTable(qa_sets)
59
+ print("Load an answer table of size %d." % (len(self.answer_table.ans2id_map())))
60
+
61
+ # Modify the answers
62
+ for datum in self.data:
63
+ labelf = datum['labelf']
64
+ for cat, labels in labelf.items():
65
+ for label in labels:
66
+ for ans in list(label.keys()):
67
+ new_ans = self.answer_table.convert_ans(ans)
68
+ if self.answer_table.used(new_ans):
69
+ if ans != new_ans:
70
+ label[new_ans] = label.pop(ans)
71
+ else:
72
+ label.pop(ans)
73
+
74
+ def __len__(self):
75
+ return len(self.data)
76
+
77
+
78
+ def make_uid(img_id, dset, sent_idx):
79
+ return "%s_%s_%03d" % (img_id, dset, sent_idx)
80
+
81
+
82
+ """
83
+ Example in obj tsv:
84
+ FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
85
+ "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
86
+ """
87
+ class LXMERTTorchDataset(Dataset):
88
+ def __init__(self, dataset: LXMERTDataset, topk=-1):
89
+ super().__init__()
90
+ self.raw_dataset = dataset
91
+ self.task_matched = args.task_matched
92
+
93
+ if args.tiny:
94
+ topk = TINY_IMG_NUM
95
+ elif args.fast:
96
+ topk = FAST_IMG_NUM
97
+
98
+ # Load the dataset
99
+ img_data = []
100
+ for source in self.raw_dataset.sources:
101
+ img_data.extend(load_obj_tsv(Split2ImgFeatPath[source], topk))
102
+
103
+ self.imgid2img = {}
104
+ for img_datum in img_data:
105
+ self.imgid2img[img_datum['img_id']] = img_datum
106
+
107
+ # Filter out the dataset
108
+ used_data = []
109
+ for datum in self.raw_dataset.data:
110
+ if datum['img_id'] in self.imgid2img:
111
+ used_data.append(datum)
112
+
113
+ # Flatten the dataset (into one sent + one image entries)
114
+ self.data = []
115
+ for datum in used_data:
116
+ sentf = datum['sentf']
117
+ for sents_cat, sents in sentf.items():
118
+ if sents_cat in datum['labelf']:
119
+ labels = datum['labelf'][sents_cat]
120
+ else:
121
+ labels = None
122
+ for sent_idx, sent in enumerate(sents):
123
+ new_datum = {
124
+ 'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
125
+ 'img_id': datum['img_id'],
126
+ 'sent': sent
127
+ }
128
+ if labels is not None:
129
+ new_datum['label'] = labels[sent_idx]
130
+ self.data.append(new_datum)
131
+ print("Use %d data in torch dataset" % (len(self.data)))
132
+
133
+ def __len__(self):
134
+ return len(self.data)
135
+
136
+ def random_feat(self):
137
+ """Get a random obj feat from the dataset."""
138
+ datum = self.data[random.randint(0, len(self.data)-1)]
139
+ img_id = datum['img_id']
140
+ img_info = self.imgid2img[img_id]
141
+ feat = img_info['features'][random.randint(0, 35)]
142
+ return feat
143
+
144
+ def __getitem__(self, item: int):
145
+ datum = self.data[item]
146
+
147
+ uid = datum['uid']
148
+ img_id = datum['img_id']
149
+
150
+ # Get image info
151
+ img_info = self.imgid2img[img_id]
152
+ obj_num = img_info['num_boxes']
153
+ feats = img_info['features'].copy()
154
+ boxes = img_info['boxes'].copy()
155
+ obj_labels = img_info['objects_id'].copy()
156
+ obj_confs = img_info['objects_conf'].copy()
157
+ attr_labels = img_info['attrs_id'].copy()
158
+ attr_confs = img_info['attrs_conf'].copy()
159
+ assert obj_num == len(boxes) == len(feats)
160
+
161
+ # Normalize the boxes (to 0 ~ 1)
162
+ img_h, img_w = img_info['img_h'], img_info['img_w']
163
+ boxes = boxes.copy()
164
+ boxes[:, (0, 2)] /= img_w
165
+ boxes[:, (1, 3)] /= img_h
166
+ np.testing.assert_array_less(boxes, 1+1e-5)
167
+ np.testing.assert_array_less(-boxes, 0+1e-5)
168
+
169
+ # If calculating the matched loss, replace the sentence with a sentence
170
+ # corresponding to another image.
171
+ is_matched = 1
172
+ sent = datum['sent']
173
+ if self.task_matched:
174
+ if random.random() < 0.5:
175
+ is_matched = 0
176
+ other_datum = self.data[random.randint(0, len(self.data)-1)]
177
+ while other_datum['img_id'] == img_id:
178
+ other_datum = self.data[random.randint(0, len(self.data)-1)]
179
+ sent = other_datum['sent']
180
+
181
+ # Label, convert answer to id
182
+ if 'label' in datum:
183
+ label = datum['label'].copy()
184
+ for ans in list(label.keys()):
185
+ label[self.raw_dataset.answer_table.ans2id(ans)] = label.pop(ans)
186
+ else:
187
+ label = None
188
+
189
+ # Create target
190
+ example = InputExample(
191
+ uid, sent, (feats, boxes),
192
+ (obj_labels, obj_confs), (attr_labels, attr_confs),
193
+ is_matched, label
194
+ )
195
+ return example
196
+
197
+
198
+ class LXMERTEvaluator:
199
+ def __init__(self, dataset: LXMERTDataset):
200
+ self.raw_dataset = dataset
201
+
202
+ # Create QA Eval Data
203
+ self.data = []
204
+ for datum in self.raw_dataset.data:
205
+ sentf = datum['sentf']
206
+ for sents_cat, sents in sentf.items():
207
+ if sents_cat in datum['labelf']: # A labeled dataset
208
+ labels = datum['labelf'][sents_cat]
209
+ for sent_idx, sent in enumerate(sents):
210
+ new_datum = {
211
+ 'uid': make_uid(datum['img_id'], sents_cat, sent_idx),
212
+ 'img_id': datum['img_id'],
213
+ 'sent': sent,
214
+ 'dset': sents_cat,
215
+ 'label': labels[sent_idx]
216
+ }
217
+ self.data.append(new_datum)
218
+
219
+ # uid2datum
220
+ self.uid2datum = {}
221
+ for datum in self.data:
222
+ self.uid2datum[datum['uid']] = datum
223
+
224
+ def evaluate(self, uid2ans: dict, pprint=False):
225
+ score = 0.
226
+ cnt = 0
227
+ dset2score = defaultdict(lambda: 0.)
228
+ dset2cnt = defaultdict(lambda: 0)
229
+ for uid, ans in uid2ans.items():
230
+ if uid not in self.uid2datum: # Not a labeled data
231
+ continue
232
+ datum = self.uid2datum[uid]
233
+ label = datum['label']
234
+ dset = datum['dset']
235
+ if ans in label:
236
+ score += label[ans]
237
+ dset2score[dset] += label[ans]
238
+ cnt += 1
239
+ dset2cnt[dset] += 1
240
+ accu = score / cnt
241
+ dset2accu = {}
242
+ for dset in dset2cnt:
243
+ dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
244
+
245
+ if pprint:
246
+ accu_str = "Overall Accu %0.4f, " % (accu)
247
+ sorted_keys = sorted(dset2accu.keys())
248
+ for key in sorted_keys:
249
+ accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
250
+ print(accu_str)
251
+
252
+ return accu, dset2accu
253
+
254
+ def dump_result(self, uid2ans: dict, path):
255
+ raise NotImplementedError
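A small usage sketch of the evaluator interface above. It assumes the JSON splits under data/lxmert/ are available; the split name, QA set and the constant 'yes' answer are placeholders.

dset = LXMERTDataset("mscoco_minival", qa_sets={"vqa"})
evaluator = LXMERTEvaluator(dset)

# uid2ans maps the uid produced by make_uid(...) to a predicted answer string.
uid2ans = {datum['uid']: 'yes' for datum in evaluator.data[:100]}
accu, dset2accu = evaluator.evaluate(uid2ans, pprint=True)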
lxmert/src/pretrain/lxmert_pretrain.py ADDED
@@ -0,0 +1,435 @@
1
+ # coding=utf-8
2
+ # Copyleft 2019 project LXRT.
3
+
4
+ import collections
5
+ import os
6
+ import random
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.utils.data import DataLoader
13
+
14
+ from param import args
15
+ from pretrain.lxmert_data import InputExample, LXMERTDataset, LXMERTTorchDataset, LXMERTEvaluator
16
+ from lxrt.entry import set_visual_config
17
+ from lxrt.tokenization import BertTokenizer
18
+ from lxrt.modeling import LXRTPretraining
19
+
20
+ DataTuple = collections.namedtuple("DataTuple", 'dataset torchdset loader evaluator')
21
+
22
+
23
+ def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False, topk=-1) -> DataTuple:
24
+ # Decide which QA datasets would be used in pre-training.
25
+ # Options: vqa, gqa, visual7w
26
+ # Note: visual7w is a part of vgqa, we take the name here.
27
+ qa_sets = args.qa_sets
28
+ if qa_sets is not None:
29
+ qa_sets = set(qa_set.lower().strip() for qa_set in qa_sets.split(","))
30
+
31
+ # Build dataset, data loader, and evaluator.
32
+ dset = LXMERTDataset(splits, qa_sets=qa_sets)
33
+ tset = LXMERTTorchDataset(dset, topk)
34
+ data_loader = DataLoader(
35
+ tset, batch_size=bs,
36
+ shuffle=shuffle, num_workers=args.num_workers,
37
+ collate_fn=lambda x: x,
38
+ drop_last=drop_last, pin_memory=True
39
+ )
40
+ evaluator = LXMERTEvaluator(dset)
41
+ print()
42
+
43
+ return DataTuple(dataset=dset, torchdset=tset, loader=data_loader, evaluator=evaluator)
44
+
45
+
46
+ train_tuple = get_tuple(args.train, args.batch_size, shuffle=True, drop_last=True)
47
+ valid_batch_size = 2048 if args.multiGPU else 512
48
+ valid_tuple = get_tuple(args.valid, valid_batch_size, shuffle=False, drop_last=False, topk=5000)
49
+
50
+
51
+ class InputFeatures(object):
52
+ """A single set of features of data."""
53
+
54
+ def __init__(self,
55
+ input_ids, input_mask, segment_ids, lm_label_ids,
56
+ visual_feats, obj_labels,
57
+ is_matched, ans):
58
+ self.input_ids = input_ids
59
+ self.input_mask = input_mask
60
+ self.segment_ids = segment_ids
61
+ self.lm_label_ids = lm_label_ids
62
+
63
+ self.visual_feats = visual_feats
64
+ self.obj_labels = obj_labels
65
+
66
+ self.is_matched = is_matched
67
+
68
+ self.ans = ans
69
+
70
+
71
+ def random_word(tokens, tokenizer):
72
+ """
73
+ Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
74
+ :param tokens: list of str, tokenized sentence.
75
+ :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
76
+ :return: (list of str, list of int), masked tokens and related labels for LM prediction
77
+ """
78
+ output_label = []
79
+
80
+ for i, token in enumerate(tokens):
81
+ prob = random.random()
82
+ # mask token with probability
83
+ ratio = args.word_mask_rate
84
+ if prob < ratio:
85
+ prob /= ratio
86
+
87
+ # 80% randomly change token to mask token
88
+ if prob < 0.8:
89
+ tokens[i] = "[MASK]"
90
+
91
+ # 10% randomly change token to random token
92
+ elif prob < 0.9:
93
+ tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
94
+
95
+ # -> rest 10% randomly keep current token
96
+
97
+ # append current token to output (we will predict these later)
98
+ try:
99
+ output_label.append(tokenizer.vocab[token])
100
+ except KeyError:
101
+ # For unknown words (should not occur with BPE vocab)
102
+ output_label.append(tokenizer.vocab["[UNK]"])
103
+ else:
104
+ # no masking token (will be ignored by loss function later)
105
+ output_label.append(-1)
106
+
107
+ return tokens, output_label
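Put together with the default --wordMaskRate of 0.15, the branches above give the familiar BERT-style split; a quick arithmetic check:

word_mask_rate = 0.15                                  # args.word_mask_rate default
p_mask = word_mask_rate * 0.8                          # replaced by [MASK]
p_rand = word_mask_rate * 0.1                          # replaced by a random vocab token
p_keep = word_mask_rate * 0.1                          # kept as-is but still predicted
print(round(p_mask, 3), round(p_rand, 3), round(p_keep, 3))   # 0.12 0.015 0.015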
108
+
109
+
110
+ def random_feat(feats):
111
+ mask_feats = feats.copy()
112
+ feat_mask = np.zeros(len(feats), dtype=np.float32)
113
+ for i in range(len(feats)):
114
+ prob = random.random()
115
+ # mask token with probability
116
+ if prob < args.obj_mask_rate:
117
+ prob /= args.obj_mask_rate
118
+
119
+ # 80% randomly change token to zero feat
120
+ if prob < 0.8:
121
+ mask_feats[i, :] = 0.
122
+
123
+ # 10% randomly change token to random feat
124
+ elif prob < 0.9:
125
+ mask_feats[i, :] = train_tuple.torchdset.random_feat()
126
+ # -> rest 10% randomly keep current feat
127
+
128
+ # Need to predict this feat
129
+ feat_mask[i] = 1.
130
+
131
+ return mask_feats, feat_mask
132
+
133
+
134
+ def convert_example_to_features(example: InputExample, max_seq_length, tokenizer)->InputFeatures:
135
+ """
136
+ Convert a raw sample (a sentence plus its visual features) into a proper training sample with
137
+ IDs, LM labels, input_mask, CLS and SEP tokens etc.
138
+ :param example: InputExample, containing the sentence, visual features and the is_matched label
139
+ :param max_seq_length: int, maximum length of sequence.
140
+ :param tokenizer: Tokenizer
141
+ :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
142
+ """
143
+ tokens = tokenizer.tokenize(example.sent.strip())
144
+
145
+ # Account for [CLS] and [SEP] with "- 2"
146
+ if len(tokens) > max_seq_length - 2:
147
+ tokens = tokens[:(max_seq_length - 2)]
148
+
149
+ # Ge random words
150
+ masked_tokens, masked_label = random_word(tokens, tokenizer)
151
+
152
+ # concatenate lm labels and account for CLS, SEP, SEP
153
+ masked_tokens = ['[CLS]'] + masked_tokens + ['[SEP]']
154
+ input_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
155
+
156
+ # Mask & Segment Word
157
+ lm_label_ids = ([-1] + masked_label + [-1])
158
+ input_mask = [1] * len(input_ids)
159
+ segment_ids = [0] * len(input_ids)
160
+
161
+ # Zero-pad up to the sequence length.
162
+ while len(input_ids) < max_seq_length:
163
+ input_ids.append(0)
164
+ input_mask.append(0)
165
+ segment_ids.append(0)
166
+ lm_label_ids.append(-1)
167
+
168
+ assert len(input_ids) == max_seq_length
169
+ assert len(input_mask) == max_seq_length
170
+ assert len(segment_ids) == max_seq_length
171
+ assert len(lm_label_ids) == max_seq_length
172
+
173
+ feat, boxes = example.visual_feats
174
+ obj_labels, obj_confs = example.obj_labels
175
+ attr_labels, attr_confs = example.attr_labels
176
+
177
+ # Mask Image Features:
178
+ masked_feat, feat_mask = random_feat(feat)
179
+
180
+ # QA answer label
181
+ if example.label is None or len(example.label) == 0 or example.is_matched != 1:
182
+ # 1. No label 2. Label is pruned 3. unmatched visual + language pair
183
+ ans = -1
184
+ else:
185
+ keys, values = zip(*example.label.items())
186
+ if len(keys) == 1:
187
+ ans = keys[0]
188
+ else:
189
+ value_sum = sum(values)
190
+ prob = [value / value_sum for value in values]
191
+ choice = np.random.multinomial(1, prob).argmax()
192
+ ans = keys[choice]
193
+
194
+ features = InputFeatures(
195
+ input_ids=input_ids,
196
+ input_mask=input_mask,
197
+ segment_ids=segment_ids,
198
+ lm_label_ids=lm_label_ids,
199
+ visual_feats=(masked_feat, boxes),
200
+ obj_labels={
201
+ 'obj': (obj_labels, obj_confs),
202
+ 'attr': (attr_labels, attr_confs),
203
+ 'feat': (feat, feat_mask),
204
+ },
205
+ is_matched=example.is_matched,
206
+ ans=ans,
207
+ )
208
+ return features
209
+
210
+
211
+ LOSSES_NAME = ('Mask_LM', 'Matched', 'Obj', 'Attr', 'Feat', 'QA')
212
+
213
+
214
+ class LXMERT:
215
+ def __init__(self, max_seq_length):
216
+ super().__init__()
217
+ self.max_seq_length = max_seq_length
218
+
219
+ self.tokenizer = BertTokenizer.from_pretrained(
220
+ "bert-base-uncased",
221
+ do_lower_case=True
222
+ )
223
+
224
+ # Build model
225
+ set_visual_config(args)
226
+ self.model = LXRTPretraining.from_pretrained(
227
+ "bert-base-uncased",
228
+ task_mask_lm=args.task_mask_lm,
229
+ task_obj_predict=args.task_obj_predict,
230
+ task_matched=args.task_matched,
231
+ task_qa=args.task_qa,
232
+ visual_losses=args.visual_losses,
233
+ num_answers=train_tuple.dataset.answer_table.num_answers
234
+ )
235
+
236
+ # Weight initialization and loading
237
+ if args.from_scratch:
238
+ print("Train from Scratch: re-initialize all BERT weights.")
239
+ self.model.apply(self.model.init_bert_weights)
240
+ if args.load is not None:
241
+ self.load(args.load)
242
+ if args.load_lxmert is not None:
243
+ # Load lxmert would not load the answer head.
244
+ self.load_lxmert(args.load_lxmert)
245
+
246
+ # GPU Options
247
+ self.model = self.model.cuda()
248
+ if args.multiGPU:
249
+ self.model = nn.DataParallel(self.model)
250
+
251
+ def forward(self, examples):
252
+ train_features = [convert_example_to_features(example, self.max_seq_length, self.tokenizer)
253
+ for example in examples]
254
+
255
+ # language Inputs
256
+ input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
257
+ input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
258
+ segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()
259
+
260
+ # Visual Inputs
261
+ feats = torch.from_numpy(np.stack([f.visual_feats[0] for f in train_features])).cuda()
262
+ pos = torch.from_numpy(np.stack([f.visual_feats[1] for f in train_features])).cuda()
263
+
264
+ # Language Prediction
265
+ lm_labels = torch.tensor([f.lm_label_ids for f in train_features], dtype=torch.long).cuda()
266
+
267
+ # Visual Prediction
268
+ obj_labels = {}
269
+ for key in ('obj', 'attr', 'feat'):
270
+ visn_labels = torch.from_numpy(np.stack([f.obj_labels[key][0] for f in train_features])).cuda()
271
+ visn_mask = torch.from_numpy(np.stack([f.obj_labels[key][1] for f in train_features])).cuda()
272
+ assert visn_labels.size(0) == visn_mask.size(0) and visn_labels.size(1) == visn_mask.size(1)
273
+ obj_labels[key] = (visn_labels, visn_mask)
274
+
275
+ # Joint Prediction
276
+ matched_labels = torch.tensor([f.is_matched for f in train_features], dtype=torch.long).cuda()
277
+ ans = torch.from_numpy(np.stack([f.ans for f in train_features])).cuda()
278
+
279
+ """
280
+ forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
281
+ visual_feats=None, pos=None, obj_labels=None, matched_label=None, ans=None):
282
+ """
283
+ loss, losses, ans_logit = self.model(
284
+ input_ids, segment_ids, input_mask, lm_labels,
285
+ feats, pos, obj_labels, matched_labels, ans
286
+ )
287
+ return loss, losses.detach().cpu(), ans_logit
288
+
289
+ def train_batch(self, optim, batch):
290
+ optim.zero_grad()
291
+ loss, losses, ans_logit = self.forward(batch)
292
+ if args.multiGPU:
293
+ loss = loss.mean()
294
+ losses = losses.mean(0)
295
+ loss.backward()
296
+ nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
297
+ optim.step()
298
+
299
+ return loss.item(), losses.cpu().numpy(), ans_logit
300
+
301
+ def valid_batch(self, batch):
302
+ with torch.no_grad():
303
+ loss, losses, ans_logit = self.forward(batch)
304
+ if args.multiGPU:
305
+ loss = loss.mean()
306
+ losses = losses.mean(0)
307
+ return loss.item(), losses.cpu().numpy(), ans_logit
308
+
309
+ def train(self, train_tuple: DataTuple, eval_tuple: DataTuple):
310
+ train_ld = train_tuple.loader
311
+
312
+ # Optimizer
313
+ from lxrt.optimization import BertAdam
314
+ batch_per_epoch = len(train_ld)
315
+ t_total = int(batch_per_epoch * args.epochs)
316
+ warmup_ratio = 0.05
317
+ warmup_iters = int(t_total * warmup_ratio)
318
+ print("Batch per epoch: %d" % batch_per_epoch)
319
+ print("Total Iters: %d" % t_total)
320
+ print("Warm up Iters: %d" % warmup_iters)
321
+ optim = BertAdam(self.model.parameters(), lr=args.lr, warmup=warmup_ratio, t_total=t_total)
322
+
323
+ # Train
324
+ best_eval_loss = 9595.
325
+ for epoch in range(args.epochs):
326
+ # Train
327
+ self.model.train()
328
+ total_loss = 0.
329
+ total_losses = 0.
330
+ uid2ans = {}
331
+ for batch in tqdm(train_ld, total=len(train_ld)):
332
+ loss, losses, logit = self.train_batch(optim, batch)
333
+ total_loss += loss
334
+ total_losses += losses
335
+
336
+ if args.task_qa:
337
+ score, label = logit.max(1)
338
+ for datum, l in zip(batch, label.cpu().numpy()):
339
+ uid = datum.uid
340
+ ans = train_tuple.dataset.answer_table.id2ans(l)
341
+ uid2ans[uid] = ans
342
+
343
+ print("The training loss for Epoch %d is %0.4f" % (epoch, total_loss / batch_per_epoch))
344
+ losses_str = "The losses are "
345
+ for name, loss in zip(LOSSES_NAME, total_losses):
346
+ losses_str += "%s: %0.4f " % (name, loss / batch_per_epoch)
347
+ print(losses_str)
348
+ if args.task_qa:
349
+ train_tuple.evaluator.evaluate(uid2ans, pprint=True)
350
+
351
+ # Eval
352
+ avg_eval_loss = self.evaluate_epoch(eval_tuple, iters=-1)
353
+
354
+ # Save
355
+ if avg_eval_loss < best_eval_loss:
356
+ best_eval_loss = avg_eval_loss
357
+ self.save("BEST_EVAL_LOSS")
358
+ self.save("Epoch%02d" % (epoch+1))
359
+
360
+ def evaluate_epoch(self, eval_tuple: DataTuple, iters: int=-1):
361
+ self.model.eval()
362
+ eval_ld = eval_tuple.loader
363
+ total_loss = 0.
364
+ total_losses = 0.
365
+ uid2ans = {}
366
+ for i, batch in enumerate(eval_ld):
367
+ loss, losses, logit = self.valid_batch(batch)
368
+ total_loss += loss
369
+ total_losses += losses
370
+ if args.task_qa:
371
+ score, label = logit.max(1)
372
+ for datum, l in zip(batch, label.cpu().numpy()):
373
+ uid = datum.uid
374
+ ans = train_tuple.dataset.answer_table.id2ans(l)
375
+ uid2ans[uid] = ans
376
+ if i == iters:
377
+ break
378
+
379
+ print("The valid loss is %0.4f" % (total_loss / len(eval_ld)))
380
+ losses_str = "The losses are "
381
+ for name, loss in zip(LOSSES_NAME, total_losses / len(eval_ld)):
382
+ losses_str += "%s: %0.4f " % (name, loss)
383
+ print(losses_str)
384
+
385
+ if args.task_qa:
386
+ eval_tuple.evaluator.evaluate(uid2ans, pprint=True)
387
+
388
+ return total_loss / len(eval_ld)
389
+
390
+ def save(self, name):
391
+ torch.save(self.model.state_dict(),
392
+ os.path.join(args.output, "%s_LXRT.pth" % name))
393
+
394
+ def load(self, path):
395
+ print("Load BERT extractor from %s" % path)
396
+ state_dict = torch.load("%s_LXRT.pth" % path)
397
+ self.model.load_state_dict(state_dict)
398
+
399
+ def load_lxmert(self, path):
400
+ print("Load lxmert model from %s" % path)
401
+ state_dict = torch.load("%s_LXRT.pth" % path)
402
+
403
+ # Do not load any answer head
404
+ for key in list(state_dict.keys()):
405
+ if 'answer' in key:
406
+ state_dict.pop(key)
407
+
408
+ # Change Multi GPU to single GPU
409
+ new_state_dict = {}
410
+ for key, value in state_dict.items():
411
+ if key.startswith("module."):
412
+ new_state_dict[key[len("module."):]] = value
413
+ state_dict = new_state_dict
414
+
415
+ load_keys = set(state_dict.keys())
416
+ model_keys = set(self.model.state_dict().keys())
417
+ print()
418
+ print("Keys in loaded but not in model:")
419
+ for key in sorted(load_keys.difference(model_keys)):
420
+ print(key)
421
+ print()
422
+ print("Keys in model but not in loaded:")
423
+ for key in sorted(model_keys.difference(load_keys)):
424
+ print(key)
425
+ print()
426
+
427
+ self.model.load_state_dict(state_dict, strict=False)
428
+
429
+
430
+ if __name__ == "__main__":
431
+
432
+ lxmert = LXMERT(max_seq_length=20)
433
+
434
+
435
+ lxmert.train(train_tuple, valid_tuple)
lxmert/src/pretrain/qa_answer_table.py ADDED
@@ -0,0 +1,158 @@
+ # coding=utf-8
+ # Copyleft 2019 project LXRT.
+
+ import json
+ import torch
+
+
+ class AnswerTable:
+     ANS_CONVERT = {
+         "a man": "man",
+         "the man": "man",
+         "a woman": "woman",
+         "the woman": "woman",
+         'one': '1',
+         'two': '2',
+         'three': '3',
+         'four': '4',
+         'five': '5',
+         'six': '6',
+         'seven': '7',
+         'eight': '8',
+         'nine': '9',
+         'ten': '10',
+         'grey': 'gray',
+     }
+
+     def __init__(self, dsets=None):
+         self.all_ans = json.load(open("data/lxmert/all_ans.json"))
+         if dsets is not None:
+             dsets = set(dsets)
+             # If the answer is used in the dsets
+             self.anss = [ans['ans'] for ans in self.all_ans if
+                          len(set(ans['dsets']) & dsets) > 0]
+         else:
+             self.anss = [ans['ans'] for ans in self.all_ans]
+         self.ans_set = set(self.anss)
+
+         self._id2ans_map = self.anss
+         self._ans2id_map = {ans: ans_id for ans_id, ans in enumerate(self.anss)}
+
+         assert len(self._id2ans_map) == len(self._ans2id_map)
+         for ans_id, ans in enumerate(self._id2ans_map):
+             assert self._ans2id_map[ans] == ans_id
+
+     def convert_ans(self, ans):
+         if len(ans) == 0:
+             return ""
+         ans = ans.lower()
+         if ans[-1] == '.':
+             ans = ans[:-1].strip()
+         if ans.startswith("a "):
+             ans = ans[2:].strip()
+         if ans.startswith("an "):
+             ans = ans[3:].strip()
+         if ans.startswith("the "):
+             ans = ans[4:].strip()
+         if ans in self.ANS_CONVERT:
+             ans = self.ANS_CONVERT[ans]
+         return ans
+
+     def ans2id(self, ans):
+         return self._ans2id_map[ans]
+
+     def id2ans(self, ans_id):
+         return self._id2ans_map[ans_id]
+
+     def ans2id_map(self):
+         return self._ans2id_map.copy()
+
+     def id2ans_map(self):
+         return self._id2ans_map.copy()
+
+     def used(self, ans):
+         return ans in self.ans_set
+
+     def all_answers(self):
+         return self.anss.copy()
+
+     @property
+     def num_answers(self):
+         return len(self.anss)
+
+
+ def load_lxmert_qa(path, model, label2ans):
+     """
+     Load model weights from lxmert pre-training.
+     The answers in the fine-tuned QA task (indicated by label2ans)
+     would also be properly initialized with lxmert pre-trained
+     QA heads.
+
+     :param path: Path to lxmert snapshot.
+     :param model: LXRT model instance.
+     :param label2ans: The label2ans dict of fine-tuned QA datasets, like
+         {0: 'cat', 1: 'dog', ...}
+     :return:
+     """
+     print("Load QA pre-trained lxmert from %s " % path)
+     loaded_state_dict = torch.load("%s_LXRT.pth" % path)
+     model_state_dict = model.state_dict()
+
+     # Handle Multi-GPU pre-training --> Single GPU fine-tuning
+     for key in list(loaded_state_dict.keys()):
+         loaded_state_dict[key.replace("module.", '')] = loaded_state_dict.pop(key)
+
+     # Isolate bert model
+     bert_state_dict = {}
+     for key, value in loaded_state_dict.items():
+         if key.startswith('bert.'):
+             bert_state_dict[key] = value
+
+     # Isolate answer head
+     answer_state_dict = {}
+     for key, value in loaded_state_dict.items():
+         if key.startswith("answer_head."):
+             answer_state_dict[key.replace('answer_head.', '')] = value
+
+     # Do surgery on answer state dict
+     ans_weight = answer_state_dict['logit_fc.3.weight']
+     ans_bias = answer_state_dict['logit_fc.3.bias']
+     import copy
+     new_answer_weight = copy.deepcopy(model_state_dict['logit_fc.3.weight'])
+     new_answer_bias = copy.deepcopy(model_state_dict['logit_fc.3.bias'])
+     answer_table = AnswerTable()
+     loaded = 0
+     unload = 0
+     if type(label2ans) is list:
+         label2ans = {label: ans for label, ans in enumerate(label2ans)}
+     for label, ans in label2ans.items():
+         new_ans = answer_table.convert_ans(ans)
+         if answer_table.used(new_ans):
+             ans_id_9500 = answer_table.ans2id(new_ans)
+             new_answer_weight[label] = ans_weight[ans_id_9500]
+             new_answer_bias[label] = ans_bias[ans_id_9500]
+             loaded += 1
+         else:
+             new_answer_weight[label] = 0.
+             new_answer_bias[label] = 0.
+             unload += 1
+     print("Loaded %d answers from LXRTQA pre-training and %d not" % (loaded, unload))
+     print()
+     answer_state_dict['logit_fc.3.weight'] = new_answer_weight
+     answer_state_dict['logit_fc.3.bias'] = new_answer_bias
+
+     # Load Bert Weights
+     bert_model_keys = set(model.lxrt_encoder.model.state_dict().keys())
+     bert_loaded_keys = set(bert_state_dict.keys())
+     assert len(bert_model_keys - bert_loaded_keys) == 0
+     model.lxrt_encoder.model.load_state_dict(bert_state_dict, strict=False)
+
+     # Load Answer Logic FC Weights
+     model_keys = set(model.state_dict().keys())
+     ans_loaded_keys = set(answer_state_dict.keys())
+     assert len(ans_loaded_keys - model_keys) == 0
+
+     model.load_state_dict(answer_state_dict, strict=False)
+
+
+
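A minimal sketch of how AnswerTable normalizes answers before id lookup, assuming the standard LXMERT data layout so that data/lxmert/all_ans.json is present; the 'vqa' dataset tag is illustrative:

from pretrain.qa_answer_table import AnswerTable

table = AnswerTable(dsets=['vqa'])       # keep only answers tagged with 'vqa' in all_ans.json
print(table.num_answers)                 # size of the restricted answer vocabulary
print(table.convert_ans("Two."))         # -> "2"        (lowercase, strip trailing '.', then ANS_CONVERT)
print(table.convert_ans("A grey cat."))  # -> "grey cat" (leading article stripped; multi-word answers kept as-is)
if table.used("gray"):                   # guard: only look up ids for answers actually in the table
    assert table.id2ans(table.ans2id("gray")) == "gray"   # ans2id/id2ans are inverse maps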
lxmert/src/processing_image.py ADDED
@@ -0,0 +1,147 @@
+ """
+ coding=utf-8
+ Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal
+ Adapted From Facebook Inc, Detectron2
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+ import sys
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+
+ from lxmert.lxmert.src.vqa_utils import img_tensorize
+
+
+ class ResizeShortestEdge:
+     def __init__(self, short_edge_length, max_size=sys.maxsize):
+         """
+         Args:
+             short_edge_length (list[min, max])
+             max_size (int): maximum allowed longest edge length.
+         """
+         self.interp_method = "bilinear"
+         self.max_size = max_size
+         self.short_edge_length = short_edge_length
+
+     def __call__(self, imgs):
+         img_augs = []
+         for img in imgs:
+             h, w = img.shape[:2]
+             # later: provide list and randomly choose index for resize
+             size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
+             if size == 0:
+                 return img
+             scale = size * 1.0 / min(h, w)
+             if h < w:
+                 newh, neww = size, scale * w
+             else:
+                 newh, neww = scale * h, size
+             if max(newh, neww) > self.max_size:
+                 scale = self.max_size * 1.0 / max(newh, neww)
+                 newh = newh * scale
+                 neww = neww * scale
+             neww = int(neww + 0.5)
+             newh = int(newh + 0.5)
+
+             if img.dtype == np.uint8:
+                 pil_image = Image.fromarray(img)
+                 pil_image = pil_image.resize((neww, newh), Image.BILINEAR)
+                 img = np.asarray(pil_image)
+             else:
+                 img = img.permute(2, 0, 1).unsqueeze(0)  # 3, 0, 1)  # hw(c) -> nchw
+                 img = F.interpolate(img, (newh, neww), mode=self.interp_method, align_corners=False).squeeze(0)
+             img_augs.append(img)
+
+         return img_augs
+
+
+ class Preprocess:
+     def __init__(self, cfg):
+         self.aug = ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
+         self.input_format = cfg.INPUT.FORMAT
+         self.size_divisibility = cfg.SIZE_DIVISIBILITY
+         self.pad_value = cfg.PAD_VALUE
+         self.max_image_size = cfg.INPUT.MAX_SIZE_TEST
+         self.device = cfg.MODEL.DEVICE
+         self.pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
+         self.pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
+         self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
+
+     def pad(self, images):
+         max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
+         image_sizes = [im.shape[-2:] for im in images]
+         images = [
+             F.pad(
+                 im,
+                 [0, max_size[-1] - size[1], 0, max_size[-2] - size[0]],
+                 value=self.pad_value,
+             )
+             for size, im in zip(image_sizes, images)
+         ]
+
+         return torch.stack(images), torch.tensor(image_sizes)
+
+     def __call__(self, images, single_image=False):
+         with torch.no_grad():
+             if not isinstance(images, list):
+                 images = [images]
+             if single_image:
+                 assert len(images) == 1
+             for i in range(len(images)):
+                 if isinstance(images[i], torch.Tensor):
+                     images.insert(i, images.pop(i).to(self.device).float())
+                 elif not isinstance(images[i], torch.Tensor):
+                     images.insert(
+                         i,
+                         torch.as_tensor(img_tensorize(images.pop(i), input_format=self.input_format))
+                         .to(self.device)
+                         .float(),
+                     )
+             # resize smallest edge
+             raw_sizes = torch.tensor([im.shape[:2] for im in images])
+             images = self.aug(images)
+             # transpose images and convert to torch tensors
+             # images = [torch.as_tensor(i.astype("float32")).permute(2, 0, 1).to(self.device) for i in images]
+             # now normalize before pad to avoid useless arithmetic
+             images = [self.normalizer(x) for x in images]
+             # now pad them to do the following operations
+             images, sizes = self.pad(images)
+             # Normalize
+
+             if self.size_divisibility > 0:
+                 raise NotImplementedError()
+             # pad
+             scales_yx = torch.true_divide(raw_sizes, sizes)
+             if single_image:
+                 return images[0], sizes[0], scales_yx[0]
+             else:
+                 return images, sizes, scales_yx
+
+
+ def _scale_box(boxes, scale_yx):
+     boxes[:, 0::2] *= scale_yx[:, 1]
+     boxes[:, 1::2] *= scale_yx[:, 0]
+     return boxes
+
+
+ def _clip_box(tensor, box_size: Tuple[int, int]):
+     assert torch.isfinite(tensor).all(), "Box tensor contains infinite or NaN!"
+     h, w = box_size
+     tensor[:, 0].clamp_(min=0, max=w)
+     tensor[:, 1].clamp_(min=0, max=h)
+     tensor[:, 2].clamp_(min=0, max=w)
+     tensor[:, 3].clamp_(min=0, max=h)
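A minimal sketch of running Preprocess on one of the images shipped with this upload. It assumes vqa_utils exposes a Config.from_pretrained helper for the unc-nlp/frcnn-vg-finetuned Faster R-CNN configuration, as in the Hugging Face LXMERT demo this file is adapted from; that helper and checkpoint name are assumptions, not shown in this diff:

from lxmert.lxmert.src.processing_image import Preprocess
from lxmert.lxmert.src.vqa_utils import Config   # assumed helper, same module that provides img_tensorize

frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")   # assumed config source
preprocess = Preprocess(frcnn_cfg)

# single_image=True returns one resized, normalized, padded CHW tensor,
# its padded (H, W) size, and the y/x factors needed to rescale boxes back.
images, sizes, scales_yx = preprocess("lxmert/experiments/paper/new.jpg", single_image=True)
print(images.shape, sizes, scales_yx)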
lxmert/src/tasks/__init__.py ADDED
File without changes
lxmert/src/tasks/gqa.py ADDED
@@ -0,0 +1,210 @@
+ # coding=utf-8
+ # Copyleft 2019 project LXRT.
+
+ import os
+ import collections
+
+ import torch
+ from tqdm import tqdm
+ import torch.nn as nn
+ from torch.utils.data.dataloader import DataLoader
+
+ from param import args
+ from pretrain.qa_answer_table import load_lxmert_qa
+ from tasks.gqa_model import GQAModel
+ from tasks.gqa_data import GQADataset, GQATorchDataset, GQAEvaluator
+
+
+ DataTuple = collections.namedtuple("DataTuple", 'dataset loader evaluator')
+
+
+ def get_tuple(splits: str, bs:int, shuffle=False, drop_last=False) -> DataTuple:
+     dset = GQADataset(splits)
+     tset = GQATorchDataset(dset)
+     evaluator = GQAEvaluator(dset)
+     data_loader = DataLoader(
+         tset, batch_size=bs,
+         shuffle=shuffle, num_workers=args.num_workers,
+         drop_last=drop_last, pin_memory=True
+     )
+
+     return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
+
+
+ class GQA:
+     def __init__(self):
+         self.train_tuple = get_tuple(
+             args.train, bs=args.batch_size, shuffle=True, drop_last=True
+         )
+         if args.valid != "":
+             valid_bsize = 2048 if args.multiGPU else 512
+             self.valid_tuple = get_tuple(
+                 args.valid, bs=valid_bsize,
+                 shuffle=False, drop_last=False
+             )
+         else:
+             self.valid_tuple = None
+
+         self.model = GQAModel(self.train_tuple.dataset.num_answers)
+
+         # Load pre-trained weights
+         if args.load_lxmert is not None:
+             self.model.lxrt_encoder.load(args.load_lxmert)
+         if args.load_lxmert_qa is not None:
+             load_lxmert_qa(args.load_lxmert_qa, self.model,
+                            label2ans=self.train_tuple.dataset.label2ans)
+
+         # GPU options
+         self.model = self.model.cuda()
+         if args.multiGPU:
+             self.model.lxrt_encoder.multi_gpu()
+
+         # Losses and optimizer
+         self.bce_loss = nn.BCEWithLogitsLoss()
+         self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
+         if 'bert' in args.optim:
+             batch_per_epoch = len(self.train_tuple.loader)
+             t_total = int(batch_per_epoch * args.epochs)
+             print("Total Iters: %d" % t_total)
+             from lxrt.optimization import BertAdam
+             self.optim = BertAdam(list(self.model.parameters()),
+                                   lr=args.lr,
+                                   warmup=0.1,
+                                   t_total=t_total)
+         else:
+             self.optim = args.optimizer(list(self.model.parameters()), args.lr)
+
+         self.output = args.output
+         os.makedirs(self.output, exist_ok=True)
+
+     def train(self, train_tuple, eval_tuple):
+         dset, loader, evaluator = train_tuple
+         iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
+
+         best_valid = 0.
+         for epoch in range(args.epochs):
+             quesid2ans = {}
+             for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):
+
+                 self.model.train()
+                 self.optim.zero_grad()
+
+                 feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
+                 logit = self.model(feats, boxes, sent)
+                 assert logit.dim() == target.dim() == 2
+                 if args.mce_loss:
+                     max_value, target = target.max(1)
+                     loss = self.mce_loss(logit, target) * logit.size(1)
+                 else:
+                     loss = self.bce_loss(logit, target)
+                     loss = loss * logit.size(1)
+
+                 loss.backward()
+                 nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
+                 self.optim.step()
+
+                 score, label = logit.max(1)
+                 for qid, l in zip(ques_id, label.cpu().numpy()):
+                     ans = dset.label2ans[l]
+                     quesid2ans[qid] = ans
+
+             log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)
+
+             if self.valid_tuple is not None:  # Do Validation
+                 valid_score = self.evaluate(eval_tuple)
+                 if valid_score > best_valid:
+                     best_valid = valid_score
+                     self.save("BEST")
+
+                 log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
+                            "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
+
+             print(log_str, end='')
+
+             with open(self.output + "/log.log", 'a') as f:
+                 f.write(log_str)
+                 f.flush()
+
+         self.save("LAST")
+
+     def predict(self, eval_tuple: DataTuple, dump=None):
+         self.model.eval()
+         dset, loader, evaluator = eval_tuple
+         quesid2ans = {}
+         for i, datum_tuple in enumerate(loader):
+             ques_id, feats, boxes, sent = datum_tuple[:4]  # avoid handling target
+             with torch.no_grad():
+                 feats, boxes = feats.cuda(), boxes.cuda()
+                 logit = self.model(feats, boxes, sent)
+                 score, label = logit.max(1)
+                 for qid, l in zip(ques_id, label.cpu().numpy()):
+                     ans = dset.label2ans[l]
+                     quesid2ans[qid] = ans
+         if dump is not None:
+             evaluator.dump_result(quesid2ans, dump)
+         return quesid2ans
+
+     def evaluate(self, eval_tuple: DataTuple, dump=None):
+         dset, loader, evaluator = eval_tuple
+         quesid2ans = self.predict(eval_tuple, dump)
+         return evaluator.evaluate(quesid2ans)
+
+     @staticmethod
+     def oracle_score(data_tuple):
+         dset, loader, evaluator = data_tuple
+         quesid2ans = {}
+         for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
+             _, label = target.max(1)
+             for qid, l in zip(ques_id, label.cpu().numpy()):
+                 ans = dset.label2ans[l]
+                 quesid2ans[qid] = ans
+         return evaluator.evaluate(quesid2ans)
+
+     def save(self, name):
+         torch.save(self.model.state_dict(),
+                    os.path.join(self.output, "%s.pth" % name))
+
+     def load(self, path):
+         print("Load model from %s" % path)
+         state_dict = torch.load("%s.pth" % path)
+         for key in list(state_dict.keys()):
+             if '.module' in key:
+                 state_dict[key.replace('.module', '')] = state_dict.pop(key)
+         self.model.load_state_dict(state_dict, strict=False)
+
+
+ if __name__ == "__main__":
+     # Build Class
+     gqa = GQA()
+
+     # Load Model
+     if args.load is not None:
+         gqa.load(args.load)
+
+     # Test or Train
+     if args.test is not None:
+         args.fast = args.tiny = False  # Always loading all data in test
+         if 'submit' in args.test:
+             gqa.predict(
+                 get_tuple(args.test, bs=args.batch_size,
+                           shuffle=False, drop_last=False),
+                 dump=os.path.join(args.output, 'submit_predict.json')
+             )
+         if 'testdev' in args.test:
+             result = gqa.evaluate(
+                 get_tuple('testdev', bs=args.batch_size,
+                           shuffle=False, drop_last=False),
+                 dump=os.path.join(args.output, 'testdev_predict.json')
+             )
+             print(result)
+     else:
+         # print("Train Oracle: %0.2f" % (gqa.oracle_score(gqa.train_tuple) * 100))
+         print('Splits in Train data:', gqa.train_tuple.dataset.splits)
+         if gqa.valid_tuple is not None:
+             print('Splits in Valid data:', gqa.valid_tuple.dataset.splits)
+             print("Valid Oracle: %0.2f" % (gqa.oracle_score(gqa.valid_tuple) * 100))
+         else:
+             print("DO NOT USE VALIDATION")
+         gqa.train(gqa.train_tuple, gqa.valid_tuple)
+
+
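Finally, a minimal sketch of evaluating a fine-tuned GQA snapshot outside the __main__ block above, assuming the data and command-line arguments are prepared as in run/gqa_test.bash; the snapshot path snap/gqa/BEST.pth is hypothetical:

from param import args
from tasks.gqa import GQA, get_tuple

gqa = GQA()
gqa.load("snap/gqa/BEST")          # load() appends ".pth" to the prefix itself
testdev = get_tuple('testdev', bs=args.batch_size, shuffle=False, drop_last=False)
accuracy = gqa.evaluate(testdev, dump='testdev_predict.json')
print("testdev accuracy: %0.2f" % (accuracy * 100.))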