yangapku commited on
Commit
0d735a2
1 Parent(s): 4068994

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. README.md +4 -4
  3. airship.jpg +0 -0
  4. app.py +153 -0
  5. checkpoints.md +13 -0
  6. colab.md +8 -0
  7. criterions/__init__.py +2 -0
  8. criterions/label_smoothed_cross_entropy.py +343 -0
  9. criterions/scst_loss.py +280 -0
  10. data/__init__.py +0 -0
  11. data/data_utils.py +601 -0
  12. data/file_dataset.py +102 -0
  13. data/mm_data/__init__.py +0 -0
  14. data/mm_data/caption_dataset.py +154 -0
  15. data/mm_data/refcoco_dataset.py +168 -0
  16. data/mm_data/vqa_gen_dataset.py +211 -0
  17. data/ofa_dataset.py +74 -0
  18. datasets.md +10 -0
  19. evaluate.py +156 -0
  20. fairseq/.github/ISSUE_TEMPLATE.md +3 -0
  21. fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  22. fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  23. fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  24. fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  25. fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
  26. fairseq/.github/stale.yml +30 -0
  27. fairseq/.github/workflows/build.yml +55 -0
  28. fairseq/.github/workflows/build_wheels.yml +41 -0
  29. fairseq/.gitignore +136 -0
  30. fairseq/.gitmodules +4 -0
  31. fairseq/CODE_OF_CONDUCT.md +77 -0
  32. fairseq/CONTRIBUTING.md +28 -0
  33. fairseq/LICENSE +21 -0
  34. fairseq/README.md +229 -0
  35. fairseq/docs/Makefile +20 -0
  36. fairseq/docs/_static/theme_overrides.css +9 -0
  37. fairseq/docs/command_line_tools.rst +85 -0
  38. fairseq/docs/conf.py +134 -0
  39. fairseq/docs/criterions.rst +31 -0
  40. fairseq/docs/data.rst +58 -0
  41. fairseq/docs/docutils.conf +2 -0
  42. fairseq/docs/fairseq_logo.png +0 -0
  43. fairseq/docs/getting_started.rst +216 -0
  44. fairseq/docs/hydra_integration.md +284 -0
  45. fairseq/docs/index.rst +49 -0
  46. fairseq/docs/lr_scheduler.rst +34 -0
  47. fairseq/docs/make.bat +36 -0
  48. fairseq/docs/models.rst +104 -0
  49. fairseq/docs/modules.rst +9 -0
  50. fairseq/docs/optim.rst +38 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 1999-2022 Alibaba Group Holding Ltd.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: OFA Open_Domain_VQA
3
- emoji: 🔥
4
- colorFrom: red
5
- colorTo: indigo
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
 
1
  ---
2
+ title: OFA-Open_Domain_VQA
3
+ emoji: 💩
4
+ colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
airship.jpg ADDED
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;'
4
+ 'pip install --use-feature=in-tree-build ./; cd ..')
5
+ os.system('ls -l')
6
+
7
+ import torch
8
+ import numpy as np
9
+ import re
10
+ from fairseq import utils,tasks
11
+ from fairseq import checkpoint_utils
12
+ from fairseq import distributed_utils, options, tasks, utils
13
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
14
+ from utils.zero_shot_utils import zero_shot_step
15
+ from tasks.mm_tasks.vqa_gen import VqaGenTask
16
+ from models.ofa import OFAModel
17
+ from PIL import Image
18
+ from torchvision import transforms
19
+ import gradio as gr
20
+
21
+ # Register VQA task
22
+ tasks.register_task('vqa_gen',VqaGenTask)
23
+ # turn on cuda if GPU is available
24
+ use_cuda = torch.cuda.is_available()
25
+ # use fp16 only when GPU is available
26
+ use_fp16 = False
27
+
28
+ os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/ofa_large_384.pt; '
29
+ 'mkdir -p checkpoints; mv ofa_large_384.pt checkpoints/ofa_large_384.pt')
30
+
31
+ # specify some options for evaluation
32
+ parser = options.get_generation_parser()
33
+ input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", "--path=checkpoints/ofa_large_384.pt", "--bpe-dir=utils/BPE"]
34
+ args = options.parse_args_and_arch(parser, input_args)
35
+ cfg = convert_namespace_to_omegaconf(args)
36
+
37
+ # Load pretrained ckpt & config
38
+ task = tasks.setup_task(cfg.task)
39
+ models, cfg = checkpoint_utils.load_model_ensemble(
40
+ utils.split_paths(cfg.common_eval.path),
41
+ task=task
42
+ )
43
+
44
+ # Move models to GPU
45
+ for model in models:
46
+ model.eval()
47
+ if use_fp16:
48
+ model.half()
49
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
50
+ model.cuda()
51
+ model.prepare_for_inference_(cfg)
52
+
53
+ # Initialize generator
54
+ generator = task.build_generator(models, cfg.generation)
55
+
56
+ # Image transform
57
+ from torchvision import transforms
58
+ mean = [0.5, 0.5, 0.5]
59
+ std = [0.5, 0.5, 0.5]
60
+
61
+ patch_resize_transform = transforms.Compose([
62
+ lambda image: image.convert("RGB"),
63
+ transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize(mean=mean, std=std),
66
+ ])
67
+
68
+ # Text preprocess
69
+ bos_item = torch.LongTensor([task.src_dict.bos()])
70
+ eos_item = torch.LongTensor([task.src_dict.eos()])
71
+ pad_idx = task.src_dict.pad()
72
+
73
+ # Normalize the question
74
+ def pre_question(question, max_ques_words):
75
+ question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
76
+ question = re.sub(
77
+ r"\s{2,}",
78
+ ' ',
79
+ question,
80
+ )
81
+ question = question.rstrip('\n')
82
+ question = question.strip(' ')
83
+ # truncate question
84
+ question_words = question.split(' ')
85
+ if len(question_words) > max_ques_words:
86
+ question = ' '.join(question_words[:max_ques_words])
87
+ return question
88
+
89
+ def encode_text(text, length=None, append_bos=False, append_eos=False):
90
+ s = task.tgt_dict.encode_line(
91
+ line=task.bpe.encode(text),
92
+ add_if_not_exist=False,
93
+ append_eos=False
94
+ ).long()
95
+ if length is not None:
96
+ s = s[:length]
97
+ if append_bos:
98
+ s = torch.cat([bos_item, s])
99
+ if append_eos:
100
+ s = torch.cat([s, eos_item])
101
+ return s
102
+
103
+ # Construct input for open-domain VQA task
104
+ def construct_sample(image: Image, question: str):
105
+ patch_image = patch_resize_transform(image).unsqueeze(0)
106
+ patch_mask = torch.tensor([True])
107
+
108
+ question = pre_question(question, task.cfg.max_src_length)
109
+ question = question + '?' if not question.endswith('?') else question
110
+ src_text = encode_text(' {}'.format(question), append_bos=True, append_eos=True).unsqueeze(0)
111
+
112
+ src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
113
+ ref_dict = np.array([{'yes': 1.0}]) # just placeholder
114
+ sample = {
115
+ "id":np.array(['42']),
116
+ "net_input": {
117
+ "src_tokens": src_text,
118
+ "src_lengths": src_length,
119
+ "patch_images": patch_image,
120
+ "patch_masks": patch_mask,
121
+ },
122
+ "ref_dict": ref_dict,
123
+ }
124
+ return sample
125
+
126
+ # Function to turn FP32 to FP16
127
+ def apply_half(t):
128
+ if t.dtype is torch.float32:
129
+ return t.to(dtype=torch.half)
130
+ return t
131
+
132
+
133
+ # Function for image captioning
134
+ def open_domain_vqa(Image, Question):
135
+ sample = construct_sample(Image, Question)
136
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
137
+ sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
138
+ # Run eval step for open-domain VQA
139
+ with torch.no_grad():
140
+ result, scores = zero_shot_step(task, generator, models, sample)
141
+ return result[0]['answer']
142
+
143
+
144
+ title = "OFA-Open_Domain_VQA"
145
+ description = "Gradio Demo for OFA-Open_Domain_VQA. Upload your own image or click any one of the examples, and click " \
146
+ "\"Submit\" and then wait for OFA's answer. "
147
+ article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
148
+ "Repo</a></p> "
149
+ examples = [['money_tree.png', 'what is grown on the plant?'], ['airship.jpg', 'what does the red-roofed building right to the big airship look like?'], ['sitting_man.png', 'what is the man sitting on?']]
150
+ io = gr.Interface(fn=open_domain_vqa, inputs=[gr.inputs.Image(type='pil'), "textbox"], outputs=gr.outputs.Textbox(label="Answer"),
151
+ title=title, description=description, article=article, examples=examples,
152
+ allow_flagging=False, allow_screenshot=False)
153
+ io.launch(cache_examples=True)
checkpoints.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Checkpoints
2
+
3
+ We provide links for you to download our checkpoints. We will release all the checkpoints including pretrained and finetuned models on different tasks.
4
+
5
+ ## Pretraining
6
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a>
7
+
8
+ ## Finetuning
9
+
10
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
11
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
12
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
13
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
colab.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Colab Notebooks
2
+
3
+ We provide Colab notebooks of different downstream task for you guys to enjoy OFA. See below.
4
+
5
+ * Image Captioning: [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
6
+ * Referring Expression Comprehension: [![][colab]](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
7
+
8
+ [colab]: <https://colab.research.google.com/assets/colab-badge.svg>
criterions/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .scst_loss import ScstRewardCriterion
2
+ from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
criterions/label_smoothed_cross_entropy.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+
19
+ @dataclass
20
+ class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
21
+ label_smoothing: float = field(
22
+ default=0.0,
23
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
24
+ )
25
+ report_accuracy: bool = field(
26
+ default=False,
27
+ metadata={"help": "report accuracy metric"},
28
+ )
29
+ ignore_prefix_size: int = field(
30
+ default=0,
31
+ metadata={"help": "Ignore first N tokens"},
32
+ )
33
+ ignore_eos: bool = field(
34
+ default=False,
35
+ metadata={"help": "Ignore eos token"},
36
+ )
37
+ sentence_avg: bool = II("optimization.sentence_avg")
38
+ drop_worst_ratio: float = field(
39
+ default=0.0,
40
+ metadata={"help": "ratio for discarding bad samples"},
41
+ )
42
+ drop_worst_after: int = field(
43
+ default=0,
44
+ metadata={"help": "steps for discarding bad samples"},
45
+ )
46
+ use_rdrop: bool = field(
47
+ default=False, metadata={"help": "use R-Drop"}
48
+ )
49
+ reg_alpha: float = field(
50
+ default=1.0, metadata={"help": "weight for R-Drop"}
51
+ )
52
+ sample_patch_num: int = field(
53
+ default=196, metadata={"help": "sample patchs for v1"}
54
+ )
55
+ constraint_range: Optional[str] = field(
56
+ default=None,
57
+ metadata={"help": "constraint range"}
58
+ )
59
+
60
+
61
+ def construct_rdrop_sample(x):
62
+ if isinstance(x, dict):
63
+ for key in x:
64
+ x[key] = construct_rdrop_sample(x[key])
65
+ return x
66
+ elif isinstance(x, torch.Tensor):
67
+ return x.repeat(2, *([1] * (x.dim()-1)))
68
+ elif isinstance(x, int):
69
+ return x * 2
70
+ elif isinstance(x, np.ndarray):
71
+ return x.repeat(2)
72
+ else:
73
+ raise NotImplementedError
74
+
75
+
76
+ def kl_loss(p, q):
77
+ p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
78
+ q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
79
+ loss = (p_loss + q_loss) / 2
80
+ return loss
81
+
82
+
83
+ def label_smoothed_nll_loss(
84
+ lprobs, target, epsilon, update_num, reduce=True,
85
+ drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
86
+ constraint_masks=None, constraint_start=None, constraint_end=None
87
+ ):
88
+ if target.dim() == lprobs.dim() - 1:
89
+ target = target.unsqueeze(-1)
90
+ nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
91
+ if constraint_masks is not None:
92
+ smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
93
+ eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
94
+ elif constraint_start is not None and constraint_end is not None:
95
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
96
+ smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
97
+ eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
98
+ else:
99
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
100
+ eps_i = epsilon / (lprobs.size(-1) - 1)
101
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
102
+ if drop_worst_ratio > 0 and update_num > drop_worst_after:
103
+ if use_rdrop:
104
+ true_batch_size = loss.size(0) // 2
105
+ _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
106
+ loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
107
+ nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
108
+ lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
109
+ else:
110
+ loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
111
+ nll_loss = nll_loss[indices]
112
+ lprobs = lprobs[indices]
113
+
114
+ ntokens = loss.numel()
115
+ nll_loss = nll_loss.sum()
116
+ loss = loss.sum()
117
+ if use_rdrop:
118
+ true_batch_size = lprobs.size(0) // 2
119
+ p = lprobs[:true_batch_size]
120
+ q = lprobs[true_batch_size:]
121
+ if constraint_start is not None and constraint_end is not None:
122
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
123
+ p = p[:, constraint_range]
124
+ q = q[:, constraint_range]
125
+ loss += kl_loss(p, q) * reg_alpha
126
+
127
+ return loss, nll_loss, ntokens
128
+
129
+
130
+ @register_criterion(
131
+ "ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
132
+ )
133
+ class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
134
+ def __init__(
135
+ self,
136
+ task,
137
+ sentence_avg,
138
+ label_smoothing,
139
+ ignore_prefix_size=0,
140
+ ignore_eos=False,
141
+ report_accuracy=False,
142
+ drop_worst_ratio=0,
143
+ drop_worst_after=0,
144
+ use_rdrop=False,
145
+ reg_alpha=1.0,
146
+ sample_patch_num=196,
147
+ constraint_range=None
148
+ ):
149
+ super().__init__(task)
150
+ self.sentence_avg = sentence_avg
151
+ self.eps = label_smoothing
152
+ self.ignore_prefix_size = ignore_prefix_size
153
+ self.ignore_eos = ignore_eos
154
+ self.report_accuracy = report_accuracy
155
+ self.drop_worst_ratio = drop_worst_ratio
156
+ self.drop_worst_after = drop_worst_after
157
+ self.use_rdrop = use_rdrop
158
+ self.reg_alpha = reg_alpha
159
+ self.sample_patch_num = sample_patch_num
160
+
161
+ self.constraint_start = None
162
+ self.constraint_end = None
163
+ if constraint_range is not None:
164
+ constraint_start, constraint_end = constraint_range.split(',')
165
+ self.constraint_start = int(constraint_start)
166
+ self.constraint_end = int(constraint_end)
167
+
168
+ def forward(self, model, sample, update_num=0, reduce=True):
169
+ """Compute the loss for the given sample.
170
+
171
+ Returns a tuple with three elements:
172
+ 1) the loss
173
+ 2) the sample size, which is used as the denominator for the gradient
174
+ 3) logging outputs to display while training
175
+ """
176
+ if isinstance(sample, list):
177
+ if self.sample_patch_num > 0:
178
+ sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
179
+ loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
180
+ loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
181
+ loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
182
+ sample_size = 1
183
+ logging_output = {
184
+ "loss": loss.data,
185
+ "loss_v1": loss_v1.data,
186
+ "loss_v2": loss_v2.data,
187
+ "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
188
+ "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
189
+ "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
190
+ "sample_size": 1,
191
+ "sample_size_v1": sample_size_v1,
192
+ "sample_size_v2": sample_size_v2,
193
+ }
194
+ return loss, sample_size, logging_output
195
+
196
+ if self.use_rdrop:
197
+ construct_rdrop_sample(sample)
198
+
199
+ net_output = model(**sample["net_input"])
200
+ loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
201
+ sample_size = (
202
+ sample["target"].size(0) if self.sentence_avg else ntokens
203
+ )
204
+ logging_output = {
205
+ "loss": loss.data,
206
+ "nll_loss": nll_loss.data,
207
+ "ntokens": sample["ntokens"],
208
+ "nsentences": sample["nsentences"],
209
+ "sample_size": sample_size,
210
+ }
211
+ if self.report_accuracy:
212
+ n_correct, total = self.compute_accuracy(model, net_output, sample)
213
+ logging_output["n_correct"] = utils.item(n_correct.data)
214
+ logging_output["total"] = utils.item(total.data)
215
+ return loss, sample_size, logging_output
216
+
217
+ def get_lprobs_and_target(self, model, net_output, sample):
218
+ conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
219
+ constraint_masks = None
220
+ if "constraint_masks" in sample and sample["constraint_masks"] is not None:
221
+ constraint_masks = sample["constraint_masks"]
222
+ net_output[0].masked_fill_(~constraint_masks, -math.inf)
223
+ if self.constraint_start is not None and self.constraint_end is not None:
224
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
225
+ net_output[0][:, :, self.constraint_end:] = -math.inf
226
+ lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
227
+ target = model.get_targets(sample, net_output)
228
+ if self.ignore_prefix_size > 0:
229
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
230
+ target = target[:, self.ignore_prefix_size :].contiguous()
231
+ if constraint_masks is not None:
232
+ constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
233
+ if self.ignore_eos:
234
+ bsz, seq_len, embed_dim = lprobs.size()
235
+ eos_indices = target.eq(self.task.tgt_dict.eos())
236
+ lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
237
+ target = target[~eos_indices].reshape(bsz, seq_len-1)
238
+ if constraint_masks is not None:
239
+ constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
240
+ if constraint_masks is not None:
241
+ constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
242
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
243
+
244
+ def compute_loss(self, model, net_output, sample, update_num, reduce=True):
245
+ lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
246
+ if constraint_masks is not None:
247
+ constraint_masks = constraint_masks[target != self.padding_idx]
248
+ lprobs = lprobs[target != self.padding_idx]
249
+ target = target[target != self.padding_idx]
250
+ loss, nll_loss, ntokens = label_smoothed_nll_loss(
251
+ lprobs,
252
+ target,
253
+ self.eps,
254
+ update_num,
255
+ reduce=reduce,
256
+ drop_worst_ratio=self.drop_worst_ratio,
257
+ drop_worst_after=self.drop_worst_after,
258
+ use_rdrop=self.use_rdrop,
259
+ reg_alpha=self.reg_alpha,
260
+ constraint_masks=constraint_masks,
261
+ constraint_start=self.constraint_start,
262
+ constraint_end=self.constraint_end
263
+ )
264
+ return loss, nll_loss, ntokens
265
+
266
+ def compute_accuracy(self, model, net_output, sample):
267
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
268
+ mask = target.ne(self.padding_idx)
269
+ n_correct = torch.sum(
270
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
271
+ )
272
+ total = torch.sum(mask)
273
+ return n_correct, total
274
+
275
+ @classmethod
276
+ def reduce_metrics(cls, logging_outputs) -> None:
277
+ """Aggregate logging outputs from data parallel training."""
278
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
279
+ loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
280
+ loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
281
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
282
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
283
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
284
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
285
+ sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
286
+ sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
287
+
288
+ metrics.log_scalar(
289
+ "loss", loss_sum / sample_size, sample_size, round=3
290
+ )
291
+ metrics.log_scalar(
292
+ "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
293
+ )
294
+ metrics.log_scalar(
295
+ "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
296
+ )
297
+ metrics.log_scalar(
298
+ "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
299
+ )
300
+ metrics.log_derived(
301
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
302
+ )
303
+
304
+ metrics.log_scalar(
305
+ "ntokens", ntokens, 1, round=3
306
+ )
307
+ metrics.log_scalar(
308
+ "nsentences", nsentences, 1, round=3
309
+ )
310
+ metrics.log_scalar(
311
+ "sample_size", sample_size, 1, round=3
312
+ )
313
+ metrics.log_scalar(
314
+ "sample_size_v1", sample_size_v1, 1, round=3
315
+ )
316
+ metrics.log_scalar(
317
+ "sample_size_v2", sample_size_v2, 1, round=3
318
+ )
319
+
320
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
321
+ if total > 0:
322
+ metrics.log_scalar("total", total)
323
+ n_correct = utils.item(
324
+ sum(log.get("n_correct", 0) for log in logging_outputs)
325
+ )
326
+ metrics.log_scalar("n_correct", n_correct)
327
+ metrics.log_derived(
328
+ "accuracy",
329
+ lambda meters: round(
330
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
331
+ )
332
+ if meters["total"].sum > 0
333
+ else float("nan"),
334
+ )
335
+
336
+ @staticmethod
337
+ def logging_outputs_can_be_summed() -> bool:
338
+ """
339
+ Whether the logging outputs returned by `forward` can be summed
340
+ across workers prior to calling `reduce_metrics`. Setting this
341
+ to True will improves distributed training speed.
342
+ """
343
+ return True
criterions/scst_loss.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import string
8
+ from dataclasses import dataclass, field
9
+ from collections import OrderedDict
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+ from data import data_utils
19
+ from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
20
+
21
+
22
+ def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
23
+ loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
24
+ if ignore_index is not None:
25
+ pad_mask = target.eq(ignore_index)
26
+ loss.masked_fill_(pad_mask, 0.0)
27
+ ntokens = (~pad_mask).sum()
28
+ else:
29
+ loss = loss.squeeze(-1)
30
+ ntokens = target.numel()
31
+ if reduce:
32
+ loss = loss.sum()
33
+ return loss, ntokens
34
+
35
+ @dataclass
36
+ class ScstRewardCriterionConfig(FairseqDataclass):
37
+ scst_cider_cached_tokens: str = field(
38
+ default="coco-train-words.p",
39
+ metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
40
+ )
41
+ ignore_prefix_size: int = field(
42
+ default=0,
43
+ metadata={"help": "Ignore first N tokens"},
44
+ )
45
+ sentence_avg: bool = II("optimization.sentence_avg")
46
+ constraint_range: Optional[str] = field(
47
+ default=None,
48
+ metadata={"help": "constraint range"}
49
+ )
50
+
51
+
52
+ @register_criterion(
53
+ "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
54
+ )
55
+ class ScstRewardCriterion(FairseqCriterion):
56
+ CIDER_REWARD_WEIGHT = 1
57
+
58
+ def __init__(
59
+ self,
60
+ task,
61
+ scst_cider_cached_tokens,
62
+ sentence_avg,
63
+ ignore_prefix_size=0,
64
+ constraint_range=None
65
+ ):
66
+ super().__init__(task)
67
+ self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
68
+ self.sentence_avg = sentence_avg
69
+ self.ignore_prefix_size = ignore_prefix_size
70
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
71
+
72
+ self.constraint_start = None
73
+ self.constraint_end = None
74
+ if constraint_range is not None:
75
+ constraint_start, constraint_end = constraint_range.split(',')
76
+ self.constraint_start = int(constraint_start)
77
+ self.constraint_end = int(constraint_end)
78
+
79
+ def forward(self, model, sample, update_num=0, reduce=True):
80
+ """Compute the loss for the given sample.
81
+
82
+ Returns a tuple with three elements:
83
+ 1) the loss
84
+ 2) the sample size, which is used as the denominator for the gradient
85
+ 3) logging outputs to display while training
86
+ """
87
+ loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
88
+
89
+ sample_size = (
90
+ nsentences if self.sentence_avg else ntokens
91
+ )
92
+ logging_output = {
93
+ "loss": loss.data,
94
+ "score": score,
95
+ "ntokens": ntokens,
96
+ "nsentences": nsentences,
97
+ "sample_size": sample_size,
98
+ }
99
+ return loss, sample_size, logging_output
100
+
101
+ def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
102
+ '''
103
+ gen_res: generated captions, list of str
104
+ gt_idx: list of int, of the same length as gen_res
105
+ gt_res: ground truth captions, list of list of str.
106
+ gen_res[i] corresponds to gt_res[gt_idx[i]]
107
+ Each image can have multiple ground truth captions
108
+ '''
109
+ gen_res_size = len(gen_res)
110
+
111
+ res = OrderedDict()
112
+ for i in range(gen_res_size):
113
+ res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
114
+
115
+ gts = OrderedDict()
116
+ gt_res_ = [
117
+ [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
118
+ for i in range(len(gt_res))
119
+ ]
120
+ for i in range(gen_res_size):
121
+ gts[i] = gt_res_[gt_idx[i]]
122
+
123
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
124
+ _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
125
+ scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
126
+ return scores
127
+
128
+ @classmethod
129
+ def _wrap_sentence(self, s):
130
+ # ensure the sentence ends with <eos> token
131
+ # in order to keep consisitent with cider_cached_tokens
132
+ r = s.strip()
133
+ if r.endswith('.'):
134
+ r = r[:-1]
135
+ r += ' <eos>'
136
+ return r
137
+
138
+ def get_generator_out(self, model, sample):
139
+ def decode(toks):
140
+ hypo = toks.int().cpu()
141
+ hypo_str = self.task.tgt_dict.string(hypo)
142
+ hypo_str = self.task.bpe.decode(hypo_str).strip()
143
+ return hypo, hypo_str
144
+
145
+ model.eval()
146
+ with torch.no_grad():
147
+ self.task.scst_generator.model.eval()
148
+ gen_out = self.task.scst_generator.generate([model], sample)
149
+
150
+ gen_target = []
151
+ gen_res = []
152
+ gt_res = []
153
+ for i in range(len(gen_out)):
154
+ for j in range(len(gen_out[i])):
155
+ hypo, hypo_str = decode(gen_out[i][j]["tokens"])
156
+ gen_target.append(hypo)
157
+ gen_res.append(hypo_str)
158
+ gt_res.append(
159
+ decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
160
+ )
161
+
162
+ return gen_target, gen_res, gt_res
163
+
164
+ def get_reward_and_scores(self, gen_res, gt_res, device):
165
+ batch_size = len(gt_res)
166
+ gen_res_size = len(gen_res)
167
+ seq_per_img = gen_res_size // batch_size
168
+
169
+ gt_idx = [i // seq_per_img for i in range(gen_res_size)]
170
+ scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
171
+ sc_ = scores.reshape(batch_size, seq_per_img)
172
+ baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
173
+ # sample - baseline
174
+ reward = scores.reshape(batch_size, seq_per_img)
175
+ reward = reward - baseline
176
+ reward = reward.reshape(gen_res_size)
177
+ reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
178
+
179
+ return reward, scores
180
+
181
+ def get_net_output(self, model, sample, gen_target):
182
+ def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
183
+ return data_utils.collate_tokens(
184
+ sample_list,
185
+ pad_idx=self.padding_idx,
186
+ eos_idx=eos,
187
+ left_pad=False,
188
+ move_eos_to_beginning=move_eos_to_beginning,
189
+ )
190
+
191
+ batch_size = len(sample["target"])
192
+ gen_target_size = len(gen_target)
193
+ seq_per_img = gen_target_size // batch_size
194
+
195
+ model.train()
196
+ sample_src_tokens = torch.repeat_interleave(
197
+ sample['net_input']['src_tokens'], seq_per_img, dim=0
198
+ )
199
+ sample_src_lengths = torch.repeat_interleave(
200
+ sample['net_input']['src_lengths'], seq_per_img, dim=0
201
+ )
202
+ sample_patch_images = torch.repeat_interleave(
203
+ sample['net_input']['patch_images'], seq_per_img, dim=0
204
+ )
205
+ sample_patch_masks = torch.repeat_interleave(
206
+ sample['net_input']['patch_masks'], seq_per_img, dim=0
207
+ )
208
+ gen_prev_output_tokens = torch.as_tensor(
209
+ merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
210
+ device=sample["target"].device, dtype=torch.int64
211
+ )
212
+ gen_target_tokens = torch.as_tensor(
213
+ merge(gen_target), device=sample["target"].device, dtype=torch.int64
214
+ )
215
+ net_output = model(
216
+ src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
217
+ patch_images=sample_patch_images, patch_masks=sample_patch_masks,
218
+ prev_output_tokens=gen_prev_output_tokens
219
+ )
220
+
221
+ return net_output, gen_target_tokens
222
+
223
+ def get_lprobs_and_target(self, model, net_output, gen_target):
224
+ if self.constraint_start is not None and self.constraint_end is not None:
225
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
226
+ net_output[0][:, :, self.constraint_end:] = -math.inf
227
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
228
+ if self.ignore_prefix_size > 0:
229
+ if getattr(lprobs, "batch_first", False):
230
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
231
+ gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
232
+ else:
233
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
234
+ gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
235
+ return lprobs, gen_target
236
+
237
+ def compute_loss(self, model, sample, reduce=True):
238
+ gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
239
+ reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
240
+ net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
241
+ gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
242
+ loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
243
+ nsentences = gen_target_tokens.size(0)
244
+
245
+ return loss, scores.sum(), ntokens, nsentences
246
+
247
+ @classmethod
248
+ def reduce_metrics(cls, logging_outputs) -> None:
249
+ """Aggregate logging outputs from data parallel training."""
250
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
251
+ score_sum = sum(log.get("score", 0) for log in logging_outputs)
252
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
253
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
254
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
255
+
256
+ metrics.log_scalar(
257
+ "loss", loss_sum / sample_size, sample_size, round=3
258
+ )
259
+ metrics.log_scalar(
260
+ "score", score_sum / nsentences, nsentences, round=3
261
+ )
262
+
263
+ metrics.log_scalar(
264
+ "ntokens", ntokens, 1, round=3
265
+ )
266
+ metrics.log_scalar(
267
+ "nsentences", nsentences, 1, round=3
268
+ )
269
+ metrics.log_scalar(
270
+ "sample_size", sample_size, 1, round=3
271
+ )
272
+
273
+ @staticmethod
274
+ def logging_outputs_can_be_summed() -> bool:
275
+ """
276
+ Whether the logging outputs returned by `forward` can be summed
277
+ across workers prior to calling `reduce_metrics`. Setting this
278
+ to True will improves distributed training speed.
279
+ """
280
+ return True
data/__init__.py ADDED
File without changes
data/data_utils.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from collections.abc import Iterable
8
+ except ImportError:
9
+ from collections import Iterable
10
+ import contextlib
11
+ import itertools
12
+ import logging
13
+ import re
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from fairseq.file_io import PathManager
21
+ from fairseq import utils
22
+ import os
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def infer_language_pair(path):
28
+ """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
29
+ src, dst = None, None
30
+ for filename in PathManager.ls(path):
31
+ parts = filename.split(".")
32
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
33
+ return parts[1].split("-")
34
+ return src, dst
35
+
36
+
37
+ def collate_tokens(
38
+ values,
39
+ pad_idx,
40
+ eos_idx=None,
41
+ left_pad=False,
42
+ move_eos_to_beginning=False,
43
+ pad_to_length=None,
44
+ pad_to_multiple=1,
45
+ pad_to_bsz=None,
46
+ ):
47
+ """Convert a list of 1d tensors into a padded 2d tensor."""
48
+ size = max(v.size(0) for v in values)
49
+ size = size if pad_to_length is None else max(size, pad_to_length)
50
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
51
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
52
+
53
+ def copy_tensor(src, dst):
54
+ assert dst.numel() == src.numel()
55
+ if move_eos_to_beginning:
56
+ if eos_idx is None:
57
+ # if no eos_idx is specified, then use the last token in src
58
+ dst[0] = src[-1]
59
+ else:
60
+ dst[0] = eos_idx
61
+ dst[1:] = src[:-1]
62
+ else:
63
+ dst.copy_(src)
64
+
65
+ if values[0].dim() == 1:
66
+ res = values[0].new(len(values), size).fill_(pad_idx)
67
+ elif values[0].dim() == 2:
68
+ assert move_eos_to_beginning is False
69
+ res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ for i, v in enumerate(values):
74
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
75
+ return res
76
+
77
+
78
+ def load_indexed_dataset(
79
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
80
+ ):
81
+ """A helper function for loading indexed datasets.
82
+
83
+ Args:
84
+ path (str): path to indexed dataset (e.g., 'data-bin/train')
85
+ dictionary (~fairseq.data.Dictionary): data dictionary
86
+ dataset_impl (str, optional): which dataset implementation to use. If
87
+ not provided, it will be inferred automatically. For legacy indexed
88
+ data we use the 'cached' implementation by default.
89
+ combine (bool, optional): automatically load and combine multiple
90
+ datasets. For example, if *path* is 'data-bin/train', then we will
91
+ combine 'data-bin/train', 'data-bin/train1', ... and return a
92
+ single ConcatDataset instance.
93
+ """
94
+ import fairseq.data.indexed_dataset as indexed_dataset
95
+ from fairseq.data.concat_dataset import ConcatDataset
96
+
97
+ datasets = []
98
+ for k in itertools.count():
99
+ path_k = path + (str(k) if k > 0 else "")
100
+ try:
101
+ path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
102
+ except Exception as e:
103
+ if "StorageException: [404] Path not found" in str(e):
104
+ logger.warning(f"path_k: {e} not found")
105
+ else:
106
+ raise e
107
+
108
+ dataset_impl_k = dataset_impl
109
+ if dataset_impl_k is None:
110
+ dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
111
+ dataset = indexed_dataset.make_dataset(
112
+ path_k,
113
+ impl=dataset_impl_k or default,
114
+ fix_lua_indexing=True,
115
+ dictionary=dictionary,
116
+ )
117
+ if dataset is None:
118
+ break
119
+ logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
120
+ datasets.append(dataset)
121
+ if not combine:
122
+ break
123
+ if len(datasets) == 0:
124
+ return None
125
+ elif len(datasets) == 1:
126
+ return datasets[0]
127
+ else:
128
+ return ConcatDataset(datasets)
129
+
130
+
131
+ @contextlib.contextmanager
132
+ def numpy_seed(seed, *addl_seeds):
133
+ """Context manager which seeds the NumPy PRNG with the specified seed and
134
+ restores the state afterward"""
135
+ if seed is None:
136
+ yield
137
+ return
138
+ if len(addl_seeds) > 0:
139
+ seed = int(hash((seed, *addl_seeds)) % 1e6)
140
+ state = np.random.get_state()
141
+ np.random.seed(seed)
142
+ try:
143
+ yield
144
+ finally:
145
+ np.random.set_state(state)
146
+
147
+
148
+ def collect_filtered(function, iterable, filtered):
149
+ """
150
+ Similar to :func:`filter` but collects filtered elements in ``filtered``.
151
+
152
+ Args:
153
+ function (callable): function that returns ``False`` for elements that
154
+ should be filtered
155
+ iterable (iterable): iterable to filter
156
+ filtered (list): list to store filtered elements
157
+ """
158
+ for el in iterable:
159
+ if function(el):
160
+ yield el
161
+ else:
162
+ filtered.append(el)
163
+
164
+
165
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
166
+ def compare_leq(a, b):
167
+ return a <= b if not isinstance(a, tuple) else max(a) <= b
168
+
169
+ def check_size(idx):
170
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
171
+ return size_fn(idx) <= max_positions
172
+ elif isinstance(max_positions, dict):
173
+ idx_size = size_fn(idx)
174
+ assert isinstance(idx_size, dict)
175
+ intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
176
+ return all(
177
+ all(
178
+ a is None or b is None or a <= b
179
+ for a, b in zip(idx_size[key], max_positions[key])
180
+ )
181
+ for key in intersect_keys
182
+ )
183
+ else:
184
+ # For MultiCorpusSampledDataset, will generalize it later
185
+ if not isinstance(size_fn(idx), Iterable):
186
+ return all(size_fn(idx) <= b for b in max_positions)
187
+ return all(
188
+ a is None or b is None or a <= b
189
+ for a, b in zip(size_fn(idx), max_positions)
190
+ )
191
+
192
+ ignored = []
193
+ itr = collect_filtered(check_size, indices, ignored)
194
+ indices = np.fromiter(itr, dtype=np.int64, count=-1)
195
+ return indices, ignored
196
+
197
+
198
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
199
+ """
200
+ [deprecated] Filter indices based on their size.
201
+ Use `FairseqDataset::filter_indices_by_size` instead.
202
+
203
+ Args:
204
+ indices (List[int]): ordered list of dataset indices
205
+ dataset (FairseqDataset): fairseq dataset instance
206
+ max_positions (tuple): filter elements larger than this size.
207
+ Comparisons are done component-wise.
208
+ raise_exception (bool, optional): if ``True``, raise an exception if
209
+ any elements are filtered (default: False).
210
+ """
211
+ warnings.warn(
212
+ "data_utils.filter_by_size is deprecated. "
213
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
214
+ stacklevel=2,
215
+ )
216
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
217
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
218
+ ignored = indices[dataset.sizes[indices] > max_positions].tolist()
219
+ indices = indices[dataset.sizes[indices] <= max_positions]
220
+ elif (
221
+ hasattr(dataset, "sizes")
222
+ and isinstance(dataset.sizes, list)
223
+ and len(dataset.sizes) == 1
224
+ ):
225
+ ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
226
+ indices = indices[dataset.sizes[0][indices] <= max_positions]
227
+ else:
228
+ indices, ignored = _filter_by_size_dynamic(
229
+ indices, dataset.size, max_positions
230
+ )
231
+ else:
232
+ indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
233
+
234
+ if len(ignored) > 0 and raise_exception:
235
+ raise Exception(
236
+ (
237
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
238
+ "skip this example with --skip-invalid-size-inputs-valid-test"
239
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
240
+ )
241
+ if len(ignored) > 0:
242
+ logger.warning(
243
+ (
244
+ "{} samples have invalid sizes and will be skipped, "
245
+ "max_positions={}, first few sample ids={}"
246
+ ).format(len(ignored), max_positions, ignored[:10])
247
+ )
248
+ return indices
249
+
250
+
251
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
252
+ """Filter a list of sample indices. Remove those that are longer
253
+ than specified in max_sizes.
254
+
255
+ Args:
256
+ indices (np.array): original array of sample indices
257
+ max_sizes (int or list[int] or tuple[int]): max sample size,
258
+ can be defined separately for src and tgt (then list or tuple)
259
+
260
+ Returns:
261
+ np.array: filtered sample array
262
+ list: list of removed indices
263
+ """
264
+ if max_sizes is None:
265
+ return indices, []
266
+ if type(max_sizes) in (int, float):
267
+ max_src_size, max_tgt_size = max_sizes, max_sizes
268
+ else:
269
+ max_src_size, max_tgt_size = max_sizes
270
+ if tgt_sizes is None:
271
+ ignored = indices[src_sizes[indices] > max_src_size]
272
+ else:
273
+ ignored = indices[
274
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
275
+ ]
276
+ if len(ignored) > 0:
277
+ if tgt_sizes is None:
278
+ indices = indices[src_sizes[indices] <= max_src_size]
279
+ else:
280
+ indices = indices[
281
+ (src_sizes[indices] <= max_src_size)
282
+ & (tgt_sizes[indices] <= max_tgt_size)
283
+ ]
284
+ return indices, ignored.tolist()
285
+
286
+
287
+ def batch_by_size(
288
+ indices,
289
+ num_tokens_fn,
290
+ num_tokens_vec=None,
291
+ max_tokens=None,
292
+ max_sentences=None,
293
+ required_batch_size_multiple=1,
294
+ fixed_shapes=None,
295
+ ):
296
+ """
297
+ Yield mini-batches of indices bucketed by size. Batches may contain
298
+ sequences of different lengths.
299
+
300
+ Args:
301
+ indices (List[int]): ordered list of dataset indices
302
+ num_tokens_fn (callable): function that returns the number of tokens at
303
+ a given index
304
+ num_tokens_vec (List[int], optional): precomputed vector of the number
305
+ of tokens for each index in indices (to enable faster batch generation)
306
+ max_tokens (int, optional): max number of tokens in each batch
307
+ (default: None).
308
+ max_sentences (int, optional): max number of sentences in each
309
+ batch (default: None).
310
+ required_batch_size_multiple (int, optional): require batch size to
311
+ be less than N or a multiple of N (default: 1).
312
+ fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
313
+ only be created with the given shapes. *max_sentences* and
314
+ *required_batch_size_multiple* will be ignored (default: None).
315
+ """
316
+ try:
317
+ from fairseq.data.data_utils_fast import (
318
+ batch_by_size_fn,
319
+ batch_by_size_vec,
320
+ batch_fixed_shapes_fast,
321
+ )
322
+ except ImportError:
323
+ raise ImportError(
324
+ "Please build Cython components with: "
325
+ "`python setup.py build_ext --inplace`"
326
+ )
327
+ except ValueError:
328
+ raise ValueError(
329
+ "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
330
+ )
331
+
332
+ # added int() to avoid TypeError: an integer is required
333
+ max_tokens = (
334
+ int(max_tokens) if max_tokens is not None else -1
335
+ )
336
+ max_sentences = max_sentences if max_sentences is not None else -1
337
+ bsz_mult = required_batch_size_multiple
338
+
339
+ if not isinstance(indices, np.ndarray):
340
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
341
+
342
+ if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
343
+ num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
344
+
345
+ if fixed_shapes is None:
346
+ if num_tokens_vec is None:
347
+ return batch_by_size_fn(
348
+ indices,
349
+ num_tokens_fn,
350
+ max_tokens,
351
+ max_sentences,
352
+ bsz_mult,
353
+ )
354
+ else:
355
+ return batch_by_size_vec(
356
+ indices,
357
+ num_tokens_vec,
358
+ max_tokens,
359
+ max_sentences,
360
+ bsz_mult,
361
+ )
362
+
363
+ else:
364
+ fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
365
+ sort_order = np.lexsort(
366
+ [
367
+ fixed_shapes[:, 1].argsort(), # length
368
+ fixed_shapes[:, 0].argsort(), # bsz
369
+ ]
370
+ )
371
+ fixed_shapes_sorted = fixed_shapes[sort_order]
372
+ return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
373
+
374
+
375
+ def post_process(sentence: str, symbol: str):
376
+ if symbol == "sentencepiece":
377
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
378
+ elif symbol == "wordpiece":
379
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
380
+ elif symbol == "letter":
381
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
382
+ elif symbol == "silence":
383
+ import re
384
+ sentence = sentence.replace("<SIL>", "")
385
+ sentence = re.sub(' +', ' ', sentence).strip()
386
+ elif symbol == "_EOW":
387
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
388
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
389
+ if symbol == "subword_nmt":
390
+ symbol = "@@ "
391
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
392
+ elif symbol == "none":
393
+ pass
394
+ elif symbol is not None:
395
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
396
+ return sentence
397
+
398
+
399
+ def compute_mask_indices(
400
+ shape: Tuple[int, int],
401
+ padding_mask: Optional[torch.Tensor],
402
+ mask_prob: float,
403
+ mask_length: int,
404
+ mask_type: str = "static",
405
+ mask_other: float = 0.0,
406
+ min_masks: int = 0,
407
+ no_overlap: bool = False,
408
+ min_space: int = 0,
409
+ ) -> np.ndarray:
410
+ """
411
+ Computes random mask spans for a given shape
412
+
413
+ Args:
414
+ shape: the the shape for which to compute masks.
415
+ should be of size 2 where first element is batch size and 2nd is timesteps
416
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
417
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
418
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
419
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
420
+ mask_type: how to compute mask lengths
421
+ static = fixed size
422
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
423
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
424
+ poisson = sample from possion distribution with lambda = mask length
425
+ min_masks: minimum number of masked spans
426
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
427
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
428
+ """
429
+
430
+ bsz, all_sz = shape
431
+ mask = np.full((bsz, all_sz), False)
432
+
433
+ all_num_mask = int(
434
+ # add a random number for probabilistic rounding
435
+ mask_prob * all_sz / float(mask_length)
436
+ + np.random.rand()
437
+ )
438
+
439
+ all_num_mask = max(min_masks, all_num_mask)
440
+
441
+ mask_idcs = []
442
+ for i in range(bsz):
443
+ if padding_mask is not None:
444
+ sz = all_sz - padding_mask[i].long().sum().item()
445
+ num_mask = int(
446
+ # add a random number for probabilistic rounding
447
+ mask_prob * sz / float(mask_length)
448
+ + np.random.rand()
449
+ )
450
+ num_mask = max(min_masks, num_mask)
451
+ else:
452
+ sz = all_sz
453
+ num_mask = all_num_mask
454
+
455
+ if mask_type == "static":
456
+ lengths = np.full(num_mask, mask_length)
457
+ elif mask_type == "uniform":
458
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
459
+ elif mask_type == "normal":
460
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
461
+ lengths = [max(1, int(round(x))) for x in lengths]
462
+ elif mask_type == "poisson":
463
+ lengths = np.random.poisson(mask_length, size=num_mask)
464
+ lengths = [int(round(x)) for x in lengths]
465
+ else:
466
+ raise Exception("unknown mask selection " + mask_type)
467
+
468
+ if sum(lengths) == 0:
469
+ lengths[0] = min(mask_length, sz - 1)
470
+
471
+ if no_overlap:
472
+ mask_idc = []
473
+
474
+ def arrange(s, e, length, keep_length):
475
+ span_start = np.random.randint(s, e - length)
476
+ mask_idc.extend(span_start + i for i in range(length))
477
+
478
+ new_parts = []
479
+ if span_start - s - min_space >= keep_length:
480
+ new_parts.append((s, span_start - min_space + 1))
481
+ if e - span_start - keep_length - min_space > keep_length:
482
+ new_parts.append((span_start + length + min_space, e))
483
+ return new_parts
484
+
485
+ parts = [(0, sz)]
486
+ min_length = min(lengths)
487
+ for length in sorted(lengths, reverse=True):
488
+ lens = np.fromiter(
489
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
490
+ np.int,
491
+ )
492
+ l_sum = np.sum(lens)
493
+ if l_sum == 0:
494
+ break
495
+ probs = lens / np.sum(lens)
496
+ c = np.random.choice(len(parts), p=probs)
497
+ s, e = parts.pop(c)
498
+ parts.extend(arrange(s, e, length, min_length))
499
+ mask_idc = np.asarray(mask_idc)
500
+ else:
501
+ min_len = min(lengths)
502
+ if sz - min_len <= num_mask:
503
+ min_len = sz - num_mask - 1
504
+
505
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
506
+
507
+ mask_idc = np.asarray(
508
+ [
509
+ mask_idc[j] + offset
510
+ for j in range(len(mask_idc))
511
+ for offset in range(lengths[j])
512
+ ]
513
+ )
514
+
515
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
516
+
517
+ min_len = min([len(m) for m in mask_idcs])
518
+ for i, mask_idc in enumerate(mask_idcs):
519
+ if len(mask_idc) > min_len:
520
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
521
+ mask[i, mask_idc] = True
522
+
523
+ return mask
524
+
525
+
526
+ def get_mem_usage():
527
+ try:
528
+ import psutil
529
+
530
+ mb = 1024 * 1024
531
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
532
+ except ImportError:
533
+ return "N/A"
534
+
535
+
536
+ # lens: torch.LongTensor
537
+ # returns: torch.BoolTensor
538
+ def lengths_to_padding_mask(lens):
539
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
540
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
541
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
542
+ return mask
543
+
544
+
545
+ # lens: torch.LongTensor
546
+ # returns: torch.BoolTensor
547
+ def lengths_to_mask(lens):
548
+ return ~lengths_to_padding_mask(lens)
549
+
550
+
551
+ def get_buckets(sizes, num_buckets):
552
+ buckets = np.unique(
553
+ np.percentile(
554
+ sizes,
555
+ np.linspace(0, 100, num_buckets + 1),
556
+ interpolation='lower',
557
+ )[1:]
558
+ )
559
+ return buckets
560
+
561
+
562
+ def get_bucketed_sizes(orig_sizes, buckets):
563
+ sizes = np.copy(orig_sizes)
564
+ assert np.min(sizes) >= 0
565
+ start_val = -1
566
+ for end_val in buckets:
567
+ mask = (sizes > start_val) & (sizes <= end_val)
568
+ sizes[mask] = end_val
569
+ start_val = end_val
570
+ return sizes
571
+
572
+
573
+
574
+ def _find_extra_valid_paths(dataset_path: str) -> set:
575
+ paths = utils.split_paths(dataset_path)
576
+ all_valid_paths = set()
577
+ for sub_dir in paths:
578
+ contents = PathManager.ls(sub_dir)
579
+ valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
580
+ all_valid_paths |= {os.path.basename(p) for p in valid_paths}
581
+ # Remove .bin, .idx etc
582
+ roots = {os.path.splitext(p)[0] for p in all_valid_paths}
583
+ return roots
584
+
585
+
586
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
587
+ """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
588
+ if (
589
+ train_cfg.dataset.ignore_unused_valid_subsets
590
+ or train_cfg.dataset.combine_valid_subsets
591
+ or train_cfg.dataset.disable_validation
592
+ or not hasattr(train_cfg.task, "data")
593
+ ):
594
+ return
595
+ other_paths = _find_extra_valid_paths(train_cfg.task.data)
596
+ specified_subsets = train_cfg.dataset.valid_subset.split(",")
597
+ ignored_paths = [p for p in other_paths if p not in specified_subsets]
598
+ if ignored_paths:
599
+ advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
600
+ msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
601
+ raise ValueError(msg)
data/file_dataset.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pickle
4
+
5
+
6
+ class FileDataset:
7
+ def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
8
+ self.file_path = file_path
9
+ assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)
10
+
11
+ self.separator = separator
12
+ if selected_col_ids is None:
13
+ # default to all fields
14
+ self.selected_col_ids = list(
15
+ range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
16
+ else:
17
+ self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
18
+ if dtypes is None:
19
+ # default to str
20
+ self.dtypes = [str for col_id in self.selected_col_ids]
21
+ else:
22
+ self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
23
+ assert len(self.dtypes) == len(self.selected_col_ids)
24
+
25
+ self.data_cnt = 0
26
+ try:
27
+ self.slice_id = torch.distributed.get_rank()
28
+ self.slice_count = torch.distributed.get_world_size()
29
+ except Exception:
30
+ self.slice_id = 0
31
+ self.slice_count = 1
32
+ self.cached_index = cached_index
33
+ self._init_seek_index()
34
+ self._reader = self._get_reader()
35
+ print("file {} slice_id {} row count {} total row count {}".format(
36
+ self.file_path, self.slice_id, self.row_count, self.total_row_count)
37
+ )
38
+
39
+ def _init_seek_index(self):
40
+ if self.cached_index:
41
+ cache_path = "{}.index".format(self.file_path)
42
+ assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
43
+ self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
44
+ print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
45
+ self.file_path, self.slice_id))
46
+ else:
47
+ # make an iteration over the file to get row_count and line_idx-to-offset mapping
48
+ fp = open(self.file_path, "r")
49
+ print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
50
+ self.file_path, self.slice_id))
51
+ self.total_row_count = 0
52
+ offset = 0
53
+ self.lineid_to_offset = []
54
+ for line in fp:
55
+ self.lineid_to_offset.append(offset)
56
+ self.total_row_count += 1
57
+ offset += len(line.encode('utf-8'))
58
+ self._compute_start_pos_and_row_count()
59
+ print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
60
+ self.file_path, self.slice_id))
61
+
62
+ def _compute_start_pos_and_row_count(self):
63
+ self.row_count = self.total_row_count // self.slice_count
64
+ if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
65
+ self.row_count += 1
66
+ self.start_pos = self.row_count * self.slice_id
67
+ else:
68
+ self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
69
+
70
+ def _get_reader(self):
71
+ fp = open(self.file_path, "r")
72
+ fp.seek(self.lineid_to_offset[self.start_pos])
73
+ return fp
74
+
75
+ def _seek(self, offset=0):
76
+ try:
77
+ print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
78
+ self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
79
+ self.data_cnt = offset
80
+ except Exception:
81
+ print("slice_id {} seek offset {}".format(self.slice_id, offset))
82
+ self._reader.seek(self.lineid_to_offset[offset])
83
+ self.data_cnt = offset
84
+
85
+ def __del__(self):
86
+ self._reader.close()
87
+
88
+ def __len__(self):
89
+ return self.row_count
90
+
91
+ def get_total_row_count(self):
92
+ return self.total_row_count
93
+
94
+ def __getitem__(self, index):
95
+ if self.data_cnt == self.row_count:
96
+ print("reach the end of datafile, start a new reader")
97
+ self.data_cnt = 0
98
+ self._reader = self._get_reader()
99
+ column_l = self._reader.readline().rstrip("\n").split(self.separator)
100
+ self.data_cnt += 1
101
+ column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
102
+ return column_l
data/mm_data/__init__.py ADDED
File without changes
data/mm_data/caption_dataset.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+ import string
10
+
11
+ import numpy as np
12
+ import torch
13
+ import base64
14
+ from torchvision import transforms
15
+
16
+ from PIL import Image, ImageFile
17
+
18
+ from data import data_utils
19
+ from data.ofa_dataset import OFADataset
20
+
21
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
22
+ ImageFile.MAX_IMAGE_PIXELS = None
23
+ Image.MAX_IMAGE_PIXELS = None
24
+
25
+ logger = logging.getLogger(__name__)
26
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
+
28
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
+
31
+
32
+ def collate(samples, pad_idx, eos_idx):
33
+ if len(samples) == 0:
34
+ return {}
35
+
36
+ def merge(key):
37
+ return data_utils.collate_tokens(
38
+ [s[key] for s in samples],
39
+ pad_idx,
40
+ eos_idx=eos_idx,
41
+ )
42
+
43
+ id = np.array([s["id"] for s in samples])
44
+ src_tokens = merge("source")
45
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
+
47
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
+
50
+ prev_output_tokens = None
51
+ target = None
52
+ if samples[0].get("target", None) is not None:
53
+ target = merge("target")
54
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
55
+ ntokens = tgt_lengths.sum().item()
56
+
57
+ if samples[0].get("prev_output_tokens", None) is not None:
58
+ prev_output_tokens = merge("prev_output_tokens")
59
+ else:
60
+ ntokens = src_lengths.sum().item()
61
+
62
+ batch = {
63
+ "id": id,
64
+ "nsentences": len(samples),
65
+ "ntokens": ntokens,
66
+ "net_input": {
67
+ "src_tokens": src_tokens,
68
+ "src_lengths": src_lengths,
69
+ "patch_images": patch_images,
70
+ "patch_masks": patch_masks,
71
+ "prev_output_tokens": prev_output_tokens
72
+ },
73
+ "target": target,
74
+ }
75
+
76
+ return batch
77
+
78
+
79
+ class CaptionDataset(OFADataset):
80
+ def __init__(
81
+ self,
82
+ split,
83
+ dataset,
84
+ bpe,
85
+ src_dict,
86
+ tgt_dict=None,
87
+ max_src_length=128,
88
+ max_tgt_length=30,
89
+ patch_image_size=224,
90
+ imagenet_default_mean_and_std=False,
91
+ scst=False
92
+ ):
93
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
94
+ self.max_src_length = max_src_length
95
+ self.max_tgt_length = max_tgt_length
96
+ self.patch_image_size = patch_image_size
97
+ self.scst = scst
98
+
99
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
100
+
101
+ if imagenet_default_mean_and_std:
102
+ mean = IMAGENET_DEFAULT_MEAN
103
+ std = IMAGENET_DEFAULT_STD
104
+ else:
105
+ mean = [0.5, 0.5, 0.5]
106
+ std = [0.5, 0.5, 0.5]
107
+
108
+ self.patch_resize_transform = transforms.Compose([
109
+ lambda image: image.convert("RGB"),
110
+ transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
111
+ transforms.ToTensor(),
112
+ transforms.Normalize(mean=mean, std=std),
113
+ ])
114
+
115
+ def __getitem__(self, index):
116
+ uniq_id, image, caption = self.dataset[index]
117
+
118
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
119
+ patch_image = self.patch_resize_transform(image)
120
+ patch_mask = torch.tensor([True])
121
+
122
+ if self.split == 'train' and not self.scst:
123
+ caption = caption.translate(self.transtab).strip()
124
+ caption_token_list = caption.strip().split()
125
+ tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
126
+ else:
127
+ caption = ' '.join(caption.strip().split())
128
+ caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
129
+ tgt_caption = '&&'.join(caption_list)
130
+ src_item = self.encode_text(" what does the image describe?")
131
+ tgt_item = self.encode_text(" {}".format(tgt_caption))
132
+
133
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
134
+ target_item = torch.cat([tgt_item, self.eos_item])
135
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
136
+
137
+ example = {
138
+ "id": uniq_id,
139
+ "source": src_item,
140
+ "patch_image": patch_image,
141
+ "patch_mask": patch_mask,
142
+ "target": target_item,
143
+ "prev_output_tokens": prev_output_item
144
+ }
145
+ return example
146
+
147
+ def collater(self, samples, pad_to_length=None):
148
+ """Merge a list of samples to form a mini-batch.
149
+ Args:
150
+ samples (List[dict]): samples to collate
151
+ Returns:
152
+ dict: a mini-batch with the following keys:
153
+ """
154
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/refcoco_dataset.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ import base64
13
+ import utils.transforms as T
14
+
15
+ from PIL import Image, ImageFile
16
+
17
+ from data import data_utils
18
+ from data.ofa_dataset import OFADataset
19
+
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ ImageFile.MAX_IMAGE_PIXELS = None
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ logger = logging.getLogger(__name__)
25
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
26
+
27
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
28
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
29
+
30
+
31
+ def collate(samples, pad_idx, eos_idx):
32
+ if len(samples) == 0:
33
+ return {}
34
+
35
+ def merge(key):
36
+ return data_utils.collate_tokens(
37
+ [s[key] for s in samples],
38
+ pad_idx,
39
+ eos_idx=eos_idx,
40
+ )
41
+
42
+ id = np.array([s["id"] for s in samples])
43
+ src_tokens = merge("source")
44
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
45
+
46
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
47
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
48
+
49
+ w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
50
+ h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
51
+ region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
52
+
53
+ prev_output_tokens = None
54
+ target = None
55
+ if samples[0].get("target", None) is not None:
56
+ target = merge("target")
57
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
58
+ ntokens = tgt_lengths.sum().item()
59
+
60
+ if samples[0].get("prev_output_tokens", None) is not None:
61
+ prev_output_tokens = merge("prev_output_tokens")
62
+ else:
63
+ ntokens = src_lengths.sum().item()
64
+
65
+ batch = {
66
+ "id": id,
67
+ "nsentences": len(samples),
68
+ "ntokens": ntokens,
69
+ "net_input": {
70
+ "src_tokens": src_tokens,
71
+ "src_lengths": src_lengths,
72
+ "patch_images": patch_images,
73
+ "patch_masks": patch_masks,
74
+ "prev_output_tokens": prev_output_tokens
75
+ },
76
+ "target": target,
77
+ "w_resize_ratios": w_resize_ratios,
78
+ "h_resize_ratios": h_resize_ratios,
79
+ "region_coords": region_coords
80
+ }
81
+
82
+ return batch
83
+
84
+
85
+ class RefcocoDataset(OFADataset):
86
+ def __init__(
87
+ self,
88
+ split,
89
+ dataset,
90
+ bpe,
91
+ src_dict,
92
+ tgt_dict=None,
93
+ max_src_length=80,
94
+ max_tgt_length=30,
95
+ patch_image_size=512,
96
+ imagenet_default_mean_and_std=False,
97
+ num_bins=1000,
98
+ max_image_size=512
99
+ ):
100
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
101
+ self.max_src_length = max_src_length
102
+ self.max_tgt_length = max_tgt_length
103
+ self.patch_image_size = patch_image_size
104
+ self.num_bins = num_bins
105
+
106
+ if imagenet_default_mean_and_std:
107
+ mean = IMAGENET_DEFAULT_MEAN
108
+ std = IMAGENET_DEFAULT_STD
109
+ else:
110
+ mean = [0.5, 0.5, 0.5]
111
+ std = [0.5, 0.5, 0.5]
112
+
113
+ # for positioning
114
+ self.positioning_transform = T.Compose([
115
+ T.RandomResize([patch_image_size], max_size=patch_image_size),
116
+ T.ToTensor(),
117
+ T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
118
+ ])
119
+
120
+ def __getitem__(self, index):
121
+ uniq_id, base64_str, text, region_coord = self.dataset[index]
122
+
123
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
124
+ w, h = image.size
125
+ boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
126
+ x0, y0, x1, y1 = region_coord.strip().split(',')
127
+ region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
128
+ boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
129
+ boxes_target["labels"] = np.array([0])
130
+ boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
131
+
132
+ patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
133
+ resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
134
+ patch_mask = torch.tensor([True])
135
+ quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
136
+ quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
137
+ quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
138
+ quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
139
+ region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
140
+ src_caption = self.pre_caption(text, self.max_src_length)
141
+ src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
142
+ tgt_item = self.encode_text(region_coord, use_bpe=False)
143
+
144
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
145
+ target_item = torch.cat([tgt_item, self.eos_item])
146
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
147
+
148
+ example = {
149
+ "id": uniq_id,
150
+ "source": src_item,
151
+ "patch_image": patch_image,
152
+ "patch_mask": patch_mask,
153
+ "target": target_item,
154
+ "prev_output_tokens": prev_output_item,
155
+ "w_resize_ratio": resize_w / w,
156
+ "h_resize_ratio": resize_h / h,
157
+ "region_coord": region
158
+ }
159
+ return example
160
+
161
+ def collater(self, samples, pad_to_length=None):
162
+ """Merge a list of samples to form a mini-batch.
163
+ Args:
164
+ samples (List[dict]): samples to collate
165
+ Returns:
166
+ dict: a mini-batch with the following keys:
167
+ """
168
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/vqa_gen_dataset.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ import base64
13
+ from torchvision import transforms
14
+
15
+ from PIL import Image, ImageFile
16
+
17
+ from data import data_utils
18
+ from data.ofa_dataset import OFADataset
19
+
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ ImageFile.MAX_IMAGE_PIXELS = None
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ logger = logging.getLogger(__name__)
25
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
26
+
27
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
28
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
29
+
30
+
31
+ def collate(samples, pad_idx, eos_idx):
32
+ if len(samples) == 0:
33
+ return {}
34
+
35
+ def merge(key):
36
+ return data_utils.collate_tokens(
37
+ [s[key] for s in samples],
38
+ pad_idx,
39
+ eos_idx=eos_idx,
40
+ )
41
+
42
+ id = np.array([s["id"] for s in samples])
43
+ src_tokens = merge("source")
44
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
45
+
46
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
47
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
48
+
49
+ conf = None
50
+ if samples[0].get("conf", None) is not None:
51
+ conf = torch.cat([s['conf'] for s in samples], dim=0)
52
+
53
+ ref_dict = None
54
+ if samples[0].get("ref_dict", None) is not None:
55
+ ref_dict = np.array([s['ref_dict'] for s in samples])
56
+
57
+ constraint_masks = None
58
+ if samples[0].get("constraint_mask", None) is not None:
59
+ constraint_masks = merge("constraint_mask")
60
+
61
+ decoder_prompts = None
62
+ if samples[0].get("decoder_prompt", None) is not None:
63
+ decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])
64
+
65
+ prev_output_tokens = None
66
+ target = None
67
+ if samples[0].get("target", None) is not None:
68
+ target = merge("target")
69
+ tgt_lengths = torch.LongTensor(
70
+ [s["target"].ne(pad_idx).long().sum() for s in samples]
71
+ )
72
+ ntokens = tgt_lengths.sum().item()
73
+
74
+ if samples[0].get("prev_output_tokens", None) is not None:
75
+ prev_output_tokens = merge("prev_output_tokens")
76
+ else:
77
+ ntokens = src_lengths.sum().item()
78
+
79
+ batch = {
80
+ "id": id,
81
+ "nsentences": len(samples),
82
+ "ntokens": ntokens,
83
+ "net_input": {
84
+ "src_tokens": src_tokens,
85
+ "src_lengths": src_lengths,
86
+ "patch_images": patch_images,
87
+ "patch_masks": patch_masks,
88
+ "prev_output_tokens": prev_output_tokens
89
+ },
90
+ "conf": conf,
91
+ "ref_dict": ref_dict,
92
+ "constraint_masks": constraint_masks,
93
+ "decoder_prompts": decoder_prompts,
94
+ "target": target
95
+ }
96
+
97
+ return batch
98
+
99
+
100
+ class VqaGenDataset(OFADataset):
101
+ def __init__(
102
+ self,
103
+ split,
104
+ dataset,
105
+ bpe,
106
+ src_dict,
107
+ tgt_dict=None,
108
+ max_src_length=128,
109
+ max_object_length=30,
110
+ max_tgt_length=30,
111
+ patch_image_size=224,
112
+ add_object=False,
113
+ constraint_trie=None,
114
+ imagenet_default_mean_and_std=False,
115
+ prompt_type="none"
116
+ ):
117
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
118
+ self.max_src_length = max_src_length
119
+ self.max_object_length = max_object_length
120
+ self.max_tgt_length = max_tgt_length
121
+ self.patch_image_size = patch_image_size
122
+
123
+ self.add_object = add_object
124
+ self.constraint_trie = constraint_trie
125
+ self.prompt_type = prompt_type
126
+
127
+ if imagenet_default_mean_and_std:
128
+ mean = IMAGENET_DEFAULT_MEAN
129
+ std = IMAGENET_DEFAULT_STD
130
+ else:
131
+ mean = [0.5, 0.5, 0.5]
132
+ std = [0.5, 0.5, 0.5]
133
+
134
+ self.patch_resize_transform = transforms.Compose([
135
+ lambda image: image.convert("RGB"),
136
+ transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
137
+ transforms.ToTensor(),
138
+ transforms.Normalize(mean=mean, std=std),
139
+ ])
140
+
141
+ def __getitem__(self, index):
142
+ item = self.dataset[index]
143
+ if len(item) == 5:
144
+ uniq_id, image, question, ref, predict_objects = item
145
+ else:
146
+ uniq_id, image, question, ref, predict_objects, caption = item
147
+
148
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
149
+ patch_image = self.patch_resize_transform(image)
150
+ patch_mask = torch.tensor([True])
151
+
152
+ question = self.pre_question(question, self.max_src_length)
153
+ question = question + '?' if not question.endswith('?') else question
154
+ src_item = self.encode_text(' {}'.format(question))
155
+
156
+ ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
157
+ answer = max(ref_dict, key=ref_dict.get)
158
+ conf = torch.tensor([ref_dict[answer]])
159
+ tgt_item = self.encode_text(" {}".format(answer))
160
+
161
+ if self.add_object and predict_objects is not None:
162
+ predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])
163
+ predict_object_item = self.encode_text(" object: {}".format(predict_object_seq))
164
+ src_item = torch.cat([src_item, predict_object_item])
165
+
166
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
167
+ if self.prompt_type == 'none':
168
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
169
+ target_item = torch.cat([prev_output_item[1:], self.eos_item])
170
+ decoder_prompt = self.bos_item
171
+ elif self.prompt_type == 'src':
172
+ prev_output_item = torch.cat([src_item, tgt_item])
173
+ target_item = torch.cat([prev_output_item[1:], self.eos_item])
174
+ decoder_prompt = src_item
175
+ elif self.prompt_type == 'prev_output':
176
+ prev_output_item = torch.cat([src_item[:-1], tgt_item])
177
+ target_item = torch.cat([prev_output_item[1:], self.eos_item])
178
+ decoder_prompt = src_item[:-1]
179
+ else:
180
+ raise NotImplementedError
181
+ target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()
182
+
183
+ example = {
184
+ "id": uniq_id,
185
+ "source": src_item,
186
+ "patch_image": patch_image,
187
+ "patch_mask": patch_mask,
188
+ "target": target_item,
189
+ "prev_output_tokens": prev_output_item,
190
+ "decoder_prompt": decoder_prompt,
191
+ "ref_dict": ref_dict,
192
+ "conf": conf,
193
+ }
194
+ if self.constraint_trie is not None:
195
+ constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
196
+ start_idx = len(target_item) - len(tgt_item) - 1
197
+ for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
198
+ constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
199
+ constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
200
+ constraint_mask[i][constraint_nodes] = True
201
+ example["constraint_mask"] = constraint_mask
202
+ return example
203
+
204
+ def collater(self, samples, pad_to_length=None):
205
+ """Merge a list of samples to form a mini-batch.
206
+ Args:
207
+ samples (List[dict]): samples to collate
208
+ Returns:
209
+ dict: a mini-batch with the following keys:
210
+ """
211
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/ofa_dataset.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ import torch.utils.data
4
+ from fairseq.data import FairseqDataset
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class OFADataset(FairseqDataset):
10
+ def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
11
+ self.split = split
12
+ self.dataset = dataset
13
+ self.bpe = bpe
14
+ self.src_dict = src_dict
15
+ self.tgt_dict = tgt_dict
16
+
17
+ self.bos = src_dict.bos()
18
+ self.eos = src_dict.eos()
19
+ self.pad = src_dict.pad()
20
+ self.bos_item = torch.LongTensor([self.bos])
21
+ self.eos_item = torch.LongTensor([self.eos])
22
+
23
+ def __len__(self):
24
+ return len(self.dataset)
25
+
26
+ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
27
+ s = self.tgt_dict.encode_line(
28
+ line=self.bpe.encode(text) if use_bpe else text,
29
+ add_if_not_exist=False,
30
+ append_eos=False
31
+ ).long()
32
+ if length is not None:
33
+ s = s[:length]
34
+ if append_bos:
35
+ s = torch.cat([self.bos_item, s])
36
+ if append_eos:
37
+ s = torch.cat([s, self.eos_item])
38
+ return s
39
+
40
+ def pre_question(self, question, max_ques_words):
41
+ question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
42
+
43
+ question = re.sub(
44
+ r"\s{2,}",
45
+ ' ',
46
+ question,
47
+ )
48
+ question = question.rstrip('\n')
49
+ question = question.strip(' ')
50
+
51
+ # truncate question
52
+ question_words = question.split(' ')
53
+ if len(question_words) > max_ques_words:
54
+ question = ' '.join(question_words[:max_ques_words])
55
+
56
+ return question
57
+
58
+ def pre_caption(self, caption, max_words):
59
+ caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
60
+
61
+ caption = re.sub(
62
+ r"\s{2,}",
63
+ ' ',
64
+ caption,
65
+ )
66
+ caption = caption.rstrip('\n')
67
+ caption = caption.strip(' ')
68
+
69
+ # truncate caption
70
+ caption_words = caption.split(' ')
71
+ if len(caption_words) > max_words:
72
+ caption = ' '.join(caption_words[:max_words])
73
+
74
+ return caption
datasets.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Datasets
2
+
3
+ We provide links to download our preprocessed dataset. If you would like to process the data on your own, we will soon provide scripts for you to do so.
4
+
5
+ ## Finetuning
6
+
7
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
8
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
9
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
10
+ * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
evaluate.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import json
11
+ from itertools import chain
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ from fairseq import distributed_utils, options, tasks, utils
17
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
18
+ from fairseq.logging import progress_bar
19
+ from fairseq.utils import reset_logging
20
+ from omegaconf import DictConfig
21
+
22
+ from utils import checkpoint_utils
23
+ from utils.eval_utils import eval_step
24
+
25
+ logging.basicConfig(
26
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S",
28
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
29
+ stream=sys.stdout,
30
+ )
31
+ logger = logging.getLogger("ofa.evaluate")
32
+
33
+
34
+ def apply_half(t):
35
+ if t.dtype is torch.float32:
36
+ return t.to(dtype=torch.half)
37
+ return t
38
+
39
+
40
+ def main(cfg: DictConfig, **kwargs):
41
+ utils.import_user_module(cfg.common)
42
+
43
+ reset_logging()
44
+ logger.info(cfg)
45
+
46
+ assert (
47
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
48
+ ), "Must specify batch size either with --max-tokens or --batch-size"
49
+
50
+ # Fix seed for stochastic decoding
51
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
52
+ np.random.seed(cfg.common.seed)
53
+ utils.set_torch_seed(cfg.common.seed)
54
+
55
+ use_fp16 = cfg.common.fp16
56
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
57
+
58
+ if use_cuda:
59
+ torch.cuda.set_device(cfg.distributed_training.device_id)
60
+
61
+ # Load ensemble
62
+ overrides = eval(cfg.common_eval.model_overrides)
63
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
64
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
65
+ utils.split_paths(cfg.common_eval.path),
66
+ arg_overrides=overrides,
67
+ suffix=cfg.checkpoint.checkpoint_suffix,
68
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
69
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
70
+ )
71
+
72
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
73
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
74
+
75
+ # Move models to GPU
76
+ for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
77
+ if kwargs['ema_eval']:
78
+ logger.info("loading EMA weights from {}".format(ckpt_path))
79
+ model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
80
+ model.eval()
81
+ if use_fp16:
82
+ model.half()
83
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
84
+ model.cuda()
85
+ model.prepare_for_inference_(cfg)
86
+
87
+ # Load dataset (possibly sharded)
88
+ itr = task.get_batch_iterator(
89
+ dataset=task.dataset(cfg.dataset.gen_subset),
90
+ max_tokens=cfg.dataset.max_tokens,
91
+ max_sentences=cfg.dataset.batch_size,
92
+ max_positions=utils.resolve_max_positions(
93
+ task.max_positions(), *[m.max_positions() for m in models]
94
+ ),
95
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
96
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
97
+ seed=cfg.common.seed,
98
+ num_shards=cfg.distributed_training.distributed_world_size,
99
+ shard_id=cfg.distributed_training.distributed_rank,
100
+ num_workers=cfg.dataset.num_workers,
101
+ data_buffer_size=cfg.dataset.data_buffer_size,
102
+ ).next_epoch_itr(shuffle=False)
103
+ progress = progress_bar.progress_bar(
104
+ itr,
105
+ log_format=cfg.common.log_format,
106
+ log_interval=cfg.common.log_interval,
107
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
108
+ )
109
+
110
+ # Initialize generator
111
+ generator = task.build_generator(models, cfg.generation)
112
+
113
+ results = []
114
+ score_sum = torch.FloatTensor([0]).cuda()
115
+ score_cnt = torch.FloatTensor([0]).cuda()
116
+ for sample in progress:
117
+ if "net_input" not in sample:
118
+ continue
119
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
120
+ sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
121
+ with torch.no_grad():
122
+ result, scores = eval_step(task, generator, models, sample)
123
+ results += result
124
+ score_sum += sum(scores) if scores is not None else 0
125
+ score_cnt += len(scores) if scores is not None else 0
126
+ progress.log({"sentences": sample["nsentences"]})
127
+
128
+ gather_results = None
129
+ if cfg.distributed_training.distributed_world_size > 1:
130
+ gather_results = [None for _ in range(dist.get_world_size())]
131
+ dist.all_gather_object(gather_results, results)
132
+ dist.all_reduce(score_sum.data)
133
+ dist.all_reduce(score_cnt.data)
134
+ if score_cnt.item() > 0:
135
+ logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
136
+ score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
137
+ ))
138
+
139
+ if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
140
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
141
+ output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
142
+ gather_results = list(chain(*gather_results)) if gather_results is not None else results
143
+ with open(output_path, 'w') as fw:
144
+ json.dump(gather_results, fw)
145
+
146
+
147
+ def cli_main():
148
+ parser = options.get_generation_parser()
149
+ parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
150
+ args = options.parse_args_and_arch(parser)
151
+ cfg = convert_namespace_to_omegaconf(args)
152
+ distributed_utils.call_main(cfg, main, ema_eval=args.ema_eval)
153
+
154
+
155
+ if __name__ == "__main__":
156
+ cli_main()
fairseq/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
2
+
3
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🐛 Bug Report
3
+ about: Submit a bug report to help us improve
4
+ labels: 'bug, needs triage'
5
+ ---
6
+
7
+ ## 🐛 Bug
8
+
9
+ <!-- A clear and concise description of what the bug is. -->
10
+
11
+ ### To Reproduce
12
+
13
+ Steps to reproduce the behavior (**always include the command you ran**):
14
+
15
+ 1. Run cmd '....'
16
+ 2. See error
17
+
18
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
19
+
20
+
21
+ #### Code sample
22
+ <!-- Ideally attach a minimal code sample to reproduce the decried issue.
23
+ Minimal means having the shortest code but still preserving the bug. -->
24
+
25
+ ### Expected behavior
26
+
27
+ <!-- A clear and concise description of what you expected to happen. -->
28
+
29
+ ### Environment
30
+
31
+ - fairseq Version (e.g., 1.0 or main):
32
+ - PyTorch Version (e.g., 1.0)
33
+ - OS (e.g., Linux):
34
+ - How you installed fairseq (`pip`, source):
35
+ - Build command you used (if compiling from source):
36
+ - Python version:
37
+ - CUDA/cuDNN version:
38
+ - GPU models and configuration:
39
+ - Any other relevant information:
40
+
41
+ ### Additional context
42
+
43
+ <!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 📚 Documentation/Typos
3
+ about: Report an issue related to documentation or a typo
4
+ labels: 'documentation, needs triage'
5
+ ---
6
+
7
+ ## 📚 Documentation
8
+
9
+ For typos and doc fixes, please go ahead and:
10
+
11
+ 1. Create an issue.
12
+ 2. Fix the typo.
13
+ 3. Submit a PR.
14
+
15
+ Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🚀 Feature Request
3
+ about: Submit a proposal/request for a new feature
4
+ labels: 'enhancement, help wanted, needs triage'
5
+ ---
6
+
7
+ ## 🚀 Feature Request
8
+ <!-- A clear and concise description of the feature proposal -->
9
+
10
+ ### Motivation
11
+
12
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
13
+
14
+ ### Pitch
15
+
16
+ <!-- A clear and concise description of what you want to happen. -->
17
+
18
+ ### Alternatives
19
+
20
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
21
+
22
+ ### Additional context
23
+
24
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: ❓ Questions/Help
3
+ about: If you have questions, please first search existing issues and docs
4
+ labels: 'question, needs triage'
5
+ ---
6
+
7
+ ## ❓ Questions and Help
8
+
9
+ ### Before asking:
10
+ 1. search the issues.
11
+ 2. search the docs.
12
+
13
+ <!-- If you still can't find what you need: -->
14
+
15
+ #### What is your question?
16
+
17
+ #### Code
18
+
19
+ <!-- Please paste a code snippet if your question requires it! -->
20
+
21
+ #### What have you tried?
22
+
23
+ #### What's your environment?
24
+
25
+ - fairseq Version (e.g., 1.0 or main):
26
+ - PyTorch Version (e.g., 1.0)
27
+ - OS (e.g., Linux):
28
+ - How you installed fairseq (`pip`, source):
29
+ - Build command you used (if compiling from source):
30
+ - Python version:
31
+ - CUDA/cuDNN version:
32
+ - GPU models and configuration:
33
+ - Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Before submitting
2
+
3
+ - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
4
+ - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
5
+ - [ ] Did you make sure to update the docs?
6
+ - [ ] Did you write any new necessary tests?
7
+
8
+ ## What does this PR do?
9
+ Fixes # (issue).
10
+
11
+ ## PR review
12
+ Anyone in the community is free to review the PR once the tests have passed.
13
+ If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
14
+
15
+ ## Did you have fun?
16
+ Make sure you had fun coding 🙃
fairseq/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for probot-stale - https://github.com/probot/stale
2
+ # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
3
+ # Number of days of inactivity before an issue becomes stale
4
+ daysUntilStale: 90
5
+ # Number of days of inactivity before a stale issue is closed
6
+ daysUntilClose: 7
7
+ # Issues with these labels will never be considered stale
8
+ exemptLabels:
9
+ - bug
10
+ # Label to use when marking an issue as stale
11
+ staleLabel: stale
12
+ issues:
13
+ # Comment to post when marking an issue as stale.
14
+ markComment: >
15
+ This issue has been automatically marked as stale.
16
+ **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
17
+ We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
18
+ # Comment to post when closing a stale issue.
19
+ closeComment: >
20
+ Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
21
+ pulls:
22
+ # Comment to post when marking a pull request as stale.
23
+ markComment: >
24
+ This pull request has been automatically marked as stale.
25
+ **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
26
+ We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
27
+ # Comment to post when closing a stale pull request.
28
+ closeComment: >
29
+ Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
30
+
fairseq/.github/workflows/build.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build
2
+
3
+ on:
4
+ # Trigger the workflow on push to main or any pull request
5
+ push:
6
+ branches:
7
+ - main
8
+ pull_request:
9
+
10
+ jobs:
11
+ build:
12
+
13
+ strategy:
14
+ max-parallel: 4
15
+ matrix:
16
+ platform: [ubuntu-latest, macos-latest]
17
+ python-version: [3.6, 3.7]
18
+
19
+ runs-on: ${{ matrix.platform }}
20
+
21
+ steps:
22
+ - uses: actions/checkout@v2
23
+
24
+ - name: Set up Python ${{ matrix.python-version }}
25
+ uses: actions/setup-python@v2
26
+ with:
27
+ python-version: ${{ matrix.python-version }}
28
+
29
+ - name: Conditionally install pytorch
30
+ if: matrix.platform == 'windows-latest'
31
+ run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
32
+
33
+ - name: Install locally
34
+ run: |
35
+ python -m pip install --upgrade pip
36
+ git submodule update --init --recursive
37
+ python setup.py build_ext --inplace
38
+ python -m pip install --editable .
39
+
40
+ - name: Install optional test requirements
41
+ run: |
42
+ python -m pip install iopath transformers pyarrow
43
+ python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
44
+
45
+ - name: Lint with flake8
46
+ run: |
47
+ pip install flake8
48
+ # stop the build if there are Python syntax errors or undefined names
49
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
50
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
51
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
52
+
53
+ - name: Run tests
54
+ run: |
55
+ python setup.py test
fairseq/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: build_wheels
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - v[0-9]+.[0-9]+.[x0-9]+
7
+ tags:
8
+ - v*
9
+
10
+ jobs:
11
+ build_wheels:
12
+ name: Build wheels on ${{ matrix.os }}
13
+ runs-on: ${{ matrix.os }}
14
+ strategy:
15
+ matrix:
16
+ os: [ubuntu-latest, macos-latest]
17
+
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+
21
+ - name: Install Python
22
+ uses: actions/setup-python@v2
23
+ with:
24
+ python-version: '3.7'
25
+
26
+ - name: Install cibuildwheel
27
+ run: |
28
+ python -m pip install cibuildwheel
29
+
30
+ - name: Build wheels for CPython
31
+ run: |
32
+ python -m cibuildwheel --output-dir dist
33
+ env:
34
+ CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
35
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
36
+ CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
37
+
38
+ - uses: actions/upload-artifact@v2
39
+ with:
40
+ name: wheels
41
+ path: ./dist/*.whl
fairseq/.gitignore ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # JetBrains PyCharm IDE
2
+ .idea/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # macOS dir files
13
+ .DS_Store
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ env/
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+
34
+ # Checkpoints
35
+ checkpoints
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # pyenv
83
+ .python-version
84
+
85
+ # celery beat schedule file
86
+ celerybeat-schedule
87
+
88
+ # SageMath parsed files
89
+ *.sage.py
90
+
91
+ # dotenv
92
+ .env
93
+
94
+ # virtualenv
95
+ .venv
96
+ venv/
97
+ ENV/
98
+
99
+ # Spyder project settings
100
+ .spyderproject
101
+ .spyproject
102
+
103
+ # Rope project settings
104
+ .ropeproject
105
+
106
+ # mkdocs documentation
107
+ /site
108
+
109
+ # mypy
110
+ .mypy_cache/
111
+
112
+ # Generated files
113
+ /fairseq/temporal_convolution_tbc
114
+ /fairseq/modules/*_layer/*_forward.cu
115
+ /fairseq/modules/*_layer/*_backward.cu
116
+ /fairseq/version.py
117
+
118
+ # data
119
+ data-bin/
120
+
121
+ # reranking
122
+ /examples/reranking/rerank_data
123
+
124
+ # Cython-generated C++ source files
125
+ /fairseq/data/data_utils_fast.cpp
126
+ /fairseq/data/token_block_utils_fast.cpp
127
+
128
+ # VSCODE
129
+ .vscode/ftp-sync.json
130
+ .vscode/settings.json
131
+
132
+ # Experimental Folder
133
+ experimental/*
134
+
135
+ # Weights and Biases logs
136
+ wandb/
fairseq/.gitmodules ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [submodule "fairseq/model_parallel/megatron"]
2
+ path = fairseq/model_parallel/megatron
3
+ url = https://github.com/ngoyal2707/Megatron-LM
4
+ branch = fairseq
fairseq/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <conduct@pytorch.org>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
fairseq/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fairseq/README.md ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="docs/fairseq_logo.png" width="150">
3
+ <br />
4
+ <br />
5
+ <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
6
+ <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
7
+ <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
8
+ <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
9
+ </p>
10
+
11
+ --------------------------------------------------------------------------------
12
+
13
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
14
+ developers to train custom models for translation, summarization, language
15
+ modeling and other text generation tasks.
16
+
17
+ We provide reference implementations of various sequence modeling papers:
18
+
19
+ <details><summary>List of implemented papers</summary><p>
20
+
21
+ * **Convolutional Neural Networks (CNN)**
22
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
23
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
24
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
25
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
26
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
27
+ * **LightConv and DynamicConv models**
28
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
29
+ * **Long Short-Term Memory (LSTM) networks**
30
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
31
+ * **Transformer (self-attention) networks**
32
+ + Attention Is All You Need (Vaswani et al., 2017)
33
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
34
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
35
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
36
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
37
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
38
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
39
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
40
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
41
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
42
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
43
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
44
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
45
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
46
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
47
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
48
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
49
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
50
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
51
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
52
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
53
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
54
+ * **Non-autoregressive Transformers**
55
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
56
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
57
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
58
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
59
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
60
+ * **Finetuning**
61
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
62
+
63
+ </p></details>
64
+
65
+ ### What's New:
66
+
67
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
68
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
69
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
70
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
71
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
72
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
73
+ * February 2021 [Added LASER training code](examples/laser/README.md)
74
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
75
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
76
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
77
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
78
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
79
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
80
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
81
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
82
+
83
+ <details><summary>Previous updates</summary><p>
84
+
85
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
86
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
87
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
88
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
89
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
90
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
91
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
92
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
93
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
94
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
95
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
96
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
97
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
98
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
99
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
100
+ * November 2019: [BART model and code released](examples/bart/README.md)
101
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
102
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
103
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
104
+ * July 2019: fairseq relicensed under MIT license
105
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
106
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
107
+
108
+ </p></details>
109
+
110
+ ### Features:
111
+
112
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
113
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
114
+ + beam search
115
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
116
+ + sampling (unconstrained, top-k and top-p/nucleus)
117
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
118
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
119
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
120
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
121
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
122
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
123
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
124
+
125
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
126
+ with a convenient `torch.hub` interface:
127
+
128
+ ``` python
129
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
130
+ en2de.translate('Hello world', beam=5)
131
+ # 'Hallo Welt'
132
+ ```
133
+
134
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
135
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
136
+
137
+ # Requirements and Installation
138
+
139
+ * [PyTorch](http://pytorch.org/) version >= 1.5.0
140
+ * Python version >= 3.6
141
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
142
+ * **To install fairseq** and develop locally:
143
+
144
+ ``` bash
145
+ git clone https://github.com/pytorch/fairseq
146
+ cd fairseq
147
+ pip install --editable ./
148
+
149
+ # on MacOS:
150
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
151
+
152
+ # to install the latest stable release (0.10.x)
153
+ # pip install fairseq
154
+ ```
155
+
156
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
157
+
158
+ ``` bash
159
+ git clone https://github.com/NVIDIA/apex
160
+ cd apex
161
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
162
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
163
+ --global-option="--fast_multihead_attn" ./
164
+ ```
165
+
166
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
167
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
168
+ as command line options to `nvidia-docker run` .
169
+
170
+ # Getting Started
171
+
172
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
173
+ for getting started, training new models and extending fairseq with new model
174
+ types and tasks.
175
+
176
+ # Pre-trained models and examples
177
+
178
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
179
+ as well as example training and evaluation commands.
180
+
181
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
182
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
183
+
184
+ We also have more detailed READMEs to reproduce results from specific papers:
185
+
186
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
187
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
188
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
189
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
190
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
191
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
192
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
193
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
194
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
195
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
196
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
197
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
198
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
199
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
200
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
201
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
202
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
203
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
204
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
205
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
206
+
207
+ # Join the fairseq community
208
+
209
+ * Twitter: https://twitter.com/fairseq
210
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
211
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
212
+
213
+ # License
214
+
215
+ fairseq(-py) is MIT-licensed.
216
+ The license applies to the pre-trained models as well.
217
+
218
+ # Citation
219
+
220
+ Please cite as:
221
+
222
+ ``` bibtex
223
+ @inproceedings{ott2019fairseq,
224
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
225
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
226
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
227
+ year = {2019},
228
+ }
229
+ ```
fairseq/docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line.
5
+ SPHINXOPTS =
6
+ SPHINXBUILD = python -msphinx
7
+ SPHINXPROJ = fairseq
8
+ SOURCEDIR = .
9
+ BUILDDIR = _build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
fairseq/docs/_static/theme_overrides.css ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ .wy-table-responsive table td kbd {
2
+ white-space: nowrap;
3
+ }
4
+ .wy-table-responsive table td {
5
+ white-space: normal !important;
6
+ }
7
+ .wy-table-responsive {
8
+ overflow: visible !important;
9
+ }
fairseq/docs/command_line_tools.rst ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. _Command-line Tools:
2
+
3
+ Command-line Tools
4
+ ==================
5
+
6
+ Fairseq provides several command-line tools for training and evaluating models:
7
+
8
+ - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
9
+ - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
10
+ - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
11
+ - :ref:`fairseq-interactive`: Translate raw text with a trained model
12
+ - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
13
+ - :ref:`fairseq-eval-lm`: Language model evaluation
14
+
15
+
16
+ .. _fairseq-preprocess:
17
+
18
+ fairseq-preprocess
19
+ ~~~~~~~~~~~~~~~~~~
20
+ .. automodule:: fairseq_cli.preprocess
21
+
22
+ .. argparse::
23
+ :module: fairseq.options
24
+ :func: get_preprocessing_parser
25
+ :prog: fairseq-preprocess
26
+
27
+
28
+ .. _fairseq-train:
29
+
30
+ fairseq-train
31
+ ~~~~~~~~~~~~~
32
+ .. automodule:: fairseq_cli.train
33
+
34
+ .. argparse::
35
+ :module: fairseq.options
36
+ :func: get_training_parser
37
+ :prog: fairseq-train
38
+
39
+
40
+ .. _fairseq-generate:
41
+
42
+ fairseq-generate
43
+ ~~~~~~~~~~~~~~~~
44
+ .. automodule:: fairseq_cli.generate
45
+
46
+ .. argparse::
47
+ :module: fairseq.options
48
+ :func: get_generation_parser
49
+ :prog: fairseq-generate
50
+
51
+
52
+ .. _fairseq-interactive:
53
+
54
+ fairseq-interactive
55
+ ~~~~~~~~~~~~~~~~~~~
56
+ .. automodule:: fairseq_cli.interactive
57
+
58
+ .. argparse::
59
+ :module: fairseq.options
60
+ :func: get_interactive_generation_parser
61
+ :prog: fairseq-interactive
62
+
63
+
64
+ .. _fairseq-score:
65
+
66
+ fairseq-score
67
+ ~~~~~~~~~~~~~
68
+ .. automodule:: fairseq_cli.score
69
+
70
+ .. argparse::
71
+ :module: fairseq_cli.score
72
+ :func: get_parser
73
+ :prog: fairseq-score
74
+
75
+
76
+ .. _fairseq-eval-lm:
77
+
78
+ fairseq-eval-lm
79
+ ~~~~~~~~~~~~~~~
80
+ .. automodule:: fairseq_cli.eval_lm
81
+
82
+ .. argparse::
83
+ :module: fairseq.options
84
+ :func: get_eval_lm_parser
85
+ :prog: fairseq-eval-lm
fairseq/docs/conf.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ #
4
+ # fairseq documentation build configuration file, created by
5
+ # sphinx-quickstart on Fri Aug 17 21:45:30 2018.
6
+ #
7
+ # This file is execfile()d with the current directory set to its
8
+ # containing dir.
9
+ #
10
+ # Note that not all possible configuration values are present in this
11
+ # autogenerated file.
12
+ #
13
+ # All configuration values have a default; values that are commented out
14
+ # serve to show the default.
15
+
16
+ # If extensions (or modules to document with autodoc) are in another directory,
17
+ # add these directories to sys.path here. If the directory is relative to the
18
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
19
+
20
+ import os
21
+ import sys
22
+ from fairseq import __version__
23
+
24
+
25
+ # source code directory, relative to this file, for sphinx-autobuild
26
+ sys.path.insert(0, os.path.abspath(".."))
27
+
28
+ source_suffix = [".rst"]
29
+
30
+ # -- General configuration ------------------------------------------------
31
+
32
+ # If your documentation needs a minimal Sphinx version, state it here.
33
+ #
34
+ # needs_sphinx = '1.0'
35
+
36
+ # Add any Sphinx extension module names here, as strings. They can be
37
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
38
+ # ones.
39
+ extensions = [
40
+ "sphinx.ext.autodoc",
41
+ "sphinx.ext.intersphinx",
42
+ "sphinx.ext.viewcode",
43
+ "sphinx.ext.napoleon",
44
+ "sphinxarg.ext",
45
+ ]
46
+
47
+ # Add any paths that contain templates here, relative to this directory.
48
+ templates_path = ["_templates"]
49
+
50
+ # The master toctree document.
51
+ master_doc = "index"
52
+
53
+ # General information about the project.
54
+ project = "fairseq"
55
+ copyright = "Facebook AI Research (FAIR)"
56
+ author = "Facebook AI Research (FAIR)"
57
+
58
+ github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
59
+
60
+ # The version info for the project you're documenting, acts as replacement for
61
+ # |version| and |release|, also used in various other places throughout the
62
+ # built documents.
63
+ #
64
+ # The short X.Y version.
65
+ version = __version__
66
+ # The full version, including alpha/beta/rc tags.
67
+ release = __version__
68
+
69
+ # The language for content autogenerated by Sphinx. Refer to documentation
70
+ # for a list of supported languages.
71
+ #
72
+ # This is also used if you do content translation via gettext catalogs.
73
+ # Usually you set "language" from the command line for these cases.
74
+ language = None
75
+
76
+ # List of patterns, relative to source directory, that match files and
77
+ # directories to ignore when looking for source files.
78
+ # This patterns also effect to html_static_path and html_extra_path
79
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
80
+
81
+ # The name of the Pygments (syntax highlighting) style to use.
82
+ pygments_style = "sphinx"
83
+ highlight_language = "python"
84
+
85
+ # If true, `todo` and `todoList` produce output, else they produce nothing.
86
+ todo_include_todos = False
87
+
88
+
89
+ # -- Options for HTML output ----------------------------------------------
90
+
91
+ # The theme to use for HTML and HTML Help pages. See the documentation for
92
+ # a list of builtin themes.
93
+ #
94
+ html_theme = "sphinx_rtd_theme"
95
+
96
+ # Theme options are theme-specific and customize the look and feel of a theme
97
+ # further. For a list of options available for each theme, see the
98
+ # documentation.
99
+ #
100
+ # html_theme_options = {}
101
+
102
+ # Add any paths that contain custom static files (such as style sheets) here,
103
+ # relative to this directory. They are copied after the builtin static files,
104
+ # so a file named "default.css" will overwrite the builtin "default.css".
105
+ html_static_path = ["_static"]
106
+
107
+ html_context = {
108
+ "css_files": [
109
+ "_static/theme_overrides.css", # override wide tables in RTD theme
110
+ ],
111
+ }
112
+
113
+ # Custom sidebar templates, must be a dictionary that maps document names
114
+ # to template names.
115
+ #
116
+ # This is required for the alabaster theme
117
+ # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
118
+ # html_sidebars = {
119
+ # '**': [
120
+ # 'about.html',
121
+ # 'navigation.html',
122
+ # 'relations.html', # needs 'show_related': True theme option to display
123
+ # 'searchbox.html',
124
+ # 'donate.html',
125
+ # ]
126
+ # }
127
+
128
+
129
+ # Example configuration for intersphinx: refer to the Python standard library.
130
+ intersphinx_mapping = {
131
+ "numpy": ("http://docs.scipy.org/doc/numpy/", None),
132
+ "python": ("https://docs.python.org/", None),
133
+ "torch": ("https://pytorch.org/docs/master/", None),
134
+ }
fairseq/docs/criterions.rst ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _Criterions:
5
+
6
+ Criterions
7
+ ==========
8
+
9
+ Criterions compute the loss function given the model and batch, roughly::
10
+
11
+ loss = criterion(model, batch)
12
+
13
+ .. automodule:: fairseq.criterions
14
+ :members:
15
+
16
+ .. autoclass:: fairseq.criterions.FairseqCriterion
17
+ :members:
18
+ :undoc-members:
19
+
20
+ .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
21
+ :members:
22
+ :undoc-members:
23
+ .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
24
+ :members:
25
+ :undoc-members:
26
+ .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
27
+ :members:
28
+ :undoc-members:
29
+ .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
30
+ :members:
31
+ :undoc-members:
fairseq/docs/data.rst ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.data
5
+
6
+ Data Loading and Utilities
7
+ ==========================
8
+
9
+ .. _datasets:
10
+
11
+ Datasets
12
+ --------
13
+
14
+ **Datasets** define the data format and provide helpers for creating
15
+ mini-batches.
16
+
17
+ .. autoclass:: fairseq.data.FairseqDataset
18
+ :members:
19
+ .. autoclass:: fairseq.data.LanguagePairDataset
20
+ :members:
21
+ .. autoclass:: fairseq.data.MonolingualDataset
22
+ :members:
23
+
24
+ **Helper Datasets**
25
+
26
+ These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
27
+ provide additional functionality:
28
+
29
+ .. autoclass:: fairseq.data.BacktranslationDataset
30
+ :members:
31
+ .. autoclass:: fairseq.data.ConcatDataset
32
+ :members:
33
+ .. autoclass:: fairseq.data.ResamplingDataset
34
+ :members:
35
+ .. autoclass:: fairseq.data.RoundRobinZipDatasets
36
+ :members:
37
+ .. autoclass:: fairseq.data.TransformEosDataset
38
+ :members:
39
+
40
+
41
+ Dictionary
42
+ ----------
43
+
44
+ .. autoclass:: fairseq.data.Dictionary
45
+ :members:
46
+
47
+
48
+ Iterators
49
+ ---------
50
+
51
+ .. autoclass:: fairseq.data.CountingIterator
52
+ :members:
53
+ .. autoclass:: fairseq.data.EpochBatchIterator
54
+ :members:
55
+ .. autoclass:: fairseq.data.GroupedIterator
56
+ :members:
57
+ .. autoclass:: fairseq.data.ShardedIterator
58
+ :members:
fairseq/docs/docutils.conf ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [writers]
2
+ option-limit=0
fairseq/docs/fairseq_logo.png ADDED
fairseq/docs/getting_started.rst ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Evaluating Pre-trained Models
2
+ =============================
3
+
4
+ First, download a pre-trained model along with its vocabularies:
5
+
6
+ .. code-block:: console
7
+
8
+ > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
9
+
10
+ This model uses a `Byte Pair Encoding (BPE)
11
+ vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
12
+ the encoding to the source text before it can be translated. This can be
13
+ done with the
14
+ `apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
15
+ script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
16
+ used as a continuation marker and the original text can be easily
17
+ recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
18
+ flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
19
+ using ``tokenizer.perl`` from
20
+ `mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
21
+
22
+ Let's use :ref:`fairseq-interactive` to generate translations interactively.
23
+ Here, we use a beam size of 5 and preprocess the input with the Moses
24
+ tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
25
+ remove the BPE continuation markers and detokenize the output.
26
+
27
+ .. code-block:: console
28
+
29
+ > MODEL_DIR=wmt14.en-fr.fconv-py
30
+ > fairseq-interactive \
31
+ --path $MODEL_DIR/model.pt $MODEL_DIR \
32
+ --beam 5 --source-lang en --target-lang fr \
33
+ --tokenizer moses \
34
+ --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
35
+ | loading model(s) from wmt14.en-fr.fconv-py/model.pt
36
+ | [en] dictionary: 44206 types
37
+ | [fr] dictionary: 44463 types
38
+ | Type the input sentence and press return:
39
+ Why is it rare to discover new marine mammal species?
40
+ S-0 Why is it rare to discover new marine mam@@ mal species ?
41
+ H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
42
+ P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
43
+
44
+ This generation script produces three types of outputs: a line prefixed
45
+ with *O* is a copy of the original source sentence; *H* is the
46
+ hypothesis along with an average log-likelihood; and *P* is the
47
+ positional score per token position, including the
48
+ end-of-sentence marker which is omitted from the text.
49
+
50
+ Other types of output lines you might see are *D*, the detokenized hypothesis,
51
+ *T*, the reference target, *A*, alignment info, *E* the history of generation steps.
52
+
53
+ See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
54
+ full list of pre-trained models available.
55
+
56
+ Training a New Model
57
+ ====================
58
+
59
+ The following tutorial is for machine translation. For an example of how
60
+ to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
61
+ ``examples/`` directory.
62
+
63
+ Data Pre-processing
64
+ -------------------
65
+
66
+ Fairseq contains example pre-processing scripts for several translation
67
+ datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
68
+ 2014 (English-German). To pre-process and binarize the IWSLT dataset:
69
+
70
+ .. code-block:: console
71
+
72
+ > cd examples/translation/
73
+ > bash prepare-iwslt14.sh
74
+ > cd ../..
75
+ > TEXT=examples/translation/iwslt14.tokenized.de-en
76
+ > fairseq-preprocess --source-lang de --target-lang en \
77
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
78
+ --destdir data-bin/iwslt14.tokenized.de-en
79
+
80
+ This will write binarized data that can be used for model training to
81
+ ``data-bin/iwslt14.tokenized.de-en``.
82
+
83
+ Training
84
+ --------
85
+
86
+ Use :ref:`fairseq-train` to train a new model. Here a few example settings that work
87
+ well for the IWSLT 2014 dataset:
88
+
89
+ .. code-block:: console
90
+
91
+ > mkdir -p checkpoints/fconv
92
+ > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
93
+ --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
94
+ --arch fconv_iwslt_de_en --save-dir checkpoints/fconv
95
+
96
+ By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
97
+ ``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
98
+ change the number of GPU devices that will be used.
99
+
100
+ Also note that the batch size is specified in terms of the maximum
101
+ number of tokens per batch (``--max-tokens``). You may need to use a
102
+ smaller value depending on the available GPU memory on your system.
103
+
104
+ Generation
105
+ ----------
106
+
107
+ Once your model is trained, you can generate translations using
108
+ :ref:`fairseq-generate` **(for binarized data)** or
109
+ :ref:`fairseq-interactive` **(for raw text)**:
110
+
111
+ .. code-block:: console
112
+
113
+ > fairseq-generate data-bin/iwslt14.tokenized.de-en \
114
+ --path checkpoints/fconv/checkpoint_best.pt \
115
+ --batch-size 128 --beam 5
116
+ | [de] dictionary: 35475 types
117
+ | [en] dictionary: 24739 types
118
+ | data-bin/iwslt14.tokenized.de-en test 6750 examples
119
+ | model fconv
120
+ | loaded checkpoint trainings/fconv/checkpoint_best.pt
121
+ S-721 danke .
122
+ T-721 thank you .
123
+ ...
124
+
125
+ To generate translations with only a CPU, use the ``--cpu`` flag. BPE
126
+ continuation markers can be removed with the ``--remove-bpe`` flag.
127
+
128
+ Advanced Training Options
129
+ =========================
130
+
131
+ Large mini-batch training with delayed updates
132
+ ----------------------------------------------
133
+
134
+ The ``--update-freq`` option can be used to accumulate gradients from
135
+ multiple mini-batches and delay updating, creating a larger effective
136
+ batch size. Delayed updates can also improve training speed by reducing
137
+ inter-GPU communication costs and by saving idle time caused by variance
138
+ in workload across GPUs. See `Ott et al.
139
+ (2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
140
+
141
+ To train on a single GPU with an effective batch size that is equivalent
142
+ to training on 8 GPUs:
143
+
144
+ .. code-block:: console
145
+
146
+ > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
147
+
148
+ Training with half precision floating point (FP16)
149
+ --------------------------------------------------
150
+
151
+ .. note::
152
+
153
+ FP16 training requires a Volta GPU and CUDA 9.1 or greater
154
+
155
+ Recent GPUs enable efficient half precision floating point computation,
156
+ e.g., using `Nvidia Tensor Cores
157
+ <https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
158
+ Fairseq supports FP16 training with the ``--fp16`` flag:
159
+
160
+ .. code-block:: console
161
+
162
+ > fairseq-train --fp16 (...)
163
+
164
+ Distributed training
165
+ --------------------
166
+
167
+ Distributed training in fairseq is implemented on top of ``torch.distributed``.
168
+ The easiest way to launch jobs is with the `torch.distributed.launch
169
+ <https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
170
+
171
+ For example, to train a large English-German Transformer model on 2 nodes each
172
+ with 8 GPUs (in total 16 GPUs), run the following command on each node,
173
+ replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
174
+ sure to update ``--master_addr`` to the IP address of the first node:
175
+
176
+ .. code-block:: console
177
+
178
+ > python -m torch.distributed.launch --nproc_per_node=8 \
179
+ --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
180
+ --master_port=12345 \
181
+ $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
182
+ --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
183
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
184
+ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
185
+ --lr 0.0005 \
186
+ --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
187
+ --max-tokens 3584 \
188
+ --max-epoch 70 \
189
+ --fp16
190
+
191
+ On SLURM clusters, fairseq will automatically detect the number of nodes and
192
+ GPUs, but a port number must be provided:
193
+
194
+ .. code-block:: console
195
+
196
+ > salloc --gpus=16 --nodes 2 (...)
197
+ > srun fairseq-train --distributed-port 12345 (...).
198
+
199
+ Sharding very large datasets
200
+ ----------------------------
201
+
202
+ It can be challenging to train over very large datasets, particularly if your
203
+ machine does not have much system RAM. Most tasks in fairseq support training
204
+ over "sharded" datasets, in which the original dataset has been preprocessed
205
+ into non-overlapping chunks (or "shards").
206
+
207
+ For example, instead of preprocessing all your data into a single "data-bin"
208
+ directory, you can split the data and create "data-bin1", "data-bin2", etc.
209
+ Then you can adapt your training command like so:
210
+
211
+ .. code-block:: console
212
+
213
+ > fairseq-train data-bin1:data-bin2:data-bin3 (...)
214
+
215
+ Training will now iterate over each shard, one by one, with each shard
216
+ corresponding to an "epoch", thus reducing system memory usage.
fairseq/docs/hydra_integration.md ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Hydra
2
+
3
+ [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
4
+ framework that simplifies the development of research and other complex
5
+ applications. The key feature is the ability to dynamically create a
6
+ hierarchical configuration by composition and override it through config files
7
+ and the command line. The name Hydra comes from its ability to run multiple
8
+ similar jobs - much like a Hydra with multiple heads.
9
+
10
+ ## Motivation
11
+
12
+ Until recently, all components in fairseq were configured through a shared
13
+ `args` namespace that was created at application startup. Components declared
14
+ their own `add_args` method to update the argparse parser, hoping that the names
15
+ would not clash with arguments from other components. While this model works for
16
+ smaller applications, as fairseq grew and became integrated into other
17
+ applications, this became problematic. In order to determine how to configure
18
+ each component, one needed to a) examine what args were added by this component,
19
+ and b) read the code to figure out what shared arguments it is using that were
20
+ added in other places. Reproducing models involved sharing commands that often
21
+ contained dozens of command line switches.
22
+
23
+ The model described above is still supported by fairseq for backward
24
+ compatibility, but will be deprecated some time in the future.
25
+
26
+ New components in fairseq should now create a dataclass that encapsulates all
27
+ parameters required to configure this component. The dataclass is registered
28
+ along with the component, and fairseq takes care of constructing and providing
29
+ this configuration object to the component's constructor. Note that sharing
30
+ parameters can optionally still work, but one has to explicitly point to the
31
+ "source of truth" (see inheritance example below). These changes make components
32
+ in fairseq more independent and re-usable by other applications: all that is
33
+ needed to create a component is to initialize its dataclass and overwrite some
34
+ of the defaults.
35
+
36
+ While configuring fairseq through command line (using either the legacy argparse
37
+ based or the new Hydra based entry points) is still fully supported, you can now
38
+ take advantage of configuring fairseq completely or piece-by-piece through
39
+ hierarchical YAML configuration files. These files can also be shipped as
40
+ examples that others can use to run an identically configured job.
41
+
42
+ Additionally, Hydra has a rich and growing [library of
43
+ plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
44
+ provide functionality such as hyperparameter sweeping (including using bayesian
45
+ optimization through the [Ax](https://github.com/facebook/Ax) library), job
46
+ launching across various platforms, and more.
47
+
48
+ ## Creating or migrating components
49
+
50
+ In general, each new (or updated) component should provide a companion
51
+ [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are
52
+ typically located in the same file as the component and are passed as arguments
53
+ to the `register_*()` functions. Top-level configs that should be present in
54
+ every fairseq application are placed in the
55
+ [global](fairseq/dataclass/configs.py) config file and added to the
56
+ `FairseqConfig` object.
57
+
58
+ Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
59
+ classes are decorated with a `@dataclass` decorator, and typically inherit from
60
+ `FairseqDataclass` (which adds some functionality for backward compatibility).
61
+ Each field must have a type, and generally has metadata (such as a help string)
62
+ and a default value. Only primitive types or other config objects are allowed as
63
+ data types for each field.
64
+
65
+ #### Example:
66
+
67
+ ```python
68
+ from dataclasses import dataclass, field
69
+ from fairseq.dataclass import FairseqDataclass
70
+
71
+ @dataclass
72
+ class InteractiveConfig(FairseqDataclass):
73
+ buffer_size: int = field(
74
+ default=0,
75
+ metadata={
76
+ "help": "read this many sentences into a buffer before processing them"
77
+ },
78
+ )
79
+ input: str = field(
80
+ default="-",
81
+ metadata={"help": "file to read from; use - for stdin"},
82
+ )
83
+ ```
84
+
85
+ ### Inherting values
86
+
87
+ Some components require sharing a value. For example, a learning rate scheduler
88
+ and an optimizer may both need to know the initial learning rate value. One can
89
+ declare a field that, by default, will inherit its value from another config
90
+ node in the same hierarchy:
91
+
92
+ ```python
93
+ @dataclass
94
+ FairseqAdamConfig(FairseqDataclass):
95
+ ...
96
+ lr: List[float] = II("optimization.lr")
97
+ ...
98
+ ```
99
+
100
+ `II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
101
+ the value one can use in a YAML config file or through command line to achieve
102
+ the same effect. Note that this assumes that there is an "optimization" config
103
+ object in the root config and it has a field called "lr".
104
+
105
+ ### Tasks and Models
106
+
107
+ Creating Tasks and Models works same as before, except that legacy
108
+ implementations now inherit from `LegacyFairseq*` base classes, while new
109
+ components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
110
+ to the `register_*()` functions.
111
+
112
+ #### Task example:
113
+
114
+ ```python
115
+ @dataclass
116
+ class LanguageModelingConfig(FairseqDataclass):
117
+ data: Optional[str] = field(
118
+ default=None, metadata={"help": "path to data directory"}
119
+ )
120
+ ...
121
+
122
+ @register_task("language_modeling", dataclass=LanguageModelingConfig)
123
+ class LanguageModelingTask(FairseqTask):
124
+ ...
125
+ @classmethod
126
+ def setup_task(cls, cfg: LanguageModelingConfig):
127
+ ...
128
+ ```
129
+
130
+ #### Model example:
131
+
132
+ ```python
133
+ @dataclass
134
+ class TransformerLanguageModelConfig(FairseqDataclass):
135
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
136
+ default="relu", metadata={"help": "activation function to use"}
137
+ )
138
+ dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
139
+ ...
140
+
141
+ @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
142
+ class TransformerLanguageModel(FairseqLanguageModel):
143
+ ...
144
+ @classmethod
145
+ def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
146
+ ...
147
+ ```
148
+
149
+ ### Other components
150
+
151
+ Other components work as before, but they now take their configuration dataclass
152
+ as the only constructor argument:
153
+
154
+ ```python
155
+ @dataclass
156
+ class MosesTokenizerConfig(FairseqDataclass):
157
+ source_lang: str = field(default="en", metadata={"help": "source language"})
158
+ ...
159
+
160
+ @register_tokenizer("moses", dataclass=MosesTokenizerConfig)
161
+ class MosesTokenizer(object):
162
+ def __init__(self, cfg: MosesTokenizerConfig):
163
+ ...
164
+ ```
165
+
166
+ Note that if you are adding a new registry for a new set of components, you need
167
+ to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
168
+
169
+ ```python
170
+ @dataclass
171
+ class FairseqConfig(object):
172
+ ...
173
+ my_new_registry: Any = None
174
+ ```
175
+
176
+ ## Training with `fairseq-hydra-train`
177
+
178
+ To fully take advantage of configuration flexibility offered by Hydra, you may
179
+ want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
180
+ tools such as `fairseq-train` will remain supported for the foreseeable future
181
+ but will be deprecated eventually.
182
+
183
+ On startup, Hydra will create a configuration object that contains a hierarchy
184
+ of all the necessary dataclasses populated with their default values in the
185
+ code. The default values are overwritten by values found in YAML files in
186
+ `fairseq/config` directory (which currently sets minimal defaults) and then
187
+ further overwritten by values provided through command line arguments.
188
+
189
+ Some of the most common use cases are shown below:
190
+
191
+ ### 1. Override default values through command line:
192
+
193
+ ```shell script
194
+ $ fairseq-hydra-train \
195
+ distributed_training.distributed_world_size=1 \
196
+ dataset.batch_size=2 \
197
+ task.data=data-bin \
198
+ model=transformer_lm/transformer_lm_gpt \
199
+ task=language_modeling \
200
+ optimization.max_update=5000
201
+ ```
202
+
203
+ Note that along with explicitly providing values for parameters such as
204
+ `dataset.batch_size`, this also tells Hydra to overlay configuration found in
205
+ `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
206
+ values in the dataclass. If you want to train a model without specifying a
207
+ particular architecture you can simply specify `model=transformer_lm`. This only
208
+ works for migrated tasks and models.
209
+
210
+ ### 2. Replace bundled configs with an external config:
211
+
212
+ ```shell script
213
+ $ fairseq-hydra-train \
214
+ --config-dir /path/to/external/configs \
215
+ --config-name wiki103
216
+ ```
217
+
218
+ where `/path/to/external/configs/wiki103.yaml` contains:
219
+
220
+ ```yaml
221
+ # @package _group_
222
+
223
+ model:
224
+ _name: transformer_lm
225
+ distributed_training:
226
+ distributed_world_size: 1
227
+ dataset:
228
+ batch_size: 2
229
+ task:
230
+ _name: language_modeling
231
+ data: /path/to/data
232
+ add_bos_token: false
233
+ max_target_positions: 1024
234
+ optimization:
235
+ max_update: 50000
236
+ lr: [ 0.25 ]
237
+ criterion: cross_entropy
238
+ optimizer: adam
239
+ lr_scheduler:
240
+ _name: cosine
241
+ ```
242
+
243
+ Note that here bundled configs from `fairseq/config` directory are not used,
244
+ however the defaults from each dataclass will still be used (unless overwritten
245
+ by your external config).
246
+
247
+ Additionally you can choose to break up your configs by creating a directory
248
+ structure in the same location as your main config file, with the names of the
249
+ top-level fields (such as "model", "dataset", etc), and placing config files
250
+ with meaningful names that would populate that specific section of your
251
+ top-level config file (for example, you might have
252
+ `model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
253
+ can then specify the correct configuration via command line, defaults in the
254
+ main config, or even launch all of them as a sweep (see Hydra documentation on
255
+ how to do this).
256
+
257
+ ### 3. Add an external config directory to Hydra search path:
258
+
259
+ This allows combining default configuration (including using any bundled config
260
+ files), while specifying your own config files for some parts of the
261
+ configuration.
262
+
263
+ ```shell script
264
+ $ fairseq-hydra-train \
265
+ distributed_training.distributed_world_size=1 \
266
+ dataset.batch_size=2 \
267
+ task.data=/path/to/data/ \
268
+ model=transformer_lm/2_layers \
269
+ task=language_modeling \
270
+ optimization.max_update=5000 \
271
+ --config-dir /path/to/external/configs
272
+ ```
273
+
274
+ where `/path/to/external/configs` has the following structure:
275
+ ```
276
+ .
277
+ +-- model
278
+ | +-- transformer_lm
279
+ | | +-- 2_layers.yaml
280
+ ```
281
+
282
+ and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
283
+ `decoder_layers` set to 2. You can add other configs to configure other
284
+ components as well.
fairseq/docs/index.rst ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. fairseq documentation master file, created by
2
+ sphinx-quickstart on Fri Aug 17 21:45:30 2018.
3
+ You can adapt this file completely to your liking, but it should at least
4
+ contain the root `toctree` directive.
5
+
6
+ :github_url: https://github.com/pytorch/fairseq
7
+
8
+
9
+ fairseq documentation
10
+ =====================
11
+
12
+ Fairseq is a sequence modeling toolkit written in `PyTorch
13
+ <http://pytorch.org/>`_ that allows researchers and developers to
14
+ train custom models for translation, summarization, language modeling and other
15
+ text generation tasks.
16
+
17
+ .. toctree::
18
+ :maxdepth: 1
19
+ :caption: Getting Started
20
+
21
+ getting_started
22
+ command_line_tools
23
+
24
+ .. toctree::
25
+ :maxdepth: 1
26
+ :caption: Extending Fairseq
27
+
28
+ overview
29
+ tutorial_simple_lstm
30
+ tutorial_classifying_names
31
+
32
+ .. toctree::
33
+ :maxdepth: 2
34
+ :caption: Library Reference
35
+
36
+ tasks
37
+ models
38
+ criterions
39
+ optim
40
+ lr_scheduler
41
+ data
42
+ modules
43
+
44
+
45
+ Indices and tables
46
+ ==================
47
+
48
+ * :ref:`genindex`
49
+ * :ref:`search`
fairseq/docs/lr_scheduler.rst ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _Learning Rate Schedulers:
5
+
6
+ Learning Rate Schedulers
7
+ ========================
8
+
9
+ Learning Rate Schedulers update the learning rate over the course of training.
10
+ Learning rates can be updated after each update via :func:`step_update` or at
11
+ epoch boundaries via :func:`step`.
12
+
13
+ .. automodule:: fairseq.optim.lr_scheduler
14
+ :members:
15
+
16
+ .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
17
+ :members:
18
+ :undoc-members:
19
+
20
+ .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
21
+ :members:
22
+ :undoc-members:
23
+ .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
24
+ :members:
25
+ :undoc-members:
26
+ .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
27
+ :members:
28
+ :undoc-members:
29
+ .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
30
+ :members:
31
+ :undoc-members:
32
+ .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
33
+ :members:
34
+ :undoc-members:
fairseq/docs/make.bat ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @ECHO OFF
2
+
3
+ pushd %~dp0
4
+
5
+ REM Command file for Sphinx documentation
6
+
7
+ if "%SPHINXBUILD%" == "" (
8
+ set SPHINXBUILD=python -msphinx
9
+ )
10
+ set SOURCEDIR=.
11
+ set BUILDDIR=_build
12
+ set SPHINXPROJ=fairseq
13
+
14
+ if "%1" == "" goto help
15
+
16
+ %SPHINXBUILD% >NUL 2>NUL
17
+ if errorlevel 9009 (
18
+ echo.
19
+ echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20
+ echo.then set the SPHINXBUILD environment variable to point to the full
21
+ echo.path of the 'sphinx-build' executable. Alternatively you may add the
22
+ echo.Sphinx directory to PATH.
23
+ echo.
24
+ echo.If you don't have Sphinx installed, grab it from
25
+ echo.http://sphinx-doc.org/
26
+ exit /b 1
27
+ )
28
+
29
+ %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30
+ goto end
31
+
32
+ :help
33
+ %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34
+
35
+ :end
36
+ popd
fairseq/docs/models.rst ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.models
5
+
6
+ .. _Models:
7
+
8
+ Models
9
+ ======
10
+
11
+ A Model defines the neural network's ``forward()`` method and encapsulates all
12
+ of the learnable parameters in the network. Each model also provides a set of
13
+ named *architectures* that define the precise network configuration (e.g.,
14
+ embedding dimension, number of layers, etc.).
15
+
16
+ Both the model type and architecture are selected via the ``--arch``
17
+ command-line argument. Once selected, a model may expose additional command-line
18
+ arguments for further configuration.
19
+
20
+ .. note::
21
+
22
+ All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23
+ :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24
+ stand-alone Module in other PyTorch code.
25
+
26
+
27
+ Convolutional Neural Networks (CNN)
28
+ -----------------------------------
29
+
30
+ .. module:: fairseq.models.fconv
31
+ .. autoclass:: fairseq.models.fconv.FConvModel
32
+ :members:
33
+ .. autoclass:: fairseq.models.fconv.FConvEncoder
34
+ :members:
35
+ :undoc-members:
36
+ .. autoclass:: fairseq.models.fconv.FConvDecoder
37
+ :members:
38
+
39
+
40
+ Long Short-Term Memory (LSTM) networks
41
+ --------------------------------------
42
+
43
+ .. module:: fairseq.models.lstm
44
+ .. autoclass:: fairseq.models.lstm.LSTMModel
45
+ :members:
46
+ .. autoclass:: fairseq.models.lstm.LSTMEncoder
47
+ :members:
48
+ .. autoclass:: fairseq.models.lstm.LSTMDecoder
49
+ :members:
50
+
51
+
52
+ Transformer (self-attention) networks
53
+ -------------------------------------
54
+
55
+ .. module:: fairseq.models.transformer
56
+ .. autoclass:: fairseq.models.transformer.TransformerModel
57
+ :members:
58
+ .. autoclass:: fairseq.models.transformer.TransformerEncoder
59
+ :members:
60
+ .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61
+ :members:
62
+ .. autoclass:: fairseq.models.transformer.TransformerDecoder
63
+ :members:
64
+ .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65
+ :members:
66
+
67
+
68
+ Adding new models
69
+ -----------------
70
+
71
+ .. currentmodule:: fairseq.models
72
+ .. autofunction:: fairseq.models.register_model
73
+ .. autofunction:: fairseq.models.register_model_architecture
74
+ .. autoclass:: fairseq.models.BaseFairseqModel
75
+ :members:
76
+ :undoc-members:
77
+ .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78
+ :members:
79
+ :undoc-members:
80
+ .. autoclass:: fairseq.models.FairseqEncoderModel
81
+ :members:
82
+ :undoc-members:
83
+ .. autoclass:: fairseq.models.FairseqLanguageModel
84
+ :members:
85
+ :undoc-members:
86
+ .. autoclass:: fairseq.models.FairseqMultiModel
87
+ :members:
88
+ :undoc-members:
89
+ .. autoclass:: fairseq.models.FairseqEncoder
90
+ :members:
91
+ .. autoclass:: fairseq.models.CompositeEncoder
92
+ :members:
93
+ .. autoclass:: fairseq.models.FairseqDecoder
94
+ :members:
95
+
96
+
97
+ .. _Incremental decoding:
98
+
99
+ Incremental decoding
100
+ --------------------
101
+
102
+ .. autoclass:: fairseq.models.FairseqIncrementalDecoder
103
+ :members:
104
+ :undoc-members:
fairseq/docs/modules.rst ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Modules
2
+ =======
3
+
4
+ Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
5
+ be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
6
+
7
+ .. automodule:: fairseq.modules
8
+ :members:
9
+ :undoc-members:
fairseq/docs/optim.rst ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. _optimizers:
5
+
6
+ Optimizers
7
+ ==========
8
+
9
+ Optimizers update the Model parameters based on the gradients.
10
+
11
+ .. automodule:: fairseq.optim
12
+ :members:
13
+
14
+ .. autoclass:: fairseq.optim.FairseqOptimizer
15
+ :members:
16
+ :undoc-members:
17
+
18
+ .. autoclass:: fairseq.optim.adadelta.Adadelta
19
+ :members:
20
+ :undoc-members:
21
+ .. autoclass:: fairseq.optim.adagrad.Adagrad
22
+ :members:
23
+ :undoc-members:
24
+ .. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
25
+ :members:
26
+ :undoc-members:
27
+ .. autoclass:: fairseq.optim.adam.FairseqAdam
28
+ :members:
29
+ :undoc-members:
30
+ .. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
31
+ :members:
32
+ :undoc-members:
33
+ .. autoclass:: fairseq.optim.nag.FairseqNAG
34
+ :members:
35
+ :undoc-members:
36
+ .. autoclass:: fairseq.optim.sgd.SGD
37
+ :members:
38
+ :undoc-members: