JustinLin610 committed on
Commit
ce922b3
1 Parent(s): 199a06c

first commit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. LICENSE +201 -0
  2. README.md +21 -7
  3. app.py +156 -0
  4. checkpoints.md +10 -0
  5. colab.md +7 -0
  6. criterions/__init__.py +2 -0
  7. criterions/label_smoothed_cross_entropy.py +343 -0
  8. criterions/scst_loss.py +280 -0
  9. data/__init__.py +0 -0
  10. data/data_utils.py +601 -0
  11. data/file_dataset.py +102 -0
  12. data/mm_data/__init__.py +0 -0
  13. data/mm_data/caption_dataset.py +154 -0
  14. data/mm_data/refcoco_dataset.py +168 -0
  15. data/ofa_dataset.py +74 -0
  16. datasets.md +7 -0
  17. evaluate.py +152 -0
  18. models/__init__.py +1 -0
  19. models/ofa/__init__.py +1 -0
  20. models/ofa/ofa.py +410 -0
  21. models/ofa/resnet.py +225 -0
  22. models/ofa/unify_multihead_attention.py +518 -0
  23. models/ofa/unify_transformer.py +1510 -0
  24. models/ofa/unify_transformer_layer.py +542 -0
  25. models/search.py +814 -0
  26. models/sequence_generator.py +1053 -0
  27. ofa_module/__init__.py +5 -0
  28. pokemons.jpg +0 -0
  29. requirements.txt +5 -0
  30. run_scripts/caption/coco_eval.py +42 -0
  31. run_scripts/caption/evaluate_caption.sh +29 -0
  32. run_scripts/caption/train_caption_stage1.sh +104 -0
  33. run_scripts/caption/train_caption_stage2.sh +101 -0
  34. run_scripts/refcoco/evaluate_refcoco.sh +197 -0
  35. spaces.md +4 -0
  36. tasks/__init__.py +2 -0
  37. tasks/mm_tasks/__init__.py +2 -0
  38. tasks/mm_tasks/caption.py +249 -0
  39. tasks/mm_tasks/refcoco.py +161 -0
  40. tasks/ofa_task.py +338 -0
  41. train.py +523 -0
  42. trainer.py +1531 -0
  43. utils/BPE/__init__.py +0 -0
  44. utils/BPE/dict.txt +0 -0
  45. utils/BPE/encoder.json +0 -0
  46. utils/BPE/vocab.bpe +0 -0
  47. utils/__init__.py +0 -0
  48. utils/checkpoint_utils.py +875 -0
  49. utils/cider/pyciderevalcap/__init__.py +1 -0
  50. utils/cider/pyciderevalcap/cider/__init__.py +1 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 1999-2022 Alibaba Group Holding Ltd.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,26 @@
1
  ---
2
- title: OFA Visual_Grounding
3
- emoji: 😻
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
  app_file: app.py
8
- pinned: false
9
- license: apache-2.0
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
1
  ---
2
+ title: OFA-Image_Caption
3
+ emoji: 🖼
4
+ colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
7
  app_file: app.py
8
+ pinned: true
 
9
  ---
10
+ # Configuration
11
+ `title`: _string_
12
+ OFA Image Caption
13
+ `emoji`: _string_
14
+ 🖼
15
+ `colorFrom`: _string_
16
+ red
17
+ `colorTo`: _string_
18
+ indigo
19
+ `sdk`: _string_
20
+ gradio
21
+ `app_file`: _string_
22
+ app.py
23
 
24
+
25
+ `pinned`: _boolean_
26
+ true
app.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+
3
+ os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;'
4
+ 'pip install --use-feature=in-tree-build ./; cd ..')
5
+ os.system('ls -l')
6
+
7
+ import torch
8
+ import numpy as np
9
+ from fairseq import utils, tasks
10
+ from fairseq import checkpoint_utils
11
+ from utils.eval_utils import eval_step
12
+ from tasks.mm_tasks.refcoco import RefcocoTask
13
+ from models.ofa import OFAModel
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ import cv2
17
+ import gradio as gr
18
+
19
+ # Register refcoco task
20
+ tasks.register_task('refcoco', RefcocoTask)
21
+
22
+ # turn on cuda if GPU is available
23
+ use_cuda = torch.cuda.is_available()
24
+ # use fp16 only when GPU is available
25
+ use_fp16 = False
26
+
27
+ os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/refcocog_large_best.pt; '
28
+ 'mkdir -p checkpoints; mv refcocog_large_best.pt checkpoints/refcocog.pt')
29
+
30
+ # Load pretrained ckpt & config
31
+ overrides = {"bpe_dir": "utils/BPE", "eval_cider": False, "beam": 5,
32
+ "max_len_b": 16, "no_repeat_ngram_size": 3, "seed": 7}
33
+ models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
34
+ utils.split_paths('checkpoints/refcocog.pt'),
35
+ arg_overrides=overrides
36
+ )
37
+
38
+ cfg.common.seed = 7
39
+ cfg.generation.beam = 5
40
+ cfg.generation.min_len = 4
41
+ cfg.generation.max_len_a = 0
42
+ cfg.generation.max_len_b = 4
43
+ cfg.generation.no_repeat_ngram_size = 3
44
+
45
+ # Fix seed for stochastic decoding
46
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
47
+ np.random.seed(cfg.common.seed)
48
+ utils.set_torch_seed(cfg.common.seed)
49
+
50
+ # Move models to GPU
51
+ for model in models:
52
+ model.eval()
53
+ if use_fp16:
54
+ model.half()
55
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
56
+ model.cuda()
57
+ model.prepare_for_inference_(cfg)
58
+
59
+ # Initialize generator
60
+ generator = task.build_generator(models, cfg.generation)
61
+
62
+ mean = [0.5, 0.5, 0.5]
63
+ std = [0.5, 0.5, 0.5]
64
+
65
+ patch_resize_transform = transforms.Compose([
66
+ lambda image: image.convert("RGB"),
67
+ transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize(mean=mean, std=std),
70
+ ])
71
+
72
+ # Text preprocess
73
+ bos_item = torch.LongTensor([task.src_dict.bos()])
74
+ eos_item = torch.LongTensor([task.src_dict.eos()])
75
+ pad_idx = task.src_dict.pad()
76
+
77
+
78
+ def encode_text(text, length=None, append_bos=False, append_eos=False):
79
+ s = task.tgt_dict.encode_line(
80
+ line=task.bpe.encode(text),
81
+ add_if_not_exist=False,
82
+ append_eos=False
83
+ ).long()
84
+ if length is not None:
85
+ s = s[:length]
86
+ if append_bos:
87
+ s = torch.cat([bos_item, s])
88
+ if append_eos:
89
+ s = torch.cat([s, eos_item])
90
+ return s
91
+
92
+
93
+ patch_image_size = cfg.task.patch_image_size
94
+
95
+
96
+ def construct_sample(image: Image, text: str):
97
+ w, h = image.size
98
+ w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
99
+ h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
100
+ patch_image = patch_resize_transform(image).unsqueeze(0)
101
+ patch_mask = torch.tensor([True])
102
+ src_text = encode_text(' which region does the text " {} " describe?'.format(text), append_bos=True,
103
+ append_eos=True).unsqueeze(0)
104
+ src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
105
+ sample = {
106
+ "id": np.array(['42']),
107
+ "net_input": {
108
+ "src_tokens": src_text,
109
+ "src_lengths": src_length,
110
+ "patch_images": patch_image,
111
+ "patch_masks": patch_mask,
112
+ },
113
+ "w_resize_ratios": w_resize_ratio,
114
+ "h_resize_ratios": h_resize_ratio,
115
+ "region_coords": torch.randn(1, 4)
116
+ }
117
+ return sample
118
+
119
+
120
+ # Function to turn FP32 to FP16
121
+ def apply_half(t):
122
+ if t.dtype is torch.float32:
123
+ return t.to(dtype=torch.half)
124
+ return t
125
+
126
+
127
+ # Function for visual grounding
128
+ def visual_grounding(Image, Text):
129
+ sample = construct_sample(Image, Text)
130
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
131
+ sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
132
+ with torch.no_grad():
133
+ result, scores = eval_step(task, generator, models, sample)
134
+ img = cv2.cvtColor(np.asarray(Image), cv2.COLOR_RGB2BGR)
135
+ cv2.rectangle(
136
+ img,
137
+ (int(result[0]["box"][0]), int(result[0]["box"][1])),
138
+ (int(result[0]["box"][2]), int(result[0]["box"][3])),
139
+ (0, 255, 0),
140
+ 3
141
+ )
142
+ return img
143
+
144
+
145
+ title = "OFA-Visual_Grounding"
146
+ description = "Gradio Demo for OFA-Visual_Grounding. Upload your own image or click any one of the examples, " \
147
+ "and write a description about a certain object. " \
148
+ "Then click \"Submit\" and wait for the result of grounding. "
149
+ article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
150
+ "Repo</a></p> "
151
+ examples = [['pokemons.jpg', 'a blue turtle-like pokemon with round head']]
152
+ io = gr.Interface(fn=visual_grounding, inputs=[gr.inputs.Image(type='pil'), "textbox"],
153
+ outputs=gr.outputs.Image(type='numpy'),
154
+ title=title, description=description, article=article, examples=examples,
155
+ allow_flagging=False, allow_screenshot=False)
156
+ io.launch(enable_queue=True, cache_examples=True)
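The grounding pipeline above can also be exercised without the Gradio UI. The sketch below is illustrative only and assumes the checkpoint download and model setup earlier in app.py have already run in the same interpreter; it simply reuses `visual_grounding` on the example image shipped with the repo.

```python
# Illustrative sketch, not part of app.py: reuses visual_grounding() defined above.
from PIL import Image
import cv2

image = Image.open('pokemons.jpg')  # example image included in this repo
boxed = visual_grounding(image, 'a blue turtle-like pokemon with round head')
cv2.imwrite('grounding_result.jpg', boxed)  # visual_grounding returns a BGR array with the box drawn
```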
checkpoints.md ADDED
@@ -0,0 +1,10 @@
1
+ # Checkpoints
2
+
3
+ We provide links to download our checkpoints. We will release all the checkpoints, including pretrained models and models finetuned on different tasks.
4
+
5
+ ## Pretraining
6
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a>
7
+
8
+ ## Finetuning
9
+
10
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
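The finetuned caption checkpoint above can be fetched and staged the same way app.py fetches the refcocog checkpoint. A minimal sketch (the local filename `checkpoints/caption.pt` is only an illustration):

```python
import os

# Mirrors the wget/mkdir pattern used in app.py; the target filename is illustrative.
os.system('wget https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt; '
          'mkdir -p checkpoints; mv caption_large_best_clean.pt checkpoints/caption.pt')
```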
colab.md ADDED
@@ -0,0 +1,7 @@
1
+ # Colab Notebooks
2
+
3
+ We provide Colab notebooks for different downstream tasks so you can try out OFA. See below.
4
+
5
+ [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing) [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
6
+
7
+ [colab]: <https://colab.research.google.com/assets/colab-badge.svg>
criterions/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .scst_loss import ScstRewardCriterion
2
+ from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
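Importing this package is what runs the `@register_criterion` decorators in the two modules below, making the criteria visible to fairseq under their registered names. A minimal sketch (the exact flag spelling may vary with the fairseq version in use):

```python
import criterions  # noqa: F401 - importing triggers @register_criterion in the modules below

# After the import, the criteria can be selected by their registered names,
# e.g. passing --criterion=ajust_label_smoothed_cross_entropy or
# --criterion=scst_reward_criterion to a fairseq training entry point.
```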
criterions/label_smoothed_cross_entropy.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+
19
+ @dataclass
20
+ class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
21
+ label_smoothing: float = field(
22
+ default=0.0,
23
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
24
+ )
25
+ report_accuracy: bool = field(
26
+ default=False,
27
+ metadata={"help": "report accuracy metric"},
28
+ )
29
+ ignore_prefix_size: int = field(
30
+ default=0,
31
+ metadata={"help": "Ignore first N tokens"},
32
+ )
33
+ ignore_eos: bool = field(
34
+ default=False,
35
+ metadata={"help": "Ignore eos token"},
36
+ )
37
+ sentence_avg: bool = II("optimization.sentence_avg")
38
+ drop_worst_ratio: float = field(
39
+ default=0.0,
40
+ metadata={"help": "ratio for discarding bad samples"},
41
+ )
42
+ drop_worst_after: int = field(
43
+ default=0,
44
+ metadata={"help": "steps for discarding bad samples"},
45
+ )
46
+ use_rdrop: bool = field(
47
+ default=False, metadata={"help": "use R-Drop"}
48
+ )
49
+ reg_alpha: float = field(
50
+ default=1.0, metadata={"help": "weight for R-Drop"}
51
+ )
52
+ sample_patch_num: int = field(
53
+ default=196, metadata={"help": "sample patches for v1"}
54
+ )
55
+ constraint_range: Optional[str] = field(
56
+ default=None,
57
+ metadata={"help": "constraint range"}
58
+ )
59
+
60
+
61
+ def construct_rdrop_sample(x):
62
+ if isinstance(x, dict):
63
+ for key in x:
64
+ x[key] = construct_rdrop_sample(x[key])
65
+ return x
66
+ elif isinstance(x, torch.Tensor):
67
+ return x.repeat(2, *([1] * (x.dim()-1)))
68
+ elif isinstance(x, int):
69
+ return x * 2
70
+ elif isinstance(x, np.ndarray):
71
+ return x.repeat(2)
72
+ else:
73
+ raise NotImplementedError
74
+
75
+
76
+ def kl_loss(p, q):
77
+ p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
78
+ q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
79
+ loss = (p_loss + q_loss) / 2
80
+ return loss
81
+
82
+
83
+ def label_smoothed_nll_loss(
84
+ lprobs, target, epsilon, update_num, reduce=True,
85
+ drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
86
+ constraint_masks=None, constraint_start=None, constraint_end=None
87
+ ):
88
+ if target.dim() == lprobs.dim() - 1:
89
+ target = target.unsqueeze(-1)
90
+ nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
91
+ if constraint_masks is not None:
92
+ smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
93
+ eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
94
+ elif constraint_start is not None and constraint_end is not None:
95
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
96
+ smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
97
+ eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
98
+ else:
99
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
100
+ eps_i = epsilon / (lprobs.size(-1) - 1)
101
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
102
+ if drop_worst_ratio > 0 and update_num > drop_worst_after:
103
+ if use_rdrop:
104
+ true_batch_size = loss.size(0) // 2
105
+ _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
106
+ loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
107
+ nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
108
+ lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
109
+ else:
110
+ loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
111
+ nll_loss = nll_loss[indices]
112
+ lprobs = lprobs[indices]
113
+
114
+ ntokens = loss.numel()
115
+ nll_loss = nll_loss.sum()
116
+ loss = loss.sum()
117
+ if use_rdrop:
118
+ true_batch_size = lprobs.size(0) // 2
119
+ p = lprobs[:true_batch_size]
120
+ q = lprobs[true_batch_size:]
121
+ if constraint_start is not None and constraint_end is not None:
122
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
123
+ p = p[:, constraint_range]
124
+ q = q[:, constraint_range]
125
+ loss += kl_loss(p, q) * reg_alpha
126
+
127
+ return loss, nll_loss, ntokens
128
+
129
+
130
+ @register_criterion(
131
+ "ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
132
+ )
133
+ class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
134
+ def __init__(
135
+ self,
136
+ task,
137
+ sentence_avg,
138
+ label_smoothing,
139
+ ignore_prefix_size=0,
140
+ ignore_eos=False,
141
+ report_accuracy=False,
142
+ drop_worst_ratio=0,
143
+ drop_worst_after=0,
144
+ use_rdrop=False,
145
+ reg_alpha=1.0,
146
+ sample_patch_num=196,
147
+ constraint_range=None
148
+ ):
149
+ super().__init__(task)
150
+ self.sentence_avg = sentence_avg
151
+ self.eps = label_smoothing
152
+ self.ignore_prefix_size = ignore_prefix_size
153
+ self.ignore_eos = ignore_eos
154
+ self.report_accuracy = report_accuracy
155
+ self.drop_worst_ratio = drop_worst_ratio
156
+ self.drop_worst_after = drop_worst_after
157
+ self.use_rdrop = use_rdrop
158
+ self.reg_alpha = reg_alpha
159
+ self.sample_patch_num = sample_patch_num
160
+
161
+ self.constraint_start = None
162
+ self.constraint_end = None
163
+ if constraint_range is not None:
164
+ constraint_start, constraint_end = constraint_range.split(',')
165
+ self.constraint_start = int(constraint_start)
166
+ self.constraint_end = int(constraint_end)
167
+
168
+ def forward(self, model, sample, update_num=0, reduce=True):
169
+ """Compute the loss for the given sample.
170
+
171
+ Returns a tuple with three elements:
172
+ 1) the loss
173
+ 2) the sample size, which is used as the denominator for the gradient
174
+ 3) logging outputs to display while training
175
+ """
176
+ if isinstance(sample, list):
177
+ if self.sample_patch_num > 0:
178
+ sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
179
+ loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
180
+ loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
181
+ loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
182
+ sample_size = 1
183
+ logging_output = {
184
+ "loss": loss.data,
185
+ "loss_v1": loss_v1.data,
186
+ "loss_v2": loss_v2.data,
187
+ "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
188
+ "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
189
+ "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
190
+ "sample_size": 1,
191
+ "sample_size_v1": sample_size_v1,
192
+ "sample_size_v2": sample_size_v2,
193
+ }
194
+ return loss, sample_size, logging_output
195
+
196
+ if self.use_rdrop:
197
+ construct_rdrop_sample(sample)
198
+
199
+ net_output = model(**sample["net_input"])
200
+ loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
201
+ sample_size = (
202
+ sample["target"].size(0) if self.sentence_avg else ntokens
203
+ )
204
+ logging_output = {
205
+ "loss": loss.data,
206
+ "nll_loss": nll_loss.data,
207
+ "ntokens": sample["ntokens"],
208
+ "nsentences": sample["nsentences"],
209
+ "sample_size": sample_size,
210
+ }
211
+ if self.report_accuracy:
212
+ n_correct, total = self.compute_accuracy(model, net_output, sample)
213
+ logging_output["n_correct"] = utils.item(n_correct.data)
214
+ logging_output["total"] = utils.item(total.data)
215
+ return loss, sample_size, logging_output
216
+
217
+ def get_lprobs_and_target(self, model, net_output, sample):
218
+ conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
219
+ constraint_masks = None
220
+ if "constraint_masks" in sample and sample["constraint_masks"] is not None:
221
+ constraint_masks = sample["constraint_masks"]
222
+ net_output[0].masked_fill_(~constraint_masks, -math.inf)
223
+ if self.constraint_start is not None and self.constraint_end is not None:
224
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
225
+ net_output[0][:, :, self.constraint_end:] = -math.inf
226
+ lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
227
+ target = model.get_targets(sample, net_output)
228
+ if self.ignore_prefix_size > 0:
229
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
230
+ target = target[:, self.ignore_prefix_size :].contiguous()
231
+ if constraint_masks is not None:
232
+ constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
233
+ if self.ignore_eos:
234
+ bsz, seq_len, embed_dim = lprobs.size()
235
+ eos_indices = target.eq(self.task.tgt_dict.eos())
236
+ lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
237
+ target = target[~eos_indices].reshape(bsz, seq_len-1)
238
+ if constraint_masks is not None:
239
+ constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
240
+ if constraint_masks is not None:
241
+ constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
242
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
243
+
244
+ def compute_loss(self, model, net_output, sample, update_num, reduce=True):
245
+ lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
246
+ if constraint_masks is not None:
247
+ constraint_masks = constraint_masks[target != self.padding_idx]
248
+ lprobs = lprobs[target != self.padding_idx]
249
+ target = target[target != self.padding_idx]
250
+ loss, nll_loss, ntokens = label_smoothed_nll_loss(
251
+ lprobs,
252
+ target,
253
+ self.eps,
254
+ update_num,
255
+ reduce=reduce,
256
+ drop_worst_ratio=self.drop_worst_ratio,
257
+ drop_worst_after=self.drop_worst_after,
258
+ use_rdrop=self.use_rdrop,
259
+ reg_alpha=self.reg_alpha,
260
+ constraint_masks=constraint_masks,
261
+ constraint_start=self.constraint_start,
262
+ constraint_end=self.constraint_end
263
+ )
264
+ return loss, nll_loss, ntokens
265
+
266
+ def compute_accuracy(self, model, net_output, sample):
267
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
268
+ mask = target.ne(self.padding_idx)
269
+ n_correct = torch.sum(
270
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
271
+ )
272
+ total = torch.sum(mask)
273
+ return n_correct, total
274
+
275
+ @classmethod
276
+ def reduce_metrics(cls, logging_outputs) -> None:
277
+ """Aggregate logging outputs from data parallel training."""
278
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
279
+ loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
280
+ loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
281
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
282
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
283
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
284
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
285
+ sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
286
+ sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
287
+
288
+ metrics.log_scalar(
289
+ "loss", loss_sum / sample_size, sample_size, round=3
290
+ )
291
+ metrics.log_scalar(
292
+ "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
293
+ )
294
+ metrics.log_scalar(
295
+ "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
296
+ )
297
+ metrics.log_scalar(
298
+ "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
299
+ )
300
+ metrics.log_derived(
301
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
302
+ )
303
+
304
+ metrics.log_scalar(
305
+ "ntokens", ntokens, 1, round=3
306
+ )
307
+ metrics.log_scalar(
308
+ "nsentences", nsentences, 1, round=3
309
+ )
310
+ metrics.log_scalar(
311
+ "sample_size", sample_size, 1, round=3
312
+ )
313
+ metrics.log_scalar(
314
+ "sample_size_v1", sample_size_v1, 1, round=3
315
+ )
316
+ metrics.log_scalar(
317
+ "sample_size_v2", sample_size_v2, 1, round=3
318
+ )
319
+
320
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
321
+ if total > 0:
322
+ metrics.log_scalar("total", total)
323
+ n_correct = utils.item(
324
+ sum(log.get("n_correct", 0) for log in logging_outputs)
325
+ )
326
+ metrics.log_scalar("n_correct", n_correct)
327
+ metrics.log_derived(
328
+ "accuracy",
329
+ lambda meters: round(
330
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
331
+ )
332
+ if meters["total"].sum > 0
333
+ else float("nan"),
334
+ )
335
+
336
+ @staticmethod
337
+ def logging_outputs_can_be_summed() -> bool:
338
+ """
339
+ Whether the logging outputs returned by `forward` can be summed
340
+ across workers prior to calling `reduce_metrics`. Setting this
341
+ to True will improve distributed training speed.
342
+ """
343
+ return True
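Stripped of the constraint masks, R-Drop and worst-sample dropping, `label_smoothed_nll_loss` above reduces to the standard label-smoothing interpolation between the NLL term and a uniform smoothing term. A small self-contained sketch of that core computation:

```python
import torch
import torch.nn.functional as F

# Toy setup: 3 target tokens over a vocabulary of 5; epsilon is the smoothing mass.
logits = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
epsilon = 0.1

lprobs = F.log_softmax(logits, dim=-1)
nll_loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)  # -log p(target)
smooth_loss = -lprobs.sum(dim=-1)                                          # -sum over the vocabulary
eps_i = epsilon / (lprobs.size(-1) - 1)                                    # mass given to each other token
loss = ((1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss).sum()
```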
criterions/scst_loss.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import string
8
+ from dataclasses import dataclass, field
9
+ from collections import OrderedDict
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+ from data import data_utils
19
+ from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
20
+
21
+
22
+ def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
23
+ loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
24
+ if ignore_index is not None:
25
+ pad_mask = target.eq(ignore_index)
26
+ loss.masked_fill_(pad_mask, 0.0)
27
+ ntokens = (~pad_mask).sum()
28
+ else:
29
+ loss = loss.squeeze(-1)
30
+ ntokens = target.numel()
31
+ if reduce:
32
+ loss = loss.sum()
33
+ return loss, ntokens
34
+
35
+ @dataclass
36
+ class ScstRewardCriterionConfig(FairseqDataclass):
37
+ scst_cider_cached_tokens: str = field(
38
+ default="coco-train-words.p",
39
+ metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
40
+ )
41
+ ignore_prefix_size: int = field(
42
+ default=0,
43
+ metadata={"help": "Ignore first N tokens"},
44
+ )
45
+ sentence_avg: bool = II("optimization.sentence_avg")
46
+ constraint_range: Optional[str] = field(
47
+ default=None,
48
+ metadata={"help": "constraint range"}
49
+ )
50
+
51
+
52
+ @register_criterion(
53
+ "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
54
+ )
55
+ class ScstRewardCriterion(FairseqCriterion):
56
+ CIDER_REWARD_WEIGHT = 1
57
+
58
+ def __init__(
59
+ self,
60
+ task,
61
+ scst_cider_cached_tokens,
62
+ sentence_avg,
63
+ ignore_prefix_size=0,
64
+ constraint_range=None
65
+ ):
66
+ super().__init__(task)
67
+ self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
68
+ self.sentence_avg = sentence_avg
69
+ self.ignore_prefix_size = ignore_prefix_size
70
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
71
+
72
+ self.constraint_start = None
73
+ self.constraint_end = None
74
+ if constraint_range is not None:
75
+ constraint_start, constraint_end = constraint_range.split(',')
76
+ self.constraint_start = int(constraint_start)
77
+ self.constraint_end = int(constraint_end)
78
+
79
+ def forward(self, model, sample, update_num=0, reduce=True):
80
+ """Compute the loss for the given sample.
81
+
82
+ Returns a tuple with three elements:
83
+ 1) the loss
84
+ 2) the sample size, which is used as the denominator for the gradient
85
+ 3) logging outputs to display while training
86
+ """
87
+ loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
88
+
89
+ sample_size = (
90
+ nsentences if self.sentence_avg else ntokens
91
+ )
92
+ logging_output = {
93
+ "loss": loss.data,
94
+ "score": score,
95
+ "ntokens": ntokens,
96
+ "nsentences": nsentences,
97
+ "sample_size": sample_size,
98
+ }
99
+ return loss, sample_size, logging_output
100
+
101
+ def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
102
+ '''
103
+ gen_res: generated captions, list of str
104
+ gt_idx: list of int, of the same length as gen_res
105
+ gt_res: ground truth captions, list of list of str.
106
+ gen_res[i] corresponds to gt_res[gt_idx[i]]
107
+ Each image can have multiple ground truth captions
108
+ '''
109
+ gen_res_size = len(gen_res)
110
+
111
+ res = OrderedDict()
112
+ for i in range(gen_res_size):
113
+ res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
114
+
115
+ gts = OrderedDict()
116
+ gt_res_ = [
117
+ [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
118
+ for i in range(len(gt_res))
119
+ ]
120
+ for i in range(gen_res_size):
121
+ gts[i] = gt_res_[gt_idx[i]]
122
+
123
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
124
+ _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
125
+ scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
126
+ return scores
127
+
128
+ @classmethod
129
+ def _wrap_sentence(self, s):
130
+ # ensure the sentence ends with <eos> token
131
+ # in order to keep consistent with cider_cached_tokens
132
+ r = s.strip()
133
+ if r.endswith('.'):
134
+ r = r[:-1]
135
+ r += ' <eos>'
136
+ return r
137
+
138
+ def get_generator_out(self, model, sample):
139
+ def decode(toks):
140
+ hypo = toks.int().cpu()
141
+ hypo_str = self.task.tgt_dict.string(hypo)
142
+ hypo_str = self.task.bpe.decode(hypo_str).strip()
143
+ return hypo, hypo_str
144
+
145
+ model.eval()
146
+ with torch.no_grad():
147
+ self.task.scst_generator.model.eval()
148
+ gen_out = self.task.scst_generator.generate([model], sample)
149
+
150
+ gen_target = []
151
+ gen_res = []
152
+ gt_res = []
153
+ for i in range(len(gen_out)):
154
+ for j in range(len(gen_out[i])):
155
+ hypo, hypo_str = decode(gen_out[i][j]["tokens"])
156
+ gen_target.append(hypo)
157
+ gen_res.append(hypo_str)
158
+ gt_res.append(
159
+ decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
160
+ )
161
+
162
+ return gen_target, gen_res, gt_res
163
+
164
+ def get_reward_and_scores(self, gen_res, gt_res, device):
165
+ batch_size = len(gt_res)
166
+ gen_res_size = len(gen_res)
167
+ seq_per_img = gen_res_size // batch_size
168
+
169
+ gt_idx = [i // seq_per_img for i in range(gen_res_size)]
170
+ scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
171
+ sc_ = scores.reshape(batch_size, seq_per_img)
172
+ baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
173
+ # sample - baseline
174
+ reward = scores.reshape(batch_size, seq_per_img)
175
+ reward = reward - baseline
176
+ reward = reward.reshape(gen_res_size)
177
+ reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
178
+
179
+ return reward, scores
180
+
181
+ def get_net_output(self, model, sample, gen_target):
182
+ def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
183
+ return data_utils.collate_tokens(
184
+ sample_list,
185
+ pad_idx=self.padding_idx,
186
+ eos_idx=eos,
187
+ left_pad=False,
188
+ move_eos_to_beginning=move_eos_to_beginning,
189
+ )
190
+
191
+ batch_size = len(sample["target"])
192
+ gen_target_size = len(gen_target)
193
+ seq_per_img = gen_target_size // batch_size
194
+
195
+ model.train()
196
+ sample_src_tokens = torch.repeat_interleave(
197
+ sample['net_input']['src_tokens'], seq_per_img, dim=0
198
+ )
199
+ sample_src_lengths = torch.repeat_interleave(
200
+ sample['net_input']['src_lengths'], seq_per_img, dim=0
201
+ )
202
+ sample_patch_images = torch.repeat_interleave(
203
+ sample['net_input']['patch_images'], seq_per_img, dim=0
204
+ )
205
+ sample_patch_masks = torch.repeat_interleave(
206
+ sample['net_input']['patch_masks'], seq_per_img, dim=0
207
+ )
208
+ gen_prev_output_tokens = torch.as_tensor(
209
+ merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
210
+ device=sample["target"].device, dtype=torch.int64
211
+ )
212
+ gen_target_tokens = torch.as_tensor(
213
+ merge(gen_target), device=sample["target"].device, dtype=torch.int64
214
+ )
215
+ net_output = model(
216
+ src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
217
+ patch_images=sample_patch_images, patch_masks=sample_patch_masks,
218
+ prev_output_tokens=gen_prev_output_tokens
219
+ )
220
+
221
+ return net_output, gen_target_tokens
222
+
223
+ def get_lprobs_and_target(self, model, net_output, gen_target):
224
+ if self.constraint_start is not None and self.constraint_end is not None:
225
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
226
+ net_output[0][:, :, self.constraint_end:] = -math.inf
227
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
228
+ if self.ignore_prefix_size > 0:
229
+ if getattr(lprobs, "batch_first", False):
230
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
231
+ gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
232
+ else:
233
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
234
+ gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
235
+ return lprobs, gen_target
236
+
237
+ def compute_loss(self, model, sample, reduce=True):
238
+ gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
239
+ reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
240
+ net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
241
+ gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
242
+ loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
243
+ nsentences = gen_target_tokens.size(0)
244
+
245
+ return loss, scores.sum(), ntokens, nsentences
246
+
247
+ @classmethod
248
+ def reduce_metrics(cls, logging_outputs) -> None:
249
+ """Aggregate logging outputs from data parallel training."""
250
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
251
+ score_sum = sum(log.get("score", 0) for log in logging_outputs)
252
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
253
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
254
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
255
+
256
+ metrics.log_scalar(
257
+ "loss", loss_sum / sample_size, sample_size, round=3
258
+ )
259
+ metrics.log_scalar(
260
+ "score", score_sum / nsentences, nsentences, round=3
261
+ )
262
+
263
+ metrics.log_scalar(
264
+ "ntokens", ntokens, 1, round=3
265
+ )
266
+ metrics.log_scalar(
267
+ "nsentences", nsentences, 1, round=3
268
+ )
269
+ metrics.log_scalar(
270
+ "sample_size", sample_size, 1, round=3
271
+ )
272
+
273
+ @staticmethod
274
+ def logging_outputs_can_be_summed() -> bool:
275
+ """
276
+ Whether the logging outputs returned by `forward` can be summed
277
+ across workers prior to calling `reduce_metrics`. Setting this
278
+ to True will improve distributed training speed.
279
+ """
280
+ return True
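The reward used by `scst_loss` comes from `get_reward_and_scores` above, which subtracts a leave-one-out baseline: each sampled caption is compared with the mean CIDEr-D score of its sibling samples for the same image. A small numeric sketch with made-up scores:

```python
import numpy as np

# Two images, three sampled captions each; the CIDEr-D scores are invented for illustration.
scores = np.array([0.9, 0.5, 0.7,
                   0.2, 0.8, 0.5])
batch_size, seq_per_img = 2, 3

sc_ = scores.reshape(batch_size, seq_per_img)
baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)  # mean score of the *other* samples
reward = (sc_ - baseline).reshape(-1)
# e.g. reward[0] = 0.9 - (0.5 + 0.7) / 2 = 0.3, so the first caption beat its siblings.
```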
data/__init__.py ADDED
File without changes
data/data_utils.py ADDED
@@ -0,0 +1,601 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from collections.abc import Iterable
8
+ except ImportError:
9
+ from collections import Iterable
10
+ import contextlib
11
+ import itertools
12
+ import logging
13
+ import re
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from fairseq.file_io import PathManager
21
+ from fairseq import utils
22
+ import os
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def infer_language_pair(path):
28
+ """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
29
+ src, dst = None, None
30
+ for filename in PathManager.ls(path):
31
+ parts = filename.split(".")
32
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
33
+ return parts[1].split("-")
34
+ return src, dst
35
+
36
+
37
+ def collate_tokens(
38
+ values,
39
+ pad_idx,
40
+ eos_idx=None,
41
+ left_pad=False,
42
+ move_eos_to_beginning=False,
43
+ pad_to_length=None,
44
+ pad_to_multiple=1,
45
+ pad_to_bsz=None,
46
+ ):
47
+ """Convert a list of 1d tensors into a padded 2d tensor."""
48
+ size = max(v.size(0) for v in values)
49
+ size = size if pad_to_length is None else max(size, pad_to_length)
50
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
51
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
52
+
53
+ def copy_tensor(src, dst):
54
+ assert dst.numel() == src.numel()
55
+ if move_eos_to_beginning:
56
+ if eos_idx is None:
57
+ # if no eos_idx is specified, then use the last token in src
58
+ dst[0] = src[-1]
59
+ else:
60
+ dst[0] = eos_idx
61
+ dst[1:] = src[:-1]
62
+ else:
63
+ dst.copy_(src)
64
+
65
+ if values[0].dim() == 1:
66
+ res = values[0].new(len(values), size).fill_(pad_idx)
67
+ elif values[0].dim() == 2:
68
+ assert move_eos_to_beginning is False
69
+ res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ for i, v in enumerate(values):
74
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
75
+ return res
76
+
77
+
78
+ def load_indexed_dataset(
79
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
80
+ ):
81
+ """A helper function for loading indexed datasets.
82
+
83
+ Args:
84
+ path (str): path to indexed dataset (e.g., 'data-bin/train')
85
+ dictionary (~fairseq.data.Dictionary): data dictionary
86
+ dataset_impl (str, optional): which dataset implementation to use. If
87
+ not provided, it will be inferred automatically. For legacy indexed
88
+ data we use the 'cached' implementation by default.
89
+ combine (bool, optional): automatically load and combine multiple
90
+ datasets. For example, if *path* is 'data-bin/train', then we will
91
+ combine 'data-bin/train', 'data-bin/train1', ... and return a
92
+ single ConcatDataset instance.
93
+ """
94
+ import fairseq.data.indexed_dataset as indexed_dataset
95
+ from fairseq.data.concat_dataset import ConcatDataset
96
+
97
+ datasets = []
98
+ for k in itertools.count():
99
+ path_k = path + (str(k) if k > 0 else "")
100
+ try:
101
+ path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
102
+ except Exception as e:
103
+ if "StorageException: [404] Path not found" in str(e):
104
+ logger.warning(f"path_k: {e} not found")
105
+ else:
106
+ raise e
107
+
108
+ dataset_impl_k = dataset_impl
109
+ if dataset_impl_k is None:
110
+ dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
111
+ dataset = indexed_dataset.make_dataset(
112
+ path_k,
113
+ impl=dataset_impl_k or default,
114
+ fix_lua_indexing=True,
115
+ dictionary=dictionary,
116
+ )
117
+ if dataset is None:
118
+ break
119
+ logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
120
+ datasets.append(dataset)
121
+ if not combine:
122
+ break
123
+ if len(datasets) == 0:
124
+ return None
125
+ elif len(datasets) == 1:
126
+ return datasets[0]
127
+ else:
128
+ return ConcatDataset(datasets)
129
+
130
+
131
+ @contextlib.contextmanager
132
+ def numpy_seed(seed, *addl_seeds):
133
+ """Context manager which seeds the NumPy PRNG with the specified seed and
134
+ restores the state afterward"""
135
+ if seed is None:
136
+ yield
137
+ return
138
+ if len(addl_seeds) > 0:
139
+ seed = int(hash((seed, *addl_seeds)) % 1e6)
140
+ state = np.random.get_state()
141
+ np.random.seed(seed)
142
+ try:
143
+ yield
144
+ finally:
145
+ np.random.set_state(state)
146
+
147
+
148
+ def collect_filtered(function, iterable, filtered):
149
+ """
150
+ Similar to :func:`filter` but collects filtered elements in ``filtered``.
151
+
152
+ Args:
153
+ function (callable): function that returns ``False`` for elements that
154
+ should be filtered
155
+ iterable (iterable): iterable to filter
156
+ filtered (list): list to store filtered elements
157
+ """
158
+ for el in iterable:
159
+ if function(el):
160
+ yield el
161
+ else:
162
+ filtered.append(el)
163
+
164
+
165
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
166
+ def compare_leq(a, b):
167
+ return a <= b if not isinstance(a, tuple) else max(a) <= b
168
+
169
+ def check_size(idx):
170
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
171
+ return size_fn(idx) <= max_positions
172
+ elif isinstance(max_positions, dict):
173
+ idx_size = size_fn(idx)
174
+ assert isinstance(idx_size, dict)
175
+ intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
176
+ return all(
177
+ all(
178
+ a is None or b is None or a <= b
179
+ for a, b in zip(idx_size[key], max_positions[key])
180
+ )
181
+ for key in intersect_keys
182
+ )
183
+ else:
184
+ # For MultiCorpusSampledDataset, will generalize it later
185
+ if not isinstance(size_fn(idx), Iterable):
186
+ return all(size_fn(idx) <= b for b in max_positions)
187
+ return all(
188
+ a is None or b is None or a <= b
189
+ for a, b in zip(size_fn(idx), max_positions)
190
+ )
191
+
192
+ ignored = []
193
+ itr = collect_filtered(check_size, indices, ignored)
194
+ indices = np.fromiter(itr, dtype=np.int64, count=-1)
195
+ return indices, ignored
196
+
197
+
198
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
199
+ """
200
+ [deprecated] Filter indices based on their size.
201
+ Use `FairseqDataset::filter_indices_by_size` instead.
202
+
203
+ Args:
204
+ indices (List[int]): ordered list of dataset indices
205
+ dataset (FairseqDataset): fairseq dataset instance
206
+ max_positions (tuple): filter elements larger than this size.
207
+ Comparisons are done component-wise.
208
+ raise_exception (bool, optional): if ``True``, raise an exception if
209
+ any elements are filtered (default: False).
210
+ """
211
+ warnings.warn(
212
+ "data_utils.filter_by_size is deprecated. "
213
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
214
+ stacklevel=2,
215
+ )
216
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
217
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
218
+ ignored = indices[dataset.sizes[indices] > max_positions].tolist()
219
+ indices = indices[dataset.sizes[indices] <= max_positions]
220
+ elif (
221
+ hasattr(dataset, "sizes")
222
+ and isinstance(dataset.sizes, list)
223
+ and len(dataset.sizes) == 1
224
+ ):
225
+ ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
226
+ indices = indices[dataset.sizes[0][indices] <= max_positions]
227
+ else:
228
+ indices, ignored = _filter_by_size_dynamic(
229
+ indices, dataset.size, max_positions
230
+ )
231
+ else:
232
+ indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
233
+
234
+ if len(ignored) > 0 and raise_exception:
235
+ raise Exception(
236
+ (
237
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
238
+ "skip this example with --skip-invalid-size-inputs-valid-test"
239
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
240
+ )
241
+ if len(ignored) > 0:
242
+ logger.warning(
243
+ (
244
+ "{} samples have invalid sizes and will be skipped, "
245
+ "max_positions={}, first few sample ids={}"
246
+ ).format(len(ignored), max_positions, ignored[:10])
247
+ )
248
+ return indices
249
+
250
+
251
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
252
+ """Filter a list of sample indices. Remove those that are longer
253
+ than specified in max_sizes.
254
+
255
+ Args:
256
+ indices (np.array): original array of sample indices
257
+ max_sizes (int or list[int] or tuple[int]): max sample size,
258
+ can be defined separately for src and tgt (then list or tuple)
259
+
260
+ Returns:
261
+ np.array: filtered sample array
262
+ list: list of removed indices
263
+ """
264
+ if max_sizes is None:
265
+ return indices, []
266
+ if type(max_sizes) in (int, float):
267
+ max_src_size, max_tgt_size = max_sizes, max_sizes
268
+ else:
269
+ max_src_size, max_tgt_size = max_sizes
270
+ if tgt_sizes is None:
271
+ ignored = indices[src_sizes[indices] > max_src_size]
272
+ else:
273
+ ignored = indices[
274
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
275
+ ]
276
+ if len(ignored) > 0:
277
+ if tgt_sizes is None:
278
+ indices = indices[src_sizes[indices] <= max_src_size]
279
+ else:
280
+ indices = indices[
281
+ (src_sizes[indices] <= max_src_size)
282
+ & (tgt_sizes[indices] <= max_tgt_size)
283
+ ]
284
+ return indices, ignored.tolist()
285
+
286
+
287
+ def batch_by_size(
288
+ indices,
289
+ num_tokens_fn,
290
+ num_tokens_vec=None,
291
+ max_tokens=None,
292
+ max_sentences=None,
293
+ required_batch_size_multiple=1,
294
+ fixed_shapes=None,
295
+ ):
296
+ """
297
+ Yield mini-batches of indices bucketed by size. Batches may contain
298
+ sequences of different lengths.
299
+
300
+ Args:
301
+ indices (List[int]): ordered list of dataset indices
302
+ num_tokens_fn (callable): function that returns the number of tokens at
303
+ a given index
304
+ num_tokens_vec (List[int], optional): precomputed vector of the number
305
+ of tokens for each index in indices (to enable faster batch generation)
306
+ max_tokens (int, optional): max number of tokens in each batch
307
+ (default: None).
308
+ max_sentences (int, optional): max number of sentences in each
309
+ batch (default: None).
310
+ required_batch_size_multiple (int, optional): require batch size to
311
+ be less than N or a multiple of N (default: 1).
312
+ fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
313
+ only be created with the given shapes. *max_sentences* and
314
+ *required_batch_size_multiple* will be ignored (default: None).
315
+ """
316
+ try:
317
+ from fairseq.data.data_utils_fast import (
318
+ batch_by_size_fn,
319
+ batch_by_size_vec,
320
+ batch_fixed_shapes_fast,
321
+ )
322
+ except ImportError:
323
+ raise ImportError(
324
+ "Please build Cython components with: "
325
+ "`python setup.py build_ext --inplace`"
326
+ )
327
+ except ValueError:
328
+ raise ValueError(
329
+ "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
330
+ )
331
+
332
+ # added int() to avoid TypeError: an integer is required
333
+ max_tokens = (
334
+ int(max_tokens) if max_tokens is not None else -1
335
+ )
336
+ max_sentences = max_sentences if max_sentences is not None else -1
337
+ bsz_mult = required_batch_size_multiple
338
+
339
+ if not isinstance(indices, np.ndarray):
340
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
341
+
342
+ if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
343
+ num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
344
+
345
+ if fixed_shapes is None:
346
+ if num_tokens_vec is None:
347
+ return batch_by_size_fn(
348
+ indices,
349
+ num_tokens_fn,
350
+ max_tokens,
351
+ max_sentences,
352
+ bsz_mult,
353
+ )
354
+ else:
355
+ return batch_by_size_vec(
356
+ indices,
357
+ num_tokens_vec,
358
+ max_tokens,
359
+ max_sentences,
360
+ bsz_mult,
361
+ )
362
+
363
+ else:
364
+ fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
365
+ sort_order = np.lexsort(
366
+ [
367
+ fixed_shapes[:, 1].argsort(), # length
368
+ fixed_shapes[:, 0].argsort(), # bsz
369
+ ]
370
+ )
371
+ fixed_shapes_sorted = fixed_shapes[sort_order]
372
+ return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
373
+
374
+
375
+ def post_process(sentence: str, symbol: str):
376
+ if symbol == "sentencepiece":
377
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
378
+ elif symbol == "wordpiece":
379
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
380
+ elif symbol == "letter":
381
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
382
+ elif symbol == "silence":
383
+ import re
384
+ sentence = sentence.replace("<SIL>", "")
385
+ sentence = re.sub(' +', ' ', sentence).strip()
386
+ elif symbol == "_EOW":
387
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
388
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
389
+ if symbol == "subword_nmt":
390
+ symbol = "@@ "
391
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
392
+ elif symbol == "none":
393
+ pass
394
+ elif symbol is not None:
395
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
396
+ return sentence
397
+
398
+
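# Usage sketch for post_process (subword-nmt convention, "@@ " continuation markers):
#   post_process("he@@ llo wor@@ ld", "subword_nmt")  ->  "hello world"
#   post_process("hello world", "none")               ->  "hello world"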
399
+ def compute_mask_indices(
400
+ shape: Tuple[int, int],
401
+ padding_mask: Optional[torch.Tensor],
402
+ mask_prob: float,
403
+ mask_length: int,
404
+ mask_type: str = "static",
405
+ mask_other: float = 0.0,
406
+ min_masks: int = 0,
407
+ no_overlap: bool = False,
408
+ min_space: int = 0,
409
+ ) -> np.ndarray:
410
+ """
411
+ Computes random mask spans for a given shape
412
+
413
+ Args:
414
+ shape: the shape for which to compute masks.
415
+ should be of size 2 where first element is batch size and 2nd is timesteps
416
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
417
+ mask_prob: probability for each token to be chosen as the start of a masked span. This will be multiplied by
418
+ the number of timesteps divided by the mask span length to mask approximately this percentage of all elements;
419
+ however, due to overlaps, the actual number will be smaller (unless no_overlap is True)
420
+ mask_type: how to compute mask lengths
421
+ static = fixed size
422
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
423
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
424
+ poisson = sample from Poisson distribution with lambda = mask length
425
+ min_masks: minimum number of masked spans
426
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
427
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
428
+ """
429
+
430
+ bsz, all_sz = shape
431
+ mask = np.full((bsz, all_sz), False)
432
+
433
+ all_num_mask = int(
434
+ # add a random number for probabilistic rounding
435
+ mask_prob * all_sz / float(mask_length)
436
+ + np.random.rand()
437
+ )
438
+
439
+ all_num_mask = max(min_masks, all_num_mask)
440
+
441
+ mask_idcs = []
442
+ for i in range(bsz):
443
+ if padding_mask is not None:
444
+ sz = all_sz - padding_mask[i].long().sum().item()
445
+ num_mask = int(
446
+ # add a random number for probabilistic rounding
447
+ mask_prob * sz / float(mask_length)
448
+ + np.random.rand()
449
+ )
450
+ num_mask = max(min_masks, num_mask)
451
+ else:
452
+ sz = all_sz
453
+ num_mask = all_num_mask
454
+
455
+ if mask_type == "static":
456
+ lengths = np.full(num_mask, mask_length)
457
+ elif mask_type == "uniform":
458
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
459
+ elif mask_type == "normal":
460
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
461
+ lengths = [max(1, int(round(x))) for x in lengths]
462
+ elif mask_type == "poisson":
463
+ lengths = np.random.poisson(mask_length, size=num_mask)
464
+ lengths = [int(round(x)) for x in lengths]
465
+ else:
466
+ raise Exception("unknown mask selection " + mask_type)
467
+
468
+ if sum(lengths) == 0:
469
+ lengths[0] = min(mask_length, sz - 1)
470
+
471
+ if no_overlap:
472
+ mask_idc = []
473
+
474
+ def arrange(s, e, length, keep_length):
475
+ span_start = np.random.randint(s, e - length)
476
+ mask_idc.extend(span_start + i for i in range(length))
477
+
478
+ new_parts = []
479
+ if span_start - s - min_space >= keep_length:
480
+ new_parts.append((s, span_start - min_space + 1))
481
+ if e - span_start - keep_length - min_space > keep_length:
482
+ new_parts.append((span_start + length + min_space, e))
483
+ return new_parts
484
+
485
+ parts = [(0, sz)]
486
+ min_length = min(lengths)
487
+ for length in sorted(lengths, reverse=True):
488
+ lens = np.fromiter(
489
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
490
+ np.int64,
491
+ )
492
+ l_sum = np.sum(lens)
493
+ if l_sum == 0:
494
+ break
495
+ probs = lens / np.sum(lens)
496
+ c = np.random.choice(len(parts), p=probs)
497
+ s, e = parts.pop(c)
498
+ parts.extend(arrange(s, e, length, min_length))
499
+ mask_idc = np.asarray(mask_idc)
500
+ else:
501
+ min_len = min(lengths)
502
+ if sz - min_len <= num_mask:
503
+ min_len = sz - num_mask - 1
504
+
505
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
506
+
507
+ mask_idc = np.asarray(
508
+ [
509
+ mask_idc[j] + offset
510
+ for j in range(len(mask_idc))
511
+ for offset in range(lengths[j])
512
+ ]
513
+ )
514
+
515
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
516
+
517
+ min_len = min([len(m) for m in mask_idcs])
518
+ for i, mask_idc in enumerate(mask_idcs):
519
+ if len(mask_idc) > min_len:
520
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
521
+ mask[i, mask_idc] = True
522
+
523
+ return mask
524
+
525
+
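# Usage sketch for compute_mask_indices: with shape=(2, 100), mask_prob=0.10 and
# mask_length=5 (the default "static" mask_type), roughly 0.10 * 100 / 5 = 2 spans
# of 5 timesteps are drawn per row, i.e. about 10 True entries per row (fewer when
# the sampled spans overlap):
#   mask = compute_mask_indices((2, 100), None, mask_prob=0.10, mask_length=5)
#   mask.shape -> (2, 100)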
526
+ def get_mem_usage():
527
+ try:
528
+ import psutil
529
+
530
+ mb = 1024 * 1024
531
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
532
+ except ImportError:
533
+ return "N/A"
534
+
535
+
536
+ # lens: torch.LongTensor
537
+ # returns: torch.BoolTensor
538
+ def lengths_to_padding_mask(lens):
539
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
540
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
541
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
542
+ return mask
543
+
544
+
545
+ # lens: torch.LongTensor
546
+ # returns: torch.BoolTensor
547
+ def lengths_to_mask(lens):
548
+ return ~lengths_to_padding_mask(lens)
549
+
550
+
551
+ def get_buckets(sizes, num_buckets):
552
+ buckets = np.unique(
553
+ np.percentile(
554
+ sizes,
555
+ np.linspace(0, 100, num_buckets + 1),
556
+ interpolation='lower',
557
+ )[1:]
558
+ )
559
+ return buckets
560
+
561
+
562
+ def get_bucketed_sizes(orig_sizes, buckets):
563
+ sizes = np.copy(orig_sizes)
564
+ assert np.min(sizes) >= 0
565
+ start_val = -1
566
+ for end_val in buckets:
567
+ mask = (sizes > start_val) & (sizes <= end_val)
568
+ sizes[mask] = end_val
569
+ start_val = end_val
570
+ return sizes
571
+
572
+
573
+
574
+ def _find_extra_valid_paths(dataset_path: str) -> set:
575
+ paths = utils.split_paths(dataset_path)
576
+ all_valid_paths = set()
577
+ for sub_dir in paths:
578
+ contents = PathManager.ls(sub_dir)
579
+ valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
580
+ all_valid_paths |= {os.path.basename(p) for p in valid_paths}
581
+ # Remove .bin, .idx etc
582
+ roots = {os.path.splitext(p)[0] for p in all_valid_paths}
583
+ return roots
584
+
585
+
586
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
587
+ """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
588
+ if (
589
+ train_cfg.dataset.ignore_unused_valid_subsets
590
+ or train_cfg.dataset.combine_valid_subsets
591
+ or train_cfg.dataset.disable_validation
592
+ or not hasattr(train_cfg.task, "data")
593
+ ):
594
+ return
595
+ other_paths = _find_extra_valid_paths(train_cfg.task.data)
596
+ specified_subsets = train_cfg.dataset.valid_subset.split(",")
597
+ ignored_paths = [p for p in other_paths if p not in specified_subsets]
598
+ if ignored_paths:
599
+ advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
600
+ msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
601
+ raise ValueError(msg)
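A small usage sketch for the padding-mask helpers above (values are illustrative; the helpers are assumed importable from this module):

import torch
from data.data_utils import lengths_to_padding_mask, lengths_to_mask

lens = torch.LongTensor([2, 4])
lengths_to_padding_mask(lens)
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])
lengths_to_mask(lens)   # the inverse: True on real (non-padded) positions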
data/file_dataset.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import torch
3
+ import pickle
4
+
5
+
6
+ class FileDataset:
7
+ def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
8
+ self.file_path = file_path
9
+ assert os.path.exists(self.file_path), "Error: The local datafile {} does not exist!".format(self.file_path)
10
+
11
+ self.separator = separator
12
+ if selected_col_ids is None:
13
+ # default to all fields
14
+ self.selected_col_ids = list(
15
+ range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
16
+ else:
17
+ self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
18
+ if dtypes is None:
19
+ # default to str
20
+ self.dtypes = [str for col_id in self.selected_col_ids]
21
+ else:
22
+ self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
23
+ assert len(self.dtypes) == len(self.selected_col_ids)
24
+
25
+ self.data_cnt = 0
26
+ try:
27
+ self.slice_id = torch.distributed.get_rank()
28
+ self.slice_count = torch.distributed.get_world_size()
29
+ except Exception:
30
+ self.slice_id = 0
31
+ self.slice_count = 1
32
+ self.cached_index = cached_index
33
+ self._init_seek_index()
34
+ self._reader = self._get_reader()
35
+ print("file {} slice_id {} row count {} total row count {}".format(
36
+ self.file_path, self.slice_id, self.row_count, self.total_row_count)
37
+ )
38
+
39
+ def _init_seek_index(self):
40
+ if self.cached_index:
41
+ cache_path = "{}.index".format(self.file_path)
42
+ assert os.path.exists(cache_path), "cache file {} does not exist!".format(cache_path)
43
+ self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
44
+ print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
45
+ self.file_path, self.slice_id))
46
+ else:
47
+ # make an iteration over the file to get row_count and line_idx-to-offset mapping
48
+ fp = open(self.file_path, "r")
49
+ print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
50
+ self.file_path, self.slice_id))
51
+ self.total_row_count = 0
52
+ offset = 0
53
+ self.lineid_to_offset = []
54
+ for line in fp:
55
+ self.lineid_to_offset.append(offset)
56
+ self.total_row_count += 1
57
+ offset += len(line.encode('utf-8'))
58
+ self._compute_start_pos_and_row_count()
59
+ print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
60
+ self.file_path, self.slice_id))
61
+
62
+ def _compute_start_pos_and_row_count(self):
63
+ self.row_count = self.total_row_count // self.slice_count
64
+ if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
65
+ self.row_count += 1
66
+ self.start_pos = self.row_count * self.slice_id
67
+ else:
68
+ self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
69
+
70
+ def _get_reader(self):
71
+ fp = open(self.file_path, "r")
72
+ fp.seek(self.lineid_to_offset[self.start_pos])
73
+ return fp
74
+
75
+ def _seek(self, offset=0):
76
+ try:
77
+ print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
78
+ self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
79
+ self.data_cnt = offset
80
+ except Exception:
81
+ print("slice_id {} seek offset {}".format(self.slice_id, offset))
82
+ self._reader.seek(self.lineid_to_offset[offset])
83
+ self.data_cnt = offset
84
+
85
+ def __del__(self):
86
+ self._reader.close()
87
+
88
+ def __len__(self):
89
+ return self.row_count
90
+
91
+ def get_total_row_count(self):
92
+ return self.total_row_count
93
+
94
+ def __getitem__(self, index):
95
+ if self.data_cnt == self.row_count:
96
+ print("reach the end of datafile, start a new reader")
97
+ self.data_cnt = 0
98
+ self._reader = self._get_reader()
99
+ column_l = self._reader.readline().rstrip("\n").split(self.separator)
100
+ self.data_cnt += 1
101
+ column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
102
+ return column_l
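A minimal usage sketch for FileDataset, assuming a hypothetical tab-separated file caption.tsv whose columns are uniq_id, a base64-encoded image and a caption:

from data.file_dataset import FileDataset

dataset = FileDataset("caption.tsv", selected_col_ids="0,1,2", dtypes="str,str,str")
print(len(dataset))                        # number of rows assigned to this worker slice
uniq_id, image_b64, caption = dataset[0]   # reads the next row sequentially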
data/mm_data/__init__.py ADDED
File without changes
data/mm_data/caption_dataset.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+ import string
10
+
11
+ import numpy as np
12
+ import torch
13
+ import base64
14
+ from torchvision import transforms
15
+
16
+ from PIL import Image, ImageFile
17
+
18
+ from data import data_utils
19
+ from data.ofa_dataset import OFADataset
20
+
21
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
22
+ ImageFile.MAX_IMAGE_PIXELS = None
23
+ Image.MAX_IMAGE_PIXELS = None
24
+
25
+ logger = logging.getLogger(__name__)
26
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
+
28
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
+
31
+
32
+ def collate(samples, pad_idx, eos_idx):
33
+ if len(samples) == 0:
34
+ return {}
35
+
36
+ def merge(key):
37
+ return data_utils.collate_tokens(
38
+ [s[key] for s in samples],
39
+ pad_idx,
40
+ eos_idx=eos_idx,
41
+ )
42
+
43
+ id = np.array([s["id"] for s in samples])
44
+ src_tokens = merge("source")
45
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
+
47
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
+
50
+ prev_output_tokens = None
51
+ target = None
52
+ if samples[0].get("target", None) is not None:
53
+ target = merge("target")
54
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
55
+ ntokens = tgt_lengths.sum().item()
56
+
57
+ if samples[0].get("prev_output_tokens", None) is not None:
58
+ prev_output_tokens = merge("prev_output_tokens")
59
+ else:
60
+ ntokens = src_lengths.sum().item()
61
+
62
+ batch = {
63
+ "id": id,
64
+ "nsentences": len(samples),
65
+ "ntokens": ntokens,
66
+ "net_input": {
67
+ "src_tokens": src_tokens,
68
+ "src_lengths": src_lengths,
69
+ "patch_images": patch_images,
70
+ "patch_masks": patch_masks,
71
+ "prev_output_tokens": prev_output_tokens
72
+ },
73
+ "target": target,
74
+ }
75
+
76
+ return batch
77
+
78
+
79
+ class CaptionDataset(OFADataset):
80
+ def __init__(
81
+ self,
82
+ split,
83
+ dataset,
84
+ bpe,
85
+ src_dict,
86
+ tgt_dict=None,
87
+ max_src_length=128,
88
+ max_tgt_length=30,
89
+ patch_image_size=224,
90
+ imagenet_default_mean_and_std=False,
91
+ scst=False
92
+ ):
93
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
94
+ self.max_src_length = max_src_length
95
+ self.max_tgt_length = max_tgt_length
96
+ self.patch_image_size = patch_image_size
97
+ self.scst = scst
98
+
99
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
100
+
101
+ if imagenet_default_mean_and_std:
102
+ mean = IMAGENET_DEFAULT_MEAN
103
+ std = IMAGENET_DEFAULT_STD
104
+ else:
105
+ mean = [0.5, 0.5, 0.5]
106
+ std = [0.5, 0.5, 0.5]
107
+
108
+ self.patch_resize_transform = transforms.Compose([
109
+ lambda image: image.convert("RGB"),
110
+ transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
111
+ transforms.ToTensor(),
112
+ transforms.Normalize(mean=mean, std=std),
113
+ ])
114
+
115
+ def __getitem__(self, index):
116
+ uniq_id, image, caption = self.dataset[index]
117
+
118
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
119
+ patch_image = self.patch_resize_transform(image)
120
+ patch_mask = torch.tensor([True])
121
+
122
+ if self.split == 'train' and not self.scst:
123
+ caption = caption.translate(self.transtab).strip()
124
+ caption_token_list = caption.strip().split()
125
+ tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
126
+ else:
127
+ caption = ' '.join(caption.strip().split())
128
+ caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
129
+ tgt_caption = '&&'.join(caption_list)
130
+ src_item = self.encode_text(" what does the image describe?")
131
+ tgt_item = self.encode_text(" {}".format(tgt_caption))
132
+
133
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
134
+ target_item = torch.cat([tgt_item, self.eos_item])
135
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
136
+
137
+ example = {
138
+ "id": uniq_id,
139
+ "source": src_item,
140
+ "patch_image": patch_image,
141
+ "patch_mask": patch_mask,
142
+ "target": target_item,
143
+ "prev_output_tokens": prev_output_item
144
+ }
145
+ return example
146
+
147
+ def collater(self, samples, pad_to_length=None):
148
+ """Merge a list of samples to form a mini-batch.
149
+ Args:
150
+ samples (List[dict]): samples to collate
151
+ Returns:
152
+ dict: a mini-batch with keys id, nsentences, ntokens, net_input and target
153
+ """
154
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
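A sketch of the caption image preprocessing above with the default (non-ImageNet) mean/std of 0.5; the image path is hypothetical:

from PIL import Image
from torchvision import transforms

patch_image_size = 224
patch_resize_transform = transforms.Compose([
    lambda image: image.convert("RGB"),
    transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
patch_image = patch_resize_transform(Image.open("example.jpg"))   # tensor of shape (3, 224, 224)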
data/mm_data/refcoco_dataset.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ import base64
13
+ import utils.transforms as T
14
+
15
+ from PIL import Image, ImageFile
16
+
17
+ from data import data_utils
18
+ from data.ofa_dataset import OFADataset
19
+
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ ImageFile.MAX_IMAGE_PIXELS = None
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ logger = logging.getLogger(__name__)
25
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
26
+
27
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
28
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
29
+
30
+
31
+ def collate(samples, pad_idx, eos_idx):
32
+ if len(samples) == 0:
33
+ return {}
34
+
35
+ def merge(key):
36
+ return data_utils.collate_tokens(
37
+ [s[key] for s in samples],
38
+ pad_idx,
39
+ eos_idx=eos_idx,
40
+ )
41
+
42
+ id = np.array([s["id"] for s in samples])
43
+ src_tokens = merge("source")
44
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
45
+
46
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
47
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
48
+
49
+ w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
50
+ h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
51
+ region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
52
+
53
+ prev_output_tokens = None
54
+ target = None
55
+ if samples[0].get("target", None) is not None:
56
+ target = merge("target")
57
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
58
+ ntokens = tgt_lengths.sum().item()
59
+
60
+ if samples[0].get("prev_output_tokens", None) is not None:
61
+ prev_output_tokens = merge("prev_output_tokens")
62
+ else:
63
+ ntokens = src_lengths.sum().item()
64
+
65
+ batch = {
66
+ "id": id,
67
+ "nsentences": len(samples),
68
+ "ntokens": ntokens,
69
+ "net_input": {
70
+ "src_tokens": src_tokens,
71
+ "src_lengths": src_lengths,
72
+ "patch_images": patch_images,
73
+ "patch_masks": patch_masks,
74
+ "prev_output_tokens": prev_output_tokens
75
+ },
76
+ "target": target,
77
+ "w_resize_ratios": w_resize_ratios,
78
+ "h_resize_ratios": h_resize_ratios,
79
+ "region_coords": region_coords
80
+ }
81
+
82
+ return batch
83
+
84
+
85
+ class RefcocoDataset(OFADataset):
86
+ def __init__(
87
+ self,
88
+ split,
89
+ dataset,
90
+ bpe,
91
+ src_dict,
92
+ tgt_dict=None,
93
+ max_src_length=80,
94
+ max_tgt_length=30,
95
+ patch_image_size=512,
96
+ imagenet_default_mean_and_std=False,
97
+ num_bins=1000,
98
+ max_image_size=512
99
+ ):
100
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
101
+ self.max_src_length = max_src_length
102
+ self.max_tgt_length = max_tgt_length
103
+ self.patch_image_size = patch_image_size
104
+ self.num_bins = num_bins
105
+
106
+ if imagenet_default_mean_and_std:
107
+ mean = IMAGENET_DEFAULT_MEAN
108
+ std = IMAGENET_DEFAULT_STD
109
+ else:
110
+ mean = [0.5, 0.5, 0.5]
111
+ std = [0.5, 0.5, 0.5]
112
+
113
+ # for positioning
114
+ self.positioning_transform = T.Compose([
115
+ T.RandomResize([patch_image_size], max_size=patch_image_size),
116
+ T.ToTensor(),
117
+ T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
118
+ ])
119
+
120
+ def __getitem__(self, index):
121
+ uniq_id, base64_str, text, region_coord = self.dataset[index]
122
+
123
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
124
+ w, h = image.size
125
+ boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
126
+ x0, y0, x1, y1 = region_coord.strip().split(',')
127
+ region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
128
+ boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
129
+ boxes_target["labels"] = np.array([0])
130
+ boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
131
+
132
+ patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
133
+ resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
134
+ patch_mask = torch.tensor([True])
135
+ quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
136
+ quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
137
+ quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
138
+ quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
139
+ region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
140
+ src_caption = self.pre_caption(text, self.max_src_length)
141
+ src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
142
+ tgt_item = self.encode_text(region_coord, use_bpe=False)
143
+
144
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
145
+ target_item = torch.cat([tgt_item, self.eos_item])
146
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
147
+
148
+ example = {
149
+ "id": uniq_id,
150
+ "source": src_item,
151
+ "patch_image": patch_image,
152
+ "patch_mask": patch_mask,
153
+ "target": target_item,
154
+ "prev_output_tokens": prev_output_item,
155
+ "w_resize_ratio": resize_w / w,
156
+ "h_resize_ratio": resize_h / h,
157
+ "region_coord": region
158
+ }
159
+ return example
160
+
161
+ def collater(self, samples, pad_to_length=None):
162
+ """Merge a list of samples to form a mini-batch.
163
+ Args:
164
+ samples (List[dict]): samples to collate
165
+ Returns:
166
+ dict: a mini-batch with keys id, nsentences, ntokens, net_input, target, w_resize_ratios, h_resize_ratios and region_coords
167
+ """
168
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
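A worked sketch of the coordinate quantization above, assuming box coordinates already normalized to [0, 1] by the positioning transform:

num_bins = 1000
coords = [0.1, 0.2, 0.5, 0.9]                                    # illustrative x0, y0, x1, y1
tokens = ["<bin_{}>".format(int(round(c * (num_bins - 1)))) for c in coords]
# ['<bin_100>', '<bin_200>', '<bin_500>', '<bin_899>']
region_coord = " ".join(tokens)                                  # passed to encode_text(..., use_bpe=False)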
data/ofa_dataset.py ADDED
@@ -0,0 +1,74 @@
1
+ import logging
2
+ import re
3
+ import torch.utils.data
4
+ from fairseq.data import FairseqDataset
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class OFADataset(FairseqDataset):
10
+ def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
11
+ self.split = split
12
+ self.dataset = dataset
13
+ self.bpe = bpe
14
+ self.src_dict = src_dict
15
+ self.tgt_dict = tgt_dict
16
+
17
+ self.bos = src_dict.bos()
18
+ self.eos = src_dict.eos()
19
+ self.pad = src_dict.pad()
20
+ self.bos_item = torch.LongTensor([self.bos])
21
+ self.eos_item = torch.LongTensor([self.eos])
22
+
23
+ def __len__(self):
24
+ return len(self.dataset)
25
+
26
+ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
27
+ s = self.tgt_dict.encode_line(
28
+ line=self.bpe.encode(text) if use_bpe else text,
29
+ add_if_not_exist=False,
30
+ append_eos=False
31
+ ).long()
32
+ if length is not None:
33
+ s = s[:length]
34
+ if append_bos:
35
+ s = torch.cat([self.bos_item, s])
36
+ if append_eos:
37
+ s = torch.cat([s, self.eos_item])
38
+ return s
39
+
40
+ def pre_question(self, question, max_ques_words):
41
+ question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
42
+
43
+ question = re.sub(
44
+ r"\s{2,}",
45
+ ' ',
46
+ question,
47
+ )
48
+ question = question.rstrip('\n')
49
+ question = question.strip(' ')
50
+
51
+ # truncate question
52
+ question_words = question.split(' ')
53
+ if len(question_words) > max_ques_words:
54
+ question = ' '.join(question_words[:max_ques_words])
55
+
56
+ return question
57
+
58
+ def pre_caption(self, caption, max_words):
59
+ caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
60
+
61
+ caption = re.sub(
62
+ r"\s{2,}",
63
+ ' ',
64
+ caption,
65
+ )
66
+ caption = caption.rstrip('\n')
67
+ caption = caption.strip(' ')
68
+
69
+ # truncate caption
70
+ caption_words = caption.split(' ')
71
+ if len(caption_words) > max_words:
72
+ caption = ' '.join(caption_words[:max_words])
73
+
74
+ return caption
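The text cleanup above is deterministic, so its effect can be shown directly (ds stands for any OFADataset instance):

# ds.pre_caption("Two dogs -- playing in the <person>'s yard", max_words=4)
#   -> 'two dogs playing in'
# ds.pre_question("What COLOR is the bus??", max_ques_words=10)
#   -> 'what color is the bus??'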
datasets.md ADDED
@@ -0,0 +1,7 @@
1
+ # Datasets
2
+
3
+ We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will release the preprocessing scripts soon.
4
+
5
+ ## Finetuning
6
+
7
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
evaluate.py ADDED
@@ -0,0 +1,152 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import json
11
+ from itertools import chain
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ from fairseq import distributed_utils, options, tasks, utils
17
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
18
+ from fairseq.logging import progress_bar
19
+ from fairseq.utils import reset_logging
20
+ from omegaconf import DictConfig
21
+
22
+ from utils import checkpoint_utils
23
+ from utils.eval_utils import eval_step
24
+
25
+ logging.basicConfig(
26
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S",
28
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
29
+ stream=sys.stdout,
30
+ )
31
+ logger = logging.getLogger("ofa.evaluate")
32
+
33
+
34
+ def apply_half(t):
35
+ if t.dtype is torch.float32:
36
+ return t.to(dtype=torch.half)
37
+ return t
38
+
39
+
40
+ def main(cfg: DictConfig):
41
+ utils.import_user_module(cfg.common)
42
+
43
+ reset_logging()
44
+ logger.info(cfg)
45
+
46
+ assert (
47
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
48
+ ), "Must specify batch size either with --max-tokens or --batch-size"
49
+
50
+ # Fix seed for stochastic decoding
51
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
52
+ np.random.seed(cfg.common.seed)
53
+ utils.set_torch_seed(cfg.common.seed)
54
+
55
+ use_fp16 = cfg.common.fp16
56
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
57
+
58
+ if use_cuda:
59
+ torch.cuda.set_device(cfg.distributed_training.device_id)
60
+
61
+ # Load ensemble
62
+ overrides = eval(cfg.common_eval.model_overrides)
63
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
64
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
65
+ utils.split_paths(cfg.common_eval.path),
66
+ arg_overrides=overrides,
67
+ suffix=cfg.checkpoint.checkpoint_suffix,
68
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
69
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
70
+ )
71
+
72
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
73
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
74
+
75
+ # Move models to GPU
76
+ for model in models:
77
+ model.eval()
78
+ if use_fp16:
79
+ model.half()
80
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
81
+ model.cuda()
82
+ model.prepare_for_inference_(cfg)
83
+
84
+ # Load dataset (possibly sharded)
85
+ itr = task.get_batch_iterator(
86
+ dataset=task.dataset(cfg.dataset.gen_subset),
87
+ max_tokens=cfg.dataset.max_tokens,
88
+ max_sentences=cfg.dataset.batch_size,
89
+ max_positions=utils.resolve_max_positions(
90
+ task.max_positions(), *[m.max_positions() for m in models]
91
+ ),
92
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
93
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
94
+ seed=cfg.common.seed,
95
+ num_shards=cfg.distributed_training.distributed_world_size,
96
+ shard_id=cfg.distributed_training.distributed_rank,
97
+ num_workers=cfg.dataset.num_workers,
98
+ data_buffer_size=cfg.dataset.data_buffer_size,
99
+ ).next_epoch_itr(shuffle=False)
100
+ progress = progress_bar.progress_bar(
101
+ itr,
102
+ log_format=cfg.common.log_format,
103
+ log_interval=cfg.common.log_interval,
104
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
105
+ )
106
+
107
+ # Initialize generator
108
+ generator = task.build_generator(models, cfg.generation)
109
+
110
+ results = []
111
+ score_sum = torch.FloatTensor([0]).cuda()
112
+ score_cnt = torch.FloatTensor([0]).cuda()
113
+ for sample in progress:
114
+ if "net_input" not in sample:
115
+ continue
116
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
117
+ sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
118
+ with torch.no_grad():
119
+ result, scores = eval_step(task, generator, models, sample)
120
+ results += result
121
+ score_sum += sum(scores) if scores is not None else 0
122
+ score_cnt += len(scores) if scores is not None else 0
123
+ progress.log({"sentences": sample["nsentences"]})
124
+
125
+ gather_results = None
126
+ if cfg.distributed_training.distributed_world_size > 1:
127
+ gather_results = [None for _ in range(dist.get_world_size())]
128
+ dist.all_gather_object(gather_results, results)
129
+ dist.all_reduce(score_sum.data)
130
+ dist.all_reduce(score_cnt.data)
131
+ if score_cnt.item() > 0:
132
+ logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
133
+ score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
134
+ ))
135
+
136
+ if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
137
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
138
+ output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
139
+ gather_results = list(chain(*gather_results)) if gather_results is not None else results
140
+ with open(output_path, 'w') as fw:
141
+ json.dump(gather_results, fw)
142
+
143
+
144
+ def cli_main():
145
+ parser = options.get_generation_parser()
146
+ args = options.parse_args_and_arch(parser)
147
+ cfg = convert_namespace_to_omegaconf(args)
148
+ distributed_utils.call_main(cfg, main)
149
+
150
+
151
+ if __name__ == "__main__":
152
+ cli_main()
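A small sketch of how the fp16 conversion above is applied to a nested sample (shapes are illustrative):

import torch
from fairseq import utils

sample = {
    "net_input": {
        "src_tokens": torch.ones(2, 4, dtype=torch.long),   # token ids stay int64
        "patch_images": torch.randn(2, 3, 224, 224),        # float32 -> float16
    }
}
sample = utils.apply_to_sample(apply_half, sample)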
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .ofa import OFAModel, ofa_base_architecture, ofa_large_architecture, ofa_huge_architecture
models/ofa/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .ofa import OFAModel, ofa_base_architecture, ofa_large_architecture, ofa_huge_architecture
models/ofa/ofa.py ADDED
@@ -0,0 +1,410 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ OFA
7
+ """
8
+ from typing import Optional
9
+
10
+ import logging
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from fairseq import utils
16
+ from fairseq.models import register_model, register_model_architecture
17
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
18
+
19
+ from .unify_transformer import TransformerModel
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @register_model("ofa")
25
+ class OFAModel(TransformerModel):
26
+ __jit_unused_properties__ = ["supported_targets"]
27
+
28
+ def __init__(self, args, encoder, decoder):
29
+ super().__init__(args, encoder, decoder)
30
+
31
+ # We follow BERT's random weight initialization
32
+ self.apply(init_bert_params)
33
+
34
+ self.classification_heads = nn.ModuleDict()
35
+ if hasattr(self.encoder, "dictionary"):
36
+ self.eos: int = self.encoder.dictionary.eos()
37
+
38
+ @staticmethod
39
+ def add_args(parser):
40
+ super(OFAModel, OFAModel).add_args(parser)
41
+ parser.add_argument(
42
+ "--pooler-dropout",
43
+ type=float,
44
+ metavar="D",
45
+ help="dropout probability in the masked_lm pooler layers",
46
+ )
47
+ parser.add_argument(
48
+ "--pooler-classifier",
49
+ type=str,
50
+ choices=['mlp', 'linear'],
51
+ help="type of pooler classifier",
52
+ )
53
+ parser.add_argument(
54
+ "--pooler-activation-fn",
55
+ choices=utils.get_available_activation_fns(),
56
+ help="activation function to use for pooler layer",
57
+ )
58
+ parser.add_argument(
59
+ "--spectral-norm-classification-head",
60
+ action="store_true",
61
+ help="Apply spectral normalization on the classification head",
62
+ )
63
+
64
+ @property
65
+ def supported_targets(self):
66
+ return {"self"}
67
+
68
+ def forward(
69
+ self,
70
+ src_tokens,
71
+ src_lengths,
72
+ prev_output_tokens,
73
+ patch_images: Optional[torch.Tensor] = None,
74
+ patch_images_2: Optional[torch.Tensor] = None,
75
+ patch_masks: Optional[torch.Tensor] = None,
76
+ code_masks: Optional[torch.Tensor] = None,
77
+ sample_patch_num: Optional[int] = None,
78
+ features_only: bool = False,
79
+ classification_head_name: Optional[str] = None,
80
+ token_embeddings: Optional[torch.Tensor] = None,
81
+ return_all_hiddens: bool = False,
82
+ alignment_layer: Optional[int] = None,
83
+ alignment_heads: Optional[int] = None,
84
+ ):
85
+ if classification_head_name is not None:
86
+ features_only = True
87
+
88
+ encoder_out = self.encoder(
89
+ src_tokens,
90
+ src_lengths=src_lengths,
91
+ patch_images=patch_images,
92
+ patch_masks=patch_masks,
93
+ patch_images_2=patch_images_2,
94
+ token_embeddings=token_embeddings,
95
+ return_all_hiddens=return_all_hiddens,
96
+ sample_patch_num=sample_patch_num
97
+ )
98
+ x, extra = self.decoder(
99
+ prev_output_tokens,
100
+ code_masks=code_masks,
101
+ encoder_out=encoder_out,
102
+ features_only=features_only,
103
+ alignment_layer=alignment_layer,
104
+ alignment_heads=alignment_heads,
105
+ src_lengths=src_lengths,
106
+ return_all_hiddens=return_all_hiddens,
107
+ )
108
+
109
+ pad = self.encoder.padding_idx
110
+ if classification_head_name is not None:
111
+ prev_lengths = prev_output_tokens.ne(pad).sum(1)
112
+ gather_index = prev_lengths[:, None, None].expand(x.size(0), 1, x.size(2)) - 1
113
+ sentence_representation = x.gather(1, gather_index).squeeze()
114
+ if self.classification_heads[classification_head_name].use_two_images:
115
+ hidden_size = sentence_representation.size(1)
116
+ sentence_representation = sentence_representation.view(-1, hidden_size * 2)
117
+ for k, head in self.classification_heads.items():
118
+ # for torch script only supports iteration
119
+ if k == classification_head_name:
120
+ x = head(sentence_representation)
121
+ break
122
+
123
+ return x, extra
124
+
125
+ def register_embedding_tokens(self, ans2label_dict, src_dict, bpe):
126
+ """Register embedding tokens"""
127
+ logger.info("Registering embedding tokens")
128
+ self.ans_tensor_list = []
129
+ for i in range(len(ans2label_dict)):
130
+ ans = src_dict[-len(ans2label_dict)+i]
131
+ ans = ans[5:-1].replace('_', ' ')
132
+ ans_tensor = src_dict.encode_line(
133
+ line=bpe.encode(' {}'.format(ans.lower())),
134
+ add_if_not_exist=False,
135
+ append_eos=False
136
+ ).long()
137
+ self.ans_tensor_list.append(ans_tensor)
138
+
139
+ def register_classification_head(
140
+ self, name, num_classes=None, inner_dim=None, use_two_images=False, **kwargs
141
+ ):
142
+ """Register a classification head."""
143
+ logger.info("Registering classification head: {0}".format(name))
144
+ if name in self.classification_heads:
145
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
146
+ prev_inner_dim = self.classification_heads[name].dense.out_features
147
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
148
+ logger.warning(
149
+ 're-registering head "{}" with num_classes {} (prev: {}) '
150
+ "and inner_dim {} (prev: {})".format(
151
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
152
+ )
153
+ )
154
+ self.classification_heads[name] = OFAClassificationHead(
155
+ input_dim=self.args.encoder_embed_dim,
156
+ inner_dim=inner_dim or self.args.encoder_embed_dim,
157
+ num_classes=num_classes,
158
+ activation_fn=self.args.pooler_activation_fn,
159
+ pooler_dropout=self.args.pooler_dropout,
160
+ pooler_classifier=self.args.pooler_classifier,
161
+ use_two_images=use_two_images,
162
+ do_spectral_norm=getattr(
163
+ self.args, "spectral_norm_classification_head", False
164
+ ),
165
+ )
166
+
167
+ def upgrade_state_dict_named(self, state_dict, name):
168
+ super().upgrade_state_dict_named(state_dict, name)
169
+
170
+ prefix = name + "." if name != "" else ""
171
+ current_head_names = (
172
+ []
173
+ if not hasattr(self, "classification_heads")
174
+ else self.classification_heads.keys()
175
+ )
176
+
177
+ # Handle new classification heads present in the state dict.
178
+ keys_to_delete = []
179
+ for k in state_dict.keys():
180
+ if not k.startswith(prefix + "classification_heads."):
181
+ continue
182
+
183
+ head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
184
+ num_classes = state_dict[
185
+ prefix + "classification_heads." + head_name + ".out_proj.weight"
186
+ ].size(0)
187
+ inner_dim = state_dict[
188
+ prefix + "classification_heads." + head_name + ".dense.weight"
189
+ ].size(0)
190
+
191
+ if getattr(self.args, "load_checkpoint_heads", False):
192
+ if head_name not in current_head_names:
193
+ self.register_classification_head(head_name, num_classes, inner_dim)
194
+ else:
195
+ if head_name not in current_head_names:
196
+ logger.warning(
197
+ "deleting classification head ({}) from checkpoint "
198
+ "not present in current model: {}".format(head_name, k)
199
+ )
200
+ keys_to_delete.append(k)
201
+ elif (
202
+ num_classes
203
+ != self.classification_heads[head_name].out_proj.out_features
204
+ or inner_dim
205
+ != self.classification_heads[head_name].dense.out_features
206
+ ):
207
+ logger.warning(
208
+ "deleting classification head ({}) from checkpoint "
209
+ "with different dimensions than current model: {}".format(
210
+ head_name, k
211
+ )
212
+ )
213
+ keys_to_delete.append(k)
214
+ for k in keys_to_delete:
215
+ del state_dict[k]
216
+
217
+ def truncate_emb(key):
218
+ if key in state_dict:
219
+ state_dict[key] = state_dict[key][:-1, :]
220
+
221
+ # When finetuning on translation task, remove last row of
222
+ # embedding matrix that corresponds to mask_idx token.
223
+ loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
224
+ if (
225
+ loaded_dict_size == len(self.encoder.dictionary) + 1
226
+ and "<mask>" not in self.encoder.dictionary
227
+ ):
228
+ truncate_emb("encoder.embed_tokens.weight")
229
+ truncate_emb("decoder.embed_tokens.weight")
230
+ truncate_emb("encoder.output_projection.weight")
231
+ truncate_emb("decoder.output_projection.weight")
232
+
233
+ if loaded_dict_size < len(self.encoder.dictionary):
234
+ num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
235
+ embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)
236
+
237
+ new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
238
+ if getattr(self, "ans_tensor_list", None):
239
+ assert len(new_lang_embed_to_add) == len(self.ans_tensor_list)
240
+ for i, ans_tensor in enumerate(self.ans_tensor_list):
241
+ ans_embed = F.embedding(ans_tensor, state_dict["encoder.embed_tokens.weight"])
242
+ ans_embed = ans_embed.sum(0) / ans_embed.size(0)
243
+ new_lang_embed_to_add[i] = ans_embed
244
+ else:
245
+ nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
246
+ new_lang_embed_to_add = new_lang_embed_to_add.to(
247
+ dtype=state_dict["encoder.embed_tokens.weight"].dtype,
248
+ )
249
+
250
+ state_dict["encoder.embed_tokens.weight"] = torch.cat(
251
+ [state_dict["encoder.embed_tokens.weight"], new_lang_embed_to_add]
252
+ )
253
+ state_dict["decoder.embed_tokens.weight"] = torch.cat(
254
+ [state_dict["decoder.embed_tokens.weight"], new_lang_embed_to_add]
255
+ )
256
+ state_dict["decoder.output_projection.weight"] = torch.cat(
257
+ [state_dict["decoder.output_projection.weight"], new_lang_embed_to_add]
258
+ )
259
+
260
+ # Copy any newly-added classification heads into the state dict
261
+ # with their current weights.
262
+ if hasattr(self, "classification_heads"):
263
+ cur_state = self.classification_heads.state_dict()
264
+ for k, v in cur_state.items():
265
+ if prefix + "classification_heads." + k not in state_dict:
266
+ logger.info("Overwriting " + prefix + "classification_heads." + k)
267
+ state_dict[prefix + "classification_heads." + k] = v
268
+
269
+
270
+ class OFAClassificationHead(nn.Module):
271
+ """Head for sentence-level classification tasks."""
272
+
273
+ def __init__(
274
+ self,
275
+ input_dim,
276
+ inner_dim,
277
+ num_classes,
278
+ activation_fn,
279
+ pooler_dropout,
280
+ pooler_classifier,
281
+ use_two_images=False,
282
+ do_spectral_norm=False,
283
+ ):
284
+ super().__init__()
285
+ self.pooler_classifier = pooler_classifier
286
+ self.use_two_images = use_two_images
287
+ input_dim = input_dim * 2 if use_two_images else input_dim
288
+ if pooler_classifier == "mlp":
289
+ self.dense = nn.Linear(input_dim, inner_dim)
290
+ self.activation_fn = utils.get_activation_fn(activation_fn)
291
+ self.dropout = nn.Dropout(p=pooler_dropout)
292
+ self.out_proj = nn.Linear(inner_dim, num_classes)
293
+ elif pooler_classifier == "linear":
294
+ self.dropout = nn.Dropout(p=pooler_dropout)
295
+ self.out_proj = nn.Linear(input_dim, num_classes)
296
+ else:
297
+ raise NotImplementedError
298
+
299
+ if do_spectral_norm:
300
+ self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
301
+
302
+ def forward(self, features, **kwargs):
303
+ if self.pooler_classifier == 'mlp':
304
+ x = features
305
+ x = self.dropout(x)
306
+ x = self.dense(x)
307
+ x = self.activation_fn(x)
308
+ x = self.dropout(x)
309
+ x = self.out_proj(x)
310
+ elif self.pooler_classifier == 'linear':
311
+ x = features
312
+ x = self.dropout(x)
313
+ x = self.out_proj(x)
314
+ else:
315
+ raise NotImplementedError
316
+ return x
317
+
318
+
319
+ @register_model_architecture("ofa", "ofa_large")
320
+ def ofa_large_architecture(args):
321
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
322
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
323
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
324
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
325
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
326
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
327
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
328
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
329
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
330
+ args.decoder_ffn_embed_dim = getattr(
331
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
332
+ )
333
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
334
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
335
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
336
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
337
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
338
+ args.relu_dropout = getattr(args, "relu_dropout", 0.0)
339
+ args.dropout = getattr(args, "dropout", 0.0)
340
+ args.max_target_positions = getattr(args, "max_target_positions", 1024)
341
+ args.max_source_positions = getattr(args, "max_source_positions", 1024)
342
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
343
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
344
+ args.share_decoder_input_output_embed = getattr(
345
+ args, "share_decoder_input_output_embed", True
346
+ )
347
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
348
+
349
+ args.decoder_output_dim = getattr(
350
+ args, "decoder_output_dim", args.decoder_embed_dim
351
+ )
352
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
353
+
354
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
355
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
356
+
357
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
358
+ args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
359
+ args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
360
+ args.pooler_classifier = getattr(args, "pooler_classifier", "mlp")
361
+
362
+ args.resnet_drop_path_rate = getattr(args, "resnet_drop_path_rate", 0.0)
363
+ args.encoder_drop_path_rate = getattr(args, "encoder_drop_path_rate", 0.0)
364
+ args.decoder_drop_path_rate = getattr(args, "decoder_drop_path_rate", 0.0)
365
+
366
+ args.resnet_type = getattr(args, "resnet_type", "resnet152")
367
+ args.token_bucket_size = getattr(args, "token_bucket_size", 256)
368
+ args.image_bucket_size = getattr(args, "image_bucket_size", 42)
369
+
370
+ args.freeze_encoder_embedding = getattr(args, "freeze_encoder_embedding", False)
371
+ args.freeze_decoder_embedding = getattr(args, "freeze_decoder_embedding", False)
372
+ args.add_type_embedding = getattr(args, "add_type_embedding", True)
373
+ args.attn_scale_factor = getattr(args, "attn_scale_factor", 2)
374
+
375
+ args.code_image_size = getattr(args, "code_image_size", 128)
376
+ args.patch_layernorm_embedding = getattr(args, "patch_layernorm_embedding", True)
377
+ args.code_layernorm_embedding = getattr(args, "code_layernorm_embedding", True)
378
+ args.entangle_position_embedding = getattr(args, "entangle_position_embedding", False)
379
+ args.disable_entangle = getattr(args, "disable_entangle", False)
380
+ args.sync_bn = getattr(args, "sync_bn", False)
381
+
382
+ args.scale_attn = getattr(args, "scale_attn", False)
383
+ args.scale_fc = getattr(args, "scale_fc", False)
384
+ args.scale_heads = getattr(args, "scale_heads", False)
385
+ args.scale_resids = getattr(args, "scale_resids", False)
386
+
387
+
388
+ @register_model_architecture("ofa", "ofa_base")
389
+ def ofa_base_architecture(args):
390
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
391
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
392
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
393
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
394
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
395
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
396
+ args.resnet_type = getattr(args, "resnet_type", "resnet101")
397
+ ofa_large_architecture(args)
398
+
399
+
400
+ @register_model_architecture("ofa", "ofa_huge")
401
+ def ofa_huge_architecture(args):
402
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280)
403
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1280)
404
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
405
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
406
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
407
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
408
+ args.resnet_type = getattr(args, "resnet_type", "resnet152")
409
+ ofa_large_architecture(args)
410
+
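A minimal sketch of the pooler head defined above (dimensions and class count are illustrative):

import torch

head = OFAClassificationHead(
    input_dim=768,
    inner_dim=768,
    num_classes=3,
    activation_fn="tanh",
    pooler_dropout=0.0,
    pooler_classifier="mlp",
)
features = torch.randn(4, 768)   # one pooled representation per sample
logits = head(features)          # shape: (4, 3)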
models/ofa/resnet.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
6
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
7
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
8
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
9
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
10
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
11
+ 'survival rate' as the argument.
12
+ """
13
+ if drop_prob == 0. or not training:
14
+ return x
15
+ keep_prob = 1 - drop_prob
16
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
17
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
18
+ random_tensor.floor_() # binarize
19
+ output = x.div(keep_prob) * random_tensor
20
+ return output
21
+
22
+
23
+ class DropPath(nn.Module):
24
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
25
+ """
26
+ def __init__(self, drop_prob=None):
27
+ super(DropPath, self).__init__()
28
+ self.drop_prob = drop_prob
29
+
30
+ def forward(self, x):
31
+ return drop_path(x, self.drop_prob, self.training)
32
+
33
+
34
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
35
+ """3x3 convolution with padding"""
36
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
37
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
38
+
39
+
40
+ def conv1x1(in_planes, out_planes, stride=1):
41
+ """1x1 convolution"""
42
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
43
+
44
+
45
+ class BasicBlock(nn.Module):
46
+ expansion = 1
47
+
48
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
49
+ base_width=64, dilation=1, norm_layer=None):
50
+ super(BasicBlock, self).__init__()
51
+ if norm_layer is None:
52
+ norm_layer = nn.BatchNorm2d
53
+ if groups != 1 or base_width != 64:
54
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
55
+ if dilation > 1:
56
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
57
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
58
+ self.conv1 = conv3x3(inplanes, planes, stride)
59
+ self.bn1 = norm_layer(planes)
60
+ self.relu = nn.ReLU(inplace=True)
61
+ self.conv2 = conv3x3(planes, planes)
62
+ self.bn2 = norm_layer(planes)
63
+ self.downsample = downsample
64
+ self.stride = stride
65
+
66
+ def forward(self, x):
67
+ assert False
68
+ identity = x
69
+
70
+ out = self.conv1(x)
71
+ out = self.bn1(out)
72
+ out = self.relu(out)
73
+
74
+ out = self.conv2(out)
75
+ out = self.bn2(out)
76
+
77
+ if self.downsample is not None:
78
+ identity = self.downsample(x)
79
+
80
+ out += identity
81
+ out = self.relu(out)
82
+
83
+ return out
84
+
85
+
86
+ class Bottleneck(nn.Module):
87
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
88
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
89
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
90
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
91
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
92
+
93
+ expansion = 4
94
+
95
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
96
+ base_width=64, dilation=1, norm_layer=None, drop_path_rate=0.0):
97
+ super(Bottleneck, self).__init__()
98
+ if norm_layer is None:
99
+ norm_layer = nn.BatchNorm2d
100
+ width = int(planes * (base_width / 64.)) * groups
101
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
102
+ self.conv1 = conv1x1(inplanes, width)
103
+ self.bn1 = norm_layer(width)
104
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
105
+ self.bn2 = norm_layer(width)
106
+ self.conv3 = conv1x1(width, planes * self.expansion)
107
+ self.bn3 = norm_layer(planes * self.expansion)
108
+ self.relu = nn.ReLU(inplace=True)
109
+ self.downsample = downsample
110
+ self.stride = stride
111
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
112
+
113
+ def forward(self, x):
114
+ identity = x
115
+
116
+ out = self.conv1(x)
117
+ out = self.bn1(out)
118
+ out = self.relu(out)
119
+
120
+ out = self.conv2(out)
121
+ out = self.bn2(out)
122
+ out = self.relu(out)
123
+
124
+ out = self.conv3(out)
125
+ out = self.bn3(out)
126
+
127
+ if self.downsample is not None:
128
+ identity = self.downsample(x)
129
+
130
+ out = identity + self.drop_path(out)
131
+ out = self.relu(out)
132
+
133
+ return out
134
+
135
+
136
+ class ResNet(nn.Module):
137
+
138
+ def __init__(self, layers, zero_init_residual=False,
139
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
140
+ norm_layer=None, drop_path_rate=0.0):
141
+ super(ResNet, self).__init__()
142
+ if norm_layer is None:
143
+ norm_layer = nn.BatchNorm2d
144
+ self._norm_layer = norm_layer
145
+
146
+ self.inplanes = 64
147
+ self.dilation = 1
148
+ if replace_stride_with_dilation is None:
149
+ # each element in the tuple indicates if we should replace
150
+ # the 2x2 stride with a dilated convolution instead
151
+ replace_stride_with_dilation = [False, False, False]
152
+ if len(replace_stride_with_dilation) != 3:
153
+ raise ValueError("replace_stride_with_dilation should be None "
154
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
155
+ self.groups = groups
156
+ self.base_width = width_per_group
157
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
158
+ bias=False)
159
+ self.bn1 = norm_layer(self.inplanes)
160
+ self.relu = nn.ReLU(inplace=True)
161
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
162
+ self.layer1 = self._make_layer(Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
163
+ self.layer2 = self._make_layer(Bottleneck, 128, layers[1], stride=2,
164
+ dilate=replace_stride_with_dilation[0], drop_path_rate=drop_path_rate)
165
+ self.layer3 = self._make_layer(Bottleneck, 256, layers[2], stride=2,
166
+ dilate=replace_stride_with_dilation[1], drop_path_rate=drop_path_rate)
167
+
168
+ for m in self.modules():
169
+ if isinstance(m, nn.Conv2d):
170
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
171
+ elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
172
+ nn.init.constant_(m.weight, 1)
173
+ nn.init.constant_(m.bias, 0)
174
+
175
+ # Zero-initialize the last BN in each residual branch,
176
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
177
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
178
+ if zero_init_residual:
179
+ for m in self.modules():
180
+ if isinstance(m, Bottleneck):
181
+ nn.init.constant_(m.bn3.weight, 0)
182
+ elif isinstance(m, BasicBlock):
183
+ nn.init.constant_(m.bn2.weight, 0)
184
+
185
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, drop_path_rate=0.0):
186
+ norm_layer = self._norm_layer
187
+ downsample = None
188
+ previous_dilation = self.dilation
189
+ if dilate:
190
+ self.dilation *= stride
191
+ stride = 1
192
+ if stride != 1 or self.inplanes != planes * block.expansion:
193
+ downsample = nn.Sequential(
194
+ conv1x1(self.inplanes, planes * block.expansion, stride),
195
+ norm_layer(planes * block.expansion),
196
+ )
197
+
198
+ layers = []
199
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
200
+ self.base_width, previous_dilation, norm_layer))
201
+ self.inplanes = planes * block.expansion
202
+
203
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)]
204
+ for i in range(1, blocks):
205
+ layers.append(block(self.inplanes, planes, groups=self.groups,
206
+ base_width=self.base_width, dilation=self.dilation,
207
+ norm_layer=norm_layer, drop_path_rate=dpr[i]))
208
+
209
+ return nn.Sequential(*layers)
210
+
211
+ def _forward_impl(self, x):
212
+ # See note [TorchScript super()]
213
+ x = self.conv1(x)
214
+ x = self.bn1(x)
215
+ x = self.relu(x)
216
+ x = self.maxpool(x)
217
+
218
+ x = self.layer1(x)
219
+ x = self.layer2(x)
220
+ x = self.layer3(x)
221
+
222
+ return x
223
+
224
+ def forward(self, x):
225
+ return self._forward_impl(x)
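This ResNet is truncated after `layer3`, so it returns a stride-16 feature map rather than pooled logits. A hedged usage sketch (not in the commit; the 480x480 input resolution is an assumed example) of how such a feature map becomes a patch sequence, mirroring the `flatten(2).transpose(1, 2)` step used by the encoder later in this diff:

import torch

backbone = ResNet([3, 4, 23])                   # resnet101 layer counts used elsewhere in this commit
images = torch.randn(2, 3, 480, 480)            # example batch of images (resolution assumed)
feats = backbone(images)                        # (2, 1024, 30, 30): stride-16 features from layer3
patches = feats.flatten(2).transpose(1, 2)      # (2, 900, 1024): one embedding per image patch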
models/ofa/unify_multihead_attention.py ADDED
@@ -0,0 +1,518 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq import utils
12
+ from fairseq.incremental_decoding_utils import with_incremental_state
13
+ from fairseq.modules.fairseq_dropout import FairseqDropout
14
+ from fairseq.modules.quant_noise import quant_noise
15
+ from torch import Tensor, nn
16
+ from torch.nn import Parameter
17
+
18
+
19
+ @with_incremental_state
20
+ class MultiheadAttention(nn.Module):
21
+ """Multi-headed attention.
22
+
23
+ See "Attention Is All You Need" for more details.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ embed_dim,
29
+ num_heads,
30
+ kdim=None,
31
+ vdim=None,
32
+ dropout=0.0,
33
+ bias=True,
34
+ add_bias_kv=False,
35
+ add_zero_attn=False,
36
+ self_attention=False,
37
+ encoder_decoder_attention=False,
38
+ q_noise=0.0,
39
+ qn_block_size=8,
40
+ scale_factor=2,
41
+ scale_heads=False
42
+ ):
43
+ super().__init__()
44
+ self.embed_dim = embed_dim
45
+ self.kdim = kdim if kdim is not None else embed_dim
46
+ self.vdim = vdim if vdim is not None else embed_dim
47
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
48
+
49
+ self.num_heads = num_heads
50
+ self.dropout_module = FairseqDropout(
51
+ dropout, module_name=self.__class__.__name__
52
+ )
53
+
54
+ self.head_dim = embed_dim // num_heads
55
+ assert (
56
+ self.head_dim * num_heads == self.embed_dim
57
+ ), "embed_dim must be divisible by num_heads"
58
+ self.scaling = float(self.head_dim * scale_factor) ** -0.5
59
+
60
+ self.self_attention = self_attention
61
+ self.encoder_decoder_attention = encoder_decoder_attention
62
+ self.c_attn = nn.Parameter(torch.ones((self.num_heads,)), requires_grad=True) if scale_heads else None
63
+
64
+ assert not self.self_attention or self.qkv_same_dim, (
65
+ "Self-attention requires query, key and " "value to be of the same size"
66
+ )
67
+
68
+ self.k_proj = quant_noise(
69
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
70
+ )
71
+ self.v_proj = quant_noise(
72
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
73
+ )
74
+ self.q_proj = quant_noise(
75
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
76
+ )
77
+
78
+ self.out_proj = quant_noise(
79
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
80
+ )
81
+
82
+ if add_bias_kv:
83
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
84
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
85
+ else:
86
+ self.bias_k = self.bias_v = None
87
+
88
+ self.add_zero_attn = add_zero_attn
89
+
90
+ self.reset_parameters()
91
+
92
+ self.onnx_trace = False
93
+
94
+ def prepare_for_onnx_export_(self):
95
+ self.onnx_trace = True
96
+
97
+ def reset_parameters(self):
98
+ if self.qkv_same_dim:
99
+ # Empirically observed the convergence to be much better with
100
+ # the scaled initialization
101
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
102
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
103
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
104
+ else:
105
+ nn.init.xavier_uniform_(self.k_proj.weight)
106
+ nn.init.xavier_uniform_(self.v_proj.weight)
107
+ nn.init.xavier_uniform_(self.q_proj.weight)
108
+
109
+ nn.init.xavier_uniform_(self.out_proj.weight)
110
+ if self.out_proj.bias is not None:
111
+ nn.init.constant_(self.out_proj.bias, 0.0)
112
+ if self.bias_k is not None:
113
+ nn.init.xavier_normal_(self.bias_k)
114
+ if self.bias_v is not None:
115
+ nn.init.xavier_normal_(self.bias_v)
116
+
117
+ def forward(
118
+ self,
119
+ query,
120
+ key: Optional[Tensor],
121
+ value: Optional[Tensor],
122
+ key_padding_mask: Optional[Tensor] = None,
123
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
124
+ need_weights: bool = True,
125
+ static_kv: bool = False,
126
+ attn_mask: Optional[Tensor] = None,
127
+ self_attn_mask: Optional[Tensor] = None,
128
+ before_softmax: bool = False,
129
+ need_head_weights: bool = False,
130
+ attn_bias: Optional[Tensor] = None
131
+ ) -> Tuple[Tensor, Optional[Tensor]]:
132
+ """Input shape: Time x Batch x Channel
133
+
134
+ Args:
135
+ key_padding_mask (ByteTensor, optional): mask to exclude
136
+ keys that are pads, of shape `(batch, src_len)`, where
137
+ padding elements are indicated by 1s.
138
+ need_weights (bool, optional): return the attention weights,
139
+ averaged over heads (default: False).
140
+ attn_mask (ByteTensor, optional): typically used to
141
+ implement causal attention, where the mask prevents the
142
+ attention from looking forward in time (default: None).
143
+ before_softmax (bool, optional): return the raw attention
144
+ weights and values before the attention softmax.
145
+ need_head_weights (bool, optional): return the attention
146
+ weights for each head. Implies *need_weights*. Default:
147
+ return the average attention weights over all heads.
148
+ """
149
+ if need_head_weights:
150
+ need_weights = True
151
+
152
+ is_tpu = query.device.type == "xla"
153
+
154
+ tgt_len, bsz, embed_dim = query.size()
155
+ src_len = tgt_len
156
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
157
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
158
+ if key is not None:
159
+ src_len, key_bsz, _ = key.size()
160
+ if not torch.jit.is_scripting():
161
+ assert key_bsz == bsz
162
+ assert value is not None
163
+ assert src_len, bsz == value.shape[:2]
164
+
165
+ if (
166
+ not self.onnx_trace
167
+ and not is_tpu # don't use PyTorch version on TPUs
168
+ and incremental_state is None
169
+ and not static_kv
170
+ # A workaround for quantization to work. Otherwise JIT compilation
171
+ # treats bias in linear module as method.
172
+ and not torch.jit.is_scripting()
173
+ and self_attn_mask is None
174
+ and attn_bias is None
175
+ ):
176
+ assert key is not None and value is not None
177
+ return F.multi_head_attention_forward(
178
+ query,
179
+ key,
180
+ value,
181
+ self.embed_dim,
182
+ self.num_heads,
183
+ torch.empty([0]),
184
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
185
+ self.bias_k,
186
+ self.bias_v,
187
+ self.add_zero_attn,
188
+ self.dropout_module.p,
189
+ self.out_proj.weight,
190
+ self.out_proj.bias,
191
+ self.training or self.dropout_module.apply_during_inference,
192
+ key_padding_mask,
193
+ need_weights,
194
+ attn_mask,
195
+ use_separate_proj_weight=True,
196
+ q_proj_weight=self.q_proj.weight,
197
+ k_proj_weight=self.k_proj.weight,
198
+ v_proj_weight=self.v_proj.weight,
199
+ )
200
+
201
+ if incremental_state is not None:
202
+ saved_state = self._get_input_buffer(incremental_state)
203
+ if saved_state is not None and "prev_key" in saved_state:
204
+ # previous time steps are cached - no need to recompute
205
+ # key and value if they are static
206
+ if static_kv:
207
+ assert self.encoder_decoder_attention and not self.self_attention
208
+ key = value = None
209
+ else:
210
+ saved_state = None
211
+
212
+ if self.self_attention and self_attn_mask is None:
213
+ q = self.q_proj(query)
214
+ k = self.k_proj(query)
215
+ v = self.v_proj(query)
216
+ elif self.encoder_decoder_attention:
217
+ # encoder-decoder attention
218
+ q = self.q_proj(query)
219
+ if key is None:
220
+ assert value is None
221
+ k = v = None
222
+ else:
223
+ k = self.k_proj(key)
224
+ v = self.v_proj(key)
225
+
226
+ else:
227
+ assert key is not None and value is not None
228
+ q = self.q_proj(query)
229
+ k = self.k_proj(key)
230
+ v = self.v_proj(value)
231
+ q *= self.scaling
232
+
233
+ if self.bias_k is not None:
234
+ assert self.bias_v is not None
235
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
236
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
237
+ if attn_mask is not None:
238
+ attn_mask = torch.cat(
239
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
240
+ )
241
+ if key_padding_mask is not None:
242
+ key_padding_mask = torch.cat(
243
+ [
244
+ key_padding_mask,
245
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
246
+ ],
247
+ dim=1,
248
+ )
249
+
250
+ q = (
251
+ q.contiguous()
252
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
253
+ .transpose(0, 1)
254
+ )
255
+ if k is not None:
256
+ k = (
257
+ k.contiguous()
258
+ .view(-1, bsz * self.num_heads, self.head_dim)
259
+ .transpose(0, 1)
260
+ )
261
+ if v is not None:
262
+ v = (
263
+ v.contiguous()
264
+ .view(-1, bsz * self.num_heads, self.head_dim)
265
+ .transpose(0, 1)
266
+ )
267
+
268
+ if saved_state is not None:
269
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
270
+ if "prev_key" in saved_state:
271
+ _prev_key = saved_state["prev_key"]
272
+ assert _prev_key is not None
273
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
274
+ if static_kv:
275
+ k = prev_key
276
+ else:
277
+ assert k is not None
278
+ k = torch.cat([prev_key, k], dim=1)
279
+ src_len = k.size(1)
280
+ if "prev_value" in saved_state:
281
+ _prev_value = saved_state["prev_value"]
282
+ assert _prev_value is not None
283
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
284
+ if static_kv:
285
+ v = prev_value
286
+ else:
287
+ assert v is not None
288
+ v = torch.cat([prev_value, v], dim=1)
289
+ prev_key_padding_mask: Optional[Tensor] = None
290
+ if "prev_key_padding_mask" in saved_state:
291
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
292
+ assert k is not None and v is not None
293
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
294
+ key_padding_mask=key_padding_mask,
295
+ prev_key_padding_mask=prev_key_padding_mask,
296
+ batch_size=bsz,
297
+ src_len=k.size(1),
298
+ static_kv=static_kv,
299
+ )
300
+
301
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
302
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
303
+ saved_state["prev_key_padding_mask"] = key_padding_mask
304
+ # In this branch incremental_state is never None
305
+ assert incremental_state is not None
306
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
307
+ assert k is not None
308
+ assert k.size(1) == src_len
309
+
310
+ # This is part of a workaround to get around fork/join parallelism
311
+ # not supporting Optional types.
312
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
313
+ key_padding_mask = None
314
+
315
+ if key_padding_mask is not None:
316
+ assert key_padding_mask.size(0) == bsz
317
+ assert key_padding_mask.size(1) == src_len
318
+
319
+ if self.add_zero_attn:
320
+ assert v is not None
321
+ src_len += 1
322
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
323
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
324
+ if attn_mask is not None:
325
+ attn_mask = torch.cat(
326
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
327
+ )
328
+ if key_padding_mask is not None:
329
+ key_padding_mask = torch.cat(
330
+ [
331
+ key_padding_mask,
332
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
333
+ key_padding_mask
334
+ ),
335
+ ],
336
+ dim=1,
337
+ )
338
+
339
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
340
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
341
+
342
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
343
+
344
+ if attn_bias is not None:
345
+ attn_weights += attn_bias
346
+
347
+ if attn_mask is not None:
348
+ attn_mask = attn_mask.unsqueeze(0)
349
+ if self.onnx_trace:
350
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
351
+ attn_weights += attn_mask
352
+
353
+ if self_attn_mask is not None:
354
+ self_attn_mask = self_attn_mask.unsqueeze(1).expand(bsz, self.num_heads, tgt_len, src_len)
355
+ attn_weights += self_attn_mask.contiguous().view(bsz * self.num_heads, tgt_len, src_len)
356
+
357
+ if key_padding_mask is not None:
358
+ # don't attend to padding symbols
359
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
360
+ if not is_tpu:
361
+ attn_weights = attn_weights.masked_fill(
362
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
363
+ float("-inf"),
364
+ )
365
+ else:
366
+ attn_weights = attn_weights.transpose(0, 2)
367
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
368
+ attn_weights = attn_weights.transpose(0, 2)
369
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
370
+
371
+ if before_softmax:
372
+ return attn_weights, v
373
+
374
+ attn_weights_float = utils.softmax(
375
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
376
+ )
377
+ attn_weights = attn_weights_float.type_as(attn_weights)
378
+ attn_probs = self.dropout_module(attn_weights)
379
+
380
+ assert v is not None
381
+ attn = torch.bmm(attn_probs, v)
382
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
383
+ if self.onnx_trace and attn.size(1) == 1:
384
+ # when ONNX tracing a single decoder step (sequence length == 1)
385
+ # the transpose is a no-op copy before view, thus unnecessary
386
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
387
+ else:
388
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
389
+ if self.c_attn is not None:
390
+ attn = attn.view(tgt_len, bsz, self.num_heads, self.head_dim)
391
+ attn = torch.einsum('tbhd,h->tbhd', attn, self.c_attn)
392
+ attn = attn.reshape(tgt_len, bsz, self.embed_dim)
393
+ attn = self.out_proj(attn)
394
+ attn_weights: Optional[Tensor] = None
395
+ if need_weights:
396
+ attn_weights = attn_weights_float.view(
397
+ bsz, self.num_heads, tgt_len, src_len
398
+ ).transpose(1, 0)
399
+ if not need_head_weights:
400
+ # average attention weights over heads
401
+ attn_weights = attn_weights.mean(dim=0)
402
+
403
+ return attn, attn_weights
404
+
405
+ @staticmethod
406
+ def _append_prev_key_padding_mask(
407
+ key_padding_mask: Optional[Tensor],
408
+ prev_key_padding_mask: Optional[Tensor],
409
+ batch_size: int,
410
+ src_len: int,
411
+ static_kv: bool,
412
+ ) -> Optional[Tensor]:
413
+ # saved key padding masks have shape (bsz, seq_len)
414
+ if prev_key_padding_mask is not None and static_kv:
415
+ new_key_padding_mask = prev_key_padding_mask
416
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
417
+ new_key_padding_mask = torch.cat(
418
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
419
+ )
420
+ # During incremental decoding, as the padding token enters and
421
+ # leaves the frame, there will be a time when prev or current
422
+ # is None
423
+ elif prev_key_padding_mask is not None:
424
+ if src_len > prev_key_padding_mask.size(1):
425
+ filler = torch.zeros(
426
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
427
+ device=prev_key_padding_mask.device,
428
+ )
429
+ new_key_padding_mask = torch.cat(
430
+ [prev_key_padding_mask.float(), filler.float()], dim=1
431
+ )
432
+ else:
433
+ new_key_padding_mask = prev_key_padding_mask.float()
434
+ elif key_padding_mask is not None:
435
+ if src_len > key_padding_mask.size(1):
436
+ filler = torch.zeros(
437
+ (batch_size, src_len - key_padding_mask.size(1)),
438
+ device=key_padding_mask.device,
439
+ )
440
+ new_key_padding_mask = torch.cat(
441
+ [filler.float(), key_padding_mask.float()], dim=1
442
+ )
443
+ else:
444
+ new_key_padding_mask = key_padding_mask.float()
445
+ else:
446
+ new_key_padding_mask = prev_key_padding_mask
447
+ return new_key_padding_mask
448
+
449
+ @torch.jit.export
450
+ def reorder_incremental_state(
451
+ self,
452
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
453
+ new_order: Tensor,
454
+ ):
455
+ """Reorder buffered internal state (for incremental generation)."""
456
+ input_buffer = self._get_input_buffer(incremental_state)
457
+ if input_buffer is not None:
458
+ for k in input_buffer.keys():
459
+ input_buffer_k = input_buffer[k]
460
+ if input_buffer_k is not None:
461
+ if self.encoder_decoder_attention and input_buffer_k.size(
462
+ 0
463
+ ) == new_order.size(0):
464
+ break
465
+ input_buffer[k] = input_buffer_k.index_select(0, new_order)
466
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
467
+ return incremental_state
468
+
469
+ def _get_input_buffer(
470
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
471
+ ) -> Dict[str, Optional[Tensor]]:
472
+ result = self.get_incremental_state(incremental_state, "attn_state")
473
+ if result is not None:
474
+ return result
475
+ else:
476
+ empty_result: Dict[str, Optional[Tensor]] = {}
477
+ return empty_result
478
+
479
+ def _set_input_buffer(
480
+ self,
481
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
482
+ buffer: Dict[str, Optional[Tensor]],
483
+ ):
484
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
485
+
486
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
487
+ return attn_weights
488
+
489
+ def upgrade_state_dict_named(self, state_dict, name):
490
+ prefix = name + "." if name != "" else ""
491
+ items_to_add = {}
492
+ keys_to_remove = []
493
+ for k in state_dict.keys():
494
+ if k.endswith(prefix + "in_proj_weight"):
495
+ # in_proj_weight used to be q + k + v with same dimensions
496
+ dim = int(state_dict[k].shape[0] / 3)
497
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
498
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
499
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
500
+
501
+ keys_to_remove.append(k)
502
+
503
+ k_bias = prefix + "in_proj_bias"
504
+ if k_bias in state_dict.keys():
505
+ dim = int(state_dict[k].shape[0] / 3)
506
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
507
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
508
+ dim : 2 * dim
509
+ ]
510
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
511
+
512
+ keys_to_remove.append(prefix + "in_proj_bias")
513
+
514
+ for k in keys_to_remove:
515
+ del state_dict[k]
516
+
517
+ for key, value in items_to_add.items():
518
+ state_dict[key] = value
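The core path of `forward()` above is scaled dot-product attention with an optional additive `attn_bias` (used for relative-position biases), where queries are pre-scaled by `(head_dim * scale_factor) ** -0.5`. A small self-contained sketch of that path with the module's `(bsz * num_heads, len, head_dim)` layout (not part of the commit; shapes and values are illustrative):

import torch

bsz_heads, tgt_len, src_len, head_dim, scale_factor = 6, 4, 4, 64, 2
q = torch.randn(bsz_heads, tgt_len, head_dim) * (head_dim * scale_factor) ** -0.5
k = torch.randn(bsz_heads, src_len, head_dim)
v = torch.randn(bsz_heads, src_len, head_dim)
attn_bias = torch.randn(bsz_heads, tgt_len, src_len)    # e.g. a relative-position bias

attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_bias   # scaled logits plus bias
attn_probs = attn_weights.softmax(dim=-1)                    # normalize over source positions
out = torch.bmm(attn_probs, v)                                # (6, 4, 64) attended values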
models/ofa/unify_transformer.py ADDED
@@ -0,0 +1,1510 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from fairseq import utils
14
+ from fairseq.distributed import fsdp_wrap
15
+ from fairseq.models import (
16
+ FairseqEncoder,
17
+ FairseqEncoderDecoderModel,
18
+ FairseqIncrementalDecoder,
19
+ register_model,
20
+ register_model_architecture,
21
+ )
22
+ from fairseq.modules import (
23
+ AdaptiveSoftmax,
24
+ BaseLayer,
25
+ FairseqDropout,
26
+ LayerDropModuleList,
27
+ LayerNorm,
28
+ SinusoidalPositionalEmbedding,
29
+ GradMultiply
30
+ )
31
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
32
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
33
+ from torch import Tensor
34
+
35
+ from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer
36
+ from .resnet import ResNet
37
+
38
+
39
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
40
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
41
+
42
+
43
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
44
+
45
+
46
+ def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3):
47
+ return nn.SyncBatchNorm.convert_sync_batchnorm(
48
+ nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps)
49
+ )
50
+
51
+
52
+ def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
53
+ context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
54
+ memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
55
+ relative_pos = context_pos - memory_pos
56
+ sign = torch.sign(relative_pos)
57
+ mid = bucket_size // 2
58
+ abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), mid-1, torch.abs(relative_pos))
59
+ log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid
60
+ log_pos = log_pos.int()
61
+ bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long()
62
+ return bucket_pos + bucket_size - 1
63
+
64
+
65
+ def make_image_bucket_position(bucket_size, num_relative_distance):
66
+ coords_h = torch.arange(bucket_size)
67
+ coords_w = torch.arange(bucket_size)
68
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
69
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
70
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
71
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
72
+ relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0
73
+ relative_coords[:, :, 1] += bucket_size - 1
74
+ relative_coords[:, :, 0] *= 2 * bucket_size - 1
75
+ relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype)
76
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
77
+ relative_position_index[0, 0:] = num_relative_distance - 3
78
+ relative_position_index[0:, 0] = num_relative_distance - 2
79
+ relative_position_index[0, 0] = num_relative_distance - 1
80
+ return relative_position_index
81
+
82
+
83
+ @register_model("unify_transformer")
84
+ class TransformerModel(FairseqEncoderDecoderModel):
85
+ """
86
+ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
87
+ <https://arxiv.org/abs/1706.03762>`_.
88
+
89
+ Args:
90
+ encoder (TransformerEncoder): the encoder
91
+ decoder (TransformerDecoder): the decoder
92
+
93
+ The Transformer model provides the following named architectures and
94
+ command-line arguments:
95
+
96
+ .. argparse::
97
+ :ref: fairseq.models.transformer_parser
98
+ :prog:
99
+ """
100
+
101
+ def __init__(self, args, encoder, decoder):
102
+ super().__init__(encoder, decoder)
103
+ self.args = args
104
+ self.supports_align_args = True
105
+
106
+ @staticmethod
107
+ def add_args(parser):
108
+ """Add model-specific arguments to the parser."""
109
+ # fmt: off
110
+ parser.add_argument('--activation-fn',
111
+ choices=utils.get_available_activation_fns(),
112
+ help='activation function to use')
113
+ parser.add_argument('--dropout', type=float, metavar='D',
114
+ help='dropout probability')
115
+ parser.add_argument('--attention-dropout', type=float, metavar='D',
116
+ help='dropout probability for attention weights')
117
+ parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
118
+ help='dropout probability after activation in FFN.')
119
+ parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
120
+ help='path to pre-trained encoder embedding')
121
+ parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
122
+ help='encoder embedding dimension')
123
+ parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
124
+ help='encoder embedding dimension for FFN')
125
+ parser.add_argument('--encoder-layers', type=int, metavar='N',
126
+ help='num encoder layers')
127
+ parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
128
+ help='num encoder attention heads')
129
+ parser.add_argument('--encoder-normalize-before', action='store_true',
130
+ help='apply layernorm before each encoder block')
131
+ parser.add_argument('--encoder-learned-pos', action='store_true',
132
+ help='use learned positional embeddings in the encoder')
133
+ parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
134
+ help='path to pre-trained decoder embedding')
135
+ parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
136
+ help='decoder embedding dimension')
137
+ parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
138
+ help='decoder embedding dimension for FFN')
139
+ parser.add_argument('--decoder-layers', type=int, metavar='N',
140
+ help='num decoder layers')
141
+ parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
142
+ help='num decoder attention heads')
143
+ parser.add_argument('--decoder-learned-pos', action='store_true',
144
+ help='use learned positional embeddings in the decoder')
145
+ parser.add_argument('--decoder-normalize-before', action='store_true',
146
+ help='apply layernorm before each decoder block')
147
+ parser.add_argument('--decoder-output-dim', type=int, metavar='N',
148
+ help='decoder output dimension (extra linear layer '
149
+ 'if different from decoder embed dim')
150
+ parser.add_argument('--share-decoder-input-output-embed', action='store_true',
151
+ help='share decoder input and output embeddings')
152
+ parser.add_argument('--share-all-embeddings', action='store_true',
153
+ help='share encoder, decoder and output embeddings'
154
+ ' (requires shared dictionary and embed dim)')
155
+ parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
156
+ help='if set, disables positional embeddings (outside self attention)')
157
+ parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
158
+ help='comma separated list of adaptive softmax cutoff points. '
159
+ 'Must be used with adaptive_loss criterion'),
160
+ parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
161
+ help='sets adaptive softmax dropout for the tail projections')
162
+ parser.add_argument('--layernorm-embedding', action='store_true',
163
+ help='add layernorm to embedding')
164
+ parser.add_argument('--no-scale-embedding', action='store_true',
165
+ help='if True, dont scale embeddings')
166
+ parser.add_argument('--checkpoint-activations', action='store_true',
167
+ help='checkpoint activations at each layer, which saves GPU '
168
+ 'memory usage at the cost of some additional compute')
169
+ parser.add_argument('--offload-activations', action='store_true',
170
+ help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
171
+ # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
172
+ parser.add_argument('--no-cross-attention', default=False, action='store_true',
173
+ help='do not perform cross-attention')
174
+ parser.add_argument('--cross-self-attention', default=False, action='store_true',
175
+ help='perform cross+self-attention')
176
+ # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
177
+ parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
178
+ help='LayerDrop probability for encoder')
179
+ parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
180
+ help='LayerDrop probability for decoder')
181
+ parser.add_argument('--encoder-layers-to-keep', default=None,
182
+ help='which layers to *keep* when pruning as a comma-separated list')
183
+ parser.add_argument('--decoder-layers-to-keep', default=None,
184
+ help='which layers to *keep* when pruning as a comma-separated list')
185
+ # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
186
+ parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
187
+ help='iterative PQ quantization noise at training time')
188
+ parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
189
+ help='block size of quantization noise at training time')
190
+ parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
191
+ help='scalar quantization noise and scalar quantization at training time')
192
+ # args for Fully Sharded Data Parallel (FSDP) training
193
+ parser.add_argument(
194
+ '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
195
+ help=(
196
+ 'minimum number of params for a layer to be wrapped with FSDP() when '
197
+ 'training with --ddp-backend=fully_sharded. Smaller values will '
198
+ 'improve memory efficiency, but may make torch.distributed '
199
+ 'communication less efficient due to smaller input sizes. This option '
200
+ 'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
201
+ '--offload-activations are passed.'
202
+ )
203
+ )
204
+
205
+ parser.add_argument('--resnet-drop-path-rate', type=float,
206
+ help='resnet drop path rate')
207
+ parser.add_argument('--encoder-drop-path-rate', type=float,
208
+ help='encoder drop path rate')
209
+ parser.add_argument('--decoder-drop-path-rate', type=float,
210
+ help='decoder drop path rate')
211
+
212
+ parser.add_argument('--token-bucket-size', type=int,
213
+ help='token bucket size')
214
+ parser.add_argument('--image-bucket-size', type=int,
215
+ help='image bucket size')
216
+
217
+ parser.add_argument('--attn-scale-factor', type=float,
218
+ help='attention scale factor')
219
+ parser.add_argument('--freeze-resnet', action='store_true',
220
+ help='freeze resnet')
221
+ parser.add_argument('--freeze-encoder-embedding', action='store_true',
222
+ help='freeze encoder token embedding')
223
+ parser.add_argument('--freeze-decoder-embedding', action='store_true',
224
+ help='freeze decoder token embedding')
225
+ parser.add_argument('--add-type-embedding', action='store_true',
226
+ help='add source/region/patch type embedding')
227
+
228
+ parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
229
+ help='resnet type')
230
+ parser.add_argument('--resnet-model-path', type=str, metavar='STR',
231
+ help='path to load resnet')
232
+ parser.add_argument('--code-image-size', type=int,
233
+ help='code image size')
234
+ parser.add_argument('--patch-layernorm-embedding', action='store_true',
235
+ help='add layernorm to patch embedding')
236
+ parser.add_argument('--code-layernorm-embedding', action='store_true',
237
+ help='add layernorm to code embedding')
238
+ parser.add_argument('--entangle-position-embedding', action='store_true',
239
+ help='entangle position embedding')
240
+ parser.add_argument('--disable-entangle', action='store_true',
241
+ help='disable entangle')
242
+ parser.add_argument('--sync-bn', action='store_true',
243
+ help='sync batchnorm')
244
+
245
+ parser.add_argument('--scale-attn', action='store_true',
246
+ help='scale attn')
247
+ parser.add_argument('--scale-fc', action='store_true',
248
+ help='scale fc')
249
+ parser.add_argument('--scale-heads', action='store_true',
250
+ help='scale heads')
251
+ parser.add_argument('--scale-resids', action='store_true',
252
+ help='scale resids')
253
+ # fmt: on
254
+
255
+ @classmethod
256
+ def build_model(cls, args, task):
257
+ """Build a new model instance."""
258
+
259
+ # make sure all arguments are present in older models
260
+ base_architecture(args)
261
+
262
+ if args.encoder_layers_to_keep:
263
+ args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
264
+ if args.decoder_layers_to_keep:
265
+ args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
266
+
267
+ if getattr(args, "max_source_positions", None) is None:
268
+ args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
269
+ if getattr(args, "max_target_positions", None) is None:
270
+ args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
271
+
272
+ src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
273
+
274
+ if args.share_all_embeddings:
275
+ if src_dict != tgt_dict:
276
+ raise ValueError("--share-all-embeddings requires a joined dictionary")
277
+ if args.encoder_embed_dim != args.decoder_embed_dim:
278
+ raise ValueError(
279
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
280
+ )
281
+ if args.decoder_embed_path and (
282
+ args.decoder_embed_path != args.encoder_embed_path
283
+ ):
284
+ raise ValueError(
285
+ "--share-all-embeddings not compatible with --decoder-embed-path"
286
+ )
287
+ encoder_embed_tokens = cls.build_embedding(
288
+ args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
289
+ )
290
+ decoder_embed_tokens = encoder_embed_tokens
291
+ args.share_decoder_input_output_embed = True
292
+ else:
293
+ encoder_embed_tokens = cls.build_embedding(
294
+ args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
295
+ )
296
+ decoder_embed_tokens = cls.build_embedding(
297
+ args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
298
+ )
299
+ if getattr(args, "freeze_encoder_embedding", False):
300
+ encoder_embed_tokens.weight.requires_grad = False
301
+ if getattr(args, "freeze_decoder_embedding", False):
302
+ decoder_embed_tokens.weight.requires_grad = False
303
+ if getattr(args, "offload_activations", False):
304
+ args.checkpoint_activations = True # offloading implies checkpointing
305
+ encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
306
+ decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
307
+ if not args.share_all_embeddings:
308
+ min_params_to_wrap = getattr(
309
+ args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
310
+ )
311
+ # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
312
+ encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
313
+ decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
314
+ return cls(args, encoder, decoder)
315
+
316
+ @classmethod
317
+ def build_embedding(cls, args, dictionary, embed_dim, path=None):
318
+ num_embeddings = len(dictionary)
319
+ padding_idx = dictionary.pad()
320
+
321
+ emb = Embedding(num_embeddings, embed_dim, padding_idx)
322
+ # if provided, load from preloaded dictionaries
323
+ if path:
324
+ embed_dict = utils.parse_embedding(path)
325
+ utils.load_embedding(embed_dict, dictionary, emb)
326
+ return emb
327
+
328
+ @classmethod
329
+ def build_encoder(cls, args, src_dict, embed_tokens):
330
+ return TransformerEncoder(args, src_dict, embed_tokens)
331
+
332
+ @classmethod
333
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
334
+ return TransformerDecoder(
335
+ args,
336
+ tgt_dict,
337
+ embed_tokens,
338
+ no_encoder_attn=getattr(args, "no_cross_attention", False),
339
+ )
340
+
341
+ # TorchScript doesn't support optional arguments with variable length (**kwargs).
342
+ # Current workaround is to add union of all arguments in child classes.
343
+ def forward(
344
+ self,
345
+ src_tokens,
346
+ src_lengths,
347
+ prev_output_tokens,
348
+ return_all_hiddens: bool = True,
349
+ features_only: bool = False,
350
+ alignment_layer: Optional[int] = None,
351
+ alignment_heads: Optional[int] = None,
352
+ ):
353
+ """
354
+ Run the forward pass for an encoder-decoder model.
355
+
356
+ Copied from the base class, but without ``**kwargs``,
357
+ which are not supported by TorchScript.
358
+ """
359
+ encoder_out = self.encoder(
360
+ src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
361
+ )
362
+ decoder_out = self.decoder(
363
+ prev_output_tokens,
364
+ encoder_out=encoder_out,
365
+ features_only=features_only,
366
+ alignment_layer=alignment_layer,
367
+ alignment_heads=alignment_heads,
368
+ src_lengths=src_lengths,
369
+ return_all_hiddens=return_all_hiddens,
370
+ )
371
+ return decoder_out
372
+
373
+ # Since get_normalized_probs is in the Fairseq Model which is not scriptable,
374
+ # I rewrite the get_normalized_probs from Base Class to call the
375
+ # helper function in the Base Class.
376
+ @torch.jit.export
377
+ def get_normalized_probs(
378
+ self,
379
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
380
+ log_probs: bool,
381
+ sample: Optional[Dict[str, Tensor]] = None,
382
+ ):
383
+ """Get normalized probabilities (or log probs) from a net's output."""
384
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
385
+
386
+
387
+ class TransformerEncoder(FairseqEncoder):
388
+ """
389
+ Transformer encoder consisting of *args.encoder_layers* layers. Each layer
390
+ is a :class:`TransformerEncoderLayer`.
391
+
392
+ Args:
393
+ args (argparse.Namespace): parsed command-line arguments
394
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
395
+ embed_tokens (torch.nn.Embedding): input embedding
396
+ """
397
+
398
+ def __init__(self, args, dictionary, embed_tokens):
399
+ self.args = args
400
+ super().__init__(dictionary)
401
+ self.register_buffer("version", torch.Tensor([3]))
402
+
403
+ self.dropout_module = FairseqDropout(
404
+ args.dropout, module_name=self.__class__.__name__
405
+ )
406
+ self.encoder_layerdrop = args.encoder_layerdrop
407
+
408
+ embed_dim = embed_tokens.embedding_dim
409
+ self.padding_idx = embed_tokens.padding_idx
410
+ self.max_source_positions = args.max_source_positions
411
+ self.num_attention_heads = args.encoder_attention_heads
412
+
413
+ self.embed_tokens = embed_tokens
414
+
415
+ self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
416
+
417
+ if getattr(args, "layernorm_embedding", False):
418
+ self.layernorm_embedding = LayerNorm(embed_dim)
419
+ else:
420
+ self.layernorm_embedding = None
421
+
422
+ if getattr(args, "add_type_embedding", False):
423
+ self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
424
+ else:
425
+ self.type_embedding = None
426
+
427
+ if getattr(args, "sync_bn", False):
428
+ norm_layer = BatchNorm2d
429
+ else:
430
+ norm_layer = None
431
+
432
+ if args.resnet_type == 'resnet101':
433
+ self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
434
+ elif args.resnet_type == 'resnet152':
435
+ self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
436
+ else:
437
+ raise NotImplementedError
438
+ self.image_proj = Linear(1024, embed_dim)
439
+ if getattr(args, "resnet_model_path", None):
440
+ print("load resnet {}".format(args.resnet_model_path))
441
+ resnet_state_dict = torch.load(self.args.resnet_model_path)
442
+ self.embed_images.load_state_dict(resnet_state_dict)
443
+ if getattr(args, "patch_layernorm_embedding", False):
444
+ self.patch_layernorm_embedding = LayerNorm(embed_dim)
445
+ else:
446
+ self.patch_layernorm_embedding = None
447
+
448
+ self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim)
449
+ self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
450
+ self.pos_ln = LayerNorm(embed_dim)
451
+ self.image_pos_ln = LayerNorm(embed_dim)
452
+ self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5
453
+ self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
454
+ self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
455
+
456
+ if not args.adaptive_input and args.quant_noise_pq > 0:
457
+ self.quant_noise = apply_quant_noise_(
458
+ nn.Linear(embed_dim, embed_dim, bias=False),
459
+ args.quant_noise_pq,
460
+ args.quant_noise_pq_block_size,
461
+ )
462
+ else:
463
+ self.quant_noise = None
464
+
465
+ if self.encoder_layerdrop > 0.0:
466
+ self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
467
+ else:
468
+ self.layers = nn.ModuleList([])
469
+
470
+ dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)]
471
+ self.layers.extend(
472
+ [self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)]
473
+ )
474
+ self.num_layers = len(self.layers)
475
+
476
+ if args.encoder_normalize_before:
477
+ self.layer_norm = LayerNorm(embed_dim)
478
+ else:
479
+ self.layer_norm = None
480
+
481
+ token_bucket_size = args.token_bucket_size
482
+ token_num_rel_dis = 2 * token_bucket_size - 1
483
+ token_rp_bucket = make_token_bucket_position(token_bucket_size)
484
+ self.token_rel_pos_table_list = nn.ModuleList(
485
+ [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
486
+ )
487
+
488
+ image_bucket_size = args.image_bucket_size
489
+ image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
490
+ image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
491
+ self.image_rel_pos_table_list = nn.ModuleList(
492
+ [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
493
+ )
494
+
495
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
496
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
497
+ self.entangle_position_embedding = args.entangle_position_embedding
498
+
499
+ def train(self, mode=True):
500
+ super(TransformerEncoder, self).train(mode)
501
+ if getattr(self.args, "freeze_resnet", False):
502
+ for m in self.embed_images.modules():
503
+ if isinstance(m, nn.BatchNorm2d):
504
+ m.eval()
505
+ m.weight.requires_grad = False
506
+ m.bias.requires_grad = False
507
+
508
+ def build_encoder_layer(self, args, drop_path_rate=0.0):
509
+ layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate)
510
+ checkpoint = getattr(args, "checkpoint_activations", False)
511
+ if checkpoint:
512
+ offload_to_cpu = getattr(args, "offload_activations", False)
513
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
514
+ # if we are checkpointing, enforce that FSDP always wraps the
515
+ # checkpointed layer, regardless of layer size
516
+ min_params_to_wrap = (
517
+ getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
518
+ if not checkpoint else 0
519
+ )
520
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
521
+ return layer
522
+
523
+ def get_rel_pos_bias(self, x, idx):
524
+ seq_len = x.size(1)
525
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
526
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
527
+ values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
528
+ values = values.permute([0, 3, 1, 2])
529
+ return values.contiguous()
530
+
531
+ def get_image_rel_pos_bias(self, image_position_ids, idx):
532
+ bsz, seq_len = image_position_ids.shape
533
+ rp_bucket_size = self.image_rp_bucket.size(1)
534
+
535
+ rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
536
+ bsz, rp_bucket_size, rp_bucket_size
537
+ ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size)
538
+ ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len))
539
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
540
+ values = values.permute(0, 3, 1, 2)
541
+ return values
542
+
543
+ def get_patch_images_info(self, patch_images, sample_patch_num, device):
544
+ image_embed = self.embed_images(patch_images)
545
+ h, w = image_embed.shape[-2:]
546
+ image_num_patches = h * w
547
+ image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
548
+ image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
549
+ torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1
550
+ image_position_idx = image_position_idx.view(-1).to(device)
551
+ image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
552
+
553
+ image_embed = image_embed.flatten(2).transpose(1, 2)
554
+ if sample_patch_num is not None:
555
+ patch_orders = [
556
+ random.sample(range(image_num_patches), k=sample_patch_num)
557
+ for _ in range(patch_images.size(0))
558
+ ]
559
+ patch_orders = torch.LongTensor(patch_orders).to(device)
560
+ image_embed = image_embed.gather(
561
+ 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
562
+ )
563
+ image_num_patches = sample_patch_num
564
+ image_padding_mask = image_padding_mask.gather(1, patch_orders)
565
+ image_position_ids = image_position_ids.gather(1, patch_orders)
566
+ image_pos_embed = self.embed_image_positions(image_position_ids)
567
+
568
+ return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
569
+
570
+ def forward_embedding(
571
+ self,
572
+ src_tokens,
573
+ image_embed: Optional[torch.Tensor] = None,
574
+ image_embed_2: Optional[torch.Tensor] = None,
575
+ token_embedding: Optional[torch.Tensor] = None,
576
+ pos_embed: Optional[torch.Tensor] = None,
577
+ image_pos_embed: Optional[torch.Tensor] = None,
578
+ image_pos_embed_2: Optional[torch.Tensor] = None
579
+ ):
580
+ # embed tokens and positions
581
+ if token_embedding is None:
582
+ token_embedding = self.embed_tokens(src_tokens)
583
+ x = embed = self.embed_scale * token_embedding
584
+ if self.entangle_position_embedding and pos_embed is not None:
585
+ x += pos_embed
586
+ if self.type_embedding is not None:
587
+ x += self.type_embedding(src_tokens.new_zeros(x.size()[:2]))
588
+ if self.layernorm_embedding is not None:
589
+ x = self.layernorm_embedding(x)
590
+ x = self.dropout_module(x)
591
+ if self.quant_noise is not None:
592
+ x = self.quant_noise(x)
593
+
594
+ # embed raw images
595
+ if image_embed is not None:
596
+ image_embed = self.image_proj(image_embed)
597
+ image_x = image_embed = self.embed_scale * image_embed
598
+ if self.entangle_position_embedding and image_pos_embed is not None:
599
+ image_x += image_pos_embed
600
+ if self.type_embedding is not None:
601
+ image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2]))
602
+ if self.patch_layernorm_embedding is not None:
603
+ image_x = self.patch_layernorm_embedding(image_x)
604
+ image_x = self.dropout_module(image_x)
605
+ if self.quant_noise is not None:
606
+ image_x = self.quant_noise(image_x)
607
+ x = torch.cat([image_x, x], dim=1)
608
+ embed = torch.cat([image_embed, embed], dim=1)
609
+
610
+ if image_embed_2 is not None:
611
+ assert self.type_embedding is not None
612
+ image_embed_2 = self.image_proj(image_embed_2)
613
+ image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
614
+ if self.entangle_position_embedding and image_pos_embed_2 is not None:
615
+ image_x_2 += image_pos_embed_2
616
+ if self.type_embedding is not None:
617
+ image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2))
618
+ if self.patch_layernorm_embedding is not None:
619
+ image_x_2 = self.patch_layernorm_embedding(image_x_2)
620
+ image_x_2 = self.dropout_module(image_x_2)
621
+ if self.quant_noise is not None:
622
+ image_x_2 = self.quant_noise(image_x_2)
623
+ x = torch.cat([image_x_2, x], dim=1)
624
+ embed = torch.cat([image_embed_2, embed], dim=1)
625
+
626
+ return x, embed
627
+
628
+ def forward(
629
+ self,
630
+ src_tokens,
631
+ src_lengths,
632
+ patch_images: Optional[torch.Tensor] = None,
633
+ patch_images_2: Optional[torch.Tensor] = None,
634
+ patch_masks: Optional[torch.Tensor] = None,
635
+ code_masks: Optional[torch.Tensor] = None,
636
+ return_all_hiddens: bool = False,
637
+ token_embeddings: Optional[torch.Tensor] = None,
638
+ sample_patch_num: Optional[int] = None
639
+ ):
640
+ """
641
+ Args:
642
+ src_tokens (LongTensor): tokens in the source language of shape
643
+ `(batch, src_len)`
644
+ src_lengths (torch.LongTensor): lengths of each source sentence of
645
+ shape `(batch)`
646
+ return_all_hiddens (bool, optional): also return all of the
647
+ intermediate hidden states (default: False).
648
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
649
+ default `None` will recompute embeddings
650
+
651
+ Returns:
652
+ dict:
653
+ - **encoder_out** (Tensor): the last encoder layer's output of
654
+ shape `(src_len, batch, embed_dim)`
655
+ - **encoder_padding_mask** (ByteTensor): the positions of
656
+ padding elements of shape `(batch, src_len)`
657
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
658
+ of shape `(batch, src_len, embed_dim)`
659
+ - **encoder_states** (List[Tensor]): all intermediate
660
+ hidden states of shape `(src_len, batch, embed_dim)`.
661
+ Only populated if *return_all_hiddens* is True.
662
+ """
663
+ return self.forward_scriptable(src_tokens,
664
+ src_lengths,
665
+ patch_images,
666
+ patch_images_2,
667
+ patch_masks,
668
+ return_all_hiddens,
669
+ token_embeddings,
670
+ sample_patch_num)
671
+
672
+ # TorchScript doesn't support the super() method, so the scriptable subclass
673
+ # can't access the base class model in TorchScript.
674
+ # The current workaround is to add a helper function with a different name and
675
+ # call the helper function from the scriptable subclass.
676
+ def forward_scriptable(
677
+ self,
678
+ src_tokens,
679
+ src_lengths,
680
+ patch_images: Optional[torch.Tensor] = None,
681
+ patch_images_2: Optional[torch.Tensor] = None,
682
+ patch_masks: Optional[torch.Tensor] = None,
683
+ return_all_hiddens: bool = False,
684
+ token_embeddings: Optional[torch.Tensor] = None,
685
+ sample_patch_num: Optional[int] = None
686
+ ):
687
+ """
688
+ Args:
689
+ src_tokens (LongTensor): tokens in the source language of shape
690
+ `(batch, src_len)`
691
+ src_lengths (torch.LongTensor): lengths of each source sentence of
692
+ shape `(batch)`
693
+ return_all_hiddens (bool, optional): also return all of the
694
+ intermediate hidden states (default: False).
695
+ token_embeddings (torch.Tensor, optional): precomputed embeddings
696
+ default `None` will recompute embeddings
697
+
698
+ Returns:
699
+ dict:
700
+ - **encoder_out** (Tensor): the last encoder layer's output of
701
+ shape `(src_len, batch, embed_dim)`
702
+ - **encoder_padding_mask** (ByteTensor): the positions of
703
+ padding elements of shape `(batch, src_len)`
704
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
705
+ of shape `(batch, src_len, embed_dim)`
706
+ - **encoder_states** (List[Tensor]): all intermediate
707
+ hidden states of shape `(src_len, batch, embed_dim)`.
708
+ Only populated if *return_all_hiddens* is True.
709
+ """
710
+ image_embed = None
711
+ image_embed_2 = None
712
+ image_pos_embed = None
713
+ image_pos_embed_2 = None
714
+ if patch_images is not None:
715
+ image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
716
+ self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
717
+ image_padding_mask[~patch_masks] = True
718
+ if patch_images_2 is not None:
719
+ image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
720
+ self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device)
721
+ image_padding_mask_2[~patch_masks] = True
722
+
723
+ encoder_padding_mask = src_tokens.eq(self.padding_idx)
724
+ if patch_images is not None:
725
+ encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
726
+ if patch_images_2 is not None:
727
+ encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
728
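+ # On XLA devices padding is always assumed, presumably to avoid a data-dependent
+ # .any() sync; elsewhere the mask is only applied when padding is actually present.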
+ has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
729
+
730
+ pos_embed = self.embed_positions(utils.new_arange(src_tokens))
731
+ x, encoder_embedding = self.forward_embedding(
732
+ src_tokens, image_embed, image_embed_2, token_embeddings,
733
+ pos_embed, image_pos_embed, image_pos_embed_2
734
+ )
735
+
736
+ # account for padding while computing the representation
737
+ if has_pads:
738
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
739
+
740
+ # B x T x C -> T x B x C
741
+ x = x.transpose(0, 1)
742
+
743
+ pos_embed = self.pos_ln(pos_embed)
744
+ if patch_images is not None:
745
+ image_pos_embed = self.image_pos_ln(image_pos_embed)
746
+ pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
747
+ if patch_images_2 is not None:
748
+ image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
749
+ pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
750
+
751
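+ # Absolute positions enter attention as an additive bias: the layer-normed position
+ # embeddings are projected to per-head queries/keys and their dot product is added
+ # to the attention logits of every encoder layer.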
+ pos_q = self.pos_q_linear(pos_embed).view(
752
+ x.size(1), x.size(0), self.num_attention_heads, -1
753
+ ).transpose(1, 2) * self.pos_scaling
754
+ pos_k = self.pos_k_linear(pos_embed).view(
755
+ x.size(1), x.size(0), self.num_attention_heads, -1
756
+ ).transpose(1, 2)
757
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
758
+
759
+ encoder_states = []
760
+
761
+ if return_all_hiddens:
762
+ encoder_states.append(x)
763
+
764
+ # encoder layers
765
+ for idx, layer in enumerate(self.layers):
766
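+ # Each layer owns its relative-position tables: token buckets are added to the
+ # text-text block and image buckets to the patch-patch block(s), on top of the
+ # absolute-position bias computed above.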
+ self_attn_bias = abs_pos_bias.clone()
767
+ self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx)
768
+ if patch_images_2 is not None:
769
+ self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
770
+ self.get_image_rel_pos_bias(image_position_ids_2, idx)
771
+ self_attn_bias[:, :, image_num_patches_2:image_num_patches_2+image_num_patches, image_num_patches_2:image_num_patches_2+image_num_patches] += \
772
+ self.get_image_rel_pos_bias(image_position_ids, idx)
773
+ elif patch_images is not None:
774
+ self_attn_bias[:, :, :x.size(0) - src_tokens.size(1), :x.size(0) - src_tokens.size(1)] += \
775
+ self.get_image_rel_pos_bias(image_position_ids, idx)
776
+ self_attn_bias = self_attn_bias.reshape(-1, x.size(0), x.size(0))
777
+
778
+ x = layer(
779
+ x, encoder_padding_mask=encoder_padding_mask if has_pads else None, self_attn_bias=self_attn_bias
780
+ )
781
+ if return_all_hiddens:
782
+ assert encoder_states is not None
783
+ encoder_states.append(x)
784
+
785
+ if self.layer_norm is not None:
786
+ x = self.layer_norm(x)
787
+
788
+ # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
789
+ # `forward` so we use a dictionary instead.
790
+ # TorchScript does not support mixed values so the values are all lists.
791
+ # The empty list is equivalent to None.
792
+ return {
793
+ "encoder_out": [x], # T x B x C
794
+ "encoder_padding_mask": [encoder_padding_mask], # B x T
795
+ "encoder_embedding": [], # B x T x C
796
+ "encoder_states": encoder_states, # List[T x B x C]
797
+ "src_tokens": [],
798
+ "src_lengths": [],
799
+ "position_embeddings": [pos_embed], # B x T x C
800
+ }
801
+
802
+ @torch.jit.export
803
+ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
804
+ """
805
+ Reorder encoder output according to *new_order*.
806
+
807
+ Args:
808
+ encoder_out: output from the ``forward()`` method
809
+ new_order (LongTensor): desired order
810
+
811
+ Returns:
812
+ *encoder_out* rearranged according to *new_order*
813
+ """
814
+ if len(encoder_out["encoder_out"]) == 0:
815
+ new_encoder_out = []
816
+ else:
817
+ new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
818
+ if len(encoder_out["encoder_padding_mask"]) == 0:
819
+ new_encoder_padding_mask = []
820
+ else:
821
+ new_encoder_padding_mask = [
822
+ encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
823
+ ]
824
+ if len(encoder_out["encoder_embedding"]) == 0:
825
+ new_encoder_embedding = []
826
+ else:
827
+ new_encoder_embedding = [
828
+ encoder_out["encoder_embedding"][0].index_select(0, new_order)
829
+ ]
830
+
831
+ if len(encoder_out["src_tokens"]) == 0:
832
+ new_src_tokens = []
833
+ else:
834
+ new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
835
+
836
+ if len(encoder_out["src_lengths"]) == 0:
837
+ new_src_lengths = []
838
+ else:
839
+ new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
840
+
841
+ if len(encoder_out["position_embeddings"]) == 0:
842
+ new_position_embeddings = []
843
+ else:
844
+ new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)]
845
+
846
+ encoder_states = encoder_out["encoder_states"]
847
+ if len(encoder_states) > 0:
848
+ for idx, state in enumerate(encoder_states):
849
+ encoder_states[idx] = state.index_select(1, new_order)
850
+
851
+ return {
852
+ "encoder_out": new_encoder_out, # T x B x C
853
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
854
+ "encoder_embedding": new_encoder_embedding, # B x T x C
855
+ "encoder_states": encoder_states, # List[T x B x C]
856
+ "src_tokens": new_src_tokens, # B x T
857
+ "src_lengths": new_src_lengths, # B x 1
858
+ "position_embeddings": new_position_embeddings, # B x T x C
859
+ }
860
+
861
+ def max_positions(self):
862
+ """Maximum input length supported by the encoder."""
863
+ if self.embed_positions is None:
864
+ return self.max_source_positions
865
+ return self.max_source_positions
866
+
867
+ def upgrade_state_dict_named(self, state_dict, name):
868
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
869
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
870
+ weights_key = "{}.embed_positions.weights".format(name)
871
+ if weights_key in state_dict:
872
+ print("deleting {0}".format(weights_key))
873
+ del state_dict[weights_key]
874
+ state_dict[
875
+ "{}.embed_positions._float_tensor".format(name)
876
+ ] = torch.FloatTensor(1)
877
+ for i in range(self.num_layers):
878
+ # update layer norms
879
+ self.layers[i].upgrade_state_dict_named(
880
+ state_dict, "{}.layers.{}".format(name, i)
881
+ )
882
+
883
+ # version_key = "{}.version".format(name)
884
+ # if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
885
+ # # earlier checkpoints did not normalize after the stack of layers
886
+ # self.layer_norm = None
887
+ # self.normalize = False
888
+ # state_dict[version_key] = torch.Tensor([1])
889
+
890
+ prefix = name + "." if name != "" else ""
891
+ for param_name, param_tensor in self.state_dict().items():
892
+ if (prefix + param_name) not in state_dict and param_name in self.state_dict():
893
+ state_dict[prefix + param_name] = self.state_dict()[param_name]
894
+
895
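+ # If the checkpoint was trained with a smaller image bucket, pad its image position
+ # embedding table with freshly initialized rows so it matches the current model size.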
+ if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
896
+ num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"])
897
+ embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1)
898
+ new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
899
+ nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
900
+ new_pos_embed_to_add = new_pos_embed_to_add.to(
901
+ dtype=state_dict["encoder.embed_image_positions.weight"].dtype,
902
+ )
903
+ state_dict["encoder.embed_image_positions.weight"] = torch.cat(
904
+ [state_dict["encoder.embed_image_positions.weight"], new_pos_embed_to_add]
905
+ )
906
+ return state_dict
907
+
908
+
909
+ class TransformerDecoder(FairseqIncrementalDecoder):
910
+ """
911
+ Transformer decoder consisting of *args.decoder_layers* layers. Each layer
912
+ is a :class:`TransformerDecoderLayer`.
913
+
914
+ Args:
915
+ args (argparse.Namespace): parsed command-line arguments
916
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
917
+ embed_tokens (torch.nn.Embedding): output embedding
918
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
919
+ (default: False).
920
+ """
921
+
922
+ def __init__(
923
+ self,
924
+ args,
925
+ dictionary,
926
+ embed_tokens,
927
+ no_encoder_attn=False,
928
+ output_projection=None,
929
+ ):
930
+ self.args = args
931
+ super().__init__(dictionary)
932
+ self.register_buffer("version", torch.Tensor([3]))
933
+ self._future_mask = torch.empty(0)
934
+
935
+ self.dropout_module = FairseqDropout(
936
+ args.dropout, module_name=self.__class__.__name__
937
+ )
938
+ self.decoder_layerdrop = args.decoder_layerdrop
939
+ self.share_input_output_embed = args.share_decoder_input_output_embed
940
+ self.num_attention_heads = args.decoder_attention_heads
941
+
942
+ input_embed_dim = embed_tokens.embedding_dim
943
+ embed_dim = args.decoder_embed_dim
944
+ self.embed_dim = embed_dim
945
+ self.output_embed_dim = args.decoder_output_dim
946
+
947
+ self.padding_idx = embed_tokens.padding_idx
948
+ self.max_target_positions = args.max_target_positions
949
+
950
+ self.embed_tokens = embed_tokens
951
+
952
+ self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
953
+
954
+ if not args.adaptive_input and args.quant_noise_pq > 0:
955
+ self.quant_noise = apply_quant_noise_(
956
+ nn.Linear(embed_dim, embed_dim, bias=False),
957
+ args.quant_noise_pq,
958
+ args.quant_noise_pq_block_size,
959
+ )
960
+ else:
961
+ self.quant_noise = None
962
+
963
+ self.project_in_dim = (
964
+ Linear(input_embed_dim, embed_dim, bias=False)
965
+ if embed_dim != input_embed_dim
966
+ else None
967
+ )
968
+
969
+ if getattr(args, "layernorm_embedding", False):
970
+ self.layernorm_embedding = LayerNorm(embed_dim)
971
+ else:
972
+ self.layernorm_embedding = None
973
+
974
+ self.window_size = args.code_image_size // 8
975
+
976
+ self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim)
977
+ self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
978
+ self.pos_ln = LayerNorm(embed_dim)
979
+ self.image_pos_ln = LayerNorm(embed_dim)
980
+ self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5
981
+ self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim)
982
+ self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim)
983
+ self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim)
984
+ self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim)
985
+
986
+ if getattr(args, "code_layernorm_embedding", False):
987
+ self.code_layernorm_embedding = LayerNorm(embed_dim)
988
+ else:
989
+ self.code_layernorm_embedding = None
990
+
991
+ self.cross_self_attention = getattr(args, "cross_self_attention", False)
992
+
993
+ if self.decoder_layerdrop > 0.0:
994
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
995
+ else:
996
+ self.layers = nn.ModuleList([])
997
+
998
+ dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)]
999
+ self.layers.extend(
1000
+ [
1001
+ self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i])
1002
+ for i in range(args.decoder_layers)
1003
+ ]
1004
+ )
1005
+ self.num_layers = len(self.layers)
1006
+
1007
+ if args.decoder_normalize_before:
1008
+ self.layer_norm = LayerNorm(embed_dim)
1009
+ else:
1010
+ self.layer_norm = None
1011
+
1012
+ self.project_out_dim = (
1013
+ Linear(embed_dim, self.output_embed_dim, bias=False)
1014
+ if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
1015
+ else None
1016
+ )
1017
+
1018
+ self.adaptive_softmax = None
1019
+ self.output_projection = output_projection
1020
+ if self.output_projection is None:
1021
+ self.build_output_projection(args, dictionary, embed_tokens)
1022
+
1023
+ token_bucket_size = args.token_bucket_size
1024
+ token_num_rel_dis = 2 * token_bucket_size - 1
1025
+ token_rp_bucket = make_token_bucket_position(token_bucket_size)
1026
+ self.token_rel_pos_table_list = nn.ModuleList(
1027
+ [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
1028
+ )
1029
+
1030
+ image_bucket_size = args.image_bucket_size
1031
+ image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
1032
+ image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
1033
+ image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
1034
+ torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1
1035
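+ # Position id 0 is reserved for the BOS step; the remaining window positions follow in
+ # row-major order. The trailing constant ids (1024 repeated 768 times) presumably act
+ # as filler so longer target sequences still index into the bucket table.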
+ image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
1036
+ image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)])
1037
+ self.image_rel_pos_table_list = nn.ModuleList(
1038
+ [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
1039
+ )
1040
+
1041
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
1042
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
1043
+ self.register_buffer("image_position_idx", image_position_idx)
1044
+ self.entangle_position_embedding = args.entangle_position_embedding
1045
+
1046
+ def build_output_projection(self, args, dictionary, embed_tokens):
1047
+ if args.adaptive_softmax_cutoff is not None:
1048
+ self.adaptive_softmax = AdaptiveSoftmax(
1049
+ len(dictionary),
1050
+ self.output_embed_dim,
1051
+ utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
1052
+ dropout=args.adaptive_softmax_dropout,
1053
+ adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
1054
+ factor=args.adaptive_softmax_factor,
1055
+ tie_proj=args.tie_adaptive_proj,
1056
+ )
1057
+ elif self.share_input_output_embed:
1058
+ self.output_projection = nn.Linear(
1059
+ self.embed_tokens.weight.shape[1],
1060
+ self.embed_tokens.weight.shape[0],
1061
+ bias=False,
1062
+ )
1063
+ self.output_projection.weight = self.embed_tokens.weight
1064
+ else:
1065
+ self.output_projection = nn.Linear(
1066
+ self.output_embed_dim, len(dictionary), bias=False
1067
+ )
1068
+ nn.init.normal_(
1069
+ self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
1070
+ )
1071
+ num_base_layers = getattr(args, "base_layers", 0)
1072
+ for i in range(num_base_layers):
1073
+ self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args))
1074
+
1075
+ def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0):
1076
+ layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate=drop_path_rate)
1077
+ checkpoint = getattr(args, "checkpoint_activations", False)
1078
+ if checkpoint:
1079
+ offload_to_cpu = getattr(args, "offload_activations", False)
1080
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
1081
+ # if we are checkpointing, enforce that FSDP always wraps the
1082
+ # checkpointed layer, regardless of layer size
1083
+ min_params_to_wrap = (
1084
+ getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
1085
+ if not checkpoint else 0
1086
+ )
1087
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
1088
+ return layer
1089
+
1090
+ def get_rel_pos_bias(self, x, idx):
1091
+ seq_len = x.size(1)
1092
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
1093
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
1094
+ values = values.permute([2, 0, 1])
1095
+ return values.contiguous()
1096
+
1097
+ def get_image_rel_pos_bias(self, x, idx):
1098
+ seq_len = x.size(1)
1099
+ image_position_idx = self.image_position_idx[:seq_len]
1100
+ rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx]
1101
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
1102
+ values = values.permute(2, 0, 1)
1103
+ return values
1104
+
1105
+ def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False):
1106
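+ # Builds the absolute-position attention bias: self-attention bias when src_pos_embed
+ # is None, cross-attention bias otherwise; use_image selects the image position LayerNorm.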
+ batch_size = tokens.size(0)
1107
+ tgt_len = tokens.size(1)
1108
+ tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
1109
+ if src_pos_embed is not None:
1110
+ src_len = src_pos_embed.size(1)
1111
+ pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
1112
+ batch_size, tgt_len, self.num_attention_heads, -1
1113
+ ).transpose(1, 2) * self.pos_scaling
1114
+ pos_k = self.cross_pos_k_linear(src_pos_embed).view(
1115
+ batch_size, src_len, self.num_attention_heads, -1
1116
+ ).transpose(1, 2)
1117
+ else:
1118
+ src_len = tgt_pos_embed.size(1)
1119
+ pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
1120
+ batch_size, tgt_len, self.num_attention_heads, -1
1121
+ ).transpose(1, 2) * self.pos_scaling
1122
+ pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
1123
+ batch_size, src_len, self.num_attention_heads, -1
1124
+ ).transpose(1, 2)
1125
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
1126
+ return abs_pos_bias
1127
+
1128
+ def forward(
1129
+ self,
1130
+ prev_output_tokens,
1131
+ code_masks: Optional[torch.Tensor] = None,
1132
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
1133
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
1134
+ features_only: bool = False,
1135
+ full_context_alignment: bool = False,
1136
+ alignment_layer: Optional[int] = None,
1137
+ alignment_heads: Optional[int] = None,
1138
+ src_lengths: Optional[Any] = None,
1139
+ return_all_hiddens: bool = False,
1140
+ ):
1141
+ """
1142
+ Args:
1143
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
1144
+ `(batch, tgt_len)`, for teacher forcing
1145
+ encoder_out (optional): output from the encoder, used for
1146
+ encoder-side attention, should be of size T x B x C
1147
+ incremental_state (dict): dictionary used for storing state during
1148
+ :ref:`Incremental decoding`
1149
+ features_only (bool, optional): only return features without
1150
+ applying output layer (default: False).
1151
+ full_context_alignment (bool, optional): don't apply
1152
+ auto-regressive mask to self-attention (default: False).
1153
+
1154
+ Returns:
1155
+ tuple:
1156
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
1157
+ - a dictionary with any model-specific outputs
1158
+ """
1159
+
1160
+ x, extra = self.extract_features(
1161
+ prev_output_tokens,
1162
+ code_masks=code_masks,
1163
+ encoder_out=encoder_out,
1164
+ incremental_state=incremental_state,
1165
+ full_context_alignment=full_context_alignment,
1166
+ alignment_layer=alignment_layer,
1167
+ alignment_heads=alignment_heads,
1168
+ )
1169
+
1170
+ if not features_only:
1171
+ x = self.output_layer(x)
1172
+ return x, extra
1173
+
1174
+ def extract_features(
1175
+ self,
1176
+ prev_output_tokens,
1177
+ code_masks: Optional[torch.Tensor],
1178
+ encoder_out: Optional[Dict[str, List[Tensor]]],
1179
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
1180
+ full_context_alignment: bool = False,
1181
+ alignment_layer: Optional[int] = None,
1182
+ alignment_heads: Optional[int] = None,
1183
+ ):
1184
+ return self.extract_features_scriptable(
1185
+ prev_output_tokens,
1186
+ code_masks,
1187
+ encoder_out,
1188
+ incremental_state,
1189
+ full_context_alignment,
1190
+ alignment_layer,
1191
+ alignment_heads,
1192
+ )
1193
+
1194
+ """
1195
+ A scriptable subclass of this class has an extract_features method and calls
1196
+ super().extract_features, but super() is not supported in torchscript. A copy of
1197
+ this function is made to be used in the subclass instead.
1198
+ """
1199
+
1200
+ def extract_features_scriptable(
1201
+ self,
1202
+ prev_output_tokens,
1203
+ code_masks: Optional[torch.Tensor],
1204
+ encoder_out: Optional[Dict[str, List[Tensor]]],
1205
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
1206
+ full_context_alignment: bool = False,
1207
+ alignment_layer: Optional[int] = None,
1208
+ alignment_heads: Optional[int] = None,
1209
+ ):
1210
+ """
1211
+ Similar to *forward* but only return features.
1212
+
1213
+ Includes several features from "Jointly Learning to Align and
1214
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
1215
+
1216
+ Args:
1217
+ full_context_alignment (bool, optional): don't apply
1218
+ auto-regressive mask to self-attention (default: False).
1219
+ alignment_layer (int, optional): return mean alignment over
1220
+ heads at this layer (default: last layer).
1221
+ alignment_heads (int, optional): only average alignment over
1222
+ this many heads (default: all heads).
1223
+
1224
+ Returns:
1225
+ tuple:
1226
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
1227
+ - a dictionary with any model-specific outputs
1228
+ """
1229
+ bs, slen = prev_output_tokens.size()
1230
+ if alignment_layer is None:
1231
+ alignment_layer = self.num_layers - 1
1232
+
1233
+ enc: Optional[Tensor] = None
1234
+ padding_mask: Optional[Tensor] = None
1235
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
1236
+ enc = encoder_out["encoder_out"][0]
1237
+ assert (
1238
+ enc.size()[1] == bs
1239
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
1240
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
1241
+ padding_mask = encoder_out["encoder_padding_mask"][0]
1242
+
1243
+ bsz, tgt_len = prev_output_tokens.shape
1244
+ token_position_idx = utils.new_arange(prev_output_tokens)
1245
+ tgt_pos_embed = self.embed_positions(token_position_idx)
1246
+ if code_masks is not None and torch.any(code_masks):
1247
+ image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len)
1248
+ tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
1249
+
1250
+ # self attn position bias
1251
+ self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False)
1252
+ if code_masks is not None and torch.any(code_masks):
1253
+ self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True)
1254
+ self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
1255
+ # cross attn position bias
1256
+ src_pos_embed = encoder_out['position_embeddings'][0]
1257
+ cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed)
1258
+ if code_masks is not None and torch.any(code_masks):
1259
+ cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
1260
+ cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
1261
+ cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
1262
+
1263
+ all_prev_output_tokens = prev_output_tokens.clone()
1264
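+ # During incremental decoding only the newest target token goes through the layers,
+ # so position-dependent tensors are sliced to the last step; all_prev_output_tokens
+ # keeps the full history for the relative-position lookups below.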
+ if incremental_state is not None:
1265
+ prev_output_tokens = prev_output_tokens[:, -1:]
1266
+ cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
1267
+ tgt_pos_embed = tgt_pos_embed[:, -1:, :]
1268
+
1269
+ # embed tokens and positions
1270
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
1271
+
1272
+ if self.quant_noise is not None:
1273
+ x = self.quant_noise(x)
1274
+
1275
+ if self.project_in_dim is not None:
1276
+ x = self.project_in_dim(x)
1277
+
1278
+ if self.entangle_position_embedding is not None and not self.args.disable_entangle:
1279
+ x += tgt_pos_embed
1280
+
1281
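+ # Text positions are normalized with layernorm_embedding and image-code positions with
+ # code_layernorm_embedding; mixed batches are normalized row-by-row via code_masks.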
+ if self.layernorm_embedding is not None:
1282
+ if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False):
1283
+ x = self.layernorm_embedding(x)
1284
+ elif code_masks is not None and code_masks.all():
1285
+ x = self.code_layernorm_embedding(x)
1286
+ else:
1287
+ x[~code_masks] = self.layernorm_embedding(x[~code_masks])
1288
+ x[code_masks] = self.code_layernorm_embedding(x[code_masks])
1289
+
1290
+ x = self.dropout_module(x)
1291
+
1292
+ # B x T x C -> T x B x C
1293
+ x = x.transpose(0, 1)
1294
+
1295
+ self_attn_padding_mask: Optional[Tensor] = None
1296
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
1297
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
1298
+
1299
+ # decoder layers
1300
+ attn: Optional[Tensor] = None
1301
+ inner_states: List[Optional[Tensor]] = [x]
1302
+ for idx, layer in enumerate(self.layers):
1303
+ if incremental_state is None and not full_context_alignment:
1304
+ self_attn_mask = self.buffered_future_mask(x)
1305
+ else:
1306
+ self_attn_mask = None
1307
+
1308
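+ # Add the per-layer relative-position bias on top of the absolute bias: token buckets
+ # for text targets, image buckets for image-code targets, selected row-wise by code_masks.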
+ self_attn_bias = self_abs_pos_bias.clone()
1309
+ if code_masks is None or not code_masks.any():
1310
+ self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1311
+ elif code_masks is not None and code_masks.all():
1312
+ self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1313
+ else:
1314
+ self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1315
+ self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1316
+ self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
1317
+ if incremental_state is not None:
1318
+ self_attn_bias = self_attn_bias[:, -1:, :]
1319
+
1320
+ x, layer_attn, _ = layer(
1321
+ x,
1322
+ enc,
1323
+ padding_mask,
1324
+ incremental_state,
1325
+ self_attn_mask=self_attn_mask,
1326
+ self_attn_padding_mask=self_attn_padding_mask,
1327
+ need_attn=bool((idx == alignment_layer)),
1328
+ need_head_weights=bool((idx == alignment_layer)),
1329
+ self_attn_bias=self_attn_bias,
1330
+ cross_attn_bias=cross_abs_pos_bias
1331
+ )
1332
+ inner_states.append(x)
1333
+ if layer_attn is not None and idx == alignment_layer:
1334
+ attn = layer_attn.float().to(x)
1335
+
1336
+ if attn is not None:
1337
+ if alignment_heads is not None:
1338
+ attn = attn[:alignment_heads]
1339
+
1340
+ # average probabilities over heads
1341
+ attn = attn.mean(dim=0)
1342
+
1343
+ if self.layer_norm is not None:
1344
+ x = self.layer_norm(x)
1345
+
1346
+ # T x B x C -> B x T x C
1347
+ x = x.transpose(0, 1)
1348
+
1349
+ if self.project_out_dim is not None:
1350
+ x = self.project_out_dim(x)
1351
+
1352
+ return x, {"attn": [attn], "inner_states": inner_states}
1353
+
1354
+ def output_layer(self, features):
1355
+ """Project features to the vocabulary size."""
1356
+ if self.adaptive_softmax is None:
1357
+ # project back to size of vocabulary
1358
+ return self.output_projection(features)
1359
+ else:
1360
+ return features
1361
+
1362
+ def max_positions(self):
1363
+ """Maximum output length supported by the decoder."""
1364
+ if self.embed_positions is None:
1365
+ return self.max_target_positions
1366
+ return self.max_target_positions
1367
+
1368
+ def buffered_future_mask(self, tensor):
1369
+ dim = tensor.size(0)
1370
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
1371
+ if (
1372
+ self._future_mask.size(0) == 0
1373
+ or (not self._future_mask.device == tensor.device)
1374
+ or self._future_mask.size(0) < dim
1375
+ ):
1376
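+ # (Re)build and cache a causal mask with -inf above the diagonal; later calls
+ # just slice the cached mask to the current length.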
+ self._future_mask = torch.triu(
1377
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
1378
+ )
1379
+ self._future_mask = self._future_mask.to(tensor)
1380
+ return self._future_mask[:dim, :dim]
1381
+
1382
+ def upgrade_state_dict_named(self, state_dict, name):
1383
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
1384
+ if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
1385
+ weights_key = "{}.embed_positions.weights".format(name)
1386
+ if weights_key in state_dict:
1387
+ del state_dict[weights_key]
1388
+ state_dict[
1389
+ "{}.embed_positions._float_tensor".format(name)
1390
+ ] = torch.FloatTensor(1)
1391
+
1392
+ if f"{name}.output_projection.weight" not in state_dict:
1393
+ if self.share_input_output_embed:
1394
+ embed_out_key = f"{name}.embed_tokens.weight"
1395
+ else:
1396
+ embed_out_key = f"{name}.embed_out"
1397
+ if embed_out_key in state_dict:
1398
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
1399
+ embed_out_key
1400
+ ]
1401
+ if not self.share_input_output_embed:
1402
+ del state_dict[embed_out_key]
1403
+
1404
+ for i in range(self.num_layers):
1405
+ # update layer norms
1406
+ self.layers[i].upgrade_state_dict_named(
1407
+ state_dict, "{}.layers.{}".format(name, i)
1408
+ )
1409
+
1410
+ # version_key = "{}.version".format(name)
1411
+ # if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
1412
+ # # earlier checkpoints did not normalize after the stack of layers
1413
+ # self.layer_norm = None
1414
+ # self.normalize = False
1415
+ # state_dict[version_key] = torch.Tensor([1])
1416
+
1417
+ prefix = name + "." if name != "" else ""
1418
+ image_params = ["image_position_idx"]
1419
+ for image_param in image_params:
1420
+ state_dict[prefix + image_param] = self.state_dict()[image_param]
1421
+ for param_name, param_tensor in self.state_dict().items():
1422
+ if (prefix + param_name) not in state_dict and param_name in self.state_dict():
1423
+ state_dict[prefix + param_name] = self.state_dict()[param_name]
1424
+
1425
+ if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
1426
+ num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"])
1427
+ embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1)
1428
+ new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
1429
+ nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
1430
+ new_pos_embed_to_add = new_pos_embed_to_add.to(
1431
+ dtype=state_dict["decoder.embed_image_positions.weight"].dtype,
1432
+ )
1433
+ state_dict["decoder.embed_image_positions.weight"] = torch.cat(
1434
+ [state_dict["decoder.embed_image_positions.weight"], new_pos_embed_to_add]
1435
+ )
1436
+ return state_dict
1437
+
1438
+
1439
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
1440
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
1441
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
1442
+ if padding_idx is not None:
1443
+ nn.init.constant_(m.weight[padding_idx], 0)
1444
+ if zero_init:
1445
+ nn.init.constant_(m.weight, 0)
1446
+ return m
1447
+
1448
+
1449
+ def Linear(in_features, out_features, bias=True):
1450
+ m = nn.Linear(in_features, out_features, bias)
1451
+ nn.init.xavier_uniform_(m.weight)
1452
+ if bias:
1453
+ nn.init.constant_(m.bias, 0.0)
1454
+ return m
1455
+
1456
+
1457
+ @register_model_architecture("unify_transformer", "unify_transformer")
1458
+ def base_architecture(args):
1459
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
1460
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
1461
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
1462
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
1463
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
1464
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1465
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
1466
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
1467
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
1468
+ args.decoder_ffn_embed_dim = getattr(
1469
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
1470
+ )
1471
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1472
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
1473
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
1474
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
1475
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
1476
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
1477
+ args.activation_fn = getattr(args, "activation_fn", "relu")
1478
+ args.dropout = getattr(args, "dropout", 0.1)
1479
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
1480
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
1481
+ args.share_decoder_input_output_embed = getattr(
1482
+ args, "share_decoder_input_output_embed", False
1483
+ )
1484
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
1485
+ args.no_token_positional_embeddings = getattr(
1486
+ args, "no_token_positional_embeddings", False
1487
+ )
1488
+ args.adaptive_input = getattr(args, "adaptive_input", False)
1489
+ args.no_cross_attention = getattr(args, "no_cross_attention", False)
1490
+ args.cross_self_attention = getattr(args, "cross_self_attention", False)
1491
+
1492
+ args.decoder_output_dim = getattr(
1493
+ args, "decoder_output_dim", args.decoder_embed_dim
1494
+ )
1495
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
1496
+
1497
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
1498
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
1499
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
1500
+ args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
1501
+ args.offload_activations = getattr(args, "offload_activations", False)
1502
+ if args.offload_activations:
1503
+ args.checkpoint_activations = True
1504
+ args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
1505
+ args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
1506
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
1507
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
1508
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
1509
+ args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
1510
+ args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
models/ofa/unify_transformer_layer.py ADDED
@@ -0,0 +1,542 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from fairseq import utils
11
+ from fairseq.modules import LayerNorm
12
+ from fairseq.modules.fairseq_dropout import FairseqDropout
13
+ from fairseq.modules.quant_noise import quant_noise
14
+ from torch import Tensor
15
+
16
+ from .unify_multihead_attention import MultiheadAttention
17
+
18
+
19
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
20
+ """
21
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
22
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
23
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
24
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
25
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
26
+ argument.
27
+ """
28
+ if drop_prob == 0.0 or not training:
29
+ return x
30
+ keep_prob = 1 - drop_prob
31
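+ # As used in these layers, activations are (seq_len, batch, dim), so the keep/drop
+ # decision is made once per sample along dim 1 and broadcast over time and channels.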
+ shape = (1, x.shape[1], 1)
32
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33
+ random_tensor.floor_() # binarize
34
+ output = x.div(keep_prob) * random_tensor
35
+ return output
36
+
37
+
38
+ class DropPath(nn.Module):
39
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
40
+
41
+ def __init__(self, drop_prob=None):
42
+ super().__init__()
43
+ self.drop_prob = drop_prob
44
+
45
+ def forward(self, x):
46
+ return drop_path(x, self.drop_prob, self.training)
47
+
48
+ def extra_repr(self) -> str:
49
+ return "p={}".format(self.drop_prob)
50
+
51
+
52
+ class TransformerEncoderLayer(nn.Module):
53
+ """Encoder layer block.
54
+
55
+ In the original paper each operation (multi-head attention or FFN) is
56
+ postprocessed with: `dropout -> add residual -> layernorm`. In the
57
+ tensor2tensor code they suggest that learning is more robust when
58
+ preprocessing each layer with layernorm and postprocessing with:
59
+ `dropout -> add residual`. We default to the approach in the paper, but the
60
+ tensor2tensor approach can be enabled by setting
61
+ *args.encoder_normalize_before* to ``True``.
62
+
63
+ Args:
64
+ args (argparse.Namespace): parsed command-line arguments
65
+ """
66
+
67
+ def __init__(self, args, drop_path_rate=0.0):
68
+ super().__init__()
69
+ self.args = args
70
+ self.embed_dim = args.encoder_embed_dim
71
+ self.quant_noise = getattr(args, 'quant_noise_pq', 0)
72
+ self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
73
+ self.self_attn = self.build_self_attention(self.embed_dim, args)
74
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
75
+ self.dropout_module = FairseqDropout(
76
+ args.dropout, module_name=self.__class__.__name__
77
+ )
78
+ self.activation_fn = utils.get_activation_fn(
79
+ activation=getattr(args, 'activation_fn', 'relu') or "relu"
80
+ )
81
+ activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
82
+ if activation_dropout_p == 0:
83
+ # for backwards compatibility with models that use args.relu_dropout
84
+ activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
85
+ self.activation_dropout_module = FairseqDropout(
86
+ float(activation_dropout_p), module_name=self.__class__.__name__
87
+ )
88
+ self.normalize_before = args.encoder_normalize_before
89
+ self.fc1 = self.build_fc1(
90
+ self.embed_dim,
91
+ args.encoder_ffn_embed_dim,
92
+ self.quant_noise,
93
+ self.quant_noise_block_size,
94
+ )
95
+ self.fc2 = self.build_fc2(
96
+ args.encoder_ffn_embed_dim,
97
+ self.embed_dim,
98
+ self.quant_noise,
99
+ self.quant_noise_block_size,
100
+ )
101
+
102
+ self.attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
103
+ self.nh = self.self_attn.num_heads
104
+ self.head_dim = self.self_attn.head_dim
105
+
106
+ self.ffn_layernorm = LayerNorm(args.encoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
107
+ self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None
108
+
109
+ self.final_layer_norm = LayerNorm(self.embed_dim)
110
+
111
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
112
+
113
+ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
114
+ return quant_noise(
115
+ nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
116
+ )
117
+
118
+ def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
119
+ return quant_noise(
120
+ nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
121
+ )
122
+
123
+ def build_self_attention(self, embed_dim, args):
124
+ return MultiheadAttention(
125
+ embed_dim,
126
+ args.encoder_attention_heads,
127
+ dropout=args.attention_dropout,
128
+ self_attention=True,
129
+ q_noise=self.quant_noise,
130
+ qn_block_size=self.quant_noise_block_size,
131
+ scale_factor=args.attn_scale_factor,
132
+ scale_heads=getattr(args, 'scale_heads', False)
133
+ )
134
+
135
+ def residual_connection(self, x, residual):
136
+ return residual + self.drop_path(x)
137
+
138
+ def upgrade_state_dict_named(self, state_dict, name):
139
+ """
140
+ Rename layer norm states from `...layer_norms.0.weight` to
141
+ `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
142
+ `...final_layer_norm.weight`
143
+ """
144
+ layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
145
+ for old, new in layer_norm_map.items():
146
+ for m in ("weight", "bias"):
147
+ k = "{}.layer_norms.{}.{}".format(name, old, m)
148
+ if k in state_dict:
149
+ state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
150
+ del state_dict[k]
151
+ if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
152
+ state_dict[
153
+ "{}.{}.{}".format(name, new, m)
154
+ ] = self.state_dict()["{}.{}".format(new, m)]
155
+
156
+ prefix = name + "." if name != "" else ""
157
+ for param_name, param_tensor in self.state_dict().items():
158
+ if (prefix + param_name) not in state_dict and param_name in self.state_dict():
159
+ state_dict[prefix + param_name] = self.state_dict()[param_name]
160
+
161
+ def forward(
162
+ self,
163
+ x,
164
+ encoder_padding_mask: Optional[Tensor],
165
+ attn_mask: Optional[Tensor] = None,
166
+ self_attn_bias: Optional[Tensor] = None
167
+ ):
168
+ """
169
+ Args:
170
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
171
+ encoder_padding_mask (ByteTensor): binary ByteTensor of shape
172
+ `(batch, seq_len)` where padding elements are indicated by ``1``.
173
+ attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
174
+ where `tgt_len` is the length of output and `src_len` is the
175
+ length of input, though here both are equal to `seq_len`.
176
+ `attn_mask[tgt_i, src_j] = 1` means that when calculating the
177
+ embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
178
+ useful for strided self-attention.
179
+
180
+ Returns:
181
+ encoded output of shape `(seq_len, batch, embed_dim)`
182
+ """
183
+ # anything in original attn_mask = 1, becomes -1e8
184
+ # anything in original attn_mask = 0, becomes 0
185
+ # Note that we cannot use -inf here, because at some edge cases,
186
+ # the attention weight (before softmax) for some padded element in query
187
+ # will become -inf, which results in NaN in model parameters
188
+ if attn_mask is not None:
189
+ attn_mask = attn_mask.masked_fill(
190
+ attn_mask.to(torch.bool),
191
+ -1e8 if x.dtype == torch.float32 else -1e4
192
+ )
193
+
194
+ residual = x
195
+ if self.normalize_before:
196
+ x = self.self_attn_layer_norm(x)
197
+ x, _ = self.self_attn(
198
+ query=x,
199
+ key=x,
200
+ value=x,
201
+ key_padding_mask=encoder_padding_mask,
202
+ need_weights=False,
203
+ attn_mask=attn_mask,
204
+ attn_bias=self_attn_bias
205
+ )
206
+ if self.attn_ln is not None:
207
+ x = self.attn_ln(x)
208
+ x = self.dropout_module(x)
209
+ x = self.residual_connection(x, residual)
210
+ if not self.normalize_before:
211
+ x = self.self_attn_layer_norm(x)
212
+
213
+ residual = x
214
+ if self.normalize_before:
215
+ x = self.final_layer_norm(x)
216
+ x = self.activation_fn(self.fc1(x))
217
+ x = self.activation_dropout_module(x)
218
+ if self.ffn_layernorm is not None:
219
+ x = self.ffn_layernorm(x)
220
+ x = self.fc2(x)
221
+ x = self.dropout_module(x)
222
+ if self.w_resid is not None:
223
+ residual = torch.mul(self.w_resid, residual)
224
+ x = self.residual_connection(x, residual)
225
+ if not self.normalize_before:
226
+ x = self.final_layer_norm(x)
227
+ return x
228
+
229
+
230
+ class TransformerDecoderLayer(nn.Module):
231
+ """Decoder layer block.
232
+
233
+ In the original paper each operation (multi-head attention, encoder
234
+ attention or FFN) is postprocessed with: `dropout -> add residual ->
235
+ layernorm`. In the tensor2tensor code they suggest that learning is more
236
+ robust when preprocessing each layer with layernorm and postprocessing with:
237
+ `dropout -> add residual`. We default to the approach in the paper, but the
238
+ tensor2tensor approach can be enabled by setting
239
+ *args.decoder_normalize_before* to ``True``.
240
+
241
+ Args:
242
+ args (argparse.Namespace): parsed command-line arguments
243
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
244
+ (default: False).
245
+ """
246
+
247
+ def __init__(
248
+ self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, drop_path_rate=0.0
249
+ ):
250
+ super().__init__()
251
+ self.embed_dim = args.decoder_embed_dim
252
+ self.dropout_module = FairseqDropout(
253
+ args.dropout, module_name=self.__class__.__name__
254
+ )
255
+ self.quant_noise = getattr(args, "quant_noise_pq", 0)
256
+ self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
257
+
258
+ self.cross_self_attention = getattr(args, "cross_self_attention", False)
259
+
260
+ self.self_attn = self.build_self_attention(
261
+ self.embed_dim,
262
+ args,
263
+ add_bias_kv=add_bias_kv,
264
+ add_zero_attn=add_zero_attn,
265
+ )
266
+ self.self_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
267
+ self.cross_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
268
+ self.nh = self.self_attn.num_heads
269
+ self.head_dim = self.self_attn.head_dim
270
+
271
+ self.activation_fn = utils.get_activation_fn(
272
+ activation=str(args.activation_fn)
273
+ if getattr(args, "activation_fn", None) is not None
274
+ else "relu"
275
+ )
276
+ activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
277
+ if activation_dropout_p == 0:
278
+ # for backwards compatibility with models that use args.relu_dropout
279
+ activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
280
+ self.activation_dropout_module = FairseqDropout(
281
+ float(activation_dropout_p), module_name=self.__class__.__name__
282
+ )
283
+ self.normalize_before = args.decoder_normalize_before
284
+
285
+ # use layerNorm rather than FusedLayerNorm for exporting.
286
+ # char_inputs can be used to determine this.
287
+ # TODO remove this once we update apex with the fix
288
+ export = getattr(args, "char_inputs", False)
289
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
290
+
291
+ if no_encoder_attn:
292
+ self.encoder_attn = None
293
+ self.encoder_attn_layer_norm = None
294
+ else:
295
+ self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
296
+ self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
297
+
298
+ self.ffn_layernorm = LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
299
+ self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None
300
+
301
+ self.fc1 = self.build_fc1(
302
+ self.embed_dim,
303
+ args.decoder_ffn_embed_dim,
304
+ self.quant_noise,
305
+ self.quant_noise_block_size,
306
+ )
307
+ self.fc2 = self.build_fc2(
308
+ args.decoder_ffn_embed_dim,
309
+ self.embed_dim,
310
+ self.quant_noise,
311
+ self.quant_noise_block_size,
312
+ )
313
+
314
+ self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
315
+ self.need_attn = True
316
+
317
+ self.onnx_trace = False
318
+
319
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
320
+
321
+ def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
322
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
323
+
324
+ def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
325
+ return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
326
+
327
+ def build_self_attention(
328
+ self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
329
+ ):
330
+ return MultiheadAttention(
331
+ embed_dim,
332
+ args.decoder_attention_heads,
333
+ dropout=args.attention_dropout,
334
+ add_bias_kv=add_bias_kv,
335
+ add_zero_attn=add_zero_attn,
336
+ self_attention=not getattr(args, "cross_self_attention", False),
337
+ q_noise=self.quant_noise,
338
+ qn_block_size=self.quant_noise_block_size,
339
+ scale_factor=args.attn_scale_factor,
340
+ scale_heads=getattr(args, 'scale_heads', False)
341
+ )
342
+
343
+ def build_encoder_attention(self, embed_dim, args):
344
+ return MultiheadAttention(
345
+ embed_dim,
346
+ args.decoder_attention_heads,
347
+ kdim=getattr(args, "encoder_embed_dim", None),
348
+ vdim=getattr(args, "encoder_embed_dim", None),
349
+ dropout=args.attention_dropout,
350
+ encoder_decoder_attention=True,
351
+ q_noise=self.quant_noise,
352
+ qn_block_size=self.quant_noise_block_size,
353
+ scale_factor=args.attn_scale_factor,
354
+ scale_heads=getattr(args, 'scale_heads', False)
355
+ )
356
+
357
+ def prepare_for_onnx_export_(self):
358
+ self.onnx_trace = True
359
+
360
+ def residual_connection(self, x, residual):
361
+ return residual + self.drop_path(x)
362
+
363
+ def forward(
364
+ self,
365
+ x,
366
+ encoder_out: Optional[torch.Tensor] = None,
367
+ encoder_padding_mask: Optional[torch.Tensor] = None,
368
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
369
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
370
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
371
+ self_attn_mask: Optional[torch.Tensor] = None,
372
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
373
+ need_attn: bool = False,
374
+ need_head_weights: bool = False,
375
+ self_attn_bias: Optional[Tensor] = None,
376
+ cross_attn_bias: Optional[Tensor] = None
377
+ ):
378
+ """
379
+ Args:
380
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
381
+ encoder_padding_mask (ByteTensor, optional): binary
382
+ ByteTensor of shape `(batch, src_len)` where padding
383
+ elements are indicated by ``1``.
384
+ need_attn (bool, optional): return attention weights
385
+ need_head_weights (bool, optional): return attention weights
386
+ for each head (default: return average over heads).
387
+
388
+ Returns:
389
+ encoded output of shape `(seq_len, batch, embed_dim)`
390
+ """
391
+ if need_head_weights:
392
+ need_attn = True
393
+
394
+ residual = x
395
+ if self.normalize_before:
396
+ x = self.self_attn_layer_norm(x)
397
+ if prev_self_attn_state is not None:
398
+ prev_key, prev_value = prev_self_attn_state[:2]
399
+ saved_state: Dict[str, Optional[Tensor]] = {
400
+ "prev_key": prev_key,
401
+ "prev_value": prev_value,
402
+ }
403
+ if len(prev_self_attn_state) >= 3:
404
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
405
+ assert incremental_state is not None
406
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
407
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
408
+ if self.cross_self_attention and not (
409
+ incremental_state is not None
410
+ and _self_attn_input_buffer is not None
411
+ and "prev_key" in _self_attn_input_buffer
412
+ ):
413
+ if self_attn_mask is not None:
414
+ assert encoder_out is not None
415
+ self_attn_mask = torch.cat(
416
+ (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
417
+ )
418
+ if self_attn_padding_mask is not None:
419
+ if encoder_padding_mask is None:
420
+ assert encoder_out is not None
421
+ encoder_padding_mask = self_attn_padding_mask.new_zeros(
422
+ encoder_out.size(1), encoder_out.size(0)
423
+ )
424
+ self_attn_padding_mask = torch.cat(
425
+ (encoder_padding_mask, self_attn_padding_mask), dim=1
426
+ )
427
+ assert encoder_out is not None
428
+ y = torch.cat((encoder_out, x), dim=0)
429
+ else:
430
+ y = x
431
+
432
+ x, attn = self.self_attn(
433
+ query=x,
434
+ key=y,
435
+ value=y,
436
+ key_padding_mask=self_attn_padding_mask,
437
+ incremental_state=incremental_state,
438
+ need_weights=False,
439
+ attn_mask=self_attn_mask,
440
+ attn_bias=self_attn_bias
441
+ )
442
+ if self.self_attn_ln is not None:
443
+ x = self.self_attn_ln(x)
444
+ x = self.dropout_module(x)
445
+ x = self.residual_connection(x, residual)
446
+ if not self.normalize_before:
447
+ x = self.self_attn_layer_norm(x)
448
+
449
+ if self.encoder_attn is not None and encoder_out is not None:
450
+ residual = x
451
+ if self.normalize_before:
452
+ x = self.encoder_attn_layer_norm(x)
453
+ if prev_attn_state is not None:
454
+ prev_key, prev_value = prev_attn_state[:2]
455
+ saved_state: Dict[str, Optional[Tensor]] = {
456
+ "prev_key": prev_key,
457
+ "prev_value": prev_value,
458
+ }
459
+ if len(prev_attn_state) >= 3:
460
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
461
+ assert incremental_state is not None
462
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
463
+
464
+ x, attn = self.encoder_attn(
465
+ query=x,
466
+ key=encoder_out,
467
+ value=encoder_out,
468
+ key_padding_mask=encoder_padding_mask,
469
+ incremental_state=incremental_state,
470
+ static_kv=True,
471
+ need_weights=need_attn or (not self.training and self.need_attn),
472
+ need_head_weights=need_head_weights,
473
+ attn_bias=cross_attn_bias
474
+ )
475
+ if self.cross_attn_ln is not None:
476
+ x = self.cross_attn_ln(x)
477
+ x = self.dropout_module(x)
478
+ x = self.residual_connection(x, residual)
479
+ if not self.normalize_before:
480
+ x = self.encoder_attn_layer_norm(x)
481
+
482
+ residual = x
483
+ if self.normalize_before:
484
+ x = self.final_layer_norm(x)
485
+
486
+ x = self.activation_fn(self.fc1(x))
487
+ x = self.activation_dropout_module(x)
488
+ if self.ffn_layernorm is not None:
489
+ x = self.ffn_layernorm(x)
490
+ x = self.fc2(x)
491
+ x = self.dropout_module(x)
492
+ if self.w_resid is not None:
493
+ residual = torch.mul(self.w_resid, residual)
494
+ x = self.residual_connection(x, residual)
495
+ if not self.normalize_before:
496
+ x = self.final_layer_norm(x)
497
+ if self.onnx_trace and incremental_state is not None:
498
+ saved_state = self.self_attn._get_input_buffer(incremental_state)
499
+ assert saved_state is not None
500
+ if self_attn_padding_mask is not None:
501
+ self_attn_state = [
502
+ saved_state["prev_key"],
503
+ saved_state["prev_value"],
504
+ saved_state["prev_key_padding_mask"],
505
+ ]
506
+ else:
507
+ self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
508
+ return x, attn, self_attn_state
509
+ return x, attn, None
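+
+ # Editorial note (illustrative, not part of the original file): a minimal
+ # shape sketch for the forward pass above, assuming seq_len=5, batch=2,
+ # embed_dim=768. x enters as (5, 2, 768); the self-attention, cross-attention
+ # and FFN sub-blocks are each wrapped by residual_connection()
+ # (residual + drop_path(x)); the output keeps the same (5, 2, 768) shape.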
510
+
511
+ def make_generation_fast_(self, need_attn: bool = False, **kwargs):
512
+ self.need_attn = need_attn
513
+
514
+ def upgrade_state_dict_named(self, state_dict, name):
515
+ """
516
+ Rename layer norm states from `...layer_norms.{0,1,2}.weight` to
517
+ `...self_attn_layer_norm.weight`, `...encoder_attn_layer_norm.weight` and
518
+ `...final_layer_norm.weight`, respectively.
519
+ """
520
+ # update layer norms
521
+ layer_norm_map = {
522
+ "0": "self_attn_layer_norm",
523
+ "1": "encoder_attn_layer_norm",
524
+ "2": "final_layer_norm",
525
+ }
526
+ for old, new in layer_norm_map.items():
527
+ for m in ("weight", "bias"):
528
+ k = "{}.layer_norms.{}.{}".format(name, old, m)
529
+ if k in state_dict:
530
+ state_dict[
531
+ "{}.{}.{}".format(name, new, m)
532
+ ] = state_dict[k]
533
+ del state_dict[k]
534
+ if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
535
+ state_dict[
536
+ "{}.{}.{}".format(name, new, m)
537
+ ] = self.state_dict()["{}.{}".format(new, m)]
538
+
539
+ prefix = name + "." if name != "" else ""
540
+ for param_name, param_tensor in self.state_dict().items():
541
+ if (prefix + param_name) not in state_dict and param_name in self.state_dict():
542
+ state_dict[prefix + param_name] = self.state_dict()[param_name]
models/search.py ADDED
@@ -0,0 +1,814 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from fairseq.token_generation_constraints import (
12
+ ConstraintState,
13
+ OrderedConstraintState,
14
+ UnorderedConstraintState,
15
+ )
16
+ from torch import Tensor
17
+
18
+
19
+ class Search(nn.Module):
20
+ def __init__(self, tgt_dict):
21
+ super().__init__()
22
+ self.pad = tgt_dict.pad()
23
+ self.unk = tgt_dict.unk()
24
+ self.eos = tgt_dict.eos()
25
+ self.vocab_size = len(tgt_dict)
26
+ self.src_lengths = torch.tensor(-1)
27
+ self.supports_constraints = False
28
+ self.stop_on_max_len = False
29
+
30
+ def step(
31
+ self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None
32
+ ):
33
+ """Take a single search step.
34
+
35
+ Args:
36
+ step: the current search step, starting at 0
37
+ lprobs: (bsz x input_beam_size x vocab_size)
38
+ the model's log-probabilities over the vocabulary at the current step
39
+ scores: (bsz x input_beam_size x step)
40
+ the historical model scores of each hypothesis up to this point
41
+ prev_output_tokens: (bsz x step)
42
+ the previously generated output tokens
43
+ original_batch_idxs: (bsz)
44
+ the tensor with the batch indices, in the range [0, bsz)
45
+ this is useful in case a re-ordering has been applied
46
+ and we need to know the original indices
47
+
48
+ Return: A tuple of (scores, indices, beams) where:
49
+ scores: (bsz x output_beam_size)
50
+ the scores of the chosen elements; output_beam_size can be
51
+ larger than input_beam_size, e.g., we may return
52
+ 2*input_beam_size to account for EOS
53
+ indices: (bsz x output_beam_size)
54
+ the indices of the chosen elements
55
+ beams: (bsz x output_beam_size)
56
+ the hypothesis ids of the chosen elements, in the range [0, input_beam_size)
57
+ """
58
+ raise NotImplementedError
59
+
60
+ @torch.jit.export
61
+ def set_src_lengths(self, src_lengths):
62
+ self.src_lengths = src_lengths
63
+
64
+ @torch.jit.export
65
+ def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
66
+ """Initialize constraint states for constrained decoding (if supported).
67
+
68
+ Args:
69
+ batch_constraints: (torch.Tensor, optional)
70
+ the list of constraints, in packed form
71
+ beam_size: (int)
72
+ the beam size
73
+ Returns:
74
+ None; subclasses store their constraint states on the search object
75
+ """
76
+ pass
77
+
78
+ def prune_sentences(self, batch_idxs: Tensor):
79
+ """
80
+ Removes constraint states for completed sentences (if supported).
81
+ This is called from sequence_generator._generate() when sentences are
82
+ deleted from the batch.
83
+
84
+ Args:
85
+ batch_idxs: Indices of *sentences* whose constraint state should be *kept*.
86
+ """
87
+ pass
88
+
89
+ def update_constraints(self, active_hypos: Tensor):
90
+ """
91
+ Updates the constraint states by selecting the beam items that are retained.
92
+ This is called at each time step of sequence_generator._generate() when
93
+ the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size.
94
+
95
+ Args:
96
+ active_hypos: (batch size, beam size)
97
+ list of integers denoting, for each sentence, which beam candidate items
98
+ should be kept.
99
+ """
100
+ pass
101
+
102
+
103
+ class BeamSearch(Search):
104
+ def __init__(self, tgt_dict):
105
+ super().__init__(tgt_dict)
106
+ self.constraint_states = None
107
+
108
+ @torch.jit.export
109
+ def step(
110
+ self,
111
+ step: int,
112
+ lprobs,
113
+ scores: Optional[Tensor],
114
+ prev_output_tokens: Optional[Tensor] = None,
115
+ original_batch_idxs: Optional[Tensor] = None,
116
+ ):
117
+ bsz, beam_size, vocab_size = lprobs.size()
118
+
119
+ if step == 0:
120
+ # at the first step all hypotheses are equally likely, so use
121
+ # only the first beam
122
+ lprobs = lprobs[:, ::beam_size, :].contiguous()
123
+ else:
124
+ # make probs contain cumulative scores for each hypothesis
125
+ assert scores is not None
126
+ lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
127
+
128
+ top_prediction = torch.topk(
129
+ lprobs.view(bsz, -1),
130
+ k=min(
131
+ # Take the best 2 x beam_size predictions. We'll choose the first
132
+ # beam_size of these which don't predict eos to continue with.
133
+ beam_size * 2,
134
+ lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
135
+ ),
136
+ )
137
+ scores_buf = top_prediction[0]
138
+ indices_buf = top_prediction[1]
139
+ # Project back into relative indices and beams
140
+ beams_buf = indices_buf // vocab_size
141
+ indices_buf = indices_buf.fmod(vocab_size)
142
+
143
+ # At this point, beams_buf and indices_buf are single-dim and contain relative indices
144
+ return scores_buf, indices_buf, beams_buf
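+
+ # Editorial note (illustrative): a worked example of the projection above,
+ # assuming vocab_size = 4 and a flat top-k index of 6 over (beam x vocab):
+ #   beams_buf   = 6 // 4 = 1   (which hypothesis the candidate extends)
+ #   indices_buf = 6 %  4 = 2   (the token id within the vocabulary)
+ # i.e. flat index 6 means "extend beam 1 with token 2".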
145
+
146
+
147
+ class PrefixConstrainedBeamSearch(Search):
148
+ def __init__(self, tgt_dict, prefix_allowed_tokens_fn):
149
+ super().__init__(tgt_dict)
150
+ self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
151
+ self.stop_on_max_len = True
152
+
153
+ @torch.jit.export
154
+ def apply_mask(self, x, prev_output_tokens, original_batch_idxs):
155
+ beam_size = x.shape[0] // original_batch_idxs.shape[0]
156
+ original_batch_idxs = (
157
+ original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist()
158
+ )
159
+
160
+ mask = torch.full_like(x, -math.inf)
161
+ for sent_i, (sent, batch_i) in enumerate(
162
+ zip(prev_output_tokens, original_batch_idxs)
163
+ ):
164
+ mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0
165
+
166
+ return mask
167
+
168
+ @torch.jit.export
169
+ def step(
170
+ self,
171
+ step: int,
172
+ lprobs: Tensor,
173
+ scores: Tensor,
174
+ prev_output_tokens: Tensor,
175
+ original_batch_idxs: Tensor,
176
+ ):
177
+ bsz, beam_size, vocab_size = lprobs.size()
178
+
179
+ lprobs += self.apply_mask(
180
+ lprobs.view(bsz * beam_size, 1, vocab_size),
181
+ prev_output_tokens,
182
+ original_batch_idxs,
183
+ ).view(bsz, beam_size, vocab_size)
184
+
185
+ if step == 0:
186
+ # at the first step all hypotheses are equally likely, so use
187
+ # only the first beam
188
+ lprobs = lprobs[:, ::beam_size, :].contiguous()
189
+ else:
190
+ # make probs contain cumulative scores for each hypothesis
191
+ assert scores is not None
192
+ lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
193
+
194
+ top_prediction = torch.topk(
195
+ lprobs.view(bsz, -1),
196
+ k=min(
197
+ # Take the best beam_size predictions. We'll choose the first
198
+ # beam_size of these which don't predict eos to continue with.
199
+ beam_size,
200
+ lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
201
+ ),
202
+ )
203
+ scores_buf = top_prediction[0]
204
+ indices_buf = top_prediction[1]
205
+ beams_buf = indices_buf // vocab_size
206
+ indices_buf = indices_buf.fmod(vocab_size)
207
+ return scores_buf, indices_buf, beams_buf
208
+
209
+
210
+ class LexicallyConstrainedBeamSearch(Search):
211
+ """Implements lexically constrained beam search as described in
212
+
213
+ Fast Lexically Constrained Decoding with Dynamic Beam
214
+ Allocation for Neural Machine Translation. Post & Vilar,
215
+ NAACL 2018. https://www.aclweb.org/anthology/N18-1119/
216
+
217
+ and
218
+
219
+ Improved Lexically Constrained Decoding for Translation and
220
+ Monolingual Rewriting. Hu et al, NAACL
221
+ 2019. https://www.aclweb.org/anthology/N19-1090/
222
+
223
+ This is accomplished by maintaining, for each beam hypothesis, a
224
+ ConstraintState object (see constraints.py) that tracks which
225
+ constraints have been generated and using this information to
226
+ shape the beam for each input sentence.
227
+ """
228
+
229
+ def __init__(self, tgt_dict, representation):
230
+ super().__init__(tgt_dict)
231
+ self.representation = representation
232
+ self.vocab_size = len(tgt_dict)
233
+ self.num_cands = 0
234
+ self.supports_constraints = True
235
+
236
+ @torch.jit.export
237
+ def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int):
238
+ self.constraint_states = []
239
+ for constraint_tensor in batch_constraints:
240
+ if self.representation == "ordered":
241
+ constraint_state = OrderedConstraintState.create(constraint_tensor)
242
+ elif self.representation == "unordered":
243
+ constraint_state = UnorderedConstraintState.create(constraint_tensor)
244
+
245
+ self.constraint_states.append([constraint_state for i in range(beam_size)])
246
+
247
+ @torch.jit.export
248
+ def prune_sentences(self, batch_idxs: Tensor):
249
+ self.constraint_states = [
250
+ self.constraint_states[i] for i in batch_idxs.tolist()
251
+ ]
252
+
253
+ @torch.jit.export
254
+ def update_constraints(self, active_hypos: Tensor):
255
+ if self.constraint_states:
256
+ batch_size = active_hypos.size(0)
257
+ for sentid in range(batch_size):
258
+ self.constraint_states[sentid] = [
259
+ self.constraint_states[sentid][i] for i in active_hypos[sentid]
260
+ ]
261
+
262
+ @torch.jit.export
263
+ def step(
264
+ self,
265
+ step: int,
266
+ lprobs: Tensor,
267
+ scores: Optional[Tensor],
268
+ prev_output_tokens: Optional[Tensor] = None,
269
+ original_batch_idxs: Optional[Tensor] = None,
270
+ ):
271
+ """
272
+ A constrained step builds a large candidates list from the following:
273
+ - the top 2 * {beam_size} items over the whole beam
274
+ - for each item in the beam
275
+ - the top {each_k} (default 1)
276
+ - all next constraints
277
+ We then compute the constrained state of each beam item, and assign
278
+ stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so
279
+ on. We then sort by (stripe, score), and truncate the list at
280
+ 2 * beam size.
281
+
282
+ Args:
283
+ step: the decoder step
284
+ lprobs: (batch size, beam size, target vocab)
285
+ the target-vocab distributions for each item in the beam.
286
+ Return: A tuple of (scores, indices, beams, constraints) where:
287
+ scores: (batch, output beam size)
288
+ the scores of the chosen elements
289
+ indices: (batch, output beam size)
290
+ the target vocab indices of the chosen elements
291
+ beams: (batch, output beam size)
292
+ the 0-indexed hypothesis ids of the chosen elements
293
+ constraints: (batch, output beam size)
294
+ the new constraint states
295
+ """
296
+ each_k = 1
297
+ device = lprobs.device
298
+
299
+ batch_size, beam_size, vocab_size = lprobs.size()
300
+
301
+ self.num_cands = min(
302
+ # Just take the k-best. We'll get another k from the 1-best from each
303
+ # row, plus more from the constraints
304
+ beam_size * 2,
305
+ lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad
306
+ )
307
+
308
+ # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items
309
+ constraint_states = self.constraint_states
310
+ if constraint_states and step > 0:
311
+ not_finished_indices = []
312
+ for sentno, sent_constraints in enumerate(constraint_states):
313
+ for beamno, state in enumerate(sent_constraints):
314
+ index = sentno * beam_size + beamno
315
+ if not state.finished:
316
+ not_finished_indices.append(index)
317
+ not_finished_indices = torch.tensor(not_finished_indices)
318
+ if not_finished_indices.numel() > 0:
319
+ lprobs.view(batch_size * beam_size, -1)[
320
+ not_finished_indices, self.eos
321
+ ] = -math.inf
322
+
323
+ if step == 0:
324
+ # at the first step all hypotheses are equally likely, so use
325
+ # only the first beam entry for each batch item
326
+ lprobs = lprobs[:, ::beam_size, :].contiguous()
327
+ else:
328
+ # make probs contain cumulative scores for each hypothesis
329
+ assert scores is not None
330
+ lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1)
331
+
332
+ top_prediction = torch.topk(
333
+ lprobs.view(batch_size, -1),
334
+ self.num_cands,
335
+ )
336
+ scores_buf, indices_buf = top_prediction
337
+ # Project back into relative indices and beams
338
+ beams_buf = indices_buf // vocab_size
339
+ indices_buf = indices_buf.fmod(vocab_size)
340
+
341
+ # Short circuit if there are no constraints in this batch
342
+ if not constraint_states:
343
+ return scores_buf, indices_buf, beams_buf
344
+
345
+ # STEP 1: get top-1 from each hypothesis across all sentences in the batch
346
+ if step > 0:
347
+ top_scores, top_indices = torch.topk(
348
+ lprobs.view(batch_size * beam_size, -1),
349
+ k=each_k,
350
+ dim=1,
351
+ )
352
+ top_scores = top_scores.view(batch_size, -1)
353
+ top_indices = top_indices.view(batch_size, -1)
354
+ scores_buf = torch.cat((scores_buf, top_scores), dim=1)
355
+ indices_buf = torch.cat((indices_buf, top_indices), dim=1)
356
+ new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1)
357
+ beams_buf = torch.cat((beams_buf, new_beams), dim=1)
358
+
359
+ # Now, process sentences in the batch one by one.
360
+ new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device)
361
+ new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
362
+ new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long()
363
+ for sentno, states in enumerate(constraint_states):
364
+ scores, indices, beams, new_states = self.step_sentence(
365
+ step,
366
+ sentno,
367
+ lprobs[sentno],
368
+ constraint_states[sentno],
369
+ beams_buf[sentno].clone(),
370
+ indices_buf[sentno].clone(),
371
+ scores_buf[sentno].clone(),
372
+ )
373
+ new_scores_buf[sentno] = scores
374
+ new_indices_buf[sentno] = indices
375
+ new_beams_buf[sentno] = beams
376
+ self.constraint_states[sentno] = new_states
377
+
378
+ return new_scores_buf, new_indices_buf, new_beams_buf
379
+
380
+ @torch.jit.export
381
+ def step_sentence(
382
+ self,
383
+ step: int,
384
+ sentno: int,
385
+ lprobs: Tensor,
386
+ constraint_states: List[List[ConstraintState]],
387
+ beams_buf: Tensor,
388
+ indices_buf: Tensor,
389
+ scores_buf: Tensor,
390
+ ):
391
+ """Does per-sentence processing. Adds all constraints for each
392
+ hypothesis to the list of candidates; then removes duplicates,
393
+ sorts, and dynamically stripes across the banks. All tensor inputs
394
+ are collapsed to those pertaining to a single input sentence.
395
+ """
396
+ device = lprobs.device
397
+
398
+ # STEP 2: Add all constraints for each beam item
399
+ for beamno, state in enumerate(constraint_states):
400
+ next_tokens = torch.tensor(list(state.next_tokens()), device=device).long()
401
+ if next_tokens.numel() != 0:
402
+ indices_buf = torch.cat((indices_buf, next_tokens))
403
+ next_beams = (
404
+ torch.tensor(beamno, device=device)
405
+ .repeat(next_tokens.size(0))
406
+ .long()
407
+ )
408
+ beams_buf = torch.cat((beams_buf, next_beams))
409
+ next_values = lprobs[beamno].take(next_tokens.view(-1))
410
+ scores_buf = torch.cat((scores_buf, next_values))
411
+
412
+ # At the 0th time step, there is just one beam item
413
+ if step == 0:
414
+ break
415
+
416
+ # STEP 3: Compute the "bank" for each candidate. This is the
417
+ # number of constraints it's generated. We need this so that
418
+ # we can do round-robin allocation of the beam across these
419
+ # banks. If C is the number of constraints, we select the best
420
+ # item in bank C, then the best in bank C-1, etc, followed by
421
+ # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
422
+ # on, until the maximum beam size. We accomplish this by
423
+ # creating a sort key and striping across the banks.
424
+
425
+ # Compute the new states for all candidates
426
+ cands_size = indices_buf.size(0)
427
+ constraint_states = [
428
+ constraint_states[beams_buf[i]].advance(indices_buf[i])
429
+ for i in range(cands_size)
430
+ ]
431
+
432
+ banks = torch.tensor([state.bank for state in constraint_states], device=device)
433
+
434
+ # STEP 4: Sort
435
+ num_constraint_tokens = len(state.tokens)
436
+
437
+ # Sort by keys (bank, score) (i.e., sort banks together, and scores
438
+ # within banks). AFAIK pytorch doesn't support either stable sort or
439
+ # multi-key sorting, so we have to hack this.
440
+ MAX_SCORE = -100
441
+ sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf
442
+ sort_values, sort_indices = sort_key.sort(dim=0, descending=True)
443
+ scores_buf = scores_buf[sort_indices]
444
+ indices_buf = indices_buf[sort_indices]
445
+ beams_buf = beams_buf[sort_indices]
446
+ banks = banks[sort_indices]
447
+
448
+ # Sort the constraints to follow suit
449
+ constraint_states = [constraint_states[i] for i in sort_indices]
450
+
451
+ # STEP 5: Remove duplicates. The topk calls (overall and
452
+ # per-row) plus the per-row generation of constraints will
453
+ # produce duplicates. Here we remove them.
454
+
455
+ def roll(t):
456
+ """Rolls a 1d tensor left by 1.
457
+
458
+ [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3]
459
+ """
460
+ return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0)
461
+
462
+ # We map candidates (beam, token_id) to a single dimension.
463
+ # This is then shifted by 1. We can then easily identify
464
+ # duplicates and create a mask that identifies unique
465
+ # extensions.
466
+ uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf
467
+ uniques_mask = roll(uniques_mask) != uniques_mask
468
+
469
+ # Use the mask to pare down the data structures
470
+ scores_buf = torch.masked_select(scores_buf, uniques_mask)
471
+ indices_buf = torch.masked_select(indices_buf, uniques_mask)
472
+ beams_buf = torch.masked_select(beams_buf, uniques_mask)
473
+ banks = torch.masked_select(banks, uniques_mask)
474
+ i = 1
475
+ for mask in uniques_mask[1:]:
476
+ if not mask:
477
+ constraint_states.pop(i)
478
+ i += mask
479
+
480
+ # STEP 6: Assign IDs round-robin across banks, sort, and
481
+ # truncate. Now that the candidates are sorted by (bank,
482
+ # score) and uniqed, we dynamically allocate the {beam_size}
483
+ # beam by striping across the candidates. These stripes will
484
+ # be used as sort keys to do round-robin selection. This is
485
+ # accomplished in a single pass with offsets. Sorting by
486
+ # highest-banks (furthest-along hypotheses) first ensures
487
+ # progress through the constraints.
488
+ #
489
+ # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
490
+ # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
491
+ # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
492
+ # = 0 5 10 1 6 11 13 2 7 12 3 8
493
+ #
494
+ # Sorting by this then gives the following banks:
495
+ #
496
+ # 3 2 1 0 3 2 1 0 3 2 1 2
497
+ #
498
+ # We'll take the top {beam_size} of these.
499
+ stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)]
500
+ stripes = torch.zeros_like(banks)
501
+ cur_bank_count = -1
502
+ cur_bank = banks[0]
503
+ for i, bank in enumerate(banks):
504
+ if bank != cur_bank:
505
+ cur_bank_count = 0
506
+ cur_bank = bank
507
+ else:
508
+ cur_bank_count += 1
509
+ stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count]
510
+
511
+ # STEP 7: Sort by the stripes values
512
+ sort_values, sort_indices = stripes.sort(dim=0)
513
+ scores_buf = scores_buf[sort_indices]
514
+ indices_buf = indices_buf[sort_indices]
515
+ beams_buf = beams_buf[sort_indices]
516
+ constraint_states = [constraint_states[i] for i in sort_indices]
517
+
518
+ # STEP 8: Truncate to the candidates size!
519
+ scores_buf = scores_buf[: self.num_cands]
520
+ indices_buf = indices_buf[: self.num_cands]
521
+ beams_buf = beams_buf[: self.num_cands]
522
+
523
+ return scores_buf, indices_buf, beams_buf, constraint_states
524
+
525
+
526
+ class LengthConstrainedBeamSearch(Search):
527
+ def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
528
+ super().__init__(tgt_dict)
529
+ self.min_len_a = min_len_a
530
+ self.min_len_b = min_len_b
531
+ self.max_len_a = max_len_a
532
+ self.max_len_b = max_len_b
533
+ self.beam = BeamSearch(tgt_dict)
534
+ self.needs_src_lengths = True
535
+
536
+ def step(
537
+ self,
538
+ step: int,
539
+ lprobs,
540
+ scores,
541
+ prev_output_tokens: Optional[Tensor] = None,
542
+ original_batch_idxs: Optional[Tensor] = None,
543
+ ):
544
+ min_lens = self.min_len_a * self.src_lengths + self.min_len_b
545
+ max_lens = self.max_len_a * self.src_lengths + self.max_len_b
546
+ lprobs[step < min_lens, :, self.eos] = -math.inf
547
+ lprobs[step >= max_lens, :, self.eos] = 0
548
+ return self.beam.step(step, lprobs, scores)
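+
+ # Editorial note (illustrative): with min_len_a=0, min_len_b=5, max_len_a=1.2,
+ # max_len_b=10 and src_lengths=[20], the bounds above become min_lens=[5] and
+ # max_lens=[34]: EOS is forbidden (-inf) before step 5 and forced (log-prob 0,
+ # the maximum) once the step reaches 34.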
549
+
550
+
551
+ class DiverseBeamSearch(Search):
552
+ """Diverse Beam Search.
553
+
554
+ See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
555
+ Models" for details.
556
+
557
+ We only implement the Hamming Diversity penalty here, which performed best
558
+ in the original paper.
559
+ """
560
+
561
+ def __init__(self, tgt_dict, num_groups, diversity_strength):
562
+ super().__init__(tgt_dict)
563
+ self.num_groups = num_groups
564
+ self.diversity_strength = -diversity_strength
565
+ self.beam = BeamSearch(tgt_dict)
566
+
567
+ @torch.jit.export
568
+ def step(
569
+ self,
570
+ step: int,
571
+ lprobs,
572
+ scores,
573
+ prev_output_tokens: Optional[Tensor] = None,
574
+ original_batch_idxs: Optional[Tensor] = None,
575
+ ):
576
+ bsz, beam_size, vocab_size = lprobs.size()
577
+ if beam_size % self.num_groups != 0:
578
+ raise ValueError(
579
+ "DiverseBeamSearch requires --beam to be divisible by the number of groups"
580
+ )
581
+
582
+ # initialize diversity penalty
583
+ diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs)
584
+
585
+ scores_G, indices_G, beams_G = [], [], []
586
+ for g in range(self.num_groups):
587
+ lprobs_g = lprobs[:, g :: self.num_groups, :]
588
+ scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None
589
+
590
+ # apply diversity penalty
591
+ if g > 0:
592
+ lprobs_g = torch.add(
593
+ lprobs_g,
594
+ other=diversity_buf.unsqueeze(1),
595
+ alpha=self.diversity_strength,
596
+ )
597
+ else:
598
+ lprobs_g = lprobs_g.contiguous()
599
+
600
+ scores_buf, indices_buf, beams_buf = self.beam.step(
601
+ step, lprobs_g, scores_g
602
+ )
603
+ beams_buf.mul_(self.num_groups).add_(g)
604
+
605
+ scores_G.append(scores_buf.clone())
606
+ indices_G.append(indices_buf.clone())
607
+ beams_G.append(beams_buf.clone())
608
+
609
+ # update diversity penalty
610
+ diversity_buf.scatter_add_(
611
+ 1, indices_buf, torch.ones(indices_buf.size()).to(diversity_buf)
612
+ )
613
+
614
+ # interleave results from different groups
615
+ scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1)
616
+ indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1)
617
+ beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1)
618
+ return scores_buf, indices_buf, beams_buf
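+
+ # Editorial note (illustrative): with beam_size=6 and num_groups=3, the slice
+ # lprobs[:, g::num_groups, :] assigns beam slots round-robin:
+ #   group 0 -> slots 0, 3    group 1 -> slots 1, 4    group 2 -> slots 2, 5
+ # and beams_buf.mul_(num_groups).add_(g) maps each group's local beam ids
+ # back to these global slot positions before the groups are interleaved.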
619
+
620
+
621
+ class Sampling(Search):
622
+ sampling_topk: int
623
+ sampling_topp: float
624
+
625
+ def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
626
+ super().__init__(tgt_dict)
627
+ self.sampling_topk = sampling_topk
628
+ self.sampling_topp = sampling_topp
629
+
630
+ def _sample_topp(self, lprobs):
631
+ """Sample among the smallest set of elements whose cumulative probability mass exceeds p.
632
+
633
+ See `"The Curious Case of Neural Text Degeneration"
634
+ (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.
635
+
636
+ Args:
637
+ lprobs: (bsz x input_beam_size x vocab_size)
638
+ the model's log-probabilities over the vocabulary at the current step
639
+
640
+ Return: A tuple of (trimed_probs, truncated_indices) where:
641
+ trimed_probs: (bsz x input_beam_size x ?)
642
+ the model's probabilities over the elements selected to sample from. The
643
+ width of the third dimension is determined by top-P.
644
+ truncated_indices: (bsz x input_beam_size x ?)
645
+ the indices of the chosen elements.
646
+ """
647
+ probs = lprobs.exp_()
648
+
649
+ # sort the last dimension (vocab dimension) in descending order
650
+ sorted_probs, sorted_indices = probs.sort(descending=True)
651
+
652
+ # compute a mask to indicate the words to be included in the top-P set.
653
+ cumsum_probs = sorted_probs.cumsum(dim=2)
654
+ mask = cumsum_probs.lt(self.sampling_topp)
655
+
656
+ # note that mask was computed by 'lt'. One more word needs to be included
657
+ # so that the cumulative probability mass can exceed p.
658
+ cumsum_mask = mask.cumsum(dim=2)
659
+ last_included = cumsum_mask[:, :, -1:]
660
+ last_included.clamp_(0, mask.size()[2] - 1)
661
+ mask = mask.scatter_(2, last_included, 1)
662
+
663
+ # truncate unnecessary dims.
664
+ max_dim = last_included.max()
665
+ truncated_mask = mask[:, :, : max_dim + 1]
666
+ truncated_probs = sorted_probs[:, :, : max_dim + 1]
667
+ truncated_indices = sorted_indices[:, :, : max_dim + 1]
668
+
669
+ # trim the words that are not in top-P by setting their probabilities
670
+ # to 0, so that they would not be sampled later.
671
+ trim_mask = ~truncated_mask
672
+ trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
673
+ return trimed_probs, truncated_indices
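+
+ # Editorial note (illustrative): a worked example of the truncation above,
+ # assuming sampling_topp=0.9 and sorted probabilities [0.5, 0.3, 0.15, 0.05]:
+ #   cumulative mass = [0.5, 0.8, 0.95, 1.0]; mask (< 0.9) = [1, 1, 0, 0];
+ # one extra word is then included so the kept mass reaches 0.95 >= 0.9, and
+ # the remaining tail is zeroed out before sampling.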
674
+
675
+ @torch.jit.export
676
+ def step(
677
+ self,
678
+ step: int,
679
+ lprobs,
680
+ scores,
681
+ prev_output_tokens: Optional[Tensor] = None,
682
+ original_batch_idxs: Optional[Tensor] = None,
683
+ ):
684
+ bsz, beam_size, vocab_size = lprobs.size()
685
+
686
+ if step == 0:
687
+ # at the first step all hypotheses are equally likely, so use
688
+ # only the first beam
689
+ lprobs = lprobs[:, ::beam_size, :].contiguous()
690
+
691
+ if self.sampling_topp > 0:
692
+ # only sample from the smallest set of words whose cumulative probability mass exceeds p
693
+ probs, top_indices = self._sample_topp(lprobs)
694
+ elif self.sampling_topk > 0:
695
+ # only sample from top-k candidates
696
+ lprobs, top_indices = lprobs.topk(self.sampling_topk)
697
+ probs = lprobs.exp_()
698
+ else:
699
+ probs = lprobs.exp_()
700
+
701
+ # dummy data to be consistent with true branch for type check
702
+ top_indices = torch.empty(0).to(probs)
703
+ # sample
704
+ if step == 0:
705
+ indices_buf = torch.multinomial(
706
+ probs.view(bsz, -1),
707
+ beam_size,
708
+ replacement=True,
709
+ ).view(bsz, beam_size)
710
+ else:
711
+ indices_buf = torch.multinomial(
712
+ probs.view(bsz * beam_size, -1),
713
+ 1,
714
+ replacement=True,
715
+ ).view(bsz, beam_size)
716
+
717
+ if step == 0:
718
+ # expand to beam size
719
+ probs = probs.expand(bsz, beam_size, -1)
720
+
721
+ # gather scores
722
+ scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1))
723
+ scores_buf = scores_buf.log_().view(bsz, -1)
724
+
725
+ # remap indices if using top-k or top-P sampling
726
+ if self.sampling_topk > 0 or self.sampling_topp > 0:
727
+ indices_buf = torch.gather(
728
+ top_indices.expand(bsz, beam_size, -1),
729
+ dim=2,
730
+ index=indices_buf.unsqueeze(-1),
731
+ ).squeeze(2)
732
+
733
+ if step == 0:
734
+ beams_buf = indices_buf.new_zeros(bsz, beam_size)
735
+ else:
736
+ beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1)
737
+ # make scores cumulative
738
+ scores_buf.add_(
739
+ torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf)
740
+ )
741
+
742
+ return scores_buf, indices_buf, beams_buf
743
+
744
+
745
+ class DiverseSiblingsSearch(Search):
746
+ """
747
+ Beam search with diverse siblings.
748
+
749
+ See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details.
750
+ https://arxiv.org/abs/1611.08562
751
+
752
+ 1/ Calculate hypotheses for each beam
753
+ 2/ Intra-sibling ordering
754
+ 3/ Rewrite scores
755
+ 4/ Choose top K hypotheses
756
+
757
+ Setting diversity_rate == 0 is equivalent to BeamSearch.
758
+ """
759
+
760
+ def __init__(self, tgt_dict, diversity_rate):
761
+ super().__init__(tgt_dict)
762
+ self.diversity_rate = diversity_rate
763
+ self.beam = BeamSearch(tgt_dict)
764
+
765
+ def step(
766
+ self,
767
+ step: int,
768
+ lprobs,
769
+ scores,
770
+ prev_output_tokens: Optional[Tensor] = None,
771
+ original_batch_idxs: Optional[Tensor] = None,
772
+ ):
773
+ bsz, beam_size, vocab_size = lprobs.size()
774
+ k = min(
775
+ # Take the best 2 x beam_size predictions. We'll choose the first
776
+ # beam_size of these which don't predict eos to continue with.
777
+ beam_size * 2,
778
+ lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad
779
+ )
780
+ s_list: List[Tensor]
781
+ i_list: List[Tensor]
782
+ s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)]
783
+ i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)]
784
+ sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate
785
+
786
+ if step == 0:
787
+ return self.beam.step(step, lprobs, scores)
788
+ lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))
789
+
790
+ # 1/ Calculate hypotheses for each beam
791
+ for i in range(beam_size):
792
+ torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i]))
793
+ i_list[i].fmod_(vocab_size)
794
+
795
+ # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores
796
+ s_list[i].sub_(sibling_score)
797
+
798
+ # 4/ Choose top K hypotheses
799
+ indices = torch.stack(i_list, dim=1).view(bsz, -1)
800
+
801
+ final_scores = torch.empty(0).to(lprobs)
802
+ final_indices = torch.LongTensor().to(device=lprobs.device)
803
+ final_beams = torch.LongTensor().to(device=lprobs.device)
804
+ (final_scores, final_indices) = torch.topk(
805
+ torch.stack(s_list, dim=1).view(bsz, -1),
806
+ k,
807
+ )
808
+
809
+ final_beams = final_indices // k
810
+
811
+ for i in range(bsz):
812
+ final_indices[i] = indices[i][final_indices[i]]
813
+
814
+ return final_scores, final_indices, final_beams
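+
+ # Editorial usage sketch (illustrative, not part of the original file): all
+ # strategies above share the Search.step() contract, so they can be swapped
+ # into the generator. Assuming a fairseq Dictionary named tgt_dict:
+ #   strategy = Sampling(tgt_dict, sampling_topk=10)   # or BeamSearch(tgt_dict)
+ #   scores_buf, indices_buf, beams_buf = strategy.step(step, lprobs, scores)
+ # where lprobs is (bsz, beam_size, vocab_size) and scores is (bsz, beam_size,
+ # step). SequenceGenerator (models/sequence_generator.py) accepts such an
+ # object through its search_strategy argument.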
models/sequence_generator.py ADDED
@@ -0,0 +1,1053 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, List, Optional
8
+ import sys
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from fairseq import search, utils
13
+ from fairseq.models import FairseqIncrementalDecoder
14
+ from torch import Tensor
15
+ from fairseq.ngram_repeat_block import NGramRepeatBlock
16
+
17
+ from data import data_utils
18
+
19
+ class SequenceGenerator(nn.Module):
20
+ def __init__(
21
+ self,
22
+ models,
23
+ tgt_dict,
24
+ beam_size=1,
25
+ max_len_a=0,
26
+ max_len_b=200,
27
+ max_len=0,
28
+ min_len=1,
29
+ normalize_scores=True,
30
+ len_penalty=1.0,
31
+ unk_penalty=0.0,
32
+ temperature=1.0,
33
+ match_source_len=False,
34
+ no_repeat_ngram_size=0,
35
+ search_strategy=None,
36
+ eos=None,
37
+ symbols_to_strip_from_output=None,
38
+ lm_model=None,
39
+ lm_weight=1.0,
40
+ constraint_trie=None,
41
+ constraint_range=None,
42
+ gen_code=False,
43
+ gen_box=False,
44
+ ignore_eos=False,
45
+ zero_shot=False
46
+ ):
47
+ """Generates translations of a given source sentence.
48
+
49
+ Args:
50
+ models (List[~fairseq.models.FairseqModel]): ensemble of models,
51
+ currently support fairseq.models.TransformerModel for scripting
52
+ beam_size (int, optional): beam width (default: 1)
53
+ max_len_a/b (int, optional): generate sequences of maximum length
54
+ ax + b, where x is the source length
55
+ max_len (int, optional): the maximum length of the generated output
56
+ (not including end-of-sentence)
57
+ min_len (int, optional): the minimum length of the generated output
58
+ (not including end-of-sentence)
59
+ normalize_scores (bool, optional): normalize scores by the length
60
+ of the output (default: True)
61
+ len_penalty (float, optional): length penalty, where <1.0 favors
62
+ shorter, >1.0 favors longer sentences (default: 1.0)
63
+ unk_penalty (float, optional): unknown word penalty, where <0
64
+ produces more unks, >0 produces fewer (default: 0.0)
65
+ temperature (float, optional): temperature, where values
66
+ >1.0 produce more uniform samples and values <1.0 produce
67
+ sharper samples (default: 1.0)
68
+ match_source_len (bool, optional): outputs should match the source
69
+ length (default: False)
70
+ """
71
+ super().__init__()
72
+ if isinstance(models, EnsembleModel):
73
+ self.model = models
74
+ else:
75
+ self.model = EnsembleModel(models)
76
+ self.gen_code = gen_code
77
+ self.gen_box = gen_box
78
+ self.ignore_eos = ignore_eos
79
+ self.tgt_dict = tgt_dict
80
+ self.pad = tgt_dict.pad()
81
+ self.unk = tgt_dict.unk()
82
+ self.bos = tgt_dict.bos()
83
+ self.eos = tgt_dict.eos() if eos is None else eos
84
+ self.symbols_to_strip_from_output = (
85
+ symbols_to_strip_from_output.union({self.eos})
86
+ if symbols_to_strip_from_output is not None
87
+ else {self.bos, self.eos}
88
+ )
89
+ self.vocab_size = len(tgt_dict)
90
+ self.beam_size = beam_size
91
+ # the max beam size is the dictionary size - 1, since we never select pad
92
+ self.beam_size = min(beam_size, self.vocab_size - 1)
93
+ self.max_len_a = max_len_a
94
+ self.max_len_b = max_len_b
95
+ self.min_len = min_len
96
+ self.max_len = max_len or self.model.max_decoder_positions()
97
+
98
+ self.normalize_scores = normalize_scores
99
+ self.len_penalty = len_penalty
100
+ self.unk_penalty = unk_penalty
101
+ self.temperature = temperature
102
+ self.match_source_len = match_source_len
103
+ self.zero_shot = zero_shot
104
+
105
+ if no_repeat_ngram_size > 0:
106
+ self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
107
+ else:
108
+ self.repeat_ngram_blocker = None
109
+
110
+ assert temperature > 0, "--temperature must be greater than 0"
111
+
112
+ self.search = (
113
+ search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
114
+ )
115
+ # We only need to set src_lengths in LengthConstrainedBeamSearch.
116
+ # As a module attribute, setting it would break in multithread
117
+ # settings when the model is shared.
118
+ self.should_set_src_lengths = (
119
+ hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
120
+ )
121
+
122
+ self.model.eval()
123
+
124
+ self.lm_model = lm_model
125
+ self.lm_weight = lm_weight
126
+ if self.lm_model is not None:
127
+ self.lm_model.eval()
128
+
129
+ self.constraint_trie = constraint_trie
130
+
131
+ self.constraint_start = None
132
+ self.constraint_end = None
133
+ if constraint_range is not None:
134
+ constraint_start, constraint_end = constraint_range.split(',')
135
+ self.constraint_start = int(constraint_start)
136
+ self.constraint_end = int(constraint_end)
137
+
138
+ def cuda(self):
139
+ self.model.cuda()
140
+ return self
141
+
142
+ @torch.no_grad()
143
+ def forward(
144
+ self,
145
+ sample: Dict[str, Dict[str, Tensor]],
146
+ prefix_tokens: Optional[Tensor] = None,
147
+ bos_token: Optional[int] = None,
148
+ ):
149
+ """Generate a batch of translations.
150
+
151
+ Args:
152
+ sample (dict): batch
153
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
154
+ with these tokens
155
+ bos_token (int, optional): beginning of sentence token
156
+ (default: self.eos)
157
+ """
158
+ return self._generate(sample, prefix_tokens, bos_token=bos_token)
159
+
160
+ # TODO(myleott): unused, deprecate after pytorch-translate migration
161
+ def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
162
+ """Iterate over a batched dataset and yield individual translations.
163
+ Args:
164
+ cuda (bool, optional): use GPU for generation
165
+ timer (StopwatchMeter, optional): time generations
166
+ """
167
+ for sample in data_itr:
168
+ s = utils.move_to_cuda(sample) if cuda else sample
169
+ if "net_input" not in s:
170
+ continue
171
+ input = s["net_input"]
172
+ # model.forward normally channels prev_output_tokens into the decoder
173
+ # separately, but SequenceGenerator directly calls model.encoder
174
+ encoder_input = {
175
+ k: v for k, v in input.items() if k != "prev_output_tokens"
176
+ }
177
+ if timer is not None:
178
+ timer.start()
179
+ with torch.no_grad():
180
+ hypos = self.generate(encoder_input)
181
+ if timer is not None:
182
+ timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
183
+ for i, id in enumerate(s["id"].data):
184
+ # remove padding
185
+ src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
186
+ ref = (
187
+ utils.strip_pad(s["target"].data[i, :], self.pad)
188
+ if s["target"] is not None
189
+ else None
190
+ )
191
+ yield id, src, ref, hypos[i]
192
+
193
+ @torch.no_grad()
194
+ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
195
+ """Generate translations. Match the api of other fairseq generators.
196
+
197
+ Args:
198
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
199
+ sample (dict): batch
200
+ prefix_tokens (torch.LongTensor, optional): force decoder to begin
201
+ with these tokens
202
+ constraints (torch.LongTensor, optional): force decoder to include
203
+ the list of constraints
204
+ bos_token (int, optional): beginning of sentence token
205
+ (default: self.eos)
206
+ """
207
+ return self._generate(models, sample, **kwargs)
208
+
209
+ def _generate(
210
+ self,
211
+ models,
212
+ sample: Dict[str, Dict[str, Tensor]],
213
+ prefix_tokens: Optional[Tensor] = None,
214
+ constraints: Optional[Tensor] = None,
215
+ bos_token: Optional[int] = None,
216
+ ):
217
+ model = EnsembleModel(models)
218
+ incremental_states = torch.jit.annotate(
219
+ List[Dict[str, Dict[str, Optional[Tensor]]]],
220
+ [
221
+ torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
222
+ for i in range(model.models_size)
223
+ ],
224
+ )
225
+ net_input = sample["net_input"]
226
+
227
+ if "src_tokens" in net_input:
228
+ src_tokens = net_input["src_tokens"]
229
+ # length of the source text, i.e. the number of source tokens excluding EOS and pad
230
+ src_lengths = (
231
+ (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
232
+ )
233
+ elif "source" in net_input:
234
+ src_tokens = net_input["source"]
235
+ src_lengths = (
236
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
237
+ if net_input["padding_mask"] is not None
238
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
239
+ )
240
+ elif "features" in net_input:
241
+ src_tokens = net_input["features"]
242
+ src_lengths = (
243
+ net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
244
+ if net_input["padding_mask"] is not None
245
+ else torch.tensor(src_tokens.size(-1)).to(src_tokens)
246
+ )
247
+ else:
248
+ raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
249
+
250
+ # bsz: total number of sentences in beam
251
+ # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
252
+ bsz, src_len = src_tokens.size()[:2]
253
+ beam_size = self.beam_size
254
+
255
+ if constraints is not None and not self.search.supports_constraints:
256
+ raise NotImplementedError(
257
+ "Target-side constraints were provided, but search method doesn't support them"
258
+ )
259
+
260
+ # Initialize constraints, when active
261
+ self.search.init_constraints(constraints, beam_size)
262
+
263
+ max_len: int = -1
264
+ if self.match_source_len:
265
+ max_len = src_lengths.max().item()
266
+ else:
267
+ max_len = int(self.max_len_a * src_len + self.max_len_b)
268
+ assert (
269
+ self.min_len <= max_len
270
+ ), "min_len cannot be larger than max_len, please adjust these!"
271
+ # compute the encoder output for each beam
272
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
273
+ encoder_outs = model.forward_encoder(net_input)
274
+
275
+ # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
276
+ new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
277
+ new_order = new_order.to(src_tokens.device).long()
278
+ encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)
279
+ # ensure encoder_outs is a List.
280
+ assert encoder_outs is not None
281
+
282
+ # initialize buffers
283
+ scores = (
284
+ torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
285
+ ) # +1 for eos; pad is never chosen for scoring
286
+ tokens = (
287
+ torch.zeros(bsz * beam_size, max_len + 2)
288
+ .to(src_tokens)
289
+ .long()
290
+ .fill_(self.pad)
291
+ ) # +2 for eos and pad
292
+ # tokens[:, 0] = self.eos if bos_token is None else bos_token
293
+ tokens[:, 0] = self.bos
294
+ attn: Optional[Tensor] = None
295
+
296
+ # A list that indicates candidates that should be ignored.
297
+ # For example, suppose we're sampling and have already finalized 2/5
298
+ # samples. Then cands_to_ignore would mark 2 positions as being ignored,
299
+ # so that we only finalize the remaining 3 samples.
300
+ cands_to_ignore = (
301
+ torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
302
+ ) # forward and backward-compatible False mask
303
+
304
+ # list of completed sentences
305
+ finalized = torch.jit.annotate(
306
+ List[List[Dict[str, Tensor]]],
307
+ [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
308
+ ) # contains lists of dictionaries of information about the hypothesis being finalized at each step
309
+
310
+ # a boolean array indicating if the sentence at the index is finished or not
311
+ finished = [False for i in range(bsz)]
312
+ num_remaining_sent = bsz # number of sentences remaining
313
+
314
+ # number of candidate hypos per step
315
+ cand_size = 2 * beam_size # 2 x beam size in case half are EOS
316
+
317
+ # offset arrays for converting between different indexing schemes
318
+ bbsz_offsets = (
319
+ (torch.arange(0, bsz) * beam_size)
320
+ .unsqueeze(1)
321
+ .type_as(tokens)
322
+ .to(src_tokens.device)
323
+ )
324
+ cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
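+ # Editorial note (illustrative): a worked example of the offset arrays above,
+ # assuming bsz=2 and beam_size=3 (so cand_size=6):
+ #   bbsz_offsets = [[0], [3]]         -> beam j of sentence i lives at row i*3 + j
+ #   cand_offsets = [0, 1, 2, 3, 4, 5]
+ # Later, cand_bbsz_idx = cand_beams + bbsz_offsets converts per-sentence beam
+ # ids into absolute row indices of the flattened (bsz * beam_size) buffers.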
325
+
326
+ reorder_state: Optional[Tensor] = None
327
+ batch_idxs: Optional[Tensor] = None
328
+
329
+ original_batch_idxs: Optional[Tensor] = None
330
+ if "id" in sample and isinstance(sample["id"], Tensor):
331
+ original_batch_idxs = sample["id"]
332
+ else:
333
+ original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
334
+
335
+ for step in range(max_len + 1): # one extra step for EOS marker
336
+ # reorder decoder internal states based on the prev choice of beams
337
+ if reorder_state is not None:
338
+ if batch_idxs is not None:
339
+ # update beam indices to take into account removed sentences
340
+ corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
341
+ batch_idxs
342
+ )
343
+ reorder_state.view(-1, beam_size).add_(
344
+ corr.unsqueeze(-1) * beam_size
345
+ )
346
+ original_batch_idxs = original_batch_idxs[batch_idxs]
347
+ model.reorder_incremental_state(incremental_states, reorder_state)
348
+ encoder_outs = model.reorder_encoder_out(
349
+ encoder_outs, reorder_state
350
+ )
351
+ with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
352
+ lprobs, avg_attn_scores = model.forward_decoder(
353
+ tokens[:, : step + 1],
354
+ encoder_outs,
355
+ incremental_states,
356
+ self.temperature,
357
+ constraint_trie=self.constraint_trie,
358
+ constraint_start=self.constraint_start,
359
+ constraint_end=self.constraint_end,
360
+ gen_code=self.gen_code,
361
+ zero_shot=self.zero_shot,
362
+ prefix_tokens=prefix_tokens
363
+ )
364
+
365
+ if self.lm_model is not None:
366
+ lm_out = self.lm_model(tokens[:, : step + 1])
367
+ probs = self.lm_model.get_normalized_probs(
368
+ lm_out, log_probs=True, sample=None
369
+ )
370
+ probs = probs[:, -1, :] * self.lm_weight
371
+ lprobs += probs
372
+ # handle prefix tokens (possibly with different lengths)
373
+ if (
374
+ prefix_tokens is not None
375
+ and step < prefix_tokens.size(1)
376
+ and step < max_len
377
+ ):
378
+ lprobs, tokens, scores = self._prefix_tokens(
379
+ step, lprobs, scores, tokens, prefix_tokens, beam_size
380
+ )
381
+ elif step < self.min_len:
382
+ # minimum length constraint (does not apply if using prefix_tokens)
383
+ lprobs[:, self.eos] = -math.inf
384
+
385
+ lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
386
+
387
+ lprobs[:, self.pad] = -math.inf # never select pad
388
+ lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
389
+
390
+ if (self.gen_code or self.gen_box) and step < max_len:
391
+ lprobs[:, :4] = -math.inf
392
+ if self.gen_box:
393
+ lprobs[:, -1] = -math.inf
394
+ if (step + 1) % 5 == 0:
395
+ lprobs[:, self.constraint_start:59457] = -math.inf
396
+ else:
397
+ lprobs[:, 59457:] = -math.inf
398
+
399
+ # handle max length constraint
400
+ if step >= max_len:
401
+ lprobs[:, : self.eos] = -math.inf
402
+ lprobs[:, self.eos + 1 :] = -math.inf
403
+ if self.ignore_eos:
404
+ lprobs[:, self.eos] = 1
405
+
406
+ # Record attention scores, only support avg_attn_scores is a Tensor
407
+ if avg_attn_scores is not None:
408
+ if attn is None:
409
+ attn = torch.empty(
410
+ bsz * beam_size, avg_attn_scores.size(1), max_len + 2
411
+ ).to(scores)
412
+ attn[:, :, step + 1].copy_(avg_attn_scores)
413
+
414
+ scores = scores.type_as(lprobs)
415
+ eos_bbsz_idx = torch.empty(0).to(
416
+ tokens
417
+ ) # indices of hypothesis ending with eos (finished sentences)
418
+ eos_scores = torch.empty(0).to(
419
+ scores
420
+ ) # scores of hypothesis ending with eos (finished sentences)
421
+
422
+ if self.should_set_src_lengths:
423
+ self.search.set_src_lengths(src_lengths)
424
+
425
+ if self.repeat_ngram_blocker is not None:
426
+ lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
427
+
428
+ # Shape: (batch, cand_size)
429
+ cand_scores, cand_indices, cand_beams = self.search.step(
430
+ step,
431
+ lprobs.view(bsz, -1, self.vocab_size),
432
+ scores.view(bsz, beam_size, -1)[:, :, :step],
433
+ tokens[:, : step + 1],
434
+ original_batch_idxs,
435
+ )
436
+
437
+ # cand_bbsz_idx contains beam indices for the top candidate
438
+ # hypotheses, with a range of values: [0, bsz*beam_size),
439
+ # and dimensions: [bsz, cand_size]
440
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
441
+
442
+ # finalize hypotheses that end in eos
443
+ # Shape of eos_mask: (batch size, beam size)
444
+ eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
445
+ eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
446
+
447
+ # only consider eos when it's among the top beam_size indices
448
+ # Now we know what beam item(s) to finish
449
+ # Shape: 1d list of absolute-numbered
450
+ eos_bbsz_idx = torch.masked_select(
451
+ cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
452
+ )
453
+
454
+ finalized_sents: List[int] = []
455
+ if eos_bbsz_idx.numel() > 0:
456
+ eos_scores = torch.masked_select(
457
+ cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
458
+ )
459
+
460
+ finalized_sents = self.finalize_hypos(
461
+ step,
462
+ eos_bbsz_idx,
463
+ eos_scores,
464
+ tokens,
465
+ scores,
466
+ finalized,
467
+ finished,
468
+ beam_size,
469
+ attn,
470
+ src_lengths,
471
+ max_len,
472
+ )
473
+ num_remaining_sent -= len(finalized_sents)
474
+
475
+ assert num_remaining_sent >= 0
476
+ if num_remaining_sent == 0:
477
+ break
478
+ if self.search.stop_on_max_len and step >= max_len:
479
+ break
480
+ assert step < max_len, f"{step} < {max_len}"
481
+
482
+ # Remove finalized sentences (ones for which {beam_size}
483
+ # finished hypotheses have been generated) from the batch.
484
+ if len(finalized_sents) > 0:
485
+ new_bsz = bsz - len(finalized_sents)
486
+
487
+ # construct batch_idxs which holds indices of batches to keep for the next pass
488
+ batch_mask = torch.ones(
489
+ bsz, dtype=torch.bool, device=cand_indices.device
490
+ )
491
+ batch_mask[finalized_sents] = False
492
+ # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
493
+ batch_idxs = torch.arange(
494
+ bsz, device=cand_indices.device
495
+ ).masked_select(batch_mask)
496
+
497
+ # Choose the subset of the hypothesized constraints that will continue
498
+ self.search.prune_sentences(batch_idxs)
499
+
500
+ eos_mask = eos_mask[batch_idxs]
501
+ cand_beams = cand_beams[batch_idxs]
502
+ bbsz_offsets.resize_(new_bsz, 1)
503
+ cand_bbsz_idx = cand_beams.add(bbsz_offsets)
504
+ cand_scores = cand_scores[batch_idxs]
505
+ cand_indices = cand_indices[batch_idxs]
506
+
507
+ if prefix_tokens is not None:
508
+ prefix_tokens = prefix_tokens[batch_idxs]
509
+ src_lengths = src_lengths[batch_idxs]
510
+ cands_to_ignore = cands_to_ignore[batch_idxs]
511
+
512
+ scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
513
+ tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
514
+ if attn is not None:
515
+ attn = attn.view(bsz, -1)[batch_idxs].view(
516
+ new_bsz * beam_size, attn.size(1), -1
517
+ )
518
+ bsz = new_bsz
519
+ else:
520
+ batch_idxs = None
521
+
522
+ # Set active_mask so that values > cand_size indicate eos hypos
523
+ # and values < cand_size indicate candidate active hypos.
524
+ # After, the min values per row are the top candidate active hypos
525
+
526
+ # Rewrite the operator since the element wise or is not supported in torchscript.
527
+
528
+ eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
529
+ active_mask = torch.add(
530
+ eos_mask.type_as(cand_offsets) * cand_size,
531
+ cand_offsets[: eos_mask.size(1)],
532
+ )
533
+
534
+ # get the top beam_size active hypotheses, which are just
535
+ # the hypos with the smallest values in active_mask.
536
+ # {active_hypos} indicates which {beam_size} hypotheses
537
+ # from the list of {2 * beam_size} candidates were
538
+ # selected. Shapes: (batch size, beam size)
539
+ new_cands_to_ignore, active_hypos = torch.topk(
540
+ active_mask, k=beam_size, dim=1, largest=False
541
+ )
542
+
543
+ # update cands_to_ignore to ignore any finalized hypos.
544
+ cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
545
+ # Make sure there is at least one active item for each sentence in the batch.
546
+ assert (~cands_to_ignore).any(dim=1).all()
547
+
548
+ # update cands_to_ignore to ignore any finalized hypos
549
+
550
+ # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
551
+ # can be selected more than once).
552
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
553
+ active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
554
+
555
+ active_bbsz_idx = active_bbsz_idx.view(-1)
556
+ active_scores = active_scores.view(-1)
557
+
558
+ # copy tokens and scores for active hypotheses
559
+
560
+ # Set the tokens for each beam (can select the same row more than once)
561
+ tokens[:, : step + 1] = torch.index_select(
562
+ tokens[:, : step + 1], dim=0, index=active_bbsz_idx
563
+ )
564
+ # Select the next token for each of them
565
+ tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
566
+ cand_indices, dim=1, index=active_hypos
567
+ )
568
+ if step > 0:
569
+ scores[:, :step] = torch.index_select(
570
+ scores[:, :step], dim=0, index=active_bbsz_idx
571
+ )
572
+ scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
573
+ cand_scores, dim=1, index=active_hypos
574
+ )
575
+
576
+ # Update constraints based on which candidates were selected for the next beam
577
+ self.search.update_constraints(active_hypos)
578
+
579
+ # copy attention for active hypotheses
580
+ if attn is not None:
581
+ attn[:, :, : step + 2] = torch.index_select(
582
+ attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
583
+ )
584
+
585
+ # reorder incremental state in decoder
586
+ reorder_state = active_bbsz_idx
587
+
588
+ # sort by score descending
589
+ for sent in range(len(finalized)):
590
+ scores = torch.tensor(
591
+ [float(elem["score"].item()) for elem in finalized[sent]]
592
+ )
593
+ _, sorted_scores_indices = torch.sort(scores, descending=True)
594
+ finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
595
+ finalized[sent] = torch.jit.annotate(
596
+ List[Dict[str, Tensor]], finalized[sent]
597
+ )
598
+ return finalized
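The `finalized` value returned above is one list per input sentence, each holding hypothesis dicts already sorted by score. A minimal, hypothetical sketch of how a caller might read off the best hypothesis; `best_hypotheses` is not a function from this repo, and `tgt_dict`/`bpe` are assumed to be the usual fairseq dictionary and BPE objects:

```python
def best_hypotheses(finalized, tgt_dict, bpe=None):
    # finalized[sent] is a list of dicts with "tokens", "score",
    # "positional_scores", "attention" and "alignment" entries.
    results = []
    for sent_hypos in finalized:
        top = sent_hypos[0]  # highest-scoring beam after the final sort
        text = tgt_dict.string(top["tokens"].int().cpu())
        if bpe is not None:
            text = bpe.decode(text)
        results.append((text, top["score"].item()))
    return results
```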
599
+
600
+ def _prefix_tokens(
601
+ self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
602
+ ):
603
+ """Handle prefix tokens"""
604
+ prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
605
+ prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
606
+ prefix_mask = prefix_toks.ne(self.pad)
607
+ if self.constraint_trie is None:
608
+ lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
609
+ else:
610
+ lprobs[prefix_mask] = -math.inf
611
+ lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
612
+ -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
613
+ )
614
+ # if prefix includes eos, then we should make sure tokens and
615
+ # scores are the same across all beams
616
+ eos_mask = prefix_toks.eq(self.eos)
617
+ if eos_mask.any():
618
+ # validate that the first beam matches the prefix
619
+ first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
620
+ :, 0, 1 : step + 1
621
+ ]
622
+ eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
623
+ target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
624
+ assert (first_beam == target_prefix).all()
625
+
626
+ # copy tokens, scores and lprobs from the first beam to all beams
627
+ tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
628
+ scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
629
+ lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
630
+ return lprobs, tokens, scores
631
+
632
+ def replicate_first_beam(self, tensor, mask, beam_size: int):
633
+ tensor = tensor.view(-1, beam_size, tensor.size(-1))
634
+ tensor[mask] = tensor[mask][:, :1, :]
635
+ return tensor.view(-1, tensor.size(-1))
636
+
637
+ def finalize_hypos(
638
+ self,
639
+ step: int,
640
+ bbsz_idx,
641
+ eos_scores,
642
+ tokens,
643
+ scores,
644
+ finalized: List[List[Dict[str, Tensor]]],
645
+ finished: List[bool],
646
+ beam_size: int,
647
+ attn: Optional[Tensor],
648
+ src_lengths,
649
+ max_len: int,
650
+ ):
651
+ """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
652
+ A sentence is finalized when {beam_size} finished items have been collected for it.
653
+
654
+ Returns number of sentences (not beam items) being finalized.
655
+ These will be removed from the batch and not processed further.
656
+ Args:
657
+ bbsz_idx (Tensor):
658
+ """
659
+ assert bbsz_idx.numel() == eos_scores.numel()
660
+
661
+ # clone relevant token and attention tensors.
662
+ # tokens is (batch * beam, max_len). So the index_select
663
+ # gets the newly EOS rows, then selects cols 1..{step + 2}
664
+ tokens_clone = tokens.index_select(0, bbsz_idx)[
665
+ :, 1 : step + 2
666
+ ] # skip the first index, which is EOS
667
+
668
+ tokens_clone[:, step] = self.eos
669
+ attn_clone = (
670
+ attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
671
+ if attn is not None
672
+ else None
673
+ )
674
+
675
+ # compute scores per token position
676
+ pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
677
+ pos_scores[:, step] = eos_scores
678
+ # convert from cumulative to per-position scores
679
+ pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
680
+
681
+ # normalize sentence-level scores
682
+ if self.normalize_scores:
683
+ eos_scores /= (step + 1) ** self.len_penalty
684
+
685
+ # cum_unfin records which sentences in the batch are finished.
686
+ # It helps match indexing between (a) the original sentences
687
+ # in the batch and (b) the current, possibly-reduced set of
688
+ # sentences.
689
+ cum_unfin: List[int] = []
690
+ prev = 0
691
+ for f in finished:
692
+ if f:
693
+ prev += 1
694
+ else:
695
+ cum_unfin.append(prev)
696
+ cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx)
697
+
698
+ unfin_idx = bbsz_idx // beam_size
699
+ sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx)
700
+
701
+ # Create a set of "{sent}{unfin_idx}", where
702
+ # "unfin_idx" is the index in the current (possibly reduced)
703
+ # list of sentences, and "sent" is the index in the original,
704
+ # unreduced batch
705
+ # For every finished beam item
706
+ # sentence index in the current (possibly reduced) batch
707
+ seen = (sent << 32) + unfin_idx
708
+ unique_seen: List[int] = torch.unique(seen).tolist()
709
+
710
+ if self.match_source_len:
711
+ condition = step > torch.index_select(src_lengths, 0, unfin_idx)
712
+ eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores)
713
+ sent_list: List[int] = sent.tolist()
714
+ for i in range(bbsz_idx.size()[0]):
715
+ # An input sentence (among those in a batch) is finished when
716
+ # beam_size hypotheses have been collected for it
717
+ if len(finalized[sent_list[i]]) < beam_size:
718
+ if attn_clone is not None:
719
+ # remove padding tokens from attn scores
720
+ hypo_attn = attn_clone[i]
721
+ else:
722
+ hypo_attn = torch.empty(0)
723
+
724
+ finalized[sent_list[i]].append(
725
+ {
726
+ "tokens": tokens_clone[i],
727
+ "score": eos_scores[i],
728
+ "attention": hypo_attn, # src_len x tgt_len
729
+ "alignment": torch.empty(0),
730
+ "positional_scores": pos_scores[i],
731
+ }
732
+ )
733
+
734
+ newly_finished: List[int] = []
735
+ for unique_s in unique_seen:
736
+ # check termination conditions for this sentence
737
+ unique_sent: int = unique_s >> 32
738
+ unique_unfin_idx: int = unique_s - (unique_sent << 32)
739
+
740
+ if not finished[unique_sent] and self.is_finished(
741
+ step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size
742
+ ):
743
+ finished[unique_sent] = True
744
+ newly_finished.append(unique_unfin_idx)
745
+
746
+ return newly_finished
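A small self-contained sketch (toy numbers, not repo code) of the cumulative-to-per-position score conversion performed in `finalize_hypos` above:

```python
import torch

# scores[:, t] holds the cumulative log-prob of the hypothesis up to step t,
# so differencing adjacent columns recovers the per-token log-probs.
cumulative = torch.tensor([[-0.5, -1.2, -2.0]])
per_position = cumulative.clone()
per_position[:, 1:] = cumulative[:, 1:] - cumulative[:, :-1]
print(per_position)  # tensor([[-0.5000, -0.7000, -0.8000]])
```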
747
+
748
+ def is_finished(
749
+ self,
750
+ step: int,
751
+ unfin_idx: int,
752
+ max_len: int,
753
+ finalized_sent_len: int,
754
+ beam_size: int,
755
+ ):
756
+ """
757
+ Check whether decoding for a sentence is finished, which
758
+ occurs when the list of finalized sentences has reached the
759
+ beam size, or when we reach the maximum length.
760
+ """
761
+ assert finalized_sent_len <= beam_size
762
+ if finalized_sent_len == beam_size or step == max_len:
763
+ return True
764
+ return False
765
+
766
+
767
+ class EnsembleModel(nn.Module):
768
+ """A wrapper around an ensemble of models."""
769
+
770
+ def __init__(self, models):
771
+ super().__init__()
772
+ self.models_size = len(models)
773
+ # method '__len__' is not supported in ModuleList for torch script
774
+ self.single_model = models[0]
775
+ self.models = nn.ModuleList(models)
776
+
777
+ self.has_incremental: bool = False
778
+ if all(
779
+ hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
780
+ for m in models
781
+ ):
782
+ self.has_incremental = True
783
+
784
+ def forward(self):
785
+ pass
786
+
787
+ def has_encoder(self):
788
+ return hasattr(self.single_model, "encoder")
789
+
790
+ def has_incremental_states(self):
791
+ return self.has_incremental
792
+
793
+ def max_decoder_positions(self):
794
+ return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
795
+
796
+ @torch.jit.export
797
+ def forward_encoder(self, net_input: Dict[str, Tensor]):
798
+ if not self.has_encoder():
799
+ return None
800
+ return [model.encoder.forward_torchscript(net_input) for model in self.models]
801
+
802
+ @torch.jit.export
803
+ def forward_decoder(
804
+ self,
805
+ tokens,
806
+ encoder_outs: List[Dict[str, List[Tensor]]],
807
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
808
+ temperature: float = 1.0,
809
+ constraint_trie=None,
810
+ constraint_start=None,
811
+ constraint_end=None,
812
+ gen_code=False,
813
+ zero_shot=False,
814
+ prefix_tokens=None
815
+ ):
816
+ log_probs = []
817
+ avg_attn: Optional[Tensor] = None
818
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None
819
+ code_mask = (tokens.new_ones(tokens.size(0))*gen_code).bool()
820
+ for i, model in enumerate(self.models):
821
+ if self.has_encoder():
822
+ encoder_out = encoder_outs[i]
823
+ # decode each model
824
+ if self.has_incremental_states():
825
+ decoder_out = model.decoder.forward(
826
+ tokens,
827
+ code_masks=code_mask,
828
+ encoder_out=encoder_out,
829
+ incremental_state=incremental_states[i],
830
+ )
831
+ else:
832
+ if hasattr(model, "decoder"):
833
+ decoder_out = model.decoder.forward(tokens, code_masks=code_mask, encoder_out=encoder_out)
834
+ else:
835
+ decoder_out = model.forward(tokens)
836
+
837
+ attn: Optional[Tensor] = None
838
+ decoder_len = len(decoder_out)
839
+ if decoder_len > 1 and decoder_out[1] is not None:
840
+ if isinstance(decoder_out[1], Tensor):
841
+ attn = decoder_out[1]
842
+ else:
843
+ attn_holder = decoder_out[1]["attn"]
844
+ if isinstance(attn_holder, Tensor):
845
+ attn = attn_holder
846
+ elif attn_holder is not None:
847
+ attn = attn_holder[0]
848
+ if attn is not None:
849
+ attn = attn[:, -1, :]
850
+
851
+ decoder_out_tuple = (
852
+ decoder_out[0][:, -1:, :].div_(temperature),
853
+ None if decoder_len <= 1 else decoder_out[1],
854
+ )
855
+
856
+ beam_size = decoder_out_tuple[0].size(0) // prefix_tokens.size(0) if prefix_tokens is not None else 0
857
+ if constraint_trie is not None and not zero_shot:
858
+ assert constraint_start is None and constraint_end is None
859
+ constraint_masks = decoder_out_tuple[0].new_zeros(decoder_out_tuple[0].size()).bool()
860
+ constraint_prefix_tokens = tokens.tolist()
861
+ for token_index, constraint_prefix_token in enumerate(constraint_prefix_tokens):
862
+ prefix_len = prefix_tokens[token_index // beam_size].ne(1).sum().item() if prefix_tokens is not None else 0
863
+ if len(constraint_prefix_token) > prefix_len:
864
+ constraint_prefix_token = [0] + constraint_prefix_token[prefix_len+1:]
865
+ constraint_nodes = constraint_trie.get_next_layer(constraint_prefix_token)
866
+ constraint_masks[token_index][:, constraint_nodes] = True
867
+ else:
868
+ constraint_masks[token_index] = True
869
+ decoder_out_tuple[0].masked_fill_(~constraint_masks, -math.inf)
870
+ if constraint_start is not None and constraint_end is not None and not zero_shot:
871
+ assert constraint_trie is None
872
+ decoder_out_tuple[0][:, :, 4:constraint_start] = -math.inf
873
+ decoder_out_tuple[0][:, :, constraint_end:] = -math.inf
874
+
875
+ probs = model.get_normalized_probs(
876
+ decoder_out_tuple, log_probs=True, sample=None
877
+ )
878
+ if constraint_trie is not None and zero_shot:
879
+ assert constraint_start is None and constraint_end is None
880
+ constraint_masks = decoder_out_tuple[0].new_zeros(decoder_out_tuple[0].size()).bool()
881
+ constraint_prefix_tokens = tokens.tolist()
882
+ for token_index, constraint_prefix_token in enumerate(constraint_prefix_tokens):
883
+ constraint_nodes = constraint_trie.get_next_layer(constraint_prefix_token)
884
+ constraint_masks[token_index][:, constraint_nodes] = True
885
+ probs.masked_fill_(~constraint_masks, -math.inf)
886
+ if constraint_start is not None and constraint_end is not None and zero_shot:
887
+ assert constraint_trie is None
888
+ probs[:, :, 4:constraint_start] = -math.inf
889
+ probs[:, :, constraint_end:] = -math.inf
890
+ probs = probs[:, -1, :]
891
+ if self.models_size == 1:
892
+ return probs, attn
893
+
894
+ log_probs.append(probs)
895
+ if attn is not None:
896
+ if avg_attn is None:
897
+ avg_attn = attn
898
+ else:
899
+ avg_attn.add_(attn)
900
+
901
+ avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
902
+ self.models_size
903
+ )
904
+
905
+ if avg_attn is not None:
906
+ avg_attn.div_(self.models_size)
907
+ return avg_probs, avg_attn
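The averaging above works in probability space: per-model log-probs are combined with `logsumexp` and shifted by `log(models_size)`. A toy sketch with made-up values:

```python
import math
import torch

log_p1 = torch.tensor([0.7, 0.2, 0.1]).log()
log_p2 = torch.tensor([0.5, 0.3, 0.2]).log()
avg = torch.logsumexp(torch.stack([log_p1, log_p2]), dim=0) - math.log(2)
print(avg.exp())  # tensor([0.6000, 0.2500, 0.1500]) -- the arithmetic mean of the two
```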
908
+
909
+ @torch.jit.export
910
+ def reorder_encoder_out(
911
+ self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
912
+ ):
913
+ """
914
+ Reorder encoder output according to *new_order*.
915
+
916
+ Args:
917
+ encoder_out: output from the ``forward()`` method
918
+ new_order (LongTensor): desired order
919
+
920
+ Returns:
921
+ *encoder_out* rearranged according to *new_order*
922
+ """
923
+ new_outs: List[Dict[str, List[Tensor]]] = []
924
+ if not self.has_encoder():
925
+ return new_outs
926
+ for i, model in enumerate(self.models):
927
+ assert encoder_outs is not None
928
+ new_outs.append(
929
+ model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
930
+ )
931
+ return new_outs
932
+
933
+ @torch.jit.export
934
+ def reorder_incremental_state(
935
+ self,
936
+ incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
937
+ new_order,
938
+ ):
939
+ if not self.has_incremental_states():
940
+ return
941
+ for i, model in enumerate(self.models):
942
+ model.decoder.reorder_incremental_state_scripting(
943
+ incremental_states[i], new_order
944
+ )
945
+
946
+
947
+ class SequenceGeneratorWithAlignment(SequenceGenerator):
948
+ def __init__(
949
+ self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
950
+ ):
951
+ """Generates translations of a given source sentence.
952
+
953
+ Produces alignments following "Jointly Learning to Align and
954
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
955
+
956
+ Args:
957
+ left_pad_target (bool, optional): Whether or not the
958
+ hypothesis should be left padded or not when they are
959
+ teacher forced for generating alignments.
960
+ """
961
+ super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
962
+ self.left_pad_target = left_pad_target
963
+
964
+ if print_alignment == "hard":
965
+ self.extract_alignment = utils.extract_hard_alignment
966
+ elif print_alignment == "soft":
967
+ self.extract_alignment = utils.extract_soft_alignment
968
+
969
+ @torch.no_grad()
970
+ def generate(self, models, sample, **kwargs):
971
+ finalized = super()._generate(sample, **kwargs)
972
+
973
+ src_tokens = sample["net_input"]["src_tokens"]
974
+ bsz = src_tokens.shape[0]
975
+ beam_size = self.beam_size
976
+ (
977
+ src_tokens,
978
+ src_lengths,
979
+ prev_output_tokens,
980
+ tgt_tokens,
981
+ ) = self._prepare_batch_for_alignment(sample, finalized)
982
+ if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
983
+ attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
984
+ else:
985
+ attn = [
986
+ finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
987
+ for i in range(bsz * beam_size)
988
+ ]
989
+
990
+ if src_tokens.device != "cpu":
991
+ src_tokens = src_tokens.to("cpu")
992
+ tgt_tokens = tgt_tokens.to("cpu")
993
+ attn = [i.to("cpu") for i in attn]
994
+
995
+ # Process the attn matrix to extract hard alignments.
996
+ for i in range(bsz * beam_size):
997
+ alignment = self.extract_alignment(
998
+ attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
999
+ )
1000
+ finalized[i // beam_size][i % beam_size]["alignment"] = alignment
1001
+ return finalized
1002
+
1003
+ def _prepare_batch_for_alignment(self, sample, hypothesis):
1004
+ src_tokens = sample["net_input"]["src_tokens"]
1005
+ bsz = src_tokens.shape[0]
1006
+ src_tokens = (
1007
+ src_tokens[:, None, :]
1008
+ .expand(-1, self.beam_size, -1)
1009
+ .contiguous()
1010
+ .view(bsz * self.beam_size, -1)
1011
+ )
1012
+ src_lengths = sample["net_input"]["src_lengths"]
1013
+ src_lengths = (
1014
+ src_lengths[:, None]
1015
+ .expand(-1, self.beam_size)
1016
+ .contiguous()
1017
+ .view(bsz * self.beam_size)
1018
+ )
1019
+ prev_output_tokens = data_utils.collate_tokens(
1020
+ [beam["tokens"] for example in hypothesis for beam in example],
1021
+ self.pad,
1022
+ self.eos,
1023
+ self.left_pad_target,
1024
+ move_eos_to_beginning=True,
1025
+ )
1026
+ tgt_tokens = data_utils.collate_tokens(
1027
+ [beam["tokens"] for example in hypothesis for beam in example],
1028
+ self.pad,
1029
+ self.eos,
1030
+ self.left_pad_target,
1031
+ move_eos_to_beginning=False,
1032
+ )
1033
+ return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
1034
+
1035
+
1036
+ class EnsembleModelWithAlignment(EnsembleModel):
1037
+ """A wrapper around an ensemble of models."""
1038
+
1039
+ def __init__(self, models):
1040
+ super().__init__(models)
1041
+
1042
+ def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
1043
+ avg_attn = None
1044
+ for model in self.models:
1045
+ decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
1046
+ attn = decoder_out[1]["attn"][0]
1047
+ if avg_attn is None:
1048
+ avg_attn = attn
1049
+ else:
1050
+ avg_attn.add_(attn)
1051
+ if len(self.models) > 1:
1052
+ avg_attn.div_(len(self.models))
1053
+ return avg_attn
ofa_module/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ import data
2
+ import models
3
+ import tasks
4
+ import criterions
5
+ import utils
pokemons.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ -e ./fairseq/
2
+ ftfy==6.0.3
3
+ tensorboardX==2.4.1
4
+ pycocotools==2.0.4
5
+ pycocoevalcap==1.2
run_scripts/caption/coco_eval.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ import os.path as op
4
+
5
+ from pycocotools.coco import COCO
6
+ from pycocoevalcap.eval import COCOEvalCap
7
+
8
+
9
+ def evaluate_on_coco_caption(res_file, label_file, outfile=None):
10
+ """
11
+ res_file: JSON file of generated captions in COCO result format
12
+ (a list of dicts with the fields "image_id" and "caption").
13
+ label_file: JSON file of ground truth captions in COCO format.
14
+ """
15
+ coco = COCO(label_file)
16
+ cocoRes = coco.loadRes(res_file)
17
+ cocoEval = COCOEvalCap(coco, cocoRes)
18
+
19
+ # evaluate on a subset of images by setting
20
+ # cocoEval.params['image_id'] = cocoRes.getImgIds()
21
+ # please remove this line when evaluating the full validation set
22
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
23
+
24
+ # evaluate results
25
+ # SPICE will take a few minutes the first time, but speeds up due to caching
26
+ cocoEval.evaluate()
27
+ result = cocoEval.eval
28
+ if not outfile:
29
+ print(result)
30
+ else:
31
+ with open(outfile, 'w') as fp:
32
+ json.dump(result, fp, indent=4)
33
+ return result
34
+
35
+
36
+ if __name__ == "__main__":
37
+ if len(sys.argv) == 3:
38
+ evaluate_on_coco_caption(sys.argv[1], sys.argv[2])
39
+ elif len(sys.argv) == 4:
40
+ evaluate_on_coco_caption(sys.argv[1], sys.argv[2], sys.argv[3])
41
+ else:
42
+ raise NotImplementedError
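A hypothetical call to the helper above; the paths mirror the ones used in run_scripts/caption/evaluate_caption.sh, and the import assumes coco_eval.py is on sys.path:

```python
from coco_eval import evaluate_on_coco_caption  # assumption: coco_eval.py is importable

metrics = evaluate_on_coco_caption(
    "results/caption/test_predict.json",                    # generated captions, COCO result format
    "dataset/caption_data/test_caption_coco_format.json",   # ground-truth annotations
    outfile="results/caption/coco_metrics.json",            # optional: also dump metrics to disk
)
print(metrics["CIDEr"])
```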
run_scripts/caption/evaluate_caption.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ user_dir=../../ofa_module
4
+ bpe_dir=../../utils/BPE
5
+
6
+ data=../../dataset/caption_data/caption_test.tsv
7
+ path=../../checkpoints/caption_large_best_clean.pt
8
+ result_path=../../results/caption
9
+ selected_cols=1,4,2
10
+ split='test'
11
+
12
+ CUDA_VISIBLE_DEVICES=4,5,6,7 python3 ../../evaluate.py \
13
+ ${data} \
14
+ --path=${path} \
15
+ --user-dir=${user_dir} \
16
+ --task=caption \
17
+ --batch-size=16 \
18
+ --log-format=simple --log-interval=10 \
19
+ --seed=7 \
20
+ --gen-subset=${split} \
21
+ --results-path=${result_path} \
22
+ --beam=5 \
23
+ --max-len-b=16 \
24
+ --no-repeat-ngram-size=3 \
25
+ --fp16 \
26
+ --num-workers=0 \
27
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"eval_cider\":False,\"selected_cols\":\"${selected_cols}\"}"
28
+
29
+ python coco_eval.py ../../results/caption/test_predict.json ../../dataset/caption_data/test_caption_coco_format.json
run_scripts/caption/train_caption_stage1.sh ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ log_dir=./stage1_logs
4
+ save_dir=./stage1_checkpoints
5
+ mkdir -p $log_dir $save_dir
6
+
7
+ bpe_dir=../../utils/BPE
8
+ user_dir=../../ofa_module
9
+
10
+ data_dir=../../dataset/caption_data
11
+ data=${data_dir}/caption_stage1_train.tsv,${data_dir}/caption_val.tsv
12
+ restore_file=../../checkpoints/ofa_large.pt
13
+ selected_cols=0,4,2
14
+
15
+ task=caption
16
+ arch=ofa_large
17
+ criterion=ajust_label_smoothed_cross_entropy
18
+ label_smoothing=0.1
19
+ lr=1e-5
20
+ max_epoch=5
21
+ warmup_ratio=0.06
22
+ batch_size=8
23
+ update_freq=4
24
+ resnet_drop_path_rate=0.0
25
+ encoder_drop_path_rate=0.1
26
+ decoder_drop_path_rate=0.1
27
+ dropout=0.1
28
+ attention_dropout=0.0
29
+ max_src_length=80
30
+ max_tgt_length=20
31
+ num_bins=1000
32
+ patch_image_size=480
33
+ eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p
34
+ drop_worst_ratio=0.2
35
+
36
+ for max_epoch in {2,}; do
37
+ echo "max_epoch "${max_epoch}
38
+ for warmup_ratio in {0.06,}; do
39
+ echo "warmup_ratio "${warmup_ratio}
40
+ for drop_worst_after in {2500,}; do
41
+ echo "drop_worst_after "${drop_worst_after}
42
+
43
+ log_file=${log_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}".log"
44
+ save_path=${save_dir}/${max_epoch}"_"${warmup_ratio}"_"${drop_worst_after}
45
+ mkdir -p $save_path
46
+
47
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../train.py \
48
+ $data \
49
+ --selected-cols=${selected_cols} \
50
+ --bpe-dir=${bpe_dir} \
51
+ --user-dir=${user_dir} \
52
+ --restore-file=${restore_file} \
53
+ --reset-optimizer --reset-dataloader --reset-meters \
54
+ --save-dir=${save_path} \
55
+ --task=${task} \
56
+ --arch=${arch} \
57
+ --criterion=${criterion} \
58
+ --label-smoothing=${label_smoothing} \
59
+ --batch-size=${batch_size} \
60
+ --update-freq=${update_freq} \
61
+ --encoder-normalize-before \
62
+ --decoder-normalize-before \
63
+ --share-decoder-input-output-embed \
64
+ --share-all-embeddings \
65
+ --layernorm-embedding \
66
+ --patch-layernorm-embedding \
67
+ --code-layernorm-embedding \
68
+ --resnet-drop-path-rate=${resnet_drop_path_rate} \
69
+ --encoder-drop-path-rate=${encoder_drop_path_rate} \
70
+ --decoder-drop-path-rate=${decoder_drop_path_rate} \
71
+ --dropout=${dropout} \
72
+ --attention-dropout=${attention_dropout} \
73
+ --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
74
+ --lr-scheduler=polynomial_decay --lr=${lr} \
75
+ --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
76
+ --log-format=simple --log-interval=10 \
77
+ --fixed-validation-seed=7 \
78
+ --no-epoch-checkpoints --keep-best-checkpoints=1 \
79
+ --save-interval=1 --validate-interval=1 \
80
+ --save-interval-updates=500 --validate-interval-updates=500 \
81
+ --eval-cider \
82
+ --eval-cider-cached-tokens=${eval_cider_cached} \
83
+ --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
84
+ --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \
85
+ --max-src-length=${max_src_length} \
86
+ --max-tgt-length=${max_tgt_length} \
87
+ --find-unused-parameters \
88
+ --freeze-encoder-embedding \
89
+ --freeze-decoder-embedding \
90
+ --add-type-embedding \
91
+ --scale-attn \
92
+ --scale-fc \
93
+ --scale-heads \
94
+ --disable-entangle \
95
+ --num-bins=${num_bins} \
96
+ --patch-image-size=${patch_image_size} \
97
+ --drop-worst-ratio=${drop_worst_ratio} \
98
+ --drop-worst-after=${drop_worst_after} \
99
+ --fp16 \
100
+ --fp16-scale-window=512 \
101
+ --num-workers=0 >> ${log_file} 2>&1
102
+ done
103
+ done
104
+ done
run_scripts/caption/train_caption_stage2.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ log_dir=./stage2_logs
4
+ save_dir=./stage2_checkpoints
5
+ mkdir -p $log_dir $save_dir
6
+
7
+ bpe_dir=../../utils/BPE
8
+ user_dir=../../ofa_module
9
+
10
+ data_dir=../../dataset/caption_data
11
+ data=${data_dir}/caption_stage2_train.tsv,${data_dir}/caption_val.tsv
12
+ restore_file=../../checkpoints/caption_stage1_best.pt
13
+ selected_cols=1,4,2
14
+
15
+ task=caption
16
+ arch=ofa_large
17
+ criterion=scst_reward_criterion
18
+ label_smoothing=0.1
19
+ lr=1e-5
20
+ max_epoch=5
21
+ warmup_ratio=0.06
22
+ batch_size=2
23
+ update_freq=4
24
+ resnet_drop_path_rate=0.0
25
+ encoder_drop_path_rate=0.0
26
+ decoder_drop_path_rate=0.0
27
+ dropout=0.0
28
+ attention_dropout=0.0
29
+ max_src_length=80
30
+ max_tgt_length=20
31
+ num_bins=1000
32
+ patch_image_size=480
33
+ eval_cider_cached=${data_dir}/cider_cached_tokens/coco-valid-words.p
34
+ scst_cider_cached=${data_dir}/cider_cached_tokens/coco-train-words.p
35
+
36
+ for lr in {1e-5,}; do
37
+ echo "lr "${lr}
38
+ for max_epoch in {4,}; do
39
+ echo "max_epoch "${max_epoch}
40
+
41
+ log_file=${log_dir}/${lr}"_"${max_epoch}".log"
42
+ save_path=${save_dir}/${lr}"_"${max_epoch}
43
+ mkdir -p $save_path
44
+
45
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 ../../train.py \
46
+ $data \
47
+ --selected-cols=${selected_cols} \
48
+ --bpe-dir=${bpe_dir} \
49
+ --user-dir=${user_dir} \
50
+ --restore-file=${restore_file} \
51
+ --reset-optimizer --reset-dataloader --reset-meters \
52
+ --save-dir=${save_path} \
53
+ --task=${task} \
54
+ --arch=${arch} \
55
+ --criterion=${criterion} \
56
+ --batch-size=${batch_size} \
57
+ --update-freq=${update_freq} \
58
+ --encoder-normalize-before \
59
+ --decoder-normalize-before \
60
+ --share-decoder-input-output-embed \
61
+ --share-all-embeddings \
62
+ --layernorm-embedding \
63
+ --patch-layernorm-embedding \
64
+ --code-layernorm-embedding \
65
+ --resnet-drop-path-rate=${resnet_drop_path_rate} \
66
+ --encoder-drop-path-rate=${encoder_drop_path_rate} \
67
+ --decoder-drop-path-rate=${decoder_drop_path_rate} \
68
+ --dropout=${dropout} \
69
+ --attention-dropout=${attention_dropout} \
70
+ --weight-decay=0.01 --optimizer=adam --adam-betas="(0.9,0.999)" --adam-eps=1e-08 --clip-norm=1.0 \
71
+ --lr-scheduler=polynomial_decay --lr=${lr} \
72
+ --max-epoch=${max_epoch} --warmup-ratio=${warmup_ratio} \
73
+ --log-format=simple --log-interval=10 \
74
+ --fixed-validation-seed=7 \
75
+ --no-epoch-checkpoints --keep-best-checkpoints=1 \
76
+ --save-interval=1 --validate-interval=1 \
77
+ --save-interval-updates=500 --validate-interval-updates=500 \
78
+ --eval-cider \
79
+ --eval-cider-cached-tokens=${eval_cider_cached} \
80
+ --eval-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
81
+ --best-checkpoint-metric=cider --maximize-best-checkpoint-metric \
82
+ --max-src-length=${max_src_length} \
83
+ --max-tgt-length=${max_tgt_length} \
84
+ --find-unused-parameters \
85
+ --freeze-encoder-embedding \
86
+ --freeze-decoder-embedding \
87
+ --add-type-embedding \
88
+ --scale-attn \
89
+ --scale-fc \
90
+ --scale-heads \
91
+ --disable-entangle \
92
+ --num-bins=${num_bins} \
93
+ --patch-image-size=${patch_image_size} \
94
+ --scst \
95
+ --scst-cider-cached-tokens=${scst_cider_cached} \
96
+ --scst-args='{"beam":5,"max_len_b":16,"no_repeat_ngram_size":3}' \
97
+ --memory-efficient-fp16 \
98
+ --fp16-scale-window=512 \
99
+ --num-workers=0 >> ${log_file} 2>&1
100
+ done
101
+ done
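Stage 2 switches the criterion to scst_reward_criterion. A rough, standalone sketch of the self-critical idea (this is not the repo's scst_loss.py): the CIDEr reward of a sampled caption is baselined against the greedy caption and weights the sampled log-probability.

```python
import torch

def scst_loss(sample_logprobs, sampled_cider, greedy_cider):
    """sample_logprobs: (batch,) summed log-probs of the sampled captions."""
    reward = sampled_cider - greedy_cider               # baseline-subtracted reward
    return -(reward.detach() * sample_logprobs).mean()  # REINFORCE-style objective

loss = scst_loss(
    torch.tensor([-12.3, -9.8], requires_grad=True),
    torch.tensor([1.10, 0.85]),   # hypothetical CIDEr of sampled captions
    torch.tensor([0.95, 0.90]),   # hypothetical CIDEr of greedy captions
)
loss.backward()
```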
run_scripts/refcoco/evaluate_refcoco.sh ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+
4
+ ########################## Evaluate Refcoco ##########################
5
+ user_dir=../../ofa_module
6
+ bpe_dir=../../utils/BPE
7
+ selected_cols=0,4,2,3
8
+
9
+ data=../../dataset/refcoco_data/refcoco_val.tsv
10
+ path=../../checkpoints/refcoco_large_best.pt
11
+ result_path=../../results/refcoco
12
+ split='refcoco_val'
13
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
14
+ ${data} \
15
+ --path=${path} \
16
+ --user-dir=${user_dir} \
17
+ --task=refcoco \
18
+ --batch-size=16 \
19
+ --log-format=simple --log-interval=10 \
20
+ --seed=7 \
21
+ --gen-subset=${split} \
22
+ --results-path=${result_path} \
23
+ --beam=5 \
24
+ --min-len=4 \
25
+ --max-len-a=0 \
26
+ --max-len-b=4 \
27
+ --no-repeat-ngram-size=3 \
28
+ --fp16 \
29
+ --num-workers=0 \
30
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
31
+
32
+ data=../../dataset/refcoco_data/refcoco_testA.tsv
33
+ path=../../checkpoints/refcoco_large_best.pt
34
+ result_path=../../results/refcoco
35
+ split='refcoco_testA'
36
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
37
+ ${data} \
38
+ --path=${path} \
39
+ --user-dir=${user_dir} \
40
+ --task=refcoco \
41
+ --batch-size=16 \
42
+ --log-format=simple --log-interval=10 \
43
+ --seed=7 \
44
+ --gen-subset=${split} \
45
+ --results-path=${result_path} \
46
+ --beam=5 \
47
+ --min-len=4 \
48
+ --max-len-a=0 \
49
+ --max-len-b=4 \
50
+ --no-repeat-ngram-size=3 \
51
+ --fp16 \
52
+ --num-workers=0 \
53
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
54
+
55
+ data=../../dataset/refcoco_data/refcoco_testB.tsv
56
+ path=../../checkpoints/refcoco_large_best.pt
57
+ result_path=../../results/refcoco
58
+ split='refcoco_testB'
59
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
60
+ ${data} \
61
+ --path=${path} \
62
+ --user-dir=${user_dir} \
63
+ --task=refcoco \
64
+ --batch-size=16 \
65
+ --log-format=simple --log-interval=10 \
66
+ --seed=7 \
67
+ --gen-subset=${split} \
68
+ --results-path=${result_path} \
69
+ --beam=5 \
70
+ --min-len=4 \
71
+ --max-len-a=0 \
72
+ --max-len-b=4 \
73
+ --no-repeat-ngram-size=3 \
74
+ --fp16 \
75
+ --num-workers=0 \
76
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
77
+
78
+
79
+
80
+ ######################### Evaluate Refcocoplus ##########################
81
+ data=../../dataset/refcocoplus_data/refcocoplus_val.tsv
82
+ path=../../checkpoints/refcocoplus_large_best.pt
83
+ result_path=../../results/refcocoplus
84
+ split='refcocoplus_val'
85
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
86
+ ${data} \
87
+ --path=${path} \
88
+ --user-dir=${user_dir} \
89
+ --task=refcoco \
90
+ --batch-size=16 \
91
+ --log-format=simple --log-interval=10 \
92
+ --seed=7 \
93
+ --gen-subset=${split} \
94
+ --results-path=${result_path} \
95
+ --beam=5 \
96
+ --min-len=4 \
97
+ --max-len-a=0 \
98
+ --max-len-b=4 \
99
+ --no-repeat-ngram-size=3 \
100
+ --fp16 \
101
+ --num-workers=0 \
102
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
103
+
104
+ data=../../dataset/refcocoplus_data/refcocoplus_testA.tsv
105
+ path=../../checkpoints/refcocoplus_large_best.pt
106
+ result_path=../../results/refcocoplus
107
+ split='refcocoplus_testA'
108
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
109
+ ${data} \
110
+ --path=${path} \
111
+ --user-dir=${user_dir} \
112
+ --task=refcoco \
113
+ --batch-size=16 \
114
+ --log-format=simple --log-interval=10 \
115
+ --seed=7 \
116
+ --gen-subset=${split} \
117
+ --results-path=${result_path} \
118
+ --beam=5 \
119
+ --min-len=4 \
120
+ --max-len-a=0 \
121
+ --max-len-b=4 \
122
+ --no-repeat-ngram-size=3 \
123
+ --fp16 \
124
+ --num-workers=0 \
125
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
126
+
127
+ data=../../dataset/refcocoplus_data/refcocoplus_testB.tsv
128
+ path=../../checkpoints/refcocoplus_large_best.pt
129
+ result_path=../../results/refcocoplus
130
+ split='refcocoplus_testB'
131
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
132
+ ${data} \
133
+ --path=${path} \
134
+ --user-dir=${user_dir} \
135
+ --task=refcoco \
136
+ --batch-size=16 \
137
+ --log-format=simple --log-interval=10 \
138
+ --seed=7 \
139
+ --gen-subset=${split} \
140
+ --results-path=${result_path} \
141
+ --beam=5 \
142
+ --min-len=4 \
143
+ --max-len-a=0 \
144
+ --max-len-b=4 \
145
+ --no-repeat-ngram-size=3 \
146
+ --fp16 \
147
+ --num-workers=0 \
148
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
149
+
150
+
151
+
152
+ ########################## Evaluate Refcocog ##########################
153
+ data=../../dataset/refcocog_data/refcocog_val.tsv
154
+ path=../../checkpoints/refcocog_large_best.pt
155
+ result_path=../../results/refcocog
156
+ split='refcocog_val'
157
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
158
+ ${data} \
159
+ --path=${path} \
160
+ --user-dir=${user_dir} \
161
+ --task=refcoco \
162
+ --batch-size=16 \
163
+ --log-format=simple --log-interval=10 \
164
+ --seed=7 \
165
+ --gen-subset=${split} \
166
+ --results-path=${result_path} \
167
+ --beam=5 \
168
+ --min-len=4 \
169
+ --max-len-a=0 \
170
+ --max-len-b=4 \
171
+ --no-repeat-ngram-size=3 \
172
+ --fp16 \
173
+ --num-workers=0 \
174
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
175
+
176
+ data=../../dataset/refcocog_data/refcocog_test.tsv
177
+ path=../../checkpoints/refcocog_large_best.pt
178
+ result_path=../../results/refcocog
179
+ split='refcocog_test'
180
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python3 ../../evaluate.py \
181
+ ${data} \
182
+ --path=${path} \
183
+ --user-dir=${user_dir} \
184
+ --task=refcoco \
185
+ --batch-size=16 \
186
+ --log-format=simple --log-interval=10 \
187
+ --seed=7 \
188
+ --gen-subset=${split} \
189
+ --results-path=${result_path} \
190
+ --beam=5 \
191
+ --min-len=4 \
192
+ --max-len-a=0 \
193
+ --max-len-b=4 \
194
+ --no-repeat-ngram-size=3 \
195
+ --fp16 \
196
+ --num-workers=0 \
197
+ --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
spaces.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ # Spaces
2
+ To provide a better experience, we plan to build demos for our OFA models on Hugging Face Spaces. Below we provide links to the demos. Have fun!
3
+
4
+ * Image Captioning: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)
tasks/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
1
+ from .mm_tasks import *
2
+ from .ofa_task import OFATask
tasks/mm_tasks/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
1
+ from .caption import CaptionTask
2
+ from .refcoco import RefcocoTask
tasks/mm_tasks/caption.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ import json
8
+ import logging
9
+ from typing import Optional
10
+ from argparse import Namespace
11
+ from itertools import zip_longest
12
+ from collections import OrderedDict
13
+
14
+ import numpy as np
15
+ import sacrebleu
16
+ import string
17
+ from fairseq import metrics, utils
18
+ from fairseq.tasks import register_task
19
+
20
+ from tasks.ofa_task import OFATask, OFAConfig
21
+ from data.mm_data.caption_dataset import CaptionDataset
22
+ from data.file_dataset import FileDataset
23
+ from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
24
+
25
+ EVAL_BLEU_ORDER = 4
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ @dataclass
31
+ class CaptionConfig(OFAConfig):
32
+ eval_bleu: bool = field(
33
+ default=False, metadata={"help": "evaluation with BLEU scores"}
34
+ )
35
+ eval_cider: bool = field(
36
+ default=False, metadata={"help": "evaluation with CIDEr scores"}
37
+ )
38
+ eval_args: Optional[str] = field(
39
+ default='{}',
40
+ metadata={
41
+ "help": 'generation args for BLEU or CIDEr scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
42
+ },
43
+ )
44
+ eval_print_samples: bool = field(
45
+ default=False, metadata={"help": "print sample generations during validation"}
46
+ )
47
+ eval_cider_cached_tokens: Optional[str] = field(
48
+ default=None,
49
+ metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
50
+ )
51
+
52
+ scst: bool = field(
53
+ default=False, metadata={"help": "Self-critical sequence training"}
54
+ )
55
+ scst_args: str = field(
56
+ default='{}',
57
+ metadata={
58
+ "help": 'generation args for Self-critical sequence training, as JSON string'
59
+ },
60
+ )
61
+
62
+
63
+ @register_task("caption", dataclass=CaptionConfig)
64
+ class CaptionTask(OFATask):
65
+ def __init__(self, cfg: CaptionConfig, src_dict, tgt_dict):
66
+ super().__init__(cfg, src_dict, tgt_dict)
67
+
68
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
69
+ paths = self.cfg.data.split(',')
70
+ assert len(paths) > 0
71
+
72
+ if split == 'train':
73
+ file_path = paths[(epoch - 1) % (len(paths) - 1)]
74
+ else:
75
+ file_path = paths[-1]
76
+ dataset = FileDataset(file_path, self.cfg.selected_cols)
77
+
78
+ self.datasets[split] = CaptionDataset(
79
+ split,
80
+ dataset,
81
+ self.bpe,
82
+ self.src_dict,
83
+ self.tgt_dict,
84
+ max_src_length=self.cfg.max_src_length,
85
+ max_tgt_length=self.cfg.max_tgt_length,
86
+ patch_image_size=self.cfg.patch_image_size,
87
+ imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
88
+ scst=getattr(self.cfg, 'scst', False)
89
+ )
90
+
91
+ def build_model(self, cfg):
92
+ model = super().build_model(cfg)
93
+ if self.cfg.eval_bleu or self.cfg.eval_cider:
94
+ gen_args = json.loads(self.cfg.eval_args)
95
+ self.sequence_generator = self.build_generator(
96
+ [model], Namespace(**gen_args)
97
+ )
98
+ if self.cfg.eval_cider:
99
+ self.CiderD_scorer = CiderD(df=self.cfg.eval_cider_cached_tokens)
100
+ if self.cfg.scst:
101
+ scst_args = json.loads(self.cfg.scst_args)
102
+ self.scst_generator = self.build_generator(
103
+ [model], Namespace(**scst_args)
104
+ )
105
+
106
+ return model
107
+
108
+ def _calculate_cider_scores(self, gen_res, gt_res):
109
+ '''
110
+ gen_res: generated captions, list of str
111
+ gt_res: ground truth captions, list of list of str,
112
+ of the same length as gen_res; gen_res[i] corresponds to gt_res[i].
113
+ Each image can have multiple ground truth captions.
114
+ Returns a numpy array with one CIDEr-D score per generated caption.
115
+ '''
116
+ gen_res_size = len(gen_res)
117
+
118
+ res = OrderedDict()
119
+ for i in range(gen_res_size):
120
+ res[i] = [gen_res[i].strip()]
121
+
122
+ gts = OrderedDict()
123
+ gt_res_ = [
124
+ [gt_res[i][j].strip() for j in range(len(gt_res[i]))]
125
+ for i in range(len(gt_res))
126
+ ]
127
+ for i in range(gen_res_size):
128
+ gts[i] = gt_res_[i]
129
+
130
+ res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))]
131
+ _, scores = self.CiderD_scorer.compute_score(gts, res_)
132
+ return scores
133
+
134
+ def valid_step(self, sample, model, criterion):
135
+ loss, sample_size, logging_output = criterion(model, sample)
136
+
137
+ model.eval()
138
+ if self.cfg.eval_bleu or self.cfg.eval_cider:
139
+ hyps, refs = self._inference(self.sequence_generator, sample, model)
140
+ if self.cfg.eval_bleu:
141
+ if self.cfg.eval_tokenized_bleu:
142
+ bleu = sacrebleu.corpus_bleu(hyps, list(zip_longest(*refs)), tokenize="none")
143
+ else:
144
+ bleu = sacrebleu.corpus_bleu(hyps, list(zip_longest(*refs)))
145
+ logging_output["_bleu_sys_len"] = bleu.sys_len
146
+ logging_output["_bleu_ref_len"] = bleu.ref_len
147
+ # we split counts into separate entries so that they can be
148
+ # summed efficiently across workers using fast-stat-sync
149
+ assert len(bleu.counts) == EVAL_BLEU_ORDER
150
+ for i in range(EVAL_BLEU_ORDER):
151
+ logging_output["_bleu_counts_" + str(i)] = bleu.counts[i]
152
+ logging_output["_bleu_totals_" + str(i)] = bleu.totals[i]
153
+ if self.cfg.eval_cider:
154
+ scores = self._calculate_cider_scores(hyps, refs)
155
+ logging_output["_cider_score_sum"] = scores.sum()
156
+ logging_output["_cider_cnt"] = scores.size
157
+
158
+ return loss, sample_size, logging_output
159
+
160
+ def reduce_metrics(self, logging_outputs, criterion):
161
+ super().reduce_metrics(logging_outputs, criterion)
162
+
163
+ def sum_logs(key):
164
+ import torch
165
+ result = sum(log.get(key, 0) for log in logging_outputs)
166
+ if torch.is_tensor(result):
167
+ result = result.cpu()
168
+ return result
169
+
170
+ if self.cfg.eval_bleu:
171
+ counts, totals = [], []
172
+ for i in range(EVAL_BLEU_ORDER):
173
+ counts.append(sum_logs("_bleu_counts_" + str(i)))
174
+ totals.append(sum_logs("_bleu_totals_" + str(i)))
175
+
176
+ if max(totals) > 0:
177
+ # log counts as numpy arrays -- log_scalar will sum them correctly
178
+ metrics.log_scalar("_bleu_counts", np.array(counts))
179
+ metrics.log_scalar("_bleu_totals", np.array(totals))
180
+ metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
181
+ metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
182
+
183
+ def compute_bleu(meters):
184
+ import inspect
185
+ import sacrebleu
186
+
187
+ fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
188
+ if "smooth_method" in fn_sig:
189
+ smooth = {"smooth_method": "exp"}
190
+ else:
191
+ smooth = {"smooth": "exp"}
192
+ bleu = sacrebleu.compute_bleu(
193
+ correct=meters["_bleu_counts"].sum,
194
+ total=meters["_bleu_totals"].sum,
195
+ sys_len=meters["_bleu_sys_len"].sum,
196
+ ref_len=meters["_bleu_ref_len"].sum,
197
+ **smooth
198
+ )
199
+ return round(bleu.score, 2)
200
+
201
+ metrics.log_derived("bleu", compute_bleu)
202
+
203
+ if self.cfg.eval_cider:
204
+ def compute_cider(meters):
205
+ cider = meters["_cider_score_sum"].sum / meters["_cider_cnt"].sum
206
+ cider = cider if isinstance(cider, float) else cider.item()
207
+ return round(cider, 3)
208
+
209
+ if sum_logs("_cider_cnt") > 0:
210
+ metrics.log_scalar("_cider_score_sum", sum_logs("_cider_score_sum"))
211
+ metrics.log_scalar("_cider_cnt", sum_logs("_cider_cnt"))
212
+ metrics.log_derived("cider", compute_cider)
213
+
214
+ def _inference(self, generator, sample, model):
215
+
216
+ def decode(toks, escape_unk=False):
217
+ s = self.tgt_dict.string(
218
+ toks.int().cpu(),
219
+ # The default unknown string in fairseq is `<unk>`, but
220
+ # this is tokenized by sacrebleu as `< unk >`, inflating
221
+ # BLEU scores. Instead, we use a somewhat more verbose
222
+ # alternative that is unlikely to appear in the real
223
+ # reference, but doesn't get split into multiple tokens.
224
+ unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"),
225
+ )
226
+ if self.bpe:
227
+ s = self.bpe.decode(s)
228
+ return s
229
+
230
+ gen_out = self.inference_step(generator, [model], sample)
231
+ hyps, refs = [], []
232
+ transtab = str.maketrans({key: None for key in string.punctuation})
233
+ for i in range(len(gen_out)):
234
+ decode_tokens = decode(gen_out[i][0]["tokens"])
235
+ hyps.append(decode_tokens.translate(transtab).strip())
236
+ refs.append(
237
+ [
238
+ sent.translate(transtab).strip()
239
+ for sent in decode(
240
+ utils.strip_pad(sample["target"][i], self.tgt_dict.pad()),
241
+ escape_unk=True, # don't count <unk> as matches to the hypo
242
+ ).split('&&')
243
+ ]
244
+ )
245
+ if self.cfg.eval_print_samples:
246
+ logger.info("example hypothesis: " + hyps[0])
247
+ logger.info("example reference: " + ' && '.join(refs[0]))
248
+
249
+ return hyps, refs
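A standalone sketch (toy captions, hypothetical cache path) of the input layout that `_calculate_cider_scores` builds for the CiderD scorer: `gts` maps an id to its reference captions, while the result list pairs the same id with a single-element candidate list.

```python
from collections import OrderedDict
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD

gts = OrderedDict({0: ["a dog runs on the grass", "a brown dog running outside"]})
res = [{"image_id": 0, "caption": ["a dog running on grass"]}]

scorer = CiderD(df="coco-valid-words.p")          # hypothetical cached document-frequency file
corpus_score, per_caption_scores = scorer.compute_score(gts, res)
print(per_caption_scores[0])                      # CIDEr-D score of the single candidate
```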
tasks/mm_tasks/refcoco.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ import json
8
+ import logging
9
+ from typing import Optional
10
+ from argparse import Namespace
11
+
12
+ import torch
13
+ from fairseq import metrics
14
+ from fairseq.tasks import register_task
15
+
16
+ from tasks.ofa_task import OFATask, OFAConfig
17
+ from data.mm_data.refcoco_dataset import RefcocoDataset
18
+ from data.file_dataset import FileDataset
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class RefcocoConfig(OFAConfig):
25
+ # options for reporting accuracy (IoU >= 0.5) during validation
26
+ eval_acc: bool = field(
27
+ default=False, metadata={"help": "evaluation with IoU-based accuracy"}
28
+ )
29
+ eval_args: Optional[str] = field(
30
+ default='{}',
31
+ metadata={
32
+ "help": 'generation args for accuracy scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string'
33
+ },
34
+ )
35
+ eval_print_samples: bool = field(
36
+ default=False, metadata={"help": "print sample generations during validation"}
37
+ )
38
+
39
+ max_image_size: int = field(
40
+ default=512, metadata={"help": "max image size for normalization"}
41
+ )
42
+ scst: bool = field(
43
+ default=False, metadata={"help": "Self-critical sequence training"}
44
+ )
45
+ scst_args: str = field(
46
+ default='{}',
47
+ metadata={
48
+ "help": 'generation args for Self-critical sequence training, as JSON string'
49
+ },
50
+ )
51
+
52
+
53
+ @register_task("refcoco", dataclass=RefcocoConfig)
54
+ class RefcocoTask(OFATask):
55
+ def __init__(self, cfg: RefcocoConfig, src_dict, tgt_dict):
56
+ super().__init__(cfg, src_dict, tgt_dict)
57
+
58
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
59
+ paths = self.cfg.data.split(',')
60
+ assert len(paths) > 0
61
+
62
+ if split == 'train':
63
+ file_path = paths[(epoch - 1) % (len(paths) - 1)]
64
+ else:
65
+ file_path = paths[-1]
66
+ dataset = FileDataset(file_path, self.cfg.selected_cols)
67
+
68
+ self.datasets[split] = RefcocoDataset(
69
+ split,
70
+ dataset,
71
+ self.bpe,
72
+ self.src_dict,
73
+ self.tgt_dict,
74
+ max_src_length=self.cfg.max_src_length,
75
+ max_tgt_length=self.cfg.max_tgt_length,
76
+ patch_image_size=self.cfg.patch_image_size,
77
+ imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
78
+ num_bins=self.cfg.num_bins,
79
+ max_image_size=self.cfg.max_image_size
80
+ )
81
+
82
+ def build_model(self, cfg):
83
+ model = super().build_model(cfg)
84
+ if self.cfg.eval_acc:
85
+ gen_args = json.loads(self.cfg.eval_args)
86
+ self.sequence_generator = self.build_generator(
87
+ [model], Namespace(**gen_args)
88
+ )
89
+ if self.cfg.scst:
90
+ scst_args = json.loads(self.cfg.scst_args)
91
+ self.scst_generator = self.build_generator(
92
+ [model], Namespace(**scst_args)
93
+ )
94
+
95
+ return model
96
+
97
+ def _calculate_ap_score(self, hyps, refs, thresh=0.5):
98
+ interacts = torch.cat(
99
+ [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
100
+ torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
101
+ dim=1
102
+ )
103
+ area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
104
+ area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
105
+ interacts_w = interacts[:, 2] - interacts[:, 0]
106
+ interacts_h = interacts[:, 3] - interacts[:, 1]
107
+ area_interacts = interacts_w * interacts_h
108
+ ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
109
+ return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
110
+
111
+ def valid_step(self, sample, model, criterion):
112
+ loss, sample_size, logging_output = criterion(model, sample)
113
+
114
+ model.eval()
115
+ if self.cfg.eval_acc:
116
+ hyps, refs = self._inference(self.sequence_generator, sample, model)
117
+ hyps = hyps / (self.cfg.num_bins - 1) * self.cfg.max_image_size
118
+ refs = refs / (self.cfg.num_bins - 1) * self.cfg.max_image_size
119
+ hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
120
+ hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
121
+ refs[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
122
+ refs[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
123
+
124
+ # scores = self._calculate_ap_score(hyps, refs)
125
+ scores = self._calculate_ap_score(hyps, sample['region_coords'].float())
126
+ logging_output["_score_sum"] = scores.sum().item()
127
+ logging_output["_score_cnt"] = scores.size(0)
128
+
129
+ return loss, sample_size, logging_output
130
+
131
+ def reduce_metrics(self, logging_outputs, criterion):
132
+ super().reduce_metrics(logging_outputs, criterion)
133
+
134
+ def sum_logs(key):
135
+ import torch
136
+ result = sum(log.get(key, 0) for log in logging_outputs)
137
+ if torch.is_tensor(result):
138
+ result = result.cpu()
139
+ return result
140
+
141
+ def compute_score(meters):
142
+ score = meters["_score_sum"].sum / meters["_score_cnt"].sum
143
+ score = score if isinstance(score, float) else score.item()
144
+ return round(score, 4)
145
+
146
+ if sum_logs("_score_cnt") > 0:
147
+ metrics.log_scalar("_score_sum", sum_logs("_score_sum"))
148
+ metrics.log_scalar("_score_cnt", sum_logs("_score_cnt"))
149
+ metrics.log_derived("score", compute_score)
150
+
151
+ def _inference(self, generator, sample, model):
152
+ gen_out = self.inference_step(generator, [model], sample)
153
+ hyps, refs = [], []
154
+ for i in range(len(gen_out)):
155
+ hyps.append(gen_out[i][0]["tokens"][:-1] - len(self.src_dict) + self.cfg.num_bins)
156
+ refs.append(sample["target"][i][:-1] - len(self.src_dict) + self.cfg.num_bins)
157
+ if self.cfg.eval_print_samples:
158
+ logger.info("example hypothesis: {}".format(hyps[0]))
159
+ logger.info("example reference: {}".format(refs[0]))
160
+
161
+ return torch.stack(hyps, dim=0), torch.stack(refs, dim=0)
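An illustrative, self-contained sketch (toy numbers) of what `valid_step` does with the generated region tokens: `<bin_i>` indices are mapped back to coordinates in `[0, max_image_size]`, divided by the resize ratios, and then scored with the IoU >= 0.5 test implemented in `_calculate_ap_score`.

```python
import torch

num_bins, max_image_size = 1000, 512
bin_ids = torch.tensor([[100., 150., 600., 700.]])   # hypothetical <bin_i> indices (x0, y0, x1, y1)
box = bin_ids / (num_bins - 1) * max_image_size      # back to resized-image coordinates
box /= 0.8                                           # undo a hypothetical resize ratio of 0.8

ref = torch.tensor([[60.0, 90.0, 390.0, 450.0]])     # ground-truth box in original coordinates
tl = torch.max(box[:, :2], ref[:, :2])
br = torch.min(box[:, 2:], ref[:, 2:])
wh = (br - tl).clamp(min=0)
inter = wh[:, 0] * wh[:, 1]
union = ((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
         + (ref[:, 2] - ref[:, 0]) * (ref[:, 3] - ref[:, 1]) - inter)
print((inter / union >= 0.5).float())                # tensor([1.]) when the prediction counts as a hit
```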
tasks/ofa_task.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ import logging
8
+ import os
9
+ import math
10
+ import torch
11
+ from typing import Dict, Optional
12
+
13
+ from fairseq import search
14
+ from fairseq.data import FairseqDataset, iterators
15
+ from fairseq.optim.amp_optimizer import AMPOptimizer
16
+ from fairseq.dataclass import FairseqDataclass
17
+ from fairseq.tasks import FairseqTask, register_task
18
+ from omegaconf import DictConfig
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class OFAConfig(FairseqDataclass):
26
+ data: Optional[str] = field(
27
+ default=None,
28
+ metadata={
29
+ "help": "colon separated path to data directories list, will be iterated upon during epochs "
30
+ "in round-robin manner; however, valid and test data are always in the first directory "
31
+ "to avoid the need for repeating them in all directories"
32
+ },
33
+ )
34
+ selected_cols: Optional[str] = field(
35
+ default=None,
36
+ metadata={"help": "selected cols"},
37
+ )
38
+ bpe_dir: Optional[str] = field(
39
+ default=None,
40
+ metadata={"help": "bpe dir"},
41
+ )
42
+ max_source_positions: int = field(
43
+ default=1024, metadata={"help": "max number of tokens in the source sequence"}
44
+ )
45
+ max_target_positions: int = field(
46
+ default=1024, metadata={"help": "max number of tokens in the target sequence"}
47
+ )
48
+ max_src_length: int = field(
49
+ default=128, metadata={"help": "the maximum src sequence length"}
50
+ )
51
+ max_tgt_length: int = field(
52
+ default=30, metadata={"help": "the maximum target sequence length"}
53
+ )
54
+
55
+ code_dict_size: int = field(
56
+ default=8192, metadata={"help": "code dict size"}
57
+ )
58
+ patch_image_size: int = field(
59
+ default=480, metadata={"help": "patch image size"}
60
+ )
61
+ num_bins: int = field(
62
+ default=1000, metadata={"help": "number of quantization bins"}
63
+ )
64
+
65
+ imagenet_default_mean_and_std: bool = field(
66
+ default=False,
67
+ metadata={"help": "imagenet normalize"},
68
+ )
69
+ constraint_range: Optional[str] = field(
70
+ default=None,
71
+ metadata={"help": "constraint range"}
72
+ )
73
+
74
+
75
+ @register_task("ofa", dataclass=OFAConfig)
76
+ class OFATask(FairseqTask):
77
+ def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
78
+ super().__init__(cfg)
79
+ self.src_dict = src_dict
80
+ self.tgt_dict = tgt_dict
81
+
82
+ @classmethod
83
+ def setup_task(cls, cfg: DictConfig, **kwargs):
84
+ """Setup the task."""
85
+
86
+ # load dictionaries
87
+ src_dict = cls.load_dictionary(
88
+ os.path.join(cfg.bpe_dir, "dict.txt")
89
+ )
90
+ tgt_dict = cls.load_dictionary(
91
+ os.path.join(cfg.bpe_dir, "dict.txt")
92
+ )
93
+ src_dict.add_symbol("<mask>")
94
+ tgt_dict.add_symbol("<mask>")
95
+ for i in range(cfg.code_dict_size):
96
+ src_dict.add_symbol("<code_{}>".format(i))
97
+ tgt_dict.add_symbol("<code_{}>".format(i))
98
+ # quantization
99
+ for i in range(cfg.num_bins):
100
+ src_dict.add_symbol("<bin_{}>".format(i))
101
+ tgt_dict.add_symbol("<bin_{}>".format(i))
102
+
103
+ logger.info("source dictionary: {} types".format(len(src_dict)))
104
+ logger.info("target dictionary: {} types".format(len(tgt_dict)))
105
+ return cls(cfg, src_dict, tgt_dict)
106
+
107
+ def get_batch_iterator(
108
+ self,
109
+ dataset,
110
+ max_tokens=None,
111
+ max_sentences=None,
112
+ max_positions=None,
113
+ ignore_invalid_inputs=False,
114
+ required_batch_size_multiple=1,
115
+ seed=1,
116
+ num_shards=1,
117
+ shard_id=0,
118
+ num_workers=0,
119
+ epoch=1,
120
+ data_buffer_size=0,
121
+ disable_iterator_cache=False,
122
+ ):
123
+ assert isinstance(dataset, FairseqDataset)
124
+
125
+ # initialize the dataset with the correct starting epoch
126
+ dataset.set_epoch(epoch)
127
+
128
+ # create mini-batches with given size constraints
129
+ batch_sampler = [
130
+ [j for j in range(i, min(i + max_sentences, len(dataset)))]
131
+ for i in range(0, len(dataset), max_sentences)
132
+ ]
133
+ total_row_count = dataset.dataset.get_total_row_count()
134
+ num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
135
+ if len(batch_sampler) < num_batches:
136
+ batch_sampler.append([])
137
+
138
+ # return a reusable, sharded iterator
139
+ epoch_iter = iterators.EpochBatchIterator(
140
+ dataset=dataset,
141
+ collate_fn=dataset.collater,
142
+ batch_sampler=batch_sampler,
143
+ seed=seed,
144
+ num_shards=1,
145
+ shard_id=0,
146
+ num_workers=num_workers,
147
+ epoch=epoch,
148
+ buffer_size=data_buffer_size
149
+ )
150
+
151
+ return epoch_iter
152
+
153
+ def build_model(self, cfg: FairseqDataclass):
154
+ model = super().build_model(cfg)
155
+ bpe_dict = {
156
+ "_name": "gpt2",
157
+ "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
158
+ "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe")
159
+ }
160
+ bpe_dict = DictConfig(bpe_dict)
161
+ self.bpe = self.build_bpe(bpe_dict)
162
+ return model
163
+
164
+ def build_generator(
165
+ self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
166
+ ):
167
+ """
168
+ Build a :class:`~fairseq.SequenceGenerator` instance for this
169
+ task.
170
+
171
+ Args:
172
+ models (List[~fairseq.models.FairseqModel]): ensemble of models
173
+ args (fairseq.dataclass.configs.GenerationConfig):
174
+ configuration object (dataclass) for generation
175
+ extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
176
+ through to SequenceGenerator
177
+ prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
178
+ If provided, this function constrains the beam search to
179
+ allowed tokens only at each step. The provided function
180
+ should take 2 arguments: the batch ID (`batch_id: int`)
181
+ and a unidimensional tensor of token ids (`inputs_ids:
182
+ torch.Tensor`). It has to return a `List[int]` with the
183
+ allowed tokens for the next generation step conditioned
184
+ on the previously generated tokens (`inputs_ids`) and
185
+ the batch ID (`batch_id`). This argument is useful for
186
+ constrained generation conditioned on the prefix, as
187
+ described in "Autoregressive Entity Retrieval"
188
+ (https://arxiv.org/abs/2010.00904) and
189
+ https://github.com/facebookresearch/GENRE.
190
+ """
191
+ if getattr(args, "score_reference", False):
192
+ from fairseq.sequence_scorer import SequenceScorer
193
+
194
+ return SequenceScorer(
195
+ self.target_dictionary,
196
+ compute_alignment=getattr(args, "print_alignment", False),
197
+ )
198
+
199
+ from fairseq.sequence_generator import (
200
+ # SequenceGenerator,
201
+ SequenceGeneratorWithAlignment,
202
+ )
203
+ from models.sequence_generator import SequenceGenerator
204
+
205
+ # Choose search strategy. Defaults to Beam Search.
206
+ sampling = getattr(args, "sampling", False)
207
+ sampling_topk = getattr(args, "sampling_topk", -1)
208
+ sampling_topp = getattr(args, "sampling_topp", -1.0)
209
+ diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
210
+ diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
211
+ match_source_len = getattr(args, "match_source_len", False)
212
+ diversity_rate = getattr(args, "diversity_rate", -1)
213
+ constrained = getattr(args, "constraints", False)
214
+ if prefix_allowed_tokens_fn is None:
215
+ prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
216
+ if (
217
+ sum(
218
+ int(cond)
219
+ for cond in [
220
+ sampling,
221
+ diverse_beam_groups > 0,
222
+ match_source_len,
223
+ diversity_rate > 0,
224
+ ]
225
+ )
226
+ > 1
227
+ ):
228
+ raise ValueError("Provided Search parameters are mutually exclusive.")
229
+ assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
230
+ assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
231
+
232
+ if sampling:
233
+ search_strategy = search.Sampling(
234
+ self.target_dictionary, sampling_topk, sampling_topp
235
+ )
236
+ elif diverse_beam_groups > 0:
237
+ search_strategy = search.DiverseBeamSearch(
238
+ self.target_dictionary, diverse_beam_groups, diverse_beam_strength
239
+ )
240
+ elif match_source_len:
241
+ # this is useful for tagging applications where the output
242
+ # length should match the input length, so we hardcode the
243
+ # length constraints for simplicity
244
+ search_strategy = search.LengthConstrainedBeamSearch(
245
+ self.target_dictionary,
246
+ min_len_a=1,
247
+ min_len_b=0,
248
+ max_len_a=1,
249
+ max_len_b=0,
250
+ )
251
+ elif diversity_rate > -1:
252
+ search_strategy = search.DiverseSiblingsSearch(
253
+ self.target_dictionary, diversity_rate
254
+ )
255
+ elif constrained:
256
+ search_strategy = search.LexicallyConstrainedBeamSearch(
257
+ self.target_dictionary, args.constraints
258
+ )
259
+ elif prefix_allowed_tokens_fn:
260
+ search_strategy = search.PrefixConstrainedBeamSearch(
261
+ self.target_dictionary, prefix_allowed_tokens_fn
262
+ )
263
+ else:
264
+ search_strategy = search.BeamSearch(self.target_dictionary)
265
+
266
+ extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
267
+ if seq_gen_cls is None:
268
+ if getattr(args, "print_alignment", False):
269
+ seq_gen_cls = SequenceGeneratorWithAlignment
270
+ extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
271
+ else:
272
+ seq_gen_cls = SequenceGenerator
273
+
274
+ return seq_gen_cls(
275
+ models,
276
+ self.target_dictionary,
277
+ beam_size=getattr(args, "beam", 5),
278
+ max_len_a=getattr(args, "max_len_a", 0),
279
+ max_len_b=getattr(args, "max_len_b", 200),
280
+ min_len=getattr(args, "min_len", 1),
281
+ normalize_scores=(not getattr(args, "unnormalized", False)),
282
+ len_penalty=getattr(args, "lenpen", 1),
283
+ unk_penalty=getattr(args, "unkpen", 0),
284
+ temperature=getattr(args, "temperature", 1.0),
285
+ match_source_len=getattr(args, "match_source_len", False),
286
+ no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
287
+ search_strategy=search_strategy,
288
+ constraint_range=self.cfg.constraint_range,
289
+ **extra_gen_cls_kwargs,
290
+ )
291
+
292
+ def train_step(
293
+ self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
294
+ ):
295
+ """
296
+ Do forward and backward, and return the loss as computed by *criterion*
297
+ for the given *model* and *sample*.
298
+
299
+ Args:
300
+ sample (dict): the mini-batch. The format is defined by the
301
+ :class:`~fairseq.data.FairseqDataset`.
302
+ model (~fairseq.models.BaseFairseqModel): the model
303
+ criterion (~fairseq.criterions.FairseqCriterion): the criterion
304
+ optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
305
+ update_num (int): the current update
306
+ ignore_grad (bool): multiply loss by 0 if this is set to True
307
+
308
+ Returns:
309
+ tuple:
310
+ - the loss
311
+ - the sample size, which is used as the denominator for the
312
+ gradient
313
+ - logging outputs to display while training
314
+ """
315
+ model.train()
316
+ model.set_num_updates(update_num)
317
+ with torch.autograd.profiler.record_function("forward"):
318
+ with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
319
+ loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
320
+ if ignore_grad:
321
+ loss *= 0
322
+ with torch.autograd.profiler.record_function("backward"):
323
+ optimizer.backward(loss)
324
+ return loss, sample_size, logging_output
325
+
326
+ def max_positions(self):
327
+ """Return the max sentence length allowed by the task."""
328
+ return (self.cfg.max_source_positions, self.cfg.max_target_positions)
329
+
330
+ @property
331
+ def source_dictionary(self):
332
+ """Return the source :class:`~fairseq.data.Dictionary`."""
333
+ return self.src_dict
334
+
335
+ @property
336
+ def target_dictionary(self):
337
+ """Return the target :class:`~fairseq.data.Dictionary`."""
338
+ return self.tgt_dict
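`setup_task` above registers `<mask>`, `<code_i>`, and `<bin_i>` symbols in both dictionaries. As a rough sketch of how the `num_bins` quantization symbols can encode coordinates, assuming values normalized to `[0, 1]` and simple nearest-bin rounding (the exact mapping is an assumption here; in the repository it lives in the dataset/preprocessing code, not in this task class):

```python
def coord_to_bin_token(x: float, num_bins: int = 1000) -> str:
    """Map a normalized coordinate in [0, 1] to one of the <bin_i> symbols added in setup_task."""
    idx = max(0, min(int(round(x * (num_bins - 1))), num_bins - 1))  # clamp and round to nearest bin
    return "<bin_{}>".format(idx)

def bin_token_to_coord(token: str, num_bins: int = 1000) -> float:
    """Inverse mapping from a <bin_i> symbol back to a normalized coordinate."""
    idx = int(token[len("<bin_"):-1])
    return idx / (num_bins - 1)

# example: a box (x0, y0, x1, y1) becomes four bin tokens appended after the text tokens
# tokens = [coord_to_bin_token(v) for v in (0.12, 0.30, 0.85, 0.77)]
```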
train.py ADDED
@@ -0,0 +1,523 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Train a new model on one or across multiple GPUs.
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ from typing import Dict, Optional, Any, List, Tuple, Callable
16
+
17
+ # We need to set up the root logger before importing any fairseq libraries.
18
+ logging.basicConfig(
19
+ format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("fairseq_cli.train")
25
+
26
+ import numpy as np
27
+ import torch
28
+ from fairseq import (
29
+ # checkpoint_utils,
30
+ options,
31
+ quantization_utils,
32
+ tasks,
33
+ utils,
34
+ )
35
+ from fairseq.data import iterators
36
+ from fairseq.data.plasma_utils import PlasmaStore
37
+ from fairseq.dataclass.configs import FairseqConfig
38
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
39
+ from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
40
+ from fairseq.file_io import PathManager
41
+ from fairseq.logging import meters, metrics, progress_bar
42
+ from fairseq.model_parallel.megatron_trainer import MegatronTrainer
43
+ # from fairseq.trainer import Trainer
44
+ from omegaconf import DictConfig, OmegaConf
45
+
46
+ from utils import checkpoint_utils
47
+ from trainer import Trainer
48
+
49
+
50
+ def main(cfg: FairseqConfig) -> None:
51
+ if isinstance(cfg, argparse.Namespace):
52
+ cfg = convert_namespace_to_omegaconf(cfg)
53
+
54
+ utils.import_user_module(cfg.common)
55
+
56
+ if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
57
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
58
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
59
+
60
+ assert (
61
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
62
+ ), "Must specify batch size either with --max-tokens or --batch-size"
63
+ metrics.reset()
64
+
65
+ if cfg.common.log_file is not None:
66
+ handler = logging.FileHandler(filename=cfg.common.log_file)
67
+ logger.addHandler(handler)
68
+
69
+ np.random.seed(cfg.common.seed)
70
+ utils.set_torch_seed(cfg.common.seed)
71
+
72
+ if distributed_utils.is_master(cfg.distributed_training):
73
+ checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
74
+
75
+ # Print args
76
+ logger.info(cfg)
77
+
78
+ if cfg.checkpoint.write_checkpoints_asynchronously:
79
+ try:
80
+ import iopath # noqa: F401
81
+ except ImportError:
82
+ logging.exception(
83
+ "Asynchronous checkpoint writing is specified but iopath is "
84
+ "not installed: `pip install iopath`"
85
+ )
86
+ return
87
+
88
+ # Setup task, e.g., translation, language modeling, etc.
89
+ task = tasks.setup_task(cfg.task)
90
+
91
+ assert cfg.criterion, "Please specify criterion to train a model"
92
+
93
+ # Build model and criterion
94
+ if cfg.distributed_training.ddp_backend == "fully_sharded":
95
+ with fsdp_enable_wrap(cfg.distributed_training):
96
+ model = fsdp_wrap(task.build_model(cfg.model))
97
+ else:
98
+ model = task.build_model(cfg.model)
99
+ criterion = task.build_criterion(cfg.criterion)
100
+ logger.info(model)
101
+ logger.info("task: {}".format(task.__class__.__name__))
102
+ logger.info("model: {}".format(model.__class__.__name__))
103
+ logger.info("criterion: {}".format(criterion.__class__.__name__))
104
+ logger.info(
105
+ "num. shared model params: {:,} (num. trained: {:,})".format(
106
+ sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
107
+ sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
108
+ )
109
+ )
110
+
111
+ logger.info(
112
+ "num. expert model params: {} (num. trained: {})".format(
113
+ sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
114
+ sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
115
+ )
116
+ )
117
+
118
+ # Load valid dataset (we load training data below, based on the latest checkpoint)
119
+ # We load the valid dataset AFTER building the model
120
+ # data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
121
+ if cfg.dataset.combine_valid_subsets:
122
+ task.load_dataset("valid", combine=True, epoch=1)
123
+ else:
124
+ for valid_sub_split in cfg.dataset.valid_subset.split(","):
125
+ task.load_dataset(valid_sub_split, combine=False, epoch=1)
126
+
127
+ # (optionally) Configure quantization
128
+ if cfg.common.quantization_config_path is not None:
129
+ quantizer = quantization_utils.Quantizer(
130
+ config_path=cfg.common.quantization_config_path,
131
+ max_epoch=cfg.optimization.max_epoch,
132
+ max_update=cfg.optimization.max_update,
133
+ )
134
+ else:
135
+ quantizer = None
136
+
137
+ # Build trainer
138
+ if cfg.common.model_parallel_size == 1:
139
+ trainer = Trainer(cfg, task, model, criterion, quantizer)
140
+ else:
141
+ trainer = MegatronTrainer(cfg, task, model, criterion)
142
+ logger.info(
143
+ "training on {} devices (GPUs/TPUs)".format(
144
+ cfg.distributed_training.distributed_world_size
145
+ )
146
+ )
147
+ logger.info(
148
+ "max tokens per device = {} and max sentences per device = {}".format(
149
+ cfg.dataset.max_tokens,
150
+ cfg.dataset.batch_size,
151
+ )
152
+ )
153
+
154
+ # Load the latest checkpoint if one is available and restore the
155
+ # corresponding train iterator
156
+ extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
157
+ cfg.checkpoint,
158
+ trainer,
159
+ # don't cache epoch iterators for sharded datasets
160
+ disable_iterator_cache=task.has_sharded_data("train"),
161
+ )
162
+ if cfg.common.tpu:
163
+ import torch_xla.core.xla_model as xm
164
+ xm.rendezvous("load_checkpoint") # wait for all workers
165
+
166
+ max_epoch = cfg.optimization.max_epoch or math.inf
167
+ if max_epoch > 0:
168
+ num_iter_per_epoch = (len(epoch_itr) + cfg.distributed_training.distributed_world_size - 1) \
169
+ // cfg.distributed_training.distributed_world_size
170
+ trainer.lr_reinit(num_iter_per_epoch * max_epoch, trainer.get_num_updates())
171
+ lr = trainer.get_lr()
172
+
173
+ train_meter = meters.StopwatchMeter()
174
+ train_meter.start()
175
+ while epoch_itr.next_epoch_idx <= max_epoch:
176
+ if lr <= cfg.optimization.stop_min_lr:
177
+ logger.info(
178
+ f"stopping training because current learning rate ({lr}) is smaller "
179
+ "than or equal to minimum learning rate "
180
+ f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
181
+ )
182
+ break
183
+
184
+ # train for one epoch
185
+ valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
186
+ if should_stop:
187
+ break
188
+
189
+ # only use first validation loss to update the learning rate
190
+ lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
191
+
192
+ epoch_itr = trainer.get_train_iterator(
193
+ epoch_itr.next_epoch_idx,
194
+ # sharded data: get train iterator for next epoch
195
+ load_dataset=True,
196
+ # don't cache epoch iterators for sharded datasets
197
+ disable_iterator_cache=task.has_sharded_data("train"),
198
+ )
199
+ train_meter.stop()
200
+ logger.info("done training in {:.1f} seconds".format(train_meter.sum))
201
+
202
+ # ioPath implementation to wait for all asynchronous file writes to complete.
203
+ if cfg.checkpoint.write_checkpoints_asynchronously:
204
+ logger.info(
205
+ "ioPath PathManager waiting for all asynchronous checkpoint "
206
+ "writes to finish."
207
+ )
208
+ PathManager.async_close()
209
+ logger.info("ioPath PathManager finished waiting.")
210
+
211
+
212
+ def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
213
+ # skip check if no validation was done in the current epoch
214
+ if valid_loss is None:
215
+ return False
216
+ if cfg.checkpoint.patience <= 0:
217
+ return False
218
+
219
+ def is_better(a, b):
220
+ return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
221
+
222
+ prev_best = getattr(should_stop_early, "best", None)
223
+ if prev_best is None or is_better(valid_loss, prev_best):
224
+ should_stop_early.best = valid_loss
225
+ should_stop_early.num_runs = 0
226
+ return False
227
+ else:
228
+ should_stop_early.num_runs += 1
229
+ if should_stop_early.num_runs >= cfg.checkpoint.patience:
230
+ logger.info(
231
+ "early stop since valid performance hasn't improved for last {} runs".format(
232
+ cfg.checkpoint.patience
233
+ )
234
+ )
235
+ return True
236
+ else:
237
+ return False
238
+
239
+
240
+ @metrics.aggregate("train")
241
+ def train(
242
+ cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
243
+ ) -> Tuple[List[Optional[float]], bool]:
244
+ """Train the model for one epoch and return validation losses."""
245
+ # Initialize data iterator
246
+ itr = epoch_itr.next_epoch_itr(
247
+ fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
248
+ shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
249
+ )
250
+ update_freq = (
251
+ cfg.optimization.update_freq[epoch_itr.epoch - 1]
252
+ if epoch_itr.epoch <= len(cfg.optimization.update_freq)
253
+ else cfg.optimization.update_freq[-1]
254
+ )
255
+ itr = iterators.GroupedIterator(itr, update_freq)
256
+ if cfg.common.tpu:
257
+ itr = utils.tpu_data_loader(itr)
258
+ progress = progress_bar.progress_bar(
259
+ itr,
260
+ log_format=cfg.common.log_format,
261
+ log_file=cfg.common.log_file,
262
+ log_interval=cfg.common.log_interval,
263
+ epoch=epoch_itr.epoch,
264
+ tensorboard_logdir=(
265
+ cfg.common.tensorboard_logdir
266
+ if distributed_utils.is_master(cfg.distributed_training)
267
+ else None
268
+ ),
269
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
270
+ wandb_project=(
271
+ cfg.common.wandb_project
272
+ if distributed_utils.is_master(cfg.distributed_training)
273
+ else None
274
+ ),
275
+ wandb_run_name=os.environ.get(
276
+ "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
277
+ ),
278
+ azureml_logging=(
279
+ cfg.common.azureml_logging
280
+ if distributed_utils.is_master(cfg.distributed_training)
281
+ else False
282
+ ),
283
+ )
284
+ progress.update_config(_flatten_config(cfg))
285
+
286
+ trainer.begin_epoch(epoch_itr.epoch)
287
+
288
+ valid_subsets = cfg.dataset.valid_subset.split(",")
289
+ should_stop = False
290
+ num_updates = trainer.get_num_updates()
291
+ logger.info("Start iterating over samples")
292
+ for i, samples in enumerate(progress):
293
+ with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
294
+ "train_step-%d" % i
295
+ ):
296
+ log_output = trainer.train_step(samples)
297
+
298
+ if log_output is not None: # not OOM, overflow, ...
299
+ # log mid-epoch stats
300
+ num_updates = trainer.get_num_updates()
301
+ if num_updates % cfg.common.log_interval == 0:
302
+ stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
303
+ progress.log(stats, tag="train_inner", step=num_updates)
304
+
305
+ # reset mid-epoch stats after each log interval
306
+ # the end-of-epoch stats will still be preserved
307
+ metrics.reset_meters("train_inner")
308
+
309
+ end_of_epoch = not itr.has_next()
310
+ valid_losses, should_stop = validate_and_save(
311
+ cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
312
+ )
313
+
314
+ if should_stop:
315
+ break
316
+
317
+ # log end-of-epoch stats
318
+ logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
319
+ stats = get_training_stats(metrics.get_smoothed_values("train"))
320
+ progress.print(stats, tag="train", step=num_updates)
321
+
322
+ # reset epoch-level meters
323
+ metrics.reset_meters("train")
324
+ return valid_losses, should_stop
325
+
326
+
327
+ def _flatten_config(cfg: DictConfig):
328
+ config = OmegaConf.to_container(cfg)
329
+ # remove any legacy Namespaces and replace with a single "args"
330
+ namespace = None
331
+ for k, v in list(config.items()):
332
+ if isinstance(v, argparse.Namespace):
333
+ namespace = v
334
+ del config[k]
335
+ if namespace is not None:
336
+ config["args"] = vars(namespace)
337
+ return config
338
+
339
+
340
+ def validate_and_save(
341
+ cfg: DictConfig,
342
+ trainer: Trainer,
343
+ task: tasks.FairseqTask,
344
+ epoch_itr,
345
+ valid_subsets: List[str],
346
+ end_of_epoch: bool,
347
+ ) -> Tuple[List[Optional[float]], bool]:
348
+ num_updates = trainer.get_num_updates()
349
+ max_update = cfg.optimization.max_update or math.inf
350
+
351
+ # Stopping conditions (and an additional one based on validation loss later
352
+ # on)
353
+ should_stop = False
354
+ if num_updates >= max_update:
355
+ should_stop = True
356
+ logger.info(
357
+ f"Stopping training due to "
358
+ f"num_updates: {num_updates} >= max_update: {max_update}"
359
+ )
360
+
361
+ training_time_hours = trainer.cumulative_training_time() / (60 * 60)
362
+ if (
363
+ cfg.optimization.stop_time_hours > 0
364
+ and training_time_hours > cfg.optimization.stop_time_hours
365
+ ):
366
+ should_stop = True
367
+ logger.info(
368
+ f"Stopping training due to "
369
+ f"cumulative_training_time: {training_time_hours} > "
370
+ f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
371
+ )
372
+
373
+ do_save = (
374
+ (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
375
+ or should_stop
376
+ or (
377
+ cfg.checkpoint.save_interval_updates > 0
378
+ and num_updates > 0
379
+ and num_updates % cfg.checkpoint.save_interval_updates == 0
380
+ and num_updates >= cfg.dataset.validate_after_updates
381
+ )
382
+ )
383
+ do_validate = (
384
+ (not end_of_epoch and do_save) # validate during mid-epoch saves
385
+ or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
386
+ or should_stop
387
+ or (
388
+ cfg.dataset.validate_interval_updates > 0
389
+ and num_updates > 0
390
+ and num_updates % cfg.dataset.validate_interval_updates == 0
391
+ )
392
+ ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates
393
+
394
+ # Validate
395
+ valid_losses = [None]
396
+ if do_validate:
397
+ valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
398
+
399
+ should_stop |= should_stop_early(cfg, valid_losses[0])
400
+
401
+ # Save checkpoint
402
+ if do_save or should_stop:
403
+ checkpoint_utils.save_checkpoint(
404
+ cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
405
+ )
406
+
407
+ return valid_losses, should_stop
408
+
409
+
410
+ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
411
+ stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
412
+ return stats
413
+
414
+
415
+ def validate(
416
+ cfg: DictConfig,
417
+ trainer: Trainer,
418
+ task: tasks.FairseqTask,
419
+ epoch_itr,
420
+ subsets: List[str],
421
+ ) -> List[Optional[float]]:
422
+ """Evaluate the model on the validation set(s) and return the losses."""
423
+
424
+ if cfg.dataset.fixed_validation_seed is not None:
425
+ # set fixed seed for every validation
426
+ utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
427
+
428
+ trainer.begin_valid_epoch(epoch_itr.epoch)
429
+ valid_losses = []
430
+ for subset in subsets:
431
+ logger.info('begin validation on "{}" subset'.format(subset))
432
+
433
+ # Initialize data iterator
434
+ itr = trainer.get_valid_iterator(subset).next_epoch_itr(
435
+ shuffle=False, set_dataset_epoch=False # use a fixed valid set
436
+ )
437
+ if cfg.common.tpu:
438
+ itr = utils.tpu_data_loader(itr)
439
+ progress = progress_bar.progress_bar(
440
+ itr,
441
+ log_format=cfg.common.log_format,
442
+ log_interval=cfg.common.log_interval,
443
+ epoch=epoch_itr.epoch,
444
+ prefix=f"valid on '{subset}' subset",
445
+ tensorboard_logdir=(
446
+ cfg.common.tensorboard_logdir
447
+ if distributed_utils.is_master(cfg.distributed_training)
448
+ else None
449
+ ),
450
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
451
+ wandb_project=(
452
+ cfg.common.wandb_project
453
+ if distributed_utils.is_master(cfg.distributed_training)
454
+ else None
455
+ ),
456
+ wandb_run_name=os.environ.get(
457
+ "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
458
+ ),
459
+ )
460
+
461
+ # create a new root metrics aggregator so validation metrics
462
+ # don't pollute other aggregators (e.g., train meters)
463
+ with metrics.aggregate(new_root=True) as agg:
464
+ for i, sample in enumerate(progress):
465
+ if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
466
+ break
467
+ trainer.valid_step(sample)
468
+
469
+ # log validation stats
470
+ if hasattr(task, 'get_valid_stats'):
471
+ stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
472
+ else:
473
+ stats = agg.get_smoothed_values()
474
+ stats = get_valid_stats(cfg, trainer, stats)
475
+
476
+ if hasattr(task, "post_validate"):
477
+ task.post_validate(trainer.get_model(), stats, agg)
478
+
479
+ progress.print(stats, tag=subset, step=trainer.get_num_updates())
480
+
481
+ valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
482
+ return valid_losses
483
+
484
+
485
+ def get_valid_stats(
486
+ cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
487
+ ) -> Dict[str, Any]:
488
+ stats["num_updates"] = trainer.get_num_updates()
489
+ if hasattr(checkpoint_utils.save_checkpoint, "best"):
490
+ key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
491
+ best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
492
+ stats[key] = best_function(
493
+ checkpoint_utils.save_checkpoint.best,
494
+ stats[cfg.checkpoint.best_checkpoint_metric],
495
+ )
496
+ return stats
497
+
498
+
499
+ def cli_main(
500
+ modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
501
+ ) -> None:
502
+ parser = options.get_training_parser()
503
+ args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
504
+
505
+ cfg = convert_namespace_to_omegaconf(args)
506
+
507
+ if cfg.common.use_plasma_view:
508
+ server = PlasmaStore(path=cfg.common.plasma_path)
509
+ logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")
510
+
511
+ if args.profile:
512
+ with torch.cuda.profiler.profile():
513
+ with torch.autograd.profiler.emit_nvtx():
514
+ distributed_utils.call_main(cfg, main)
515
+ else:
516
+ distributed_utils.call_main(cfg, main)
517
+
518
+ # if cfg.common.use_plasma_view:
519
+ # server.server.kill()
520
+
521
+
522
+ if __name__ == "__main__":
523
+ cli_main()
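`should_stop_early` above keeps its best metric and its patience counter as attributes on the function object. A self-contained sketch of the same bookkeeping, with the toy loss sequence in the comment given purely as an example:

```python
def early_stop(valid_loss, patience, maximize=False):
    """Return True once valid_loss has not improved for `patience` consecutive checks."""
    if valid_loss is None or patience <= 0:
        return False
    better = (lambda a, b: a > b) if maximize else (lambda a, b: a < b)
    prev_best = getattr(early_stop, "best", None)
    if prev_best is None or better(valid_loss, prev_best):
        early_stop.best = valid_loss      # new best value; reset the patience counter
        early_stop.num_runs = 0
        return False
    early_stop.num_runs += 1
    return early_stop.num_runs >= patience

# e.g. with patience=2: losses 1.0, 0.9, 0.95, 0.97 -> stops after the second non-improving check
```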
trainer.py ADDED
@@ -0,0 +1,1531 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Train a network across multiple GPUs.
8
+ """
9
+
10
+ import contextlib
11
+ import logging
12
+ import sys
13
+ import time
14
+ from argparse import Namespace
15
+ from itertools import chain
16
+ from typing import Any, Dict, List
17
+
18
+ import torch
19
+ from fairseq import models, optim, utils
20
+ from fairseq.dataclass.configs import FairseqConfig
21
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
22
+ from fairseq.distributed import utils as distributed_utils
23
+ from fairseq.file_io import PathManager
24
+ from fairseq.logging import meters, metrics
25
+ from fairseq.models.ema import build_ema
26
+ from fairseq.nan_detector import NanDetector
27
+ from fairseq.optim import lr_scheduler
28
+ from omegaconf import OmegaConf
29
+
30
+ from utils import checkpoint_utils
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class Trainer(object):
36
+ """Main class for data parallel training.
37
+
38
+ This class supports synchronous distributed data parallel training,
39
+ where multiple workers each have a full model replica and gradients
40
+ are accumulated across workers before each update. We use
41
+ :class:`~torch.nn.parallel.DistributedDataParallel` to handle
42
+ communication of the gradients across workers.
43
+ """
44
+
45
+ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
46
+
47
+ if isinstance(cfg, Namespace):
48
+ logger.warning(
49
+ "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
50
+ )
51
+ cfg = convert_namespace_to_omegaconf(cfg)
52
+
53
+ self.cfg = cfg
54
+ self.task = task
55
+
56
+ # catalog shared parameters
57
+ shared_params = _catalog_shared_params(model)
58
+ self.tpu = cfg.common.tpu
59
+ self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
60
+ if self.cuda:
61
+ self.device = torch.device("cuda")
62
+ elif self.tpu:
63
+ self.device = utils.get_tpu_device()
64
+ else:
65
+ self.device = torch.device("cpu")
66
+
67
+ if self.is_fsdp:
68
+ import fairscale
69
+ if self.cfg.common.bf16:
70
+ raise ValueError(
71
+ "FullyShardedDataParallel is not compatible with --bf16 or "
72
+ "--memory-efficient-bf16"
73
+ )
74
+ if self.cfg.distributed_training.zero_sharding != "none":
75
+ raise ValueError(
76
+ "FullyShardedDataParallel is not compatible with --zero-sharding "
77
+ "option (it's already built in)"
78
+ )
79
+ if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0":
80
+ raise RuntimeError(
81
+ "Please update to fairscale 0.4.0 or newer when combining "
82
+ "--update-freq with FullyShardedDataParallel"
83
+ )
84
+ else:
85
+ if (
86
+ hasattr(self.cfg.distributed_training, "cpu_offload")
87
+ and self.cfg.distributed_training.cpu_offload
88
+ ):
89
+ raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded")
90
+
91
+ # copy model and criterion to current device/dtype
92
+ self._criterion = criterion
93
+ self._model = model
94
+ if not self.is_fsdp:
95
+ if cfg.common.fp16:
96
+ assert not cfg.common.amp, "Cannot use fp16 and AMP together"
97
+ self._criterion = self._criterion.half()
98
+ self._model = self._model.half()
99
+ elif cfg.common.bf16:
100
+ self._criterion = self._criterion.to(dtype=torch.bfloat16)
101
+ self._model = self._model.to(dtype=torch.bfloat16)
102
+ elif cfg.common.amp:
103
+ self._amp_retries = 0
104
+ if (
105
+ not cfg.distributed_training.pipeline_model_parallel
106
+ # the DistributedFairseqModel wrapper will handle moving to device,
107
+ # so only handle cases which don't use the wrapper
108
+ and not self.use_distributed_wrapper
109
+ ):
110
+ self._criterion = self._criterion.to(device=self.device)
111
+ self._model = self._model.to(device=self.device)
112
+ self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
113
+ self.last_device = None
114
+ if self.cuda and self.pipeline_model_parallel:
115
+ self.last_device = torch.device(
116
+ cfg.distributed_training.pipeline_devices[-1]
117
+ )
118
+
119
+ # check that shared parameters are preserved after device transfer
120
+ for shared_param in shared_params:
121
+ ref = _get_module_by_path(self._model, shared_param[0])
122
+ for path in shared_param[1:]:
123
+ logger.info(
124
+ "detected shared parameter: {} <- {}".format(shared_param[0], path)
125
+ )
126
+ _set_module_by_path(self._model, path, ref)
127
+
128
+ self._dummy_batch = None # indicates we don't have a dummy batch at first
129
+ self._lr_scheduler = None
130
+ self._num_updates = 0
131
+ self._num_xla_compiles = 0 # for TPUs
132
+ self._optim_history = None
133
+ self._optimizer = None
134
+ self._warn_once = set()
135
+ self._wrapped_criterion = None
136
+ self._wrapped_model = None
137
+ self._ema = None
138
+
139
+ # TODO(myleott): support tpu
140
+ if self.cuda and self.data_parallel_world_size > 1:
141
+ self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
142
+ else:
143
+ self._grad_norm_buf = None
144
+
145
+ self.quantizer = quantizer
146
+ if self.quantizer is not None:
147
+ self.quantizer.set_trainer(self)
148
+
149
+ # get detailed cuda environment
150
+ if self.cuda:
151
+ self.cuda_env = utils.CudaEnvironment()
152
+ if self.data_parallel_world_size > 1:
153
+ self.cuda_env_arr = distributed_utils.all_gather_list(
154
+ self.cuda_env, group=distributed_utils.get_global_group()
155
+ )
156
+ else:
157
+ self.cuda_env_arr = [self.cuda_env]
158
+ if self.data_parallel_rank == 0:
159
+ utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
160
+ else:
161
+ self.cuda_env = None
162
+ self.cuda_env_arr = None
163
+
164
+ metrics.log_start_time("wall", priority=790, round=0)
165
+
166
+ self._start_time = time.time()
167
+ self._previous_training_time = 0
168
+ self._cumulative_training_time = None
169
+
170
+ def reinitialize(self):
171
+ """Reinitialize the Trainer, typically after model params change."""
172
+ self._lr_scheduler = None
173
+ self._optimizer = None
174
+ self._wrapped_criterion = None
175
+ self._wrapped_model = None
176
+
177
+ @property
178
+ def data_parallel_world_size(self):
179
+ if self.cfg.distributed_training.distributed_world_size == 1:
180
+ return 1
181
+ return distributed_utils.get_data_parallel_world_size()
182
+
183
+ @property
184
+ def data_parallel_process_group(self):
185
+ return distributed_utils.get_data_parallel_group()
186
+
187
+ @property
188
+ def data_parallel_rank(self):
189
+ if self.cfg.distributed_training.distributed_world_size == 1:
190
+ return 0
191
+ return distributed_utils.get_data_parallel_rank()
192
+
193
+ @property
194
+ def is_data_parallel_master(self):
195
+ # NOTE: this returns true for all model parallel replicas with data
196
+ # parallel rank 0
197
+ return self.data_parallel_rank == 0
198
+
199
+ @property
200
+ def use_distributed_wrapper(self) -> bool:
201
+ return (
202
+ self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf
203
+ ) or (
204
+ self.is_fsdp and self.cfg.distributed_training.cpu_offload
205
+ )
206
+
207
+ @property
208
+ def should_save_checkpoint_on_current_rank(self) -> bool:
209
+ """Indicates whether to save checkpoints on the current DDP rank."""
210
+ if (
211
+ self.is_fsdp and self.cfg.distributed_training.use_sharded_state
212
+ ) or getattr(self.cfg.model, "base_layers", 0) > 0:
213
+ return True
214
+ else:
215
+ return self.is_data_parallel_master
216
+
217
+ @property
218
+ def always_call_state_dict_during_save_checkpoint(self) -> bool:
219
+ if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state:
220
+ # FSDP calls communication collective when consolidating checkpoints
221
+ return True
222
+ else:
223
+ return False
224
+
225
+ @property
226
+ def checkpoint_suffix(self) -> str:
227
+ """Suffix to add to the checkpoint file name."""
228
+ if self.is_fsdp and self.cfg.distributed_training.use_sharded_state:
229
+ return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format(
230
+ self.data_parallel_rank
231
+ )
232
+ else:
233
+ return self.cfg.checkpoint.checkpoint_suffix or ""
234
+
235
+ @property
236
+ def criterion(self):
237
+ if self._wrapped_criterion is None:
238
+ if utils.has_parameters(self._criterion) and self.use_distributed_wrapper:
239
+ self._wrapped_criterion = models.DistributedFairseqModel(
240
+ self.cfg.distributed_training,
241
+ self._criterion,
242
+ process_group=self.data_parallel_process_group,
243
+ device=self.device,
244
+ )
245
+ else:
246
+ self._wrapped_criterion = self._criterion
247
+ return self._wrapped_criterion
248
+
249
+ @property
250
+ def model(self):
251
+ if self._wrapped_model is None:
252
+ if self.use_distributed_wrapper:
253
+ self._wrapped_model = models.DistributedFairseqModel(
254
+ self.cfg.distributed_training,
255
+ self._model,
256
+ process_group=self.data_parallel_process_group,
257
+ device=self.device,
258
+ )
259
+ else:
260
+ self._wrapped_model = self._model
261
+ return self._wrapped_model
262
+
263
+ @property
264
+ def ema(self):
265
+ if self._ema is None:
266
+ self._build_ema()
267
+ return self._ema
268
+
269
+ def _build_ema(self):
270
+ if self.cfg.ema.store_ema:
271
+ self._ema = build_ema(self._model, self.cfg.ema, self.device)
272
+ logger.info(
273
+ "Exponential Moving Average Shadow Model is initialized."
274
+ )
275
+
276
+ @property
277
+ def optimizer(self):
278
+ if self._optimizer is None:
279
+ self._build_optimizer()
280
+ return self._optimizer
281
+
282
+ @property
283
+ def lr_scheduler(self):
284
+ if self._lr_scheduler is None:
285
+ self._build_optimizer() # this will initialize self._lr_scheduler
286
+ return self._lr_scheduler
287
+
288
+ def _build_optimizer(self):
289
+ params = list(
290
+ filter(
291
+ lambda p: p.requires_grad,
292
+ chain(self.model.parameters(), self.criterion.parameters()),
293
+ )
294
+ )
295
+
296
+ if self.is_fsdp and self.cfg.common.fp16:
297
+ # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper,
298
+ # mostly for the grad scaling. But if we don't have the
299
+ # --memory-efficient-fp16 flag set, then we're effectively doing
300
+ # regular --fp16 and can allow the use of optimizers that would
301
+ # otherwise be unsupported by MemoryEfficientFP16Optimizer.
302
+ allow_unsupported = not self.cfg.common.memory_efficient_fp16
303
+ self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
304
+ self.cfg, params, allow_unsupported=allow_unsupported
305
+ )
306
+ elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp:
307
+ if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
308
+ logger.info(
309
+ "NOTE: your device does NOT support faster training with --fp16 or --amp, "
310
+ "please switch to FP32 which is likely to be faster"
311
+ )
312
+ if (
313
+ self.cfg.common.memory_efficient_fp16
314
+ or self.cfg.common.memory_efficient_bf16
315
+ ):
316
+ self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
317
+ self.cfg, params
318
+ )
319
+ elif self.cfg.common.amp:
320
+ self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params)
321
+ else:
322
+ self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
323
+ else:
324
+ if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
325
+ logger.info("NOTE: your device may support faster training with --fp16 or --amp")
326
+ self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
327
+
328
+ if self.is_fsdp:
329
+ assert (
330
+ not self.cfg.optimization.use_bmuf
331
+ ), "--ddp-backend=fully_sharded is not compatible with BMUF"
332
+ assert self._optimizer.supports_flat_params, (
333
+ "--ddp-backend=fully_sharded is only compatible with pointwise "
334
+ "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). "
335
+ "However, the sharding will result in slightly different results when "
336
+ "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)"
337
+ )
338
+
339
+ if self.cfg.optimization.use_bmuf:
340
+ self._optimizer = optim.FairseqBMUF(
341
+ self.cfg.bmuf,
342
+ self._optimizer,
343
+ )
344
+
345
+ if self.cfg.distributed_training.zero_sharding == "os":
346
+ if (
347
+ self.cfg.common.fp16
348
+ and not self.cfg.common.memory_efficient_fp16
349
+ and not self.cfg.common.memory_efficient_bf16
350
+ ) and not self.cfg.common.fp16_no_flatten_grads:
351
+ raise ValueError(
352
+ "ZeRO is incomptabile with fp16 and flattened grads. "
353
+ "Please use --fp16-no-flatten-grads"
354
+ )
355
+ else:
356
+ optim.shard_(self._optimizer, self.data_parallel_process_group)
357
+
358
+ # We should initialize the learning rate scheduler immediately after
359
+ # building the optimizer, so that the initial learning rate is set.
360
+ self._lr_scheduler = lr_scheduler.build_lr_scheduler(
361
+ self.cfg.lr_scheduler,
362
+ self.optimizer,
363
+ )
364
+ self._lr_scheduler.step_update(0)
365
+
366
+ @property
367
+ def is_fsdp(self):
368
+ return self.cfg.distributed_training.ddp_backend == "fully_sharded"
369
+
370
+ def consolidate_optimizer(self):
371
+ """For OSS, we need to consolidate the state dict."""
372
+ if self.cfg.checkpoint.no_save_optimizer_state:
373
+ return
374
+ self._gathered_optim_state = None
375
+ if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
376
+ self.optimizer.optimizer.consolidate_state_dict()
377
+ elif self.is_fsdp and not self.model.use_sharded_state:
378
+ st = self.model.gather_full_optim_state_dict(
379
+ self.optimizer
380
+ ) # only returns on rank 0
381
+ self._gathered_optim_state = st
382
+
383
+ def state_dict(self):
384
+ state_dict = {
385
+ "args": None, # legacy
386
+ "cfg": (
387
+ OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True)
388
+ if OmegaConf.is_config(self.cfg)
389
+ else self.cfg
390
+ ),
391
+ "model": self.model.state_dict(),
392
+ "criterion": (
393
+ self.criterion.state_dict()
394
+ if utils.has_parameters(self.criterion)
395
+ else None
396
+ ),
397
+ "optimizer_history": (self._optim_history or [])
398
+ + [
399
+ {
400
+ "criterion_name": self.get_criterion().__class__.__name__,
401
+ "optimizer_name": self.optimizer.__class__.__name__,
402
+ "lr_scheduler_state": self.lr_scheduler.state_dict(),
403
+ "num_updates": self.get_num_updates(),
404
+ }
405
+ ],
406
+ "task_state": self.task.state_dict() if self.task is not None else {},
407
+ "extra_state": {
408
+ "metrics": metrics.state_dict(),
409
+ "previous_training_time": self.cumulative_training_time(),
410
+ },
411
+ }
412
+ if self.cfg.ema.store_ema:
413
+ # Save EMA model state as extra state
414
+ state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict()
415
+ if self.cfg.ema.ema_fp32:
416
+ # Save EMA params in fp32
417
+ state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params
418
+ if not self.cfg.checkpoint.no_save_optimizer_state:
419
+ if self._gathered_optim_state is not None:
420
+ state_dict["last_optimizer_state"] = self._gathered_optim_state
421
+ self._gathered_optim_state = None
422
+ else:
423
+ state_dict["last_optimizer_state"] = self.optimizer.state_dict()
424
+ if self.is_fsdp:
425
+ # save meta data for recombining checkpoint upon loading
426
+ state_dict["fsdp_metadata"] = self.model.local_metadata_dict()
427
+ return state_dict
428
+
429
+ def save_checkpoint(self, filename, extra_state):
430
+ """Save all training state in a checkpoint file."""
431
+ logger.info(f"Saving checkpoint to {filename}")
432
+ # call state_dict on all ranks in case it needs internal communication
433
+ state_dict = utils.move_to_cpu(self.state_dict())
434
+ state_dict["extra_state"].update(extra_state)
435
+ if self.should_save_checkpoint_on_current_rank:
436
+ checkpoint_utils.torch_persistent_save(
437
+ state_dict,
438
+ filename,
439
+ async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
440
+ )
441
+ logger.info(f"Finished saving checkpoint to {filename}")
442
+
443
+ def load_checkpoint(
444
+ self,
445
+ filename,
446
+ reset_optimizer=False,
447
+ reset_lr_scheduler=False,
448
+ optimizer_overrides=None,
449
+ reset_meters=False,
450
+ ):
451
+ """
452
+ Load all training state from a checkpoint file.
453
+ rank = 0 will load the checkpoint, and then broadcast it to all
454
+ other ranks.
455
+ """
456
+ extra_state, self._optim_history, last_optim_state = None, [], None
457
+
458
+ logger.info(f"Preparing to load checkpoint {filename}")
459
+ is_distributed = self.data_parallel_world_size > 1
460
+ bexists = PathManager.isfile(filename)
461
+ if bexists:
462
+ load_on_all_ranks = (
463
+ self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks
464
+ # TPUs don't support broadcast yet, so load checkpoints
465
+ # on every worker for now
466
+ or self.tpu
467
+ # FSDP requires loading checkpoint shards on all ranks
468
+ or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state)
469
+ or getattr(self.cfg.model, "base_layers", 0) > 0
470
+ )
471
+
472
+ if load_on_all_ranks or self.data_parallel_rank == 0:
473
+ state = checkpoint_utils.load_checkpoint_to_cpu(
474
+ filename, load_on_all_ranks=load_on_all_ranks
475
+ )
476
+ last_optim_state = state.get("last_optimizer_state", None)
477
+
478
+ # If doing zero_sharding, do not broadcast global optimizer
479
+ # state. Later we will broadcast sharded states to each rank
480
+ # to avoid memory from exploding.
481
+ if (
482
+ not load_on_all_ranks
483
+ and self.cfg.distributed_training.zero_sharding == "os"
484
+ and "last_optimizer_state" in state
485
+ and is_distributed
486
+ ):
487
+ state["last_optimizer_state"] = "SHARDED"
488
+ else:
489
+ last_optim_state = None
490
+ state = None
491
+
492
+ if is_distributed and not load_on_all_ranks:
493
+ state = distributed_utils.broadcast_object(
494
+ state,
495
+ src_rank=0,
496
+ group=self.data_parallel_process_group,
497
+ dist_device=self.device,
498
+ )
499
+ if self.data_parallel_rank > 0:
500
+ last_optim_state = state.get("last_optimizer_state", None)
501
+
502
+ # load model parameters
503
+ try:
504
+ if self.cfg.checkpoint.use_ema_weights_to_init_param and "extra_state" in state and "ema" in state["extra_state"]:
505
+ logger.info("use_ema_weights_to_init_param = True, will use EMA weights in the ckpt to init the model param...")
506
+ ema_state_dict = state["extra_state"]["ema_fp32_params"] if "ema_fp32_params" in state["extra_state"] else state["extra_state"]["ema"]
507
+ self.model.load_state_dict(
508
+ ema_state_dict, strict=True, model_cfg=self.cfg.model
509
+ )
510
+ else:
511
+ self.model.load_state_dict(
512
+ state["model"], strict=True, model_cfg=self.cfg.model
513
+ )
514
+ # save memory for later steps
515
+ if not (self.cfg.ema.store_ema and (self.cfg.checkpoint.use_latest_weights_to_init_ema or not ("extra_state" in state and "ema" in state["extra_state"]))):
516
+ del state["model"]
517
+ if utils.has_parameters(self.get_criterion()):
518
+ self.get_criterion().load_state_dict(
519
+ state["criterion"], strict=True
520
+ )
521
+ del state["criterion"]
522
+
523
+ except Exception:
524
+ raise Exception(
525
+ "Cannot load model parameters from checkpoint {}; "
526
+ "please ensure that the architectures match.".format(filename)
527
+ )
528
+ extra_state = state["extra_state"]
529
+ self._optim_history = state["optimizer_history"]
530
+
531
+ if last_optim_state is not None and not reset_optimizer:
532
+ # rebuild optimizer after loading model, since params may have changed
533
+ self._build_optimizer()
534
+
535
+ # only reload optimizer and lr_scheduler if they match
536
+ last_optim = self._optim_history[-1]
537
+ assert (
538
+ last_optim["criterion_name"] == self.get_criterion().__class__.__name__
539
+ ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}"
540
+ assert (
541
+ last_optim["optimizer_name"] == self.optimizer.__class__.__name__
542
+ ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
543
+
544
+ if not reset_lr_scheduler:
545
+ self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
546
+
547
+ if self.is_fsdp and not self.model.use_sharded_state:
548
+ # if use_sharded_state, the last_optim_state is already sharded, skip this
549
+ last_optim_state = self.model.get_shard_from_optim_state_dict(
550
+ last_optim_state
551
+ )
552
+ elif not load_on_all_ranks and is_distributed:
553
+ last_optim_state = self.optimizer.broadcast_global_state_dict(
554
+ last_optim_state
555
+ )
556
+
557
+ self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
558
+
559
+ self.set_num_updates(last_optim["num_updates"])
560
+
561
+ if extra_state is not None:
562
+ itr_state = extra_state["train_iterator"]
563
+ epoch = itr_state["epoch"]
564
+
565
+ if "previous_training_time" in extra_state:
566
+ self._previous_training_time = extra_state["previous_training_time"]
567
+ self._start_time = time.time()
568
+
569
+ self.lr_step(epoch)
570
+
571
+ if (
572
+ itr_state.get("version", 1) >= 2
573
+ and itr_state["iterations_in_epoch"] == 0
574
+ ):
575
+ # reset meters at start of epoch
576
+ reset_meters = True
577
+
578
+ if "metrics" in extra_state and not reset_meters:
579
+ metrics.load_state_dict(extra_state["metrics"])
580
+
581
+ # reset TimeMeters, since their start times don't make sense anymore
582
+ for meter in metrics.get_meters("default"):
583
+ if isinstance(meter, meters.TimeMeter):
584
+ meter.reset()
585
+
586
+ if self.cfg.ema.store_ema:
587
+ if self.cfg.checkpoint.use_latest_weights_to_init_ema or "ema" not in extra_state:
588
+ if "ema" not in extra_state:
589
+ logger.warning(
590
+ "EMA not found in checkpoint. But store_ema is True. "
591
+ "EMA is re-initialized from checkpoint."
592
+ )
593
+ elif self.cfg.checkpoint.use_latest_weights_to_init_ema:
594
+ logger.info(
595
+ "use_latest_weights_to_init_ema = True. EMA is re-initialized from checkpoint."
596
+ )
597
+ self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32)
598
+ del state["model"]
599
+ else:
600
+ logger.info(
601
+ "Loading EMA from checkpoint"
602
+ )
603
+ self.ema.restore(extra_state["ema"], build_fp32_params=False)
604
+
605
+ if self.cfg.ema.ema_fp32:
606
+ if "ema_fp32_params" in extra_state:
607
+ logger.info(
608
+ "Loading EMA fp32 params from checkpoint"
609
+ )
610
+ self.ema.build_fp32_params(extra_state["ema_fp32_params"])
611
+ else:
612
+ logger.info(
613
+ "Building EMA fp32 params from EMA model in checkpoint"
614
+ )
615
+ self.ema.build_fp32_params()
616
+
617
+ logger.info(
618
+ "Loaded checkpoint {} (epoch {} @ {} updates)".format(
619
+ filename, epoch, self.get_num_updates()
620
+ )
621
+ )
622
+
623
+ else:
624
+ logger.info("No existing checkpoint found {}".format(filename))
625
+
626
+ return extra_state
627
+
628
+ def get_train_iterator(
629
+ self,
630
+ epoch,
631
+ combine=True,
632
+ load_dataset=True,
633
+ data_selector=None,
634
+ shard_batch_itr=True,
635
+ disable_iterator_cache=False,
636
+ ):
637
+ """Return an EpochBatchIterator over the training set for a given epoch."""
638
+ if load_dataset:
639
+ logger.info("loading train data for epoch {}".format(epoch))
640
+ self.task.load_dataset(
641
+ self.cfg.dataset.train_subset,
642
+ epoch=epoch,
643
+ combine=combine,
644
+ data_selector=data_selector,
645
+ tpu=self.tpu,
646
+ )
647
+ batch_iterator = self.task.get_batch_iterator(
648
+ dataset=self.task.dataset(self.cfg.dataset.train_subset),
649
+ max_tokens=self.cfg.dataset.max_tokens,
650
+ max_sentences=self.cfg.dataset.batch_size,
651
+ max_positions=utils.resolve_max_positions(
652
+ self.task.max_positions(),
653
+ self.model.max_positions(),
654
+ self.cfg.dataset.max_tokens,
655
+ ),
656
+ ignore_invalid_inputs=True,
657
+ required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
658
+ seed=self.cfg.common.seed,
659
+ num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
660
+ shard_id=self.data_parallel_rank if shard_batch_itr else 0,
661
+ num_workers=self.cfg.dataset.num_workers,
662
+ epoch=epoch,
663
+ data_buffer_size=self.cfg.dataset.data_buffer_size,
664
+ disable_iterator_cache=disable_iterator_cache,
665
+ )
666
+ self.reset_dummy_batch(batch_iterator.first_batch)
667
+ batch_iterator.dataset.dataset._seek()
668
+ return batch_iterator
669
+
670
+ def get_valid_iterator(
671
+ self,
672
+ subset,
673
+ disable_iterator_cache=False,
674
+ ):
675
+ """Return an EpochBatchIterator over given validation subset for a given epoch."""
676
+ self.task.dataset(subset).dataset._seek()
677
+ batch_iterator = self.task.get_batch_iterator(
678
+ dataset=self.task.dataset(subset),
679
+ max_tokens=self.cfg.dataset.max_tokens_valid,
680
+ max_sentences=self.cfg.dataset.batch_size_valid,
681
+ max_positions=utils.resolve_max_positions(
682
+ self.task.max_positions(),
683
+ self.model.max_positions(),
684
+ ),
685
+ ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
686
+ required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
687
+ seed=self.cfg.common.seed,
688
+ num_shards=self.data_parallel_world_size,
689
+ shard_id=self.data_parallel_rank,
690
+ num_workers=self.cfg.dataset.num_workers,
691
+ # always pass a fixed "epoch" to keep validation data consistent
692
+ # across training epochs
693
+ epoch=1,
694
+ data_buffer_size=self.cfg.dataset.data_buffer_size,
695
+ disable_iterator_cache=disable_iterator_cache,
696
+ )
697
+ self.reset_dummy_batch(batch_iterator.first_batch)
698
+ batch_iterator.dataset.dataset._seek()
699
+ return batch_iterator
700
+
701
+ def begin_epoch(self, epoch):
702
+ """Called at the beginning of each epoch."""
703
+ logger.info("begin training epoch {}".format(epoch))
704
+
705
+ self.lr_step_begin_epoch(epoch)
706
+
707
+ if self.quantizer is not None:
708
+ self.quantizer.begin_epoch(epoch)
709
+
710
+ # task specific setup per epoch
711
+ self.task.begin_epoch(epoch, self.get_model())
712
+
713
+ if self.tpu:
714
+ import torch_xla.core.xla_model as xm
715
+
716
+ xm.rendezvous("begin_epoch") # wait for all workers
717
+ xm.mark_step()
718
+
719
+ def begin_valid_epoch(self, epoch):
720
+ """Called at the beginning of each validation epoch."""
721
+
722
+ # task specific setup per validation epoch
723
+ self.task.begin_valid_epoch(epoch, self.get_model())
724
+
725
+ def reset_dummy_batch(self, batch):
726
+ self._dummy_batch = batch
727
+
728
+ @metrics.aggregate("train")
729
+ def train_step(self, samples, raise_oom=False):
730
+ """Do forward, backward and parameter update."""
731
+ self._set_seed()
732
+ self.model.train()
733
+ self.criterion.train()
734
+ self.zero_grad()
735
+
736
+ metrics.log_start_time("train_wall", priority=800, round=0)
737
+
738
+ # If EMA is enabled through store_ema=True
739
+ # and task.uses_ema is True, pass the EMA model as a keyword
740
+ # argument to the task.
741
+ extra_kwargs = {}
742
+ if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
743
+ extra_kwargs["ema_model"] = self.ema.get_model()
744
+
745
+ # forward and backward pass
746
+ logging_outputs, sample_size, ooms = [], 0, 0
747
+ for i, sample in enumerate(samples): # delayed update loop
748
+ sample, is_dummy_batch = self._prepare_sample(sample)
749
+
750
+ def maybe_no_sync():
751
+ """
752
+ Whenever *samples* contains more than one mini-batch, we
753
+ want to accumulate gradients locally and only call
754
+ all-reduce in the last backwards pass.
755
+ """
756
+ if (
757
+ self.data_parallel_world_size > 1
758
+ and hasattr(self.model, "no_sync")
759
+ and i < len(samples) - 1
760
+ # The no_sync context manager results in increased memory
761
+ # usage with FSDP, since full-size gradients will be
762
+ # accumulated on each GPU. It's typically a better tradeoff
763
+ # to do the extra communication with FSDP.
764
+ and not self.is_fsdp
765
+ ):
766
+ return self.model.no_sync()
767
+ else:
768
+ return contextlib.ExitStack() # dummy contextmanager
769
+
770
+ try:
771
+ with maybe_no_sync():
772
+ # forward and backward
773
+ loss, sample_size_i, logging_output = self.task.train_step(
774
+ sample=sample,
775
+ model=self.model,
776
+ criterion=self.criterion,
777
+ optimizer=self.optimizer,
778
+ update_num=self.get_num_updates(),
779
+ ignore_grad=is_dummy_batch,
780
+ **extra_kwargs,
781
+ )
782
+ del loss
783
+
784
+ logging_outputs.append(logging_output)
785
+ sample_size += sample_size_i
786
+
787
+ # emptying the CUDA cache after the first step can
788
+ # reduce the chance of OOM
789
+ if self.cuda and self.get_num_updates() == 0:
790
+ torch.cuda.empty_cache()
791
+ except RuntimeError as e:
792
+ if "out of memory" in str(e):
793
+ self._log_oom(e)
794
+ if raise_oom:
795
+ raise e
796
+ logger.warning(
797
+ "attempting to recover from OOM in forward/backward pass"
798
+ )
799
+ ooms += 1
800
+ self.zero_grad()
801
+ if self.cuda:
802
+ torch.cuda.empty_cache()
803
+ if self.cfg.distributed_training.distributed_world_size == 1:
804
+ return None
805
+ else:
806
+ raise e
807
+
808
+ if self.tpu and i < len(samples) - 1:
809
+ # tpu-comment: every XLA operation before marking step is
810
+ # appended to the IR graph, and processing too many batches
811
+ # before marking step can lead to OOM errors.
812
+ # To handle gradient accumulation use case, we explicitly
813
+ # mark step here for every forward pass without a backward pass
814
+ self._xla_markstep_and_send_to_cpu()
815
+
816
+ if is_dummy_batch:
817
+ if torch.is_tensor(sample_size):
818
+ sample_size.zero_()
819
+ else:
820
+ sample_size *= 0.0
821
+
822
+ if torch.is_tensor(sample_size):
823
+ sample_size = sample_size.float()
824
+ else:
825
+ sample_size = float(sample_size)
826
+
827
+ # gather logging outputs from all replicas
828
+ if self._sync_stats():
829
+ train_time = self._local_cumulative_training_time()
830
+ logging_outputs, (
831
+ sample_size,
832
+ ooms,
833
+ total_train_time,
834
+ ) = self._aggregate_logging_outputs(
835
+ logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch
836
+ )
837
+ self._cumulative_training_time = (
838
+ total_train_time / self.data_parallel_world_size
839
+ )
840
+
841
+ overflow = False
842
+ try:
843
+ with torch.autograd.profiler.record_function("reduce-grads"):
844
+ # reduce gradients across workers
845
+ self.optimizer.all_reduce_grads(self.model)
846
+ if utils.has_parameters(self.criterion):
847
+ self.optimizer.all_reduce_grads(self.criterion)
848
+
849
+ with torch.autograd.profiler.record_function("multiply-grads"):
850
+ # multiply gradients by (data_parallel_size / sample_size) since
851
+ # DDP normalizes by the number of data parallel workers for
852
+ # improved fp16 precision.
853
+ # Thus we get (sum_of_gradients / sample_size) at the end.
854
+ # In case of fp16, this step also undoes loss scaling.
855
+ # (Debugging note: Some optimizers perform this scaling on the
856
+ # fly, so inspecting model.parameters() or optimizer.params may
857
+ # still show the original, unscaled gradients.)
858
+ numer = (
859
+ self.data_parallel_world_size
860
+ if not self.cfg.optimization.use_bmuf or self._sync_stats()
861
+ else 1
862
+ )
863
+ self.optimizer.multiply_grads(numer / (sample_size or 1.0))
864
+ # Note: (sample_size or 1.0) handles the case of a zero gradient, in a
865
+ # way that avoids CPU/device transfers in case sample_size is a GPU or
866
+ # TPU object. The assumption is that the gradient itself is also 0.
867
+
868
+ with torch.autograd.profiler.record_function("clip-grads"):
869
+ # clip grads
870
+ grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
871
+
872
+ # check that grad norms are consistent across workers
873
+ # on tpu check tensor is slow
874
+ if not self.tpu:
875
+ if (
876
+ not self.cfg.optimization.use_bmuf
877
+ and self.cfg.distributed_training.ddp_backend != "slow_mo"
878
+ ):
879
+ self._check_grad_norms(grad_norm)
880
+ if not torch.isfinite(grad_norm).all():
881
+ # in case of AMP, if gradients are Nan/Inf then
882
+ # optimizer step is still required
883
+ if self.cfg.common.amp:
884
+ overflow = True
885
+ else:
886
+ # check local gradnorm single GPU case, trigger NanDetector
887
+ raise FloatingPointError("gradients are Nan/Inf")
888
+
889
+ with torch.autograd.profiler.record_function("optimizer"):
890
+ # take an optimization step
891
+ self.task.optimizer_step(
892
+ self.optimizer, model=self.model, update_num=self.get_num_updates()
893
+ )
894
+ if self.cfg.common.amp and overflow:
895
+ if self._amp_retries == self.cfg.common.amp_batch_retries:
896
+ logger.info("AMP: skipping this batch.")
897
+ self._amp_retries = 0
898
+ else:
899
+ self._amp_retries += 1
900
+ return self.train_step(samples, raise_oom) # recursion to feed in same batch
901
+
902
+ except FloatingPointError:
903
+ # re-run the forward and backward pass with hooks attached to print
904
+ # out where it fails
905
+ self.zero_grad()
906
+ with NanDetector(self.get_model()):
907
+ for _, sample in enumerate(samples):
908
+ sample, _ = self._prepare_sample(sample)
909
+ self.task.train_step(
910
+ sample,
911
+ self.model,
912
+ self.criterion,
913
+ self.optimizer,
914
+ self.get_num_updates(),
915
+ ignore_grad=False,
916
+ **extra_kwargs,
917
+ )
918
+ raise
919
+ except OverflowError as e:
920
+ overflow = True
921
+ logger.info(
922
+ f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
923
+ )
924
+ grad_norm = torch.tensor(0.0).cuda()
925
+ self.zero_grad()
926
+ except RuntimeError as e:
927
+ if "out of memory" in str(e):
928
+ self._log_oom(e)
929
+ logger.error("OOM during optimization, irrecoverable")
930
+ raise e
931
+
932
+ # Some distributed wrappers (e.g., SlowMo) need access to the optimizer
933
+ # after the step
934
+ if hasattr(self.model, "perform_additional_optimizer_actions"):
935
+ if hasattr(self.optimizer, "fp32_params"):
936
+ self.model.perform_additional_optimizer_actions(
937
+ self.optimizer.optimizer, self.optimizer.fp32_params
938
+ )
939
+ else:
940
+ self.model.perform_additional_optimizer_actions(
941
+ self.optimizer.optimizer
942
+ )
943
+
944
+ logging_output = None
945
+ if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo":
946
+ self.set_num_updates(self.get_num_updates() + 1)
947
+
948
+ if self.cfg.ema.store_ema:
949
+ # Step EMA forward with new model.
950
+ self.ema.step(
951
+ self.get_model(),
952
+ self.get_num_updates(),
953
+ )
954
+ metrics.log_scalar(
955
+ "ema_decay",
956
+ self.ema.get_decay(),
957
+ priority=10000,
958
+ round=5,
959
+ weight=0,
960
+ )
961
+
962
+ if self.tpu:
963
+ import torch_xla.core.xla_model as xm
964
+
965
+ # mark step on TPUs
966
+ self._xla_markstep_and_send_to_cpu()
967
+
968
+ # only log stats every log_interval steps
969
+ # this causes wps to be misreported when log_interval > 1
970
+ logging_output = {}
971
+ if self.get_num_updates() % self.cfg.common.log_interval == 0:
972
+ # log memory usage
973
+ mem_info = xm.get_memory_info(self.device)
974
+ gb_free = mem_info["kb_free"] / 1024 / 1024
975
+ gb_total = mem_info["kb_total"] / 1024 / 1024
976
+ metrics.log_scalar(
977
+ "gb_free", gb_free, priority=1500, round=1, weight=0
978
+ )
979
+ metrics.log_scalar(
980
+ "gb_total", gb_total, priority=1600, round=1, weight=0
981
+ )
982
+ logging_outputs = self._xla_markstep_and_send_to_cpu(
983
+ logging_outputs
984
+ )
985
+ logging_output = self._reduce_and_log_stats(
986
+ logging_outputs, sample_size, grad_norm
987
+ )
988
+
989
+ # log whenever there's an XLA compilation, since these
990
+ # slow down training and may indicate opportunities for
991
+ # optimization
992
+ self._check_xla_compilation()
993
+ else:
994
+ if self.cuda and self.cuda_env is not None:
995
+ # log minimum free memory over the iteration
996
+ gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
997
+ torch.cuda.reset_peak_memory_stats()
998
+ gb_free = self.cuda_env.total_memory_in_GB - gb_used
999
+ metrics.log_scalar(
1000
+ "gb_free", gb_free, priority=1500, round=1, weight=0
1001
+ )
1002
+
1003
+ # log stats
1004
+ logging_output = self._reduce_and_log_stats(
1005
+ logging_outputs, sample_size, grad_norm
1006
+ )
1007
+
1008
+ # clear CUDA cache to reduce memory fragmentation
1009
+ if (
1010
+ self.cuda
1011
+ and self.cfg.common.empty_cache_freq > 0
1012
+ and (
1013
+ (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
1014
+ % self.cfg.common.empty_cache_freq
1015
+ )
1016
+ == 0
1017
+ ):
1018
+ torch.cuda.empty_cache()
1019
+
1020
+ if self.cfg.common.fp16 or self.cfg.common.amp:
1021
+ metrics.log_scalar(
1022
+ "loss_scale",
1023
+ (
1024
+ self.optimizer.scaler.loss_scale
1025
+ if self.cfg.common.fp16
1026
+ else self.optimizer.scaler.get_scale()
1027
+ ),
1028
+ priority=700,
1029
+ round=4,
1030
+ weight=0,
1031
+ )
1032
+
1033
+ metrics.log_stop_time("train_wall")
1034
+ return logging_output
1035
+
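For reference, the "multiply-grads" arithmetic above is easier to follow with concrete numbers. A minimal sketch in plain Python, where the world size and sample size are illustrative assumptions rather than values from this repo:

# DDP averages gradients across workers, so after all-reduce each worker holds
# grad_sum / world_size. Multiplying by world_size / sample_size then yields
# grad_sum / sample_size, i.e. an average over the total sample size.
world_size = 4            # assumed number of data-parallel workers
sample_size = 1024.0      # assumed total sample size summed across workers
numer = world_size        # as in the code above (non-BMUF case)
scale = numer / (sample_size or 1.0)
print(scale)              # 0.00390625 -> applied via optimizer.multiply_grads(scale)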
1036
+ @metrics.aggregate("valid")
1037
+ def valid_step(self, sample, raise_oom=False):
1038
+ """Do forward pass in evaluation mode."""
1039
+ if self.tpu:
1040
+ import torch_xla.core.xla_model as xm
1041
+
1042
+ xm.rendezvous("valid_step") # wait for all workers
1043
+
1044
+ # If EMA is enabled through store_ema=True
1045
+ # and task.uses_ema is True, pass the EMA model as a keyword
1046
+ # argument to the task.
1047
+ extra_kwargs = {}
1048
+ if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False):
1049
+ extra_kwargs["ema_model"] = self.ema.get_model()
1050
+
1051
+ with torch.no_grad():
1052
+ self.model.eval()
1053
+ self.criterion.eval()
1054
+
1055
+ sample, is_dummy_batch = self._prepare_sample(sample)
1056
+
1057
+ try:
1058
+ _loss, sample_size, logging_output = self.task.valid_step(
1059
+ sample, self.model, self.criterion, **extra_kwargs
1060
+ )
1061
+ except RuntimeError as e:
1062
+ if "out of memory" in str(e):
1063
+ self._log_oom(e)
1064
+ if not raise_oom:
1065
+ logger.warning(
1066
+ "ran out of memory in validation step, retrying batch"
1067
+ )
1068
+ for p in self.model.parameters():
1069
+ if p.grad is not None:
1070
+ p.grad = None # free some memory
1071
+ if self.cuda:
1072
+ torch.cuda.empty_cache()
1073
+ return self.valid_step(sample, raise_oom=True)
1074
+ raise e
1075
+
1076
+ logging_outputs = [logging_output]
1077
+ if is_dummy_batch:
1078
+ if torch.is_tensor(sample_size):
1079
+ sample_size.zero_()
1080
+ else:
1081
+ sample_size *= 0.0
1082
+
1083
+ # gather logging outputs from all replicas
1084
+ if self.data_parallel_world_size > 1:
1085
+ logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
1086
+ logging_outputs,
1087
+ sample_size,
1088
+ ignore=is_dummy_batch,
1089
+ )
1090
+
1091
+ # log validation stats
1092
+ if self.tpu:
1093
+ logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
1094
+ logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)
1095
+
1096
+ return logging_output
1097
+
1098
+ def zero_grad(self):
1099
+ self.optimizer.zero_grad()
1100
+
1101
+ def lr_step_begin_epoch(self, epoch):
1102
+ """Adjust the learning rate at the beginning of the epoch."""
1103
+ self.lr_scheduler.step_begin_epoch(epoch)
1104
+ # prefer updating the LR based on the number of steps
1105
+ return self.lr_step_update()
1106
+
1107
+ def lr_reinit(self, total_updates, num_updates):
1108
+ self.lr_scheduler.reinit(total_updates, num_updates)
1109
+
1110
+ def lr_step(self, epoch, val_loss=None):
1111
+ """Adjust the learning rate at the end of the epoch."""
1112
+ self.lr_scheduler.step(epoch, val_loss)
1113
+ # prefer updating the LR based on the number of steps
1114
+ return self.lr_step_update()
1115
+
1116
+ def lr_step_update(self):
1117
+ """Update the learning rate after each update."""
1118
+ new_lr = self.lr_scheduler.step_update(self.get_num_updates())
1119
+ if isinstance(new_lr, dict):
1120
+ for k, v in new_lr.items():
1121
+ metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
1122
+ new_lr = new_lr.get("default", next(iter(new_lr.values())))
1123
+ else:
1124
+ metrics.log_scalar("lr", new_lr, weight=0, priority=300)
1125
+ return new_lr
1126
+
1127
+ def get_lr(self):
1128
+ """Get the current learning rate."""
1129
+ return self.optimizer.get_lr()
1130
+
1131
+ def get_model(self):
1132
+ """Get the (non-wrapped) model instance."""
1133
+ return self._model
1134
+
1135
+ def get_criterion(self):
1136
+ """Get the (non-wrapped) criterion instance."""
1137
+ return self._criterion
1138
+
1139
+ def get_meter(self, name):
1140
+ """[deprecated] Get a specific meter by name."""
1141
+ from fairseq import meters
1142
+
1143
+ if "get_meter" not in self._warn_once:
1144
+ self._warn_once.add("get_meter")
1145
+ utils.deprecation_warning(
1146
+ "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
1147
+ )
1148
+
1149
+ train_meters = metrics.get_meters("train")
1150
+ if train_meters is None:
1151
+ train_meters = {}
1152
+
1153
+ if name == "train_loss" and "loss" in train_meters:
1154
+ return train_meters["loss"]
1155
+ elif name == "train_nll_loss":
1156
+ # support for legacy train.py, which assumed this meter is
1157
+ # always initialized
1158
+ m = train_meters.get("nll_loss", None)
1159
+ return m or meters.AverageMeter()
1160
+ elif name == "wall":
1161
+ # support for legacy train.py, which assumed this meter is
1162
+ # always initialized
1163
+ m = metrics.get_meter("default", "wall")
1164
+ return m or meters.TimeMeter()
1165
+ elif name == "wps":
1166
+ m = metrics.get_meter("train", "wps")
1167
+ return m or meters.TimeMeter()
1168
+ elif name in {"valid_loss", "valid_nll_loss"}:
1169
+ # support for legacy train.py, which assumed these meters
1170
+ # are always initialized
1171
+ k = name[len("valid_") :]
1172
+ m = metrics.get_meter("valid", k)
1173
+ return m or meters.AverageMeter()
1174
+ elif name == "oom":
1175
+ return meters.AverageMeter()
1176
+ elif name in train_meters:
1177
+ return train_meters[name]
1178
+ return None
1179
+
1180
+ def get_num_updates(self):
1181
+ """Get the number of parameters updates."""
1182
+ return self._num_updates
1183
+
1184
+ def set_num_updates(self, num_updates):
1185
+ """Set the number of parameters updates."""
1186
+ self._num_updates = num_updates
1187
+ self.lr_step_update()
1188
+ if self.quantizer:
1189
+ self.quantizer.step_update(self._num_updates)
1190
+ metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
1191
+
1192
+ def clip_grad_norm(self, clip_norm):
1193
+ def agg_norm_fn(total_norm):
1194
+ total_norm = total_norm.cuda().float() ** 2
1195
+ total_norm = distributed_utils.all_reduce(
1196
+ total_norm, group=self.data_parallel_process_group
1197
+ )
1198
+ return total_norm ** 0.5
1199
+
1200
+ should_agg_norm = (
1201
+ self.is_fsdp
1202
+ and (
1203
+ self.data_parallel_process_group is not None
1204
+ or torch.distributed.is_initialized()
1205
+ )
1206
+ )
1207
+ return self.optimizer.clip_grad_norm(
1208
+ clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None
1209
+ )
1210
+
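For reference, a small sketch of what `agg_norm_fn` above computes when gradients are sharded across workers; the per-shard norms below are made-up values:

import math

local_norms = [3.0, 4.0]  # assumed per-shard gradient norms under FSDP
# square, sum across workers (the all_reduce above), then take the square root:
global_norm = math.sqrt(sum(n ** 2 for n in local_norms))
assert global_norm == 5.0  # equals the norm of the full, unsharded gradient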
1211
+ def cumulative_training_time(self):
1212
+ if self._cumulative_training_time is None:
1213
+ # single GPU
1214
+ return self._local_cumulative_training_time()
1215
+ else:
1216
+ return self._cumulative_training_time
1217
+
1218
+ def _local_cumulative_training_time(self):
1219
+ """Aggregate training time in seconds."""
1220
+ return time.time() - self._start_time + self._previous_training_time
1221
+
1222
+ def _fp_convert_sample(self, sample):
1223
+ def apply_half(t):
1224
+ if t.dtype is torch.float32:
1225
+ return t.to(dtype=torch.half)
1226
+ return t
1227
+
1228
+ def apply_bfloat16(t):
1229
+ if t.dtype is torch.float32:
1230
+ return t.to(dtype=torch.bfloat16)
1231
+ return t
1232
+
1233
+ if self.cfg.common.fp16:
1234
+ sample = utils.apply_to_sample(apply_half, sample)
1235
+
1236
+ if self.cfg.common.bf16:
1237
+ sample = utils.apply_to_sample(apply_bfloat16, sample)
1238
+
1239
+ return sample
1240
+
1241
+ def _prepare_sample(self, sample, is_dummy=False):
1242
+ if sample == "DUMMY":
1243
+ raise Exception(
1244
+ "Trying to use an uninitialized 'dummy' batch. This usually indicates "
1245
+ "that the total number of batches is smaller than the number of "
1246
+ "participating GPUs. Try reducing the batch size or using fewer GPUs."
1247
+ )
1248
+
1249
+ if sample is None or len(sample) == 0:
1250
+ assert (
1251
+ self._dummy_batch is not None and len(self._dummy_batch) > 0
1252
+ ), "Invalid dummy batch: {}".format(self._dummy_batch)
1253
+ sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
1254
+ return sample, True
1255
+
1256
+ # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth
1257
+ # it makes sense to do the format conversion on the CPU and then transfer
1258
+ # a smaller buffer to the device. This also saves GPU memory capacity.
1259
+
1260
+ if self.cfg.common.on_cpu_convert_precision:
1261
+ sample = self._fp_convert_sample(sample)
1262
+
1263
+ if self.cuda:
1264
+ if self.pipeline_model_parallel:
1265
+ if 'target' in sample:
1266
+ sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device)
1267
+ else:
1268
+ sample = utils.move_to_cuda(sample)
1269
+ elif self.tpu and is_dummy:
1270
+ # the dummy batch may not be on the appropriate device
1271
+ sample = utils.move_to_cuda(sample, device=self.device)
1272
+
1273
+ if not self.cfg.common.on_cpu_convert_precision:
1274
+ sample = self._fp_convert_sample(sample)
1275
+
1276
+ if self._dummy_batch == "DUMMY":
1277
+ self._dummy_batch = sample
1278
+
1279
+ return sample, False
1280
+
1281
+ def _set_seed(self):
1282
+ # Set seed based on args.seed and the update number so that we get
1283
+ # reproducible results when resuming from checkpoints
1284
+ seed = self.cfg.common.seed + self.get_num_updates()
1285
+ utils.set_torch_seed(seed)
1286
+
1287
+ def _sync_stats(self):
1288
+ # Return True when using multiple GPUs with DDP, or multiple GPUs with BMUF
1289
+ # on a BMUF sync step after the warmup iterations have completed.
1290
+ if self.data_parallel_world_size == 1:
1291
+ return False
1292
+ elif self.cfg.optimization.use_bmuf:
1293
+ return (
1294
+ self.get_num_updates() + 1
1295
+ ) % self.cfg.bmuf.global_sync_iter == 0 and (
1296
+ self.get_num_updates() + 1
1297
+ ) > self.cfg.bmuf.warmup_iterations
1298
+ else:
1299
+ return True
1300
+
1301
+ def _log_oom(self, exc):
1302
+ msg = "OOM: Ran out of memory with exception: {}".format(exc)
1303
+ logger.warning(msg)
1304
+ if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
1305
+ for device_idx in range(torch.cuda.device_count()):
1306
+ logger.warning(torch.cuda.memory_summary(device=device_idx))
1307
+ sys.stderr.flush()
1308
+
1309
+ def _aggregate_logging_outputs(
1310
+ self,
1311
+ logging_outputs: List[Dict[str, Any]],
1312
+ *extra_stats_to_sum,
1313
+ ignore=False,
1314
+ ):
1315
+ if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()):
1316
+ return self._fast_stat_sync_sum(
1317
+ logging_outputs, *extra_stats_to_sum, ignore=ignore
1318
+ )
1319
+ else:
1320
+ return self._all_gather_list_sync(
1321
+ logging_outputs, *extra_stats_to_sum, ignore=ignore
1322
+ )
1323
+
1324
+ def _all_gather_list_sync(
1325
+ self,
1326
+ logging_outputs: List[Dict[str, Any]],
1327
+ *extra_stats_to_sum,
1328
+ ignore=False,
1329
+ ):
1330
+ """
1331
+ Sync logging outputs across workers. all_gather_list_sync is
1332
+ suitable when logging outputs are complex types.
1333
+ """
1334
+ if self.tpu:
1335
+ raise NotImplementedError
1336
+ if ignore:
1337
+ logging_outputs = []
1338
+ results = list(
1339
+ zip(
1340
+ *distributed_utils.all_gather_list(
1341
+ [logging_outputs] + list(extra_stats_to_sum),
1342
+ max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
1343
+ group=self.data_parallel_process_group,
1344
+ )
1345
+ )
1346
+ )
1347
+ logging_outputs, extra_stats_to_sum = results[0], results[1:]
1348
+ logging_outputs = list(chain.from_iterable(logging_outputs))
1349
+ extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
1350
+ return logging_outputs, extra_stats_to_sum
1351
+
1352
+ def _fast_stat_sync_sum(
1353
+ self,
1354
+ logging_outputs: List[Dict[str, Any]],
1355
+ *extra_stats_to_sum,
1356
+ ignore=False,
1357
+ ):
1358
+ """
1359
+ Sync logging outputs across workers. fast_stat_sync_sum is
1360
+ faster than all_gather_list_sync, but is only suitable when
1361
+ logging outputs are scalars and can be summed. Note that
1362
+ *logging_outputs* cannot contain any nested dicts/lists.
1363
+ """
1364
+ data = {}
1365
+ for i, stat in enumerate(extra_stats_to_sum):
1366
+ data["extra_stats_" + str(i)] = stat
1367
+ if len(logging_outputs) > 0:
1368
+ log_keys = list(logging_outputs[0].keys())
1369
+ for k in log_keys:
1370
+ if not ignore:
1371
+ v = sum(log[k] for log in logging_outputs if k in log)
1372
+ else:
1373
+ v = logging_outputs[0][k]
1374
+ v = torch.zeros_like(v) if torch.is_tensor(v) else 0
1375
+ data["logging_outputs_" + k] = v
1376
+ else:
1377
+ log_keys = None
1378
+
1379
+ data = distributed_utils.all_reduce_dict(
1380
+ data, device=self.device, group=self.data_parallel_process_group
1381
+ )
1382
+
1383
+ extra_stats_to_sum = [
1384
+ data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
1385
+ ]
1386
+ if log_keys is not None:
1387
+ logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
1388
+ else:
1389
+ logging_outputs = []
1390
+ return logging_outputs, extra_stats_to_sum
1391
+
1392
+ def _check_grad_norms(self, grad_norm):
1393
+ """Check that grad norms are consistent across workers."""
1394
+ if self._grad_norm_buf is not None:
1395
+ self._grad_norm_buf.zero_()
1396
+ self._grad_norm_buf[self.data_parallel_rank] = grad_norm
1397
+ distributed_utils.all_reduce(
1398
+ self._grad_norm_buf, group=self.data_parallel_process_group
1399
+ )
1400
+
1401
+ def is_consistent(tensor):
1402
+ max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
1403
+ return (
1404
+ (torch.isfinite(tensor).all()
1405
+ and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())
1406
+ or
1407
+ (self.cfg.common.amp and not torch.isfinite(tensor).all())
1408
+ # in case of amp non-finite grads are fine
1409
+ )
1410
+
1411
+ if not is_consistent(self._grad_norm_buf):
1412
+ pretty_detail = "\n".join(
1413
+ "rank {:3d} = {:.8f}".format(r, n)
1414
+ for r, n in enumerate(self._grad_norm_buf.tolist())
1415
+ )
1416
+ error_detail = "grad_norm across the workers:\n{}\n".format(
1417
+ pretty_detail
1418
+ )
1419
+ # use FloatingPointError to trigger NanDetector
1420
+ raise FloatingPointError(
1421
+ "Fatal error: gradients are inconsistent between workers. "
1422
+ "Try --ddp-backend=legacy_ddp. "
1423
+ "Or are you mixing up different generation of GPUs in training?"
1424
+ + "\n"
1425
+ + "-" * 80
1426
+ + "\n{}\n".format(error_detail)
1427
+ + "-" * 80
1428
+ )
1429
+
1430
+ def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
1431
+ if grad_norm is not None and (
1432
+ not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
1433
+ ):
1434
+ metrics.log_speed("ups", 1.0, priority=100, round=2)
1435
+ metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
1436
+ if self.cfg.optimization.clip_norm > 0:
1437
+ metrics.log_scalar(
1438
+ "clip",
1439
+ torch.where(
1440
+ grad_norm > self.cfg.optimization.clip_norm,
1441
+ grad_norm.new_tensor(100),
1442
+ grad_norm.new_tensor(0),
1443
+ ),
1444
+ priority=500,
1445
+ round=1,
1446
+ )
1447
+
1448
+ with metrics.aggregate() as agg:
1449
+ if logging_outputs is not None:
1450
+ self.task.reduce_metrics(logging_outputs, self.get_criterion())
1451
+ del logging_outputs
1452
+
1453
+ # extra warning for criterions that don't properly log a loss value
1454
+ if "loss" not in agg:
1455
+ if "loss" not in self._warn_once:
1456
+ self._warn_once.add("loss")
1457
+ logger.warning(
1458
+ "Criterion.reduce_metrics did not log a 'loss' value, "
1459
+ "which may break some functionality"
1460
+ )
1461
+ metrics.log_scalar("loss", -1)
1462
+
1463
+ # support legacy interface
1464
+ if self.tpu:
1465
+ logging_output = {}
1466
+ else:
1467
+ logging_output = agg.get_smoothed_values()
1468
+ logging_output["sample_size"] = sample_size
1469
+ for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
1470
+ if key_to_delete in logging_output:
1471
+ del logging_output[key_to_delete]
1472
+ return logging_output
1473
+
1474
+ def _check_xla_compilation(self):
1475
+ import torch_xla.debug.metrics as met
1476
+
1477
+ compile_stats = met.metric_data("CompileTime")
1478
+ if compile_stats is None:
1479
+ return
1480
+ num_xla_compiles = compile_stats[0]
1481
+ if num_xla_compiles > self._num_xla_compiles:
1482
+ logger.warning(
1483
+ "XLA compilation detected on device #{}; too many of these can lead "
1484
+ "to slow training, but we expect a few in the beginning".format(
1485
+ self.cfg.distributed_training.distributed_rank
1486
+ )
1487
+ )
1488
+ self._num_xla_compiles = num_xla_compiles
1489
+
1490
+ def _xla_markstep_and_send_to_cpu(self, data=None):
1491
+ import torch_xla.core.xla_model as xm
1492
+
1493
+ xm.mark_step()
1494
+ if data is not None:
1495
+ from fairseq.utils import xla_device_to_cpu
1496
+
1497
+ return xla_device_to_cpu(data)
1498
+
1499
+
1500
+ def _catalog_shared_params(module, memo=None, prefix=""):
1501
+ if memo is None:
1502
+ first_call = True
1503
+ memo = {}
1504
+ else:
1505
+ first_call = False
1506
+ for name, param in module._parameters.items():
1507
+ param_prefix = prefix + ("." if prefix else "") + name
1508
+ if param not in memo:
1509
+ memo[param] = []
1510
+ memo[param].append(param_prefix)
1511
+ for name, m in module._modules.items():
1512
+ if m is None:
1513
+ continue
1514
+ submodule_prefix = prefix + ("." if prefix else "") + name
1515
+ _catalog_shared_params(m, memo, submodule_prefix)
1516
+ if first_call:
1517
+ return [x for x in memo.values() if len(x) > 1]
1518
+
1519
+
1520
+ def _get_module_by_path(module, path):
1521
+ path = path.split(".")
1522
+ for name in path:
1523
+ module = getattr(module, name)
1524
+ return module
1525
+
1526
+
1527
+ def _set_module_by_path(module, path, value):
1528
+ path = path.split(".")
1529
+ for name in path[:-1]:
1530
+ module = getattr(module, name)
1531
+ setattr(module, path[-1], value)
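For reference, the Trainer above is driven by an outer loop in train.py. A minimal sketch of that interaction, assuming a `trainer` and a fairseq `epoch_itr` built elsewhere; the helper name and the `update_freq` handling are illustrative, not this repo's exact loop:

def run_one_epoch(trainer, epoch_itr, update_freq=1):
    """Drive Trainer.train_step over one epoch with simple gradient accumulation."""
    trainer.begin_epoch(epoch_itr.epoch)
    itr = epoch_itr.next_epoch_itr(shuffle=True)  # fairseq EpochBatchIterator
    samples = []
    for sample in itr:
        samples.append(sample)
        if len(samples) == update_freq:
            trainer.train_step(samples)           # forward/backward/update
            samples = []
    if samples:
        trainer.train_step(samples)               # flush the last partial group
    trainer.lr_step(epoch_itr.epoch)              # end-of-epoch LR adjustment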
utils/BPE/__init__.py ADDED
File without changes
utils/BPE/dict.txt ADDED
The diff for this file is too large to render. See raw diff
utils/BPE/encoder.json ADDED
The diff for this file is too large to render. See raw diff
utils/BPE/vocab.bpe ADDED
The diff for this file is too large to render. See raw diff
utils/__init__.py ADDED
File without changes
utils/checkpoint_utils.py ADDED
@@ -0,0 +1,875 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import ast
7
+ import collections
8
+ import contextlib
9
+ import logging
10
+ import numpy as np
11
+ import os
12
+ import re
13
+ import time
14
+ import traceback
15
+ import math
16
+ from collections import OrderedDict
17
+ from typing import Any, Dict, Optional, Union
18
+
19
+ import torch
20
+ from fairseq.dataclass.configs import CheckpointConfig
21
+ from fairseq.dataclass.utils import (
22
+ convert_namespace_to_omegaconf,
23
+ overwrite_args_by_name,
24
+ )
25
+ from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
26
+ from fairseq.file_io import PathManager
27
+ from fairseq.models import FairseqDecoder, FairseqEncoder
28
+ from omegaconf import DictConfig, open_dict, OmegaConf
29
+
30
+ from data import data_utils
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
36
+ from fairseq import meters
37
+
38
+ # only one worker should attempt to create the required dir
39
+ if trainer.data_parallel_rank == 0:
40
+ os.makedirs(cfg.save_dir, exist_ok=True)
41
+
42
+ prev_best = getattr(save_checkpoint, "best", val_loss)
43
+ if val_loss is not None:
44
+ best_function = max if cfg.maximize_best_checkpoint_metric else min
45
+ save_checkpoint.best = best_function(val_loss, prev_best)
46
+
47
+ if cfg.no_save:
48
+ return
49
+
50
+ trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state
51
+
52
+ if not trainer.should_save_checkpoint_on_current_rank:
53
+ if trainer.always_call_state_dict_during_save_checkpoint:
54
+ trainer.state_dict()
55
+ return
56
+
57
+ write_timer = meters.StopwatchMeter()
58
+ write_timer.start()
59
+
60
+ epoch = epoch_itr.epoch
61
+ end_of_epoch = epoch_itr.end_of_epoch()
62
+ updates = trainer.get_num_updates()
63
+
64
+ logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")
65
+
66
+ def is_better(a, b):
67
+ return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
68
+
69
+ suffix = trainer.checkpoint_suffix
70
+ checkpoint_conds = collections.OrderedDict()
71
+ checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
72
+ end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
73
+ )
74
+ checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
75
+ not end_of_epoch
76
+ and cfg.save_interval_updates > 0
77
+ and updates % cfg.save_interval_updates == 0
78
+ )
79
+ checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
80
+ not hasattr(save_checkpoint, "best")
81
+ or is_better(val_loss, save_checkpoint.best)
82
+ )
83
+ if val_loss is not None and cfg.keep_best_checkpoints > 0:
84
+ worst_best = getattr(save_checkpoint, "best", None)
85
+ chkpts = checkpoint_paths(
86
+ cfg.save_dir,
87
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
88
+ cfg.best_checkpoint_metric, suffix
89
+ ),
90
+ )
91
+ if len(chkpts) > 0:
92
+ p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
93
+ worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
94
+ # add random digits to resolve ties
95
+ with data_utils.numpy_seed(epoch, updates, val_loss):
96
+ rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)
97
+
98
+ checkpoint_conds[
99
+ "checkpoint.best_{}_{:.3f}{}{}.pt".format(
100
+ cfg.best_checkpoint_metric,
101
+ val_loss,
102
+ rand_sfx,
103
+ suffix
104
+ )
105
+ ] = worst_best is None or is_better(val_loss, worst_best)
106
+ checkpoint_conds[
107
+ "checkpoint_last{}.pt".format(suffix)
108
+ ] = not cfg.no_last_checkpoints
109
+
110
+ extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
111
+ if hasattr(save_checkpoint, "best"):
112
+ extra_state.update({"best": save_checkpoint.best})
113
+
114
+ checkpoints = [
115
+ os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
116
+ ]
117
+ if len(checkpoints) > 0:
118
+ trainer.save_checkpoint(checkpoints[0], extra_state)
119
+ for cp in checkpoints[1:]:
120
+ if cfg.write_checkpoints_asynchronously:
121
+ # TODO[ioPath]: Need to implement a delayed asynchronous
122
+ # file copying/moving feature.
123
+ logger.warning(
124
+ f"ioPath is not copying {checkpoints[0]} to {cp} "
125
+ "since async write mode is on."
126
+ )
127
+ else:
128
+ assert PathManager.copy(
129
+ checkpoints[0], cp, overwrite=True
130
+ ), f"Failed to copy {checkpoints[0]} to {cp}"
131
+
132
+ write_timer.stop()
133
+ logger.info(
134
+ "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
135
+ checkpoints[0], epoch, updates, val_loss, write_timer.sum
136
+ )
137
+ )
138
+
139
+ if not end_of_epoch and cfg.keep_interval_updates > 0:
140
+ # remove old checkpoints; checkpoints are sorted in descending order
141
+ if cfg.keep_interval_updates_pattern == -1:
142
+ checkpoints = checkpoint_paths(
143
+ cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
144
+ )
145
+ else:
146
+ checkpoints = checkpoint_paths(
147
+ cfg.save_dir,
148
+ pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
149
+ keep_match=True,
150
+ )
151
+ checkpoints = [
152
+ x[0]
153
+ for x in checkpoints
154
+ if x[1] % cfg.keep_interval_updates_pattern != 0
155
+ ]
156
+
157
+ for old_chk in checkpoints[cfg.keep_interval_updates :]:
158
+ if os.path.lexists(old_chk):
159
+ os.remove(old_chk)
160
+ elif PathManager.exists(old_chk):
161
+ PathManager.rm(old_chk)
162
+
163
+ if cfg.keep_last_epochs > 0:
164
+ # remove old epoch checkpoints; checkpoints are sorted in descending order
165
+ checkpoints = checkpoint_paths(
166
+ cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
167
+ )
168
+ for old_chk in checkpoints[cfg.keep_last_epochs :]:
169
+ if os.path.lexists(old_chk):
170
+ os.remove(old_chk)
171
+ elif PathManager.exists(old_chk):
172
+ PathManager.rm(old_chk)
173
+
174
+ if cfg.keep_best_checkpoints > 0:
175
+ # only keep the best N checkpoints according to validation metric
176
+ checkpoints = checkpoint_paths(
177
+ cfg.save_dir,
178
+ pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
179
+ cfg.best_checkpoint_metric, suffix
180
+ ),
181
+ )
182
+ if not cfg.maximize_best_checkpoint_metric:
183
+ checkpoints = checkpoints[::-1]
184
+ for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
185
+ if os.path.lexists(old_chk):
186
+ os.remove(old_chk)
187
+ elif PathManager.exists(old_chk):
188
+ PathManager.rm(old_chk)
189
+
190
+
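For reference, `save_checkpoint` above writes one file and then copies it to every other name whose condition holds. A small sketch of the candidate names for illustrative values of epoch, updates and suffix:

epoch, updates, suffix = 3, 12000, ""   # illustrative values
candidate_names = [
    "checkpoint{}{}.pt".format(epoch, suffix),               # end of epoch
    "checkpoint_{}_{}{}.pt".format(epoch, updates, suffix),  # every save_interval_updates
    "checkpoint_best{}.pt".format(suffix),                   # val_loss improved
    "checkpoint_last{}.pt".format(suffix),                   # unless no_last_checkpoints
]
print(candidate_names)
# ['checkpoint3.pt', 'checkpoint_3_12000.pt', 'checkpoint_best.pt', 'checkpoint_last.pt']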
191
+ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
192
+ """
193
+ Load a checkpoint and restore the training iterator.
194
+
195
+ *passthrough_args* will be passed through to
196
+ ``trainer.get_train_iterator``.
197
+ """
198
+
199
+ reset_optimizer = cfg.reset_optimizer
200
+ reset_lr_scheduler = cfg.reset_lr_scheduler
201
+ optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
202
+ reset_meters = cfg.reset_meters
203
+ reset_dataloader = cfg.reset_dataloader
204
+
205
+ if cfg.finetune_from_model is not None and (
206
+ reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
207
+ ):
208
+ raise ValueError(
209
+ "--finetune-from-model can not be set together with either --reset-optimizer"
210
+ " or reset_lr_scheduler or reset_meters or reset_dataloader"
211
+ )
212
+
213
+ suffix = trainer.checkpoint_suffix
214
+ if (
215
+ cfg.restore_file == "checkpoint_last.pt"
216
+ ): # default value of restore_file is 'checkpoint_last.pt'
217
+ checkpoint_path = os.path.join(
218
+ cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
219
+ )
220
+ first_launch = not PathManager.exists(checkpoint_path)
221
+ if cfg.finetune_from_model is not None and first_launch:
222
+ # if there is no last checkpoint to restore, start the finetune from pretrained model
223
+ # otherwise, use the usual logic to load the checkpoint, e.g. restart from the last checkpoint, etc.
224
+ if PathManager.exists(cfg.finetune_from_model):
225
+ checkpoint_path = cfg.finetune_from_model
226
+ reset_optimizer = True
227
+ reset_lr_scheduler = True
228
+ reset_meters = True
229
+ reset_dataloader = True
230
+ logger.info(
231
+ f"loading pretrained model from {checkpoint_path}: "
232
+ "optimizer, lr scheduler, meters, dataloader will be reset"
233
+ )
234
+ else:
235
+ raise ValueError(
236
+ f"--funetune-from-model {cfg.finetune_from_model} does not exist"
237
+ )
238
+ elif suffix is not None:
239
+ checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
240
+ else:
241
+ checkpoint_path = cfg.restore_file
242
+
243
+ if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
244
+ raise ValueError(
245
+ "--finetune-from-model and --restore-file (non-default value) "
246
+ "can not be specified together: " + str(cfg)
247
+ )
248
+
249
+ extra_state = trainer.load_checkpoint(
250
+ checkpoint_path,
251
+ reset_optimizer,
252
+ reset_lr_scheduler,
253
+ optimizer_overrides,
254
+ reset_meters=reset_meters,
255
+ )
256
+
257
+ if (
258
+ extra_state is not None
259
+ and "best" in extra_state
260
+ and not reset_optimizer
261
+ and not reset_meters
262
+ ):
263
+ save_checkpoint.best = extra_state["best"]
264
+
265
+ if extra_state is not None and not reset_dataloader:
266
+ # restore iterator from checkpoint
267
+ itr_state = extra_state["train_iterator"]
268
+ epoch_itr = trainer.get_train_iterator(
269
+ epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
270
+ )
271
+ epoch_itr.load_state_dict(itr_state)
272
+ _n = itr_state['iterations_in_epoch']
273
+ offset = sum(len(_) for _ in epoch_itr.batch_sampler[:_n])
274
+ epoch_itr.dataset.dataset._seek(offset=offset)
275
+ true_num = int(math.ceil(len(epoch_itr.dataset) / 8)) * 8
276
+ another_offset = ((epoch_itr.epoch - 1) * true_num + offset) // 8
277
+ if hasattr(epoch_itr.dataset, 'pure_text_dataset'):
278
+ text_offset = (2 * another_offset) % len(epoch_itr.dataset.pure_text_dataset)
279
+ epoch_itr.dataset.pure_text_dataset._seek(offset=text_offset)
280
+ if hasattr(epoch_itr.dataset, 'pure_image_dataset'):
281
+ image_offset = another_offset % len(epoch_itr.dataset.pure_image_dataset)
282
+ epoch_itr.dataset.pure_image_dataset._seek(offset=image_offset)
283
+ if hasattr(epoch_itr.dataset, 'detection_dataset'):
284
+ detection_offset = another_offset % len(epoch_itr.dataset.detection_dataset)
285
+ epoch_itr.dataset.detection_dataset._seek(offset=detection_offset)
286
+ else:
287
+ epoch_itr = trainer.get_train_iterator(
288
+ epoch=1, load_dataset=True, **passthrough_args
289
+ )
290
+
291
+ trainer.lr_step(epoch_itr.epoch)
292
+
293
+ return extra_state, epoch_itr
294
+
295
+
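For reference, a worked example of the resume-offset arithmetic in `load_checkpoint` above, with made-up numbers; reading the hard-coded divisor 8 as the number of data-parallel shards is an assumption:

import math

iterations_done = 100                   # itr_state["iterations_in_epoch"], assumed
batch_lengths = [32] * iterations_done  # lengths of the batches already consumed
offset = sum(batch_lengths)             # samples already read this epoch -> 3200

dataset_len, epoch = 100_000, 2         # assumed values
true_num = int(math.ceil(dataset_len / 8)) * 8            # pad length to a multiple of 8
another_offset = ((epoch - 1) * true_num + offset) // 8   # per-shard position so far
# pure_text / pure_image / detection datasets are then seeked to
# (k * another_offset) % len(dataset), with k = 2 for text and k = 1 otherwise
print(offset, true_num, another_offset)  # 3200 100000 12900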
296
+ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
297
+ """Loads a checkpoint to CPU (with upgrading for backward compatibility).
298
+
299
+ If doing single-GPU training or if the checkpoint is only being loaded by at
300
+ most one process on each node (current default behavior is for only rank 0
301
+ to read the checkpoint from disk), load_on_all_ranks should be False to
302
+ avoid errors from torch.distributed not having been initialized or
303
+ torch.distributed.barrier() hanging.
304
+
305
+ If all processes on each node may be loading the checkpoint
306
+ simultaneously, load_on_all_ranks should be set to True to avoid I/O
307
+ conflicts.
308
+
309
+ There's currently no support for > 1 but < all processes loading the
310
+ checkpoint on each node.
311
+ """
312
+ local_path = PathManager.get_local_path(path)
313
+ # The locally cached file returned by get_local_path() may be stale for
314
+ # remote files that are periodically updated/overwritten (ex:
315
+ # checkpoint_last.pt) - so we remove the local copy, sync across processes
316
+ # (if needed), and then download a fresh copy.
317
+ if local_path != path and PathManager.path_requires_pathmanager(path):
318
+ try:
319
+ os.remove(local_path)
320
+ except FileNotFoundError:
321
+ # With potentially multiple processes removing the same file, the
322
+ # file being missing is benign (missing_ok isn't available until
323
+ # Python 3.8).
324
+ pass
325
+ if load_on_all_ranks:
326
+ torch.distributed.barrier()
327
+ local_path = PathManager.get_local_path(path)
328
+
329
+ with open(local_path, "rb") as f:
330
+ state = torch.load(f, map_location=torch.device("cpu"))
331
+
332
+ if "args" in state and state["args"] is not None and arg_overrides is not None:
333
+ args = state["args"]
334
+ for arg_name, arg_val in arg_overrides.items():
335
+ setattr(args, arg_name, arg_val)
336
+
337
+ if "cfg" in state and state["cfg"] is not None:
338
+
339
+ # hack to be able to set Namespace in dict config. this should be removed when we update to newer
340
+ # omegaconf version that supports object flags, or when we migrate all existing models
341
+ from omegaconf import _utils
342
+
343
+ old_primitive = _utils.is_primitive_type
344
+ _utils.is_primitive_type = lambda _: True
345
+
346
+ state["cfg"] = OmegaConf.create(state["cfg"])
347
+
348
+ _utils.is_primitive_type = old_primitive
349
+ OmegaConf.set_struct(state["cfg"], True)
350
+
351
+ if arg_overrides is not None:
352
+ overwrite_args_by_name(state["cfg"], arg_overrides)
353
+
354
+ state = _upgrade_state_dict(state)
355
+ return state
356
+
357
+
358
+ def load_model_ensemble(
359
+ filenames,
360
+ arg_overrides: Optional[Dict[str, Any]] = None,
361
+ task=None,
362
+ strict=True,
363
+ suffix="",
364
+ num_shards=1,
365
+ state=None,
366
+ ):
367
+ """Loads an ensemble of models.
368
+
369
+ Args:
370
+ filenames (List[str]): checkpoint files to load
371
+ arg_overrides (Dict[str,Any], optional): override model args that
372
+ were used during model training
373
+ task (fairseq.tasks.FairseqTask, optional): task to use for loading
374
+ """
375
+ assert not (
376
+ strict and num_shards > 1
377
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
378
+ ensemble, args, _task = load_model_ensemble_and_task(
379
+ filenames,
380
+ arg_overrides,
381
+ task,
382
+ strict,
383
+ suffix,
384
+ num_shards,
385
+ state,
386
+ )
387
+ return ensemble, args
388
+
389
+
390
+ def get_maybe_sharded_checkpoint_filename(
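For reference, a typical way `load_model_ensemble` is called at inference time; the import path and checkpoint filename below are hypothetical:

from utils import checkpoint_utils  # this module, assuming the repo root is on sys.path

models, cfg = checkpoint_utils.load_model_ensemble(
    ["checkpoints/ofa_base_best.pt"],  # hypothetical checkpoint path
)
model = models[0]
model.eval()  # members come back as regular fairseq models built via task.build_model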
391
+ filename: str, suffix: str, shard_idx: int, num_shards: int
392
+ ) -> str:
393
+ orig_filename = filename
394
+ filename = filename.replace(".pt", suffix + ".pt")
395
+ fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
396
+ model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
397
+ if PathManager.exists(fsdp_filename):
398
+ return fsdp_filename
399
+ elif num_shards > 1:
400
+ return model_parallel_filename
401
+ else:
402
+ return filename
403
+
404
+
405
+ def load_model_ensemble_and_task(
406
+ filenames,
407
+ arg_overrides: Optional[Dict[str, Any]] = None,
408
+ task=None,
409
+ strict=True,
410
+ suffix="",
411
+ num_shards=1,
412
+ state=None,
413
+ ):
414
+ assert state is None or len(filenames) == 1
415
+
416
+ from fairseq import tasks
417
+
418
+ assert not (
419
+ strict and num_shards > 1
420
+ ), "Cannot load state dict with strict=True and checkpoint shards > 1"
421
+ ensemble = []
422
+ cfg = None
423
+ for filename in filenames:
424
+ orig_filename = filename
425
+ model_shard_state = {"shard_weights": [], "shard_metadata": []}
426
+ assert num_shards > 0
427
+ st = time.time()
428
+ for shard_idx in range(num_shards):
429
+ filename = get_maybe_sharded_checkpoint_filename(
430
+ orig_filename, suffix, shard_idx, num_shards
431
+ )
432
+
433
+ if not PathManager.exists(filename):
434
+ raise IOError("Model file not found: {}".format(filename))
435
+ if state is None:
436
+ state = load_checkpoint_to_cpu(filename, arg_overrides)
437
+ if "args" in state and state["args"] is not None:
438
+ cfg = convert_namespace_to_omegaconf(state["args"])
439
+ elif "cfg" in state and state["cfg"] is not None:
440
+ cfg = state["cfg"]
441
+ else:
442
+ raise RuntimeError(
443
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
444
+ )
445
+
446
+ if task is None:
447
+ task = tasks.setup_task(cfg.task)
448
+
449
+ if "task_state" in state:
450
+ task.load_state_dict(state["task_state"])
451
+
452
+ if "fsdp_metadata" in state and num_shards > 1:
453
+ model_shard_state["shard_weights"].append(state["model"])
454
+ model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
455
+ # check FSDP import before the code goes too far
456
+ if not has_FSDP:
457
+ raise ImportError(
458
+ "Cannot find FullyShardedDataParallel. "
459
+ "Please install fairscale with: pip install fairscale"
460
+ )
461
+ if shard_idx == num_shards - 1:
462
+ consolidated_model_state = FSDP.consolidate_shard_weights(
463
+ shard_weights=model_shard_state["shard_weights"],
464
+ shard_metadata=model_shard_state["shard_metadata"],
465
+ )
466
+ model = task.build_model(cfg.model)
467
+ model.load_state_dict(
468
+ consolidated_model_state, strict=strict, model_cfg=cfg.model
469
+ )
470
+ else:
471
+ # model parallel checkpoint or unsharded checkpoint
472
+ model = task.build_model(cfg.model)
473
+ model.load_state_dict(
474
+ state["model"], strict=strict, model_cfg=cfg.model
475
+ )
476
+
477
+ # reset state so it gets loaded for the next model in ensemble
478
+ state = None
479
+ if shard_idx % 10 == 0 and shard_idx > 0:
480
+ elapsed = time.time() - st
481
+ logger.info(
482
+ f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
483
+ )
484
+
485
+ # build model for ensemble
486
+ ensemble.append(model)
487
+ return ensemble, cfg, task
488
+
489
+
490
+ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
491
+ """Retrieves all checkpoints found in `path` directory.
492
+
493
+ Checkpoints are identified by matching filename to the specified pattern. If
494
+ the pattern contains groups, the result will be sorted by the first group in
495
+ descending order.
496
+ """
497
+ pt_regexp = re.compile(pattern)
498
+ files = PathManager.ls(path)
499
+
500
+ entries = []
501
+ for i, f in enumerate(files):
502
+ m = pt_regexp.fullmatch(f)
503
+ if m is not None:
504
+ idx = float(m.group(1)) if len(m.groups()) > 0 else i
505
+ entries.append((idx, m.group(0)))
506
+ if keep_match:
507
+ return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
508
+ else:
509
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
510
+
511
+
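For reference, a quick illustration of how `checkpoint_paths` sorts matches of the default pattern; the filenames are invented for the example:

import re

files = ["checkpoint3.pt", "checkpoint12.pt", "checkpoint_last.pt", "checkpoint7.pt"]
pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
entries = [
    (float(m.group(1)), m.group(0))
    for m in (pt_regexp.fullmatch(f) for f in files)
    if m is not None
]
# descending by the captured number, newest epoch first, as in the function above
assert [name for _, name in sorted(entries, reverse=True)] == [
    "checkpoint12.pt", "checkpoint7.pt", "checkpoint3.pt"
]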
512
+ def torch_persistent_save(obj, filename, async_write: bool = False):
513
+ if async_write:
514
+ with PathManager.opena(filename, "wb") as f:
515
+ _torch_persistent_save(obj, f)
516
+ else:
517
+ with PathManager.open(filename, "wb") as f:
518
+ _torch_persistent_save(obj, f)
519
+ # if PathManager.supports_rename(filename):
520
+ # # do atomic save
521
+ # with PathManager.open(filename + ".tmp", "wb") as f:
522
+ # _torch_persistent_save(obj, f)
523
+ # PathManager.rename(filename + ".tmp", filename)
524
+ # else:
525
+ # # fallback to non-atomic save
526
+ # with PathManager.open(filename, "wb") as f:
527
+ # _torch_persistent_save(obj, f)
528
+
529
+
530
+ def _torch_persistent_save(obj, f):
531
+ if isinstance(f, str):
532
+ with PathManager.open(f, "wb") as h:
533
+ _torch_persistent_save(obj, h)
534
+ return
535
+ for i in range(3):
536
+ try:
537
+ return torch.save(obj, f)
538
+ except Exception:
539
+ if i == 2:
540
+ logger.error(traceback.format_exc())
541
+ raise
542
+
543
+
544
+ def _upgrade_state_dict(state):
545
+ """Helper for upgrading old model checkpoints."""
546
+
547
+ # add optimizer_history
548
+ if "optimizer_history" not in state:
549
+ state["optimizer_history"] = [
550
+ {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
551
+ ]
552
+ state["last_optimizer_state"] = state["optimizer"]
553
+ del state["optimizer"]
554
+ del state["best_loss"]
555
+ # move extra_state into sub-dictionary
556
+ if "epoch" in state and "extra_state" not in state:
557
+ state["extra_state"] = {
558
+ "epoch": state["epoch"],
559
+ "batch_offset": state["batch_offset"],
560
+ "val_loss": state["val_loss"],
561
+ }
562
+ del state["epoch"]
563
+ del state["batch_offset"]
564
+ del state["val_loss"]
565
+ # reduce optimizer history's memory usage (only keep the last state)
566
+ if "optimizer" in state["optimizer_history"][-1]:
567
+ state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
568
+ for optim_hist in state["optimizer_history"]:
569
+ del optim_hist["optimizer"]
570
+ # record the optimizer class name
571
+ if "optimizer_name" not in state["optimizer_history"][-1]:
572
+ state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
573
+ # move best_loss into lr_scheduler_state
574
+ if "lr_scheduler_state" not in state["optimizer_history"][-1]:
575
+ state["optimizer_history"][-1]["lr_scheduler_state"] = {
576
+ "best": state["optimizer_history"][-1]["best_loss"]
577
+ }
578
+ del state["optimizer_history"][-1]["best_loss"]
579
+ # keep track of number of updates
580
+ if "num_updates" not in state["optimizer_history"][-1]:
581
+ state["optimizer_history"][-1]["num_updates"] = 0
582
+ # old model checkpoints may not have separate source/target positions
583
+ if (
584
+ "args" in state
585
+ and hasattr(state["args"], "max_positions")
586
+ and not hasattr(state["args"], "max_source_positions")
587
+ ):
588
+ state["args"].max_source_positions = state["args"].max_positions
589
+ state["args"].max_target_positions = state["args"].max_positions
590
+ # use stateful training data iterator
591
+ if "train_iterator" not in state["extra_state"]:
592
+ state["extra_state"]["train_iterator"] = {
593
+ "epoch": state["extra_state"]["epoch"],
594
+ "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
595
+ }
596
+
597
+     # backward compatibility, cfg updates
+     if "args" in state and state["args"] is not None:
+         # default to translation task
+         if not hasattr(state["args"], "task"):
+             state["args"].task = "translation"
+         # --raw-text and --lazy-load are deprecated
+         if getattr(state["args"], "raw_text", False):
+             state["args"].dataset_impl = "raw"
+         elif getattr(state["args"], "lazy_load", False):
+             state["args"].dataset_impl = "lazy"
+         # epochs start at 1
+         if state["extra_state"]["train_iterator"] is not None:
+             state["extra_state"]["train_iterator"]["epoch"] = max(
+                 state["extra_state"]["train_iterator"].get("epoch", 1), 1
+             )
+         # --remove-bpe ==> --postprocess
+         if hasattr(state["args"], "remove_bpe"):
+             state["args"].post_process = state["args"].remove_bpe
+         # --min-lr ==> --stop-min-lr
+         if hasattr(state["args"], "min_lr"):
+             state["args"].stop_min_lr = state["args"].min_lr
+             del state["args"].min_lr
+         # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
+         if (
+             hasattr(state["args"], "criterion")
+             and state["args"].criterion in [
+                 "binary_cross_entropy",
+                 "kd_binary_cross_entropy",
+             ]
+         ):
+             state["args"].criterion = "wav2vec"
+         # remove log_keys if it's None (criteria will supply a default value of [])
+         if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
+             delattr(state["args"], "log_keys")
+         # speech_pretraining => audio pretraining
+         if (
+             hasattr(state["args"], "task")
+             and state["args"].task == "speech_pretraining"
+         ):
+             state["args"].task = "audio_pretraining"
+         # audio_cpc => wav2vec
+         if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
+             state["args"].arch = "wav2vec"
+         # convert legacy float learning rate to List[float]
+         if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
+             state["args"].lr = [state["args"].lr]
+         # convert task data arg to a string instead of List[string]
+         if (
+             hasattr(state["args"], "data")
+             and isinstance(state["args"].data, list)
+             and len(state["args"].data) > 0
+         ):
+             state["args"].data = state["args"].data[0]
+         # remove keys in state["args"] related to teacher-student learning
+         for key in [
+             "static_teachers",
+             "static_teacher_weights",
+             "dynamic_teachers",
+             "dynamic_teacher_weights",
+         ]:
+             if key in state["args"]:
+                 delattr(state["args"], key)
+
+         state["cfg"] = convert_namespace_to_omegaconf(state["args"])
+
+     if "cfg" in state and state["cfg"] is not None:
+         cfg = state["cfg"]
+         with open_dict(cfg):
+             # any upgrades for Hydra-based configs
+             if (
+                 "task" in cfg
+                 and "eval_wer_config" in cfg.task
+                 and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
+             ):
+                 cfg.task.eval_wer_config.print_alignment = "hard"
+             if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
+                 cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None
+             if (
+                 "model" in cfg
+                 and "w2v_args" in cfg.model
+                 and cfg.model.w2v_args is not None
+                 and (
+                     hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
+                 )
+                 and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
+                 and cfg.model.w2v_args.task.eval_wer_config is not None
+                 and isinstance(
+                     cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
+                 )
+             ):
+                 cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"
+
+     return state
+
+
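The block above rewrites a legacy argparse-style checkpoint in place: deprecated flags are renamed, scalar values are converted to their newer container types, and the namespace is finally turned into a Hydra config via convert_namespace_to_omegaconf. As a minimal, self-contained sketch of the flag renames only (a hypothetical Namespace for illustration, not a real OFA checkpoint and not the private function itself):

from argparse import Namespace

# Hypothetical legacy args, as they might appear in an old checkpoint.
args = Namespace(min_lr=1e-9, remove_bpe="@@ ", lr=0.0005)

if hasattr(args, "remove_bpe"):      # --remove-bpe ==> post_process
    args.post_process = args.remove_bpe
if hasattr(args, "min_lr"):          # --min-lr ==> stop_min_lr
    args.stop_min_lr = args.min_lr
    del args.min_lr
if isinstance(args.lr, float):       # legacy float lr ==> List[float]
    args.lr = [args.lr]

print(args.post_process, args.stop_min_lr, args.lr)  # "@@ " 1e-09 [0.0005]

The criterion, task, and arch renames above follow the same hasattr-check-then-rewrite pattern.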
+ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
+     """Prune the given state_dict if desired for LayerDrop
+     (https://arxiv.org/abs/1909.11556).
+
+     Training with LayerDrop allows models to be robust to pruning at inference
+     time. This function prunes state_dict to allow smaller models to be loaded
+     from a larger model and re-maps the existing state_dict for this to occur.
+
+     It's called by functions that load models from checkpoints and does not
+     need to be called directly.
+     """
+     arch = None
+     if model_cfg is not None:
+         arch = (
+             model_cfg._name
+             if isinstance(model_cfg, DictConfig)
+             else getattr(model_cfg, "arch", None)
+         )
+
+     if not model_cfg or arch is None or arch == "ptt_transformer":
+         # args should not be none, but don't crash if it is.
+         return state_dict
+
+     encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
+     decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
+
+     if not encoder_layers_to_keep and not decoder_layers_to_keep:
+         return state_dict
+
+     # apply pruning
+     logger.info(
+         "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
+     )
+
+     def create_pruning_pass(layers_to_keep, layer_name):
+         keep_layers = sorted(
+             int(layer_string) for layer_string in layers_to_keep.split(",")
+         )
+         mapping_dict = {}
+         for i in range(len(keep_layers)):
+             mapping_dict[str(keep_layers[i])] = str(i)
+
+         regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
+         return {"substitution_regex": regex, "mapping_dict": mapping_dict}
+
+     pruning_passes = []
+     if encoder_layers_to_keep:
+         pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
+     if decoder_layers_to_keep:
+         pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
+
+     new_state_dict = {}
+     for layer_name in state_dict.keys():
+         match = re.search(r"\.layers\.(\d+)\.", layer_name)
+         # if layer has no number in it, it is a supporting layer, such as an
+         # embedding
+         if not match:
+             new_state_dict[layer_name] = state_dict[layer_name]
+             continue
+
+         # otherwise, layer should be pruned.
+         original_layer_number = match.group(1)
+         # figure out which mapping dict to replace from
+         for pruning_pass in pruning_passes:
+             if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
+                 "substitution_regex"
+             ].search(layer_name):
+                 new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
+                 substitution_match = pruning_pass["substitution_regex"].search(
+                     layer_name
+                 )
+                 new_state_key = (
+                     layer_name[: substitution_match.start(1)]
+                     + new_layer_number
+                     + layer_name[substitution_match.end(1) :]
+                 )
+                 new_state_dict[new_state_key] = state_dict[layer_name]
+
+     # Since layers are now pruned, *_layers_to_keep are no longer needed.
+     # This is more of a "make it work" fix than a proper fix.
+     if isinstance(model_cfg, DictConfig):
+         context = open_dict(model_cfg)
+     else:
+         context = contextlib.ExitStack()
+     with context:
+         if hasattr(model_cfg, "encoder_layers_to_keep"):
+             model_cfg.encoder_layers_to_keep = None
+         if hasattr(model_cfg, "decoder_layers_to_keep"):
+             model_cfg.decoder_layers_to_keep = None
+
+     return new_state_dict
+
+
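To make the remapping above concrete, here is a minimal, self-contained sketch with hypothetical parameter names (keeping encoder layers 0, 2 and 4 of a six-layer stack), mirroring the regex logic of create_pruning_pass:

import re

# Hypothetical state_dict keys; a keep-list of "0,2,4" maps old indices to new ones.
keys = [f"encoder.layers.{i}.fc1.weight" for i in range(6)]
keep = {"0": "0", "2": "1", "4": "2"}  # old layer index -> new layer index
pattern = re.compile(r"^encoder.*\.layers\.(\d+)")

pruned = {}
for k in keys:
    m = pattern.search(k)
    if m and m.group(1) in keep:
        # splice the new layer index into the key, as prune_state_dict does
        pruned[k[: m.start(1)] + keep[m.group(1)] + k[m.end(1):]] = k

# pruned maps new keys to the old keys they came from:
#   encoder.layers.0.* <- layer 0, encoder.layers.1.* <- layer 2, encoder.layers.2.* <- layer 4
# layers 1, 3 and 5 are dropped entirely.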
+ def load_pretrained_component_from_model(
+     component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
+ ):
+     """
+     Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
+     provided `component` object. If state_dict fails to load, there may be a
+     mismatch in the architecture of the corresponding `component` found in the
+     `checkpoint` file.
+     """
+     if not PathManager.exists(checkpoint):
+         raise IOError("Model file not found: {}".format(checkpoint))
+     state = load_checkpoint_to_cpu(checkpoint)
+     if isinstance(component, FairseqEncoder):
+         component_type = "encoder"
+     elif isinstance(component, FairseqDecoder):
+         component_type = "decoder"
+     else:
+         raise ValueError(
+             "component to load must be either a FairseqEncoder or "
+             "FairseqDecoder. Loading other component types is not supported."
+         )
+     component_state_dict = OrderedDict()
+     for key in state["model"].keys():
+         if key.startswith(component_type):
+             # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
+             component_subkey = key[len(component_type) + 1 :]
+             component_state_dict[component_subkey] = state["model"][key]
+     component.load_state_dict(component_state_dict, strict=True)
+     return component
+
+
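A typical use of this helper is warm-starting one side of a new model from an existing checkpoint. A minimal sketch, assuming a hypothetical checkpoint path and an already-built fairseq model whose encoder architecture matches the checkpoint:

# Hypothetical warm start: copy encoder weights from a checkpoint into model.encoder.
encoder = load_pretrained_component_from_model(
    component=model.encoder, checkpoint="checkpoints/checkpoint_best.pt"
)
# Only keys prefixed with "encoder." are used; the prefix is stripped before
# load_state_dict(strict=True), so the two architectures must match exactly.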
+ def verify_checkpoint_directory(save_dir: str) -> None:
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir, exist_ok=True)
+     temp_file_path = os.path.join(save_dir, "dummy")
+     try:
+         with open(temp_file_path, "w"):
+             pass
+     except OSError as e:
+         logger.warning(
+             "Unable to access checkpoint save directory: {}".format(save_dir)
+         )
+         raise e
+     else:
+         os.remove(temp_file_path)
+
+
+ def load_ema_from_checkpoint(fpath):
+     """Loads an exponential moving average (EMA) checkpoint from the given path
+     and returns the checkpoint state with the EMA weights as its model parameters.
+
+     Args:
+       fpath: A string path of checkpoint to load from.
+
+     Returns:
+       A dict of string keys mapping to various values. The 'model' key
+       from the returned dict should correspond to an OrderedDict mapping
+       string parameter names to torch Tensors.
+     """
+     params_dict = collections.OrderedDict()
+     new_state = None
+
+     with PathManager.open(fpath, 'rb') as f:
+         new_state = torch.load(
+             f,
+             map_location=(
+                 lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
+             ),
+         )
+
+     # EMA model is stored in a separate "extra state"
+     model_params = new_state['extra_state']['ema']
+
+     for key in list(model_params.keys()):
+         p = model_params[key]
+         if isinstance(p, torch.HalfTensor):
+             p = p.float()
+         if key not in params_dict:
+             params_dict[key] = p.clone()
+             # NOTE: clone() is needed in case p is a shared parameter
+         else:
+             raise ValueError("Key {} is repeated in EMA model params.".format(key))
+
+     if len(params_dict) == 0:
+         raise ValueError(
+             f"Input checkpoint path '{fpath}' does not contain "
+             "ema model weights, is this model trained with EMA?"
+         )
+
+     new_state['model'] = params_dict
+     return new_state
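When a model was trained with EMA enabled, the averaged weights can be swapped in before evaluation. A minimal sketch, assuming a hypothetical checkpoint path and an already-built model whose parameter names match the stored EMA state:

# Hypothetical usage: evaluate with the EMA weights instead of the raw weights.
ema_state = load_ema_from_checkpoint("checkpoints/checkpoint_last.pt")
model.load_state_dict(ema_state["model"], strict=True)
model.eval()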
utils/cider/pyciderevalcap/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'tylin'
utils/cider/pyciderevalcap/cider/__init__.py ADDED
@@ -0,0 +1 @@
+ __author__ = 'tylin'