root committed on
Commit 93b9482
1 parent: a10ead2
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -10
  2. .gitignore +1 -0
  3. LICENSE +201 -0
  4. README.md +21 -6
  5. app.py +155 -0
  6. criterions/__init__.py +2 -0
  7. criterions/label_smoothed_cross_entropy.py +343 -0
  8. criterions/scst_loss.py +280 -0
  9. data/__init__.py +0 -0
  10. data/data_utils.py +601 -0
  11. data/file_dataset.py +102 -0
  12. data/mm_data/__init__.py +0 -0
  13. data/mm_data/caption_dataset.py +154 -0
  14. data/mm_data/refcoco_dataset.py +168 -0
  15. data/ofa_dataset.py +74 -0
  16. evaluate.py +152 -0
  17. fairseq/.github/ISSUE_TEMPLATE.md +3 -0
  18. fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  19. fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  20. fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  21. fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  22. fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
  23. fairseq/.github/stale.yml +30 -0
  24. fairseq/.github/workflows/build.yml +55 -0
  25. fairseq/.github/workflows/build_wheels.yml +41 -0
  26. fairseq/.gitmodules +4 -0
  27. fairseq/CODE_OF_CONDUCT.md +77 -0
  28. fairseq/CONTRIBUTING.md +28 -0
  29. fairseq/LICENSE +21 -0
  30. fairseq/README.md +229 -0
  31. fairseq/examples/__init__.py +9 -0
  32. fairseq/examples/adaptive_span/README.md +90 -0
  33. fairseq/examples/adaptive_span/__init__.py +19 -0
  34. fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
  35. fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
  36. fairseq/examples/adaptive_span/adaptive_span_loss.py +106 -0
  37. fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
  38. fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
  39. fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +281 -0
  40. fairseq/examples/backtranslation/README.md +297 -0
  41. fairseq/examples/backtranslation/deduplicate_lines.py +41 -0
  42. fairseq/examples/backtranslation/extract_bt_data.py +72 -0
  43. fairseq/examples/backtranslation/prepare-de-monolingual.sh +98 -0
  44. fairseq/examples/backtranslation/prepare-wmt18en2de.sh +135 -0
  45. fairseq/examples/backtranslation/sacrebleu.sh +37 -0
  46. fairseq/examples/backtranslation/tokenized_bleu.sh +46 -0
  47. fairseq/examples/bart/README.glue.md +99 -0
  48. fairseq/examples/bart/README.md +228 -0
  49. fairseq/examples/bart/README.summarization.md +102 -0
  50. fairseq/examples/bart/summarize.py +100 -0
.gitattributes CHANGED
@@ -1,35 +1,27 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 1999-2022 Alibaba Group Holding Ltd.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,27 @@
  ---
- title: OFA
- emoji:
- colorFrom: green
- colorTo: blue
+ title: OFA-Visual_Grounding
+ emoji: 👀
+ colorFrom: yellow
+ colorTo: green
  sdk: gradio
- sdk_version: 3.39.0
  app_file: app.py
  pinned: false
+ duplicated_from: OFA-Sys/OFA-Visual_Grounding
  ---
+ # Configuration
+ `title`: _string_
+ OFA Image Caption
+ `emoji`: _string_
+ 🖼
+ `colorFrom`: _string_
+ red
+ `colorTo`: _string_
+ indigo
+ `sdk`: _string_
+ gradio
+ `app_file`: _string_
+ app.py
  
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ 
+ `pinned`: _boolean_
+ false
app.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+
3
+ os.system('cd fairseq;'
4
+ 'pip install ./; cd ..')
5
+ os.system('ls -l')
6
+
7
+ import torch
8
+ import numpy as np
9
+ from fairseq import utils, tasks
10
+ from fairseq import checkpoint_utils
11
+ from utils.eval_utils import eval_step
12
+ from tasks.mm_tasks.refcoco import RefcocoTask
13
+ from models.ofa import OFAModel
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ import cv2
17
+ import gradio as gr
18
+
19
+ # Register refcoco task
20
+ tasks.register_task('refcoco', RefcocoTask)
21
+
22
+ # turn on cuda if GPU is available
23
+ use_cuda = torch.cuda.is_available()
24
+ # use fp16 only when GPU is available
25
+ use_fp16 = False
26
+
27
+ os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/refcocog_large_best.pt; '
28
+ 'mkdir -p checkpoints; mv refcocog_large_best.pt checkpoints/refcocog.pt')
29
+
30
+ # Load pretrained ckpt & config
31
+ overrides = {"bpe_dir": "utils/BPE", "eval_cider": False, "beam": 5,
32
+ "max_len_b": 16, "no_repeat_ngram_size": 3, "seed": 7}
33
+ models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
34
+ utils.split_paths('checkpoints/refcocog.pt'),
35
+ arg_overrides=overrides
36
+ )
37
+
38
+ cfg.common.seed = 7
39
+ cfg.generation.beam = 5
40
+ cfg.generation.min_len = 4
41
+ cfg.generation.max_len_a = 0
42
+ cfg.generation.max_len_b = 4
43
+ cfg.generation.no_repeat_ngram_size = 3
44
+
45
+ # Fix seed for stochastic decoding
46
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
47
+ np.random.seed(cfg.common.seed)
48
+ utils.set_torch_seed(cfg.common.seed)
49
+
50
+ # Move models to GPU
51
+ for model in models:
52
+ model.eval()
53
+ if use_fp16:
54
+ model.half()
55
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
56
+ model.cuda()
57
+ model.prepare_for_inference_(cfg)
58
+
59
+ # Initialize generator
60
+ generator = task.build_generator(models, cfg.generation)
61
+
62
+ mean = [0.5, 0.5, 0.5]
63
+ std = [0.5, 0.5, 0.5]
64
+
65
+ patch_resize_transform = transforms.Compose([
66
+ lambda image: image.convert("RGB"),
67
+ transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize(mean=mean, std=std),
70
+ ])
71
+
72
+ # Text preprocess
73
+ bos_item = torch.LongTensor([task.src_dict.bos()])
74
+ eos_item = torch.LongTensor([task.src_dict.eos()])
75
+ pad_idx = task.src_dict.pad()
76
+
77
+
78
+ def encode_text(text, length=None, append_bos=False, append_eos=False):
79
+ s = task.tgt_dict.encode_line(
80
+ line=task.bpe.encode(text),
81
+ add_if_not_exist=False,
82
+ append_eos=False
83
+ ).long()
84
+ if length is not None:
85
+ s = s[:length]
86
+ if append_bos:
87
+ s = torch.cat([bos_item, s])
88
+ if append_eos:
89
+ s = torch.cat([s, eos_item])
90
+ return s
91
+
92
+
93
+ patch_image_size = cfg.task.patch_image_size
94
+
95
+
96
+ def construct_sample(image: Image, text: str):
97
+ w, h = image.size
98
+ w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
99
+ h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
100
+ patch_image = patch_resize_transform(image).unsqueeze(0)
101
+ patch_mask = torch.tensor([True])
102
+ src_text = encode_text(' which region does the text " {} " describe?'.format(text), append_bos=True,
103
+ append_eos=True).unsqueeze(0)
104
+ src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
105
+ sample = {
106
+ "id": np.array(['42']),
107
+ "net_input": {
108
+ "src_tokens": src_text,
109
+ "src_lengths": src_length,
110
+ "patch_images": patch_image,
111
+ "patch_masks": patch_mask,
112
+ },
113
+ "w_resize_ratios": w_resize_ratio,
114
+ "h_resize_ratios": h_resize_ratio,
115
+ "region_coords": torch.randn(1, 4)
116
+ }
117
+ return sample
118
+
119
+
120
+ # Function to turn FP32 to FP16
121
+ def apply_half(t):
122
+ if t.dtype is torch.float32:
123
+ return t.to(dtype=torch.half)
124
+ return t
125
+
126
+
127
+ # Function for visual grounding
128
+ def visual_grounding(Image, Text):
129
+ sample = construct_sample(Image, Text.lower())
130
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
131
+ sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
132
+ with torch.no_grad():
133
+ result, scores = eval_step(task, generator, models, sample)
134
+ img = np.asarray(Image)
135
+ cv2.rectangle(
136
+ img,
137
+ (int(result[0]["box"][0]), int(result[0]["box"][1])),
138
+ (int(result[0]["box"][2]), int(result[0]["box"][3])),
139
+ (0, 255, 0),
140
+ 3
141
+ )
142
+ return img
143
+
144
+
145
+ title = "OFA Visual Grounding"
146
+ description = "Démonstration pour OFA Visual Grounding. Téléchargez votre image ou cliquez sur l'un des exemples, et rédigez une description concernant un objet spécifique."
147
+
148
+ examples = [['test-1.jpeg', 'black chair'],
149
+ ['test-2.jpeg', 'orange door'],
150
+ ['test-3.jpeg', 'fire extinguisher']]
151
+ io = gr.Interface(fn=visual_grounding, inputs=[gr.inputs.Image(type='pil'), "textbox"],
152
+ outputs=gr.outputs.Image(type='numpy'),
153
+ title=title, description=description, examples=examples,
154
+ allow_flagging=False, allow_screenshot=False)
155
+ io.launch()
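Note (not part of the commit): the grounding pipeline above can also be exercised without the Gradio UI, which can help when debugging the checkpoint load. The following is a minimal sketch that assumes the model, task, generator and helper functions defined in app.py are already loaded in the same interpreter, and that 'test-1.jpeg' is one of the bundled example images.

    from PIL import Image as PILImage

    image = PILImage.open('test-1.jpeg')

    # Build a model input exactly as the Gradio handler does.
    sample = construct_sample(image, 'black chair')
    sample = utils.move_to_cuda(sample) if use_cuda else sample

    with torch.no_grad():
        result, scores = eval_step(task, generator, models, sample)

    # result[0]["box"] holds [x0, y0, x1, y1] in the original image's pixel coordinates,
    # which is what visual_grounding() draws with cv2.rectangle.
    print(result[0]["box"])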
criterions/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .scst_loss import ScstRewardCriterion
+ from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
criterions/label_smoothed_cross_entropy.py ADDED
@@ -0,0 +1,343 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass, field
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+
19
+ @dataclass
20
+ class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
21
+ label_smoothing: float = field(
22
+ default=0.0,
23
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
24
+ )
25
+ report_accuracy: bool = field(
26
+ default=False,
27
+ metadata={"help": "report accuracy metric"},
28
+ )
29
+ ignore_prefix_size: int = field(
30
+ default=0,
31
+ metadata={"help": "Ignore first N tokens"},
32
+ )
33
+ ignore_eos: bool = field(
34
+ default=False,
35
+ metadata={"help": "Ignore eos token"},
36
+ )
37
+ sentence_avg: bool = II("optimization.sentence_avg")
38
+ drop_worst_ratio: float = field(
39
+ default=0.0,
40
+ metadata={"help": "ratio for discarding bad samples"},
41
+ )
42
+ drop_worst_after: int = field(
43
+ default=0,
44
+ metadata={"help": "steps for discarding bad samples"},
45
+ )
46
+ use_rdrop: bool = field(
47
+ default=False, metadata={"help": "use R-Drop"}
48
+ )
49
+ reg_alpha: float = field(
50
+ default=1.0, metadata={"help": "weight for R-Drop"}
51
+ )
52
+ sample_patch_num: int = field(
53
+ default=196, metadata={"help": "sample patches for v1"}
54
+ )
55
+ constraint_range: Optional[str] = field(
56
+ default=None,
57
+ metadata={"help": "constraint range"}
58
+ )
59
+
60
+
61
+ def construct_rdrop_sample(x):
62
+ if isinstance(x, dict):
63
+ for key in x:
64
+ x[key] = construct_rdrop_sample(x[key])
65
+ return x
66
+ elif isinstance(x, torch.Tensor):
67
+ return x.repeat(2, *([1] * (x.dim()-1)))
68
+ elif isinstance(x, int):
69
+ return x * 2
70
+ elif isinstance(x, np.ndarray):
71
+ return x.repeat(2)
72
+ else:
73
+ raise NotImplementedError
74
+
75
+
76
+ def kl_loss(p, q):
77
+ p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
78
+ q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
79
+ loss = (p_loss + q_loss) / 2
80
+ return loss
81
+
82
+
83
+ def label_smoothed_nll_loss(
84
+ lprobs, target, epsilon, update_num, reduce=True,
85
+ drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
86
+ constraint_masks=None, constraint_start=None, constraint_end=None
87
+ ):
88
+ if target.dim() == lprobs.dim() - 1:
89
+ target = target.unsqueeze(-1)
90
+ nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
91
+ if constraint_masks is not None:
92
+ smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
93
+ eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
94
+ elif constraint_start is not None and constraint_end is not None:
95
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
96
+ smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
97
+ eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
98
+ else:
99
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
100
+ eps_i = epsilon / (lprobs.size(-1) - 1)
101
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
102
+ if drop_worst_ratio > 0 and update_num > drop_worst_after:
103
+ if use_rdrop:
104
+ true_batch_size = loss.size(0) // 2
105
+ _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
106
+ loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
107
+ nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
108
+ lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
109
+ else:
110
+ loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
111
+ nll_loss = nll_loss[indices]
112
+ lprobs = lprobs[indices]
113
+
114
+ ntokens = loss.numel()
115
+ nll_loss = nll_loss.sum()
116
+ loss = loss.sum()
117
+ if use_rdrop:
118
+ true_batch_size = lprobs.size(0) // 2
119
+ p = lprobs[:true_batch_size]
120
+ q = lprobs[true_batch_size:]
121
+ if constraint_start is not None and constraint_end is not None:
122
+ constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
123
+ p = p[:, constraint_range]
124
+ q = q[:, constraint_range]
125
+ loss += kl_loss(p, q) * reg_alpha
126
+
127
+ return loss, nll_loss, ntokens
128
+
129
+
130
+ @register_criterion(
131
+ "ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
132
+ )
133
+ class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
134
+ def __init__(
135
+ self,
136
+ task,
137
+ sentence_avg,
138
+ label_smoothing,
139
+ ignore_prefix_size=0,
140
+ ignore_eos=False,
141
+ report_accuracy=False,
142
+ drop_worst_ratio=0,
143
+ drop_worst_after=0,
144
+ use_rdrop=False,
145
+ reg_alpha=1.0,
146
+ sample_patch_num=196,
147
+ constraint_range=None
148
+ ):
149
+ super().__init__(task)
150
+ self.sentence_avg = sentence_avg
151
+ self.eps = label_smoothing
152
+ self.ignore_prefix_size = ignore_prefix_size
153
+ self.ignore_eos = ignore_eos
154
+ self.report_accuracy = report_accuracy
155
+ self.drop_worst_ratio = drop_worst_ratio
156
+ self.drop_worst_after = drop_worst_after
157
+ self.use_rdrop = use_rdrop
158
+ self.reg_alpha = reg_alpha
159
+ self.sample_patch_num = sample_patch_num
160
+
161
+ self.constraint_start = None
162
+ self.constraint_end = None
163
+ if constraint_range is not None:
164
+ constraint_start, constraint_end = constraint_range.split(',')
165
+ self.constraint_start = int(constraint_start)
166
+ self.constraint_end = int(constraint_end)
167
+
168
+ def forward(self, model, sample, update_num=0, reduce=True):
169
+ """Compute the loss for the given sample.
170
+
171
+ Returns a tuple with three elements:
172
+ 1) the loss
173
+ 2) the sample size, which is used as the denominator for the gradient
174
+ 3) logging outputs to display while training
175
+ """
176
+ if isinstance(sample, list):
177
+ if self.sample_patch_num > 0:
178
+ sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
179
+ loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
180
+ loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
181
+ loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
182
+ sample_size = 1
183
+ logging_output = {
184
+ "loss": loss.data,
185
+ "loss_v1": loss_v1.data,
186
+ "loss_v2": loss_v2.data,
187
+ "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
188
+ "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
189
+ "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
190
+ "sample_size": 1,
191
+ "sample_size_v1": sample_size_v1,
192
+ "sample_size_v2": sample_size_v2,
193
+ }
194
+ return loss, sample_size, logging_output
195
+
196
+ if self.use_rdrop:
197
+ construct_rdrop_sample(sample)
198
+
199
+ net_output = model(**sample["net_input"])
200
+ loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
201
+ sample_size = (
202
+ sample["target"].size(0) if self.sentence_avg else ntokens
203
+ )
204
+ logging_output = {
205
+ "loss": loss.data,
206
+ "nll_loss": nll_loss.data,
207
+ "ntokens": sample["ntokens"],
208
+ "nsentences": sample["nsentences"],
209
+ "sample_size": sample_size,
210
+ }
211
+ if self.report_accuracy:
212
+ n_correct, total = self.compute_accuracy(model, net_output, sample)
213
+ logging_output["n_correct"] = utils.item(n_correct.data)
214
+ logging_output["total"] = utils.item(total.data)
215
+ return loss, sample_size, logging_output
216
+
217
+ def get_lprobs_and_target(self, model, net_output, sample):
218
+ conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
219
+ constraint_masks = None
220
+ if "constraint_masks" in sample and sample["constraint_masks"] is not None:
221
+ constraint_masks = sample["constraint_masks"]
222
+ net_output[0].masked_fill_(~constraint_masks, -math.inf)
223
+ if self.constraint_start is not None and self.constraint_end is not None:
224
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
225
+ net_output[0][:, :, self.constraint_end:] = -math.inf
226
+ lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
227
+ target = model.get_targets(sample, net_output)
228
+ if self.ignore_prefix_size > 0:
229
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
230
+ target = target[:, self.ignore_prefix_size :].contiguous()
231
+ if constraint_masks is not None:
232
+ constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
233
+ if self.ignore_eos:
234
+ bsz, seq_len, embed_dim = lprobs.size()
235
+ eos_indices = target.eq(self.task.tgt_dict.eos())
236
+ lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
237
+ target = target[~eos_indices].reshape(bsz, seq_len-1)
238
+ if constraint_masks is not None:
239
+ constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
240
+ if constraint_masks is not None:
241
+ constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
242
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
243
+
244
+ def compute_loss(self, model, net_output, sample, update_num, reduce=True):
245
+ lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
246
+ if constraint_masks is not None:
247
+ constraint_masks = constraint_masks[target != self.padding_idx]
248
+ lprobs = lprobs[target != self.padding_idx]
249
+ target = target[target != self.padding_idx]
250
+ loss, nll_loss, ntokens = label_smoothed_nll_loss(
251
+ lprobs,
252
+ target,
253
+ self.eps,
254
+ update_num,
255
+ reduce=reduce,
256
+ drop_worst_ratio=self.drop_worst_ratio,
257
+ drop_worst_after=self.drop_worst_after,
258
+ use_rdrop=self.use_rdrop,
259
+ reg_alpha=self.reg_alpha,
260
+ constraint_masks=constraint_masks,
261
+ constraint_start=self.constraint_start,
262
+ constraint_end=self.constraint_end
263
+ )
264
+ return loss, nll_loss, ntokens
265
+
266
+ def compute_accuracy(self, model, net_output, sample):
267
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
268
+ mask = target.ne(self.padding_idx)
269
+ n_correct = torch.sum(
270
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
271
+ )
272
+ total = torch.sum(mask)
273
+ return n_correct, total
274
+
275
+ @classmethod
276
+ def reduce_metrics(cls, logging_outputs) -> None:
277
+ """Aggregate logging outputs from data parallel training."""
278
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
279
+ loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
280
+ loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
281
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
282
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
283
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
284
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
285
+ sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
286
+ sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
287
+
288
+ metrics.log_scalar(
289
+ "loss", loss_sum / sample_size, sample_size, round=3
290
+ )
291
+ metrics.log_scalar(
292
+ "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
293
+ )
294
+ metrics.log_scalar(
295
+ "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
296
+ )
297
+ metrics.log_scalar(
298
+ "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
299
+ )
300
+ metrics.log_derived(
301
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
302
+ )
303
+
304
+ metrics.log_scalar(
305
+ "ntokens", ntokens, 1, round=3
306
+ )
307
+ metrics.log_scalar(
308
+ "nsentences", nsentences, 1, round=3
309
+ )
310
+ metrics.log_scalar(
311
+ "sample_size", sample_size, 1, round=3
312
+ )
313
+ metrics.log_scalar(
314
+ "sample_size_v1", sample_size_v1, 1, round=3
315
+ )
316
+ metrics.log_scalar(
317
+ "sample_size_v2", sample_size_v2, 1, round=3
318
+ )
319
+
320
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
321
+ if total > 0:
322
+ metrics.log_scalar("total", total)
323
+ n_correct = utils.item(
324
+ sum(log.get("n_correct", 0) for log in logging_outputs)
325
+ )
326
+ metrics.log_scalar("n_correct", n_correct)
327
+ metrics.log_derived(
328
+ "accuracy",
329
+ lambda meters: round(
330
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
331
+ )
332
+ if meters["total"].sum > 0
333
+ else float("nan"),
334
+ )
335
+
336
+ @staticmethod
337
+ def logging_outputs_can_be_summed() -> bool:
338
+ """
339
+ Whether the logging outputs returned by `forward` can be summed
340
+ across workers prior to calling `reduce_metrics`. Setting this
341
+ to True will improve distributed training speed.
342
+ """
343
+ return True
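Note (not part of the commit): the core of this criterion is label_smoothed_nll_loss. The sketch below, under the assumption of no constraint masks, no R-Drop and no worst-sample dropping, shows what the function returns on a toy batch; it only uses the signature defined above.

    import torch
    import torch.nn.functional as F

    # (n_tokens, vocab) log-probabilities and integer targets, as the criterion flattens them.
    lprobs = F.log_softmax(torch.randn(6, 10), dim=-1)
    target = torch.randint(0, 10, (6,))

    loss, nll_loss, ntokens = label_smoothed_nll_loss(
        lprobs, target, epsilon=0.1, update_num=0
    )
    # With epsilon=0.1 and a vocab of 10: eps_i = 0.1 / 9, and each token contributes
    # (1 - 0.1 - eps_i) * nll + eps_i * (-sum of all log-probs); ntokens == 6 here,
    # while loss and nll_loss are summed over the 6 tokens.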
criterions/scst_loss.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import string
8
+ from dataclasses import dataclass, field
9
+ from collections import OrderedDict
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from fairseq import metrics, utils
14
+ from fairseq.criterions import FairseqCriterion, register_criterion
15
+ from fairseq.dataclass import FairseqDataclass
16
+ from omegaconf import II
17
+
18
+ from data import data_utils
19
+ from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
20
+
21
+
22
+ def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
23
+ loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
24
+ if ignore_index is not None:
25
+ pad_mask = target.eq(ignore_index)
26
+ loss.masked_fill_(pad_mask, 0.0)
27
+ ntokens = (~pad_mask).sum()
28
+ else:
29
+ loss = loss.squeeze(-1)
30
+ ntokens = target.numel()
31
+ if reduce:
32
+ loss = loss.sum()
33
+ return loss, ntokens
34
+
35
+ @dataclass
36
+ class ScstRewardCriterionConfig(FairseqDataclass):
37
+ scst_cider_cached_tokens: str = field(
38
+ default="coco-train-words.p",
39
+ metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
40
+ )
41
+ ignore_prefix_size: int = field(
42
+ default=0,
43
+ metadata={"help": "Ignore first N tokens"},
44
+ )
45
+ sentence_avg: bool = II("optimization.sentence_avg")
46
+ constraint_range: Optional[str] = field(
47
+ default=None,
48
+ metadata={"help": "constraint range"}
49
+ )
50
+
51
+
52
+ @register_criterion(
53
+ "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
54
+ )
55
+ class ScstRewardCriterion(FairseqCriterion):
56
+ CIDER_REWARD_WEIGHT = 1
57
+
58
+ def __init__(
59
+ self,
60
+ task,
61
+ scst_cider_cached_tokens,
62
+ sentence_avg,
63
+ ignore_prefix_size=0,
64
+ constraint_range=None
65
+ ):
66
+ super().__init__(task)
67
+ self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
68
+ self.sentence_avg = sentence_avg
69
+ self.ignore_prefix_size = ignore_prefix_size
70
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
71
+
72
+ self.constraint_start = None
73
+ self.constraint_end = None
74
+ if constraint_range is not None:
75
+ constraint_start, constraint_end = constraint_range.split(',')
76
+ self.constraint_start = int(constraint_start)
77
+ self.constraint_end = int(constraint_end)
78
+
79
+ def forward(self, model, sample, update_num=0, reduce=True):
80
+ """Compute the loss for the given sample.
81
+
82
+ Returns a tuple with three elements:
83
+ 1) the loss
84
+ 2) the sample size, which is used as the denominator for the gradient
85
+ 3) logging outputs to display while training
86
+ """
87
+ loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
88
+
89
+ sample_size = (
90
+ nsentences if self.sentence_avg else ntokens
91
+ )
92
+ logging_output = {
93
+ "loss": loss.data,
94
+ "score": score,
95
+ "ntokens": ntokens,
96
+ "nsentences": nsentences,
97
+ "sample_size": sample_size,
98
+ }
99
+ return loss, sample_size, logging_output
100
+
101
+ def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
102
+ '''
103
+ gen_res: generated captions, list of str
104
+ gt_idx: list of int, of the same length as gen_res
105
+ gt_res: ground truth captions, list of list of str.
106
+ gen_res[i] corresponds to gt_res[gt_idx[i]]
107
+ Each image can have multiple ground truth captions
108
+ '''
109
+ gen_res_size = len(gen_res)
110
+
111
+ res = OrderedDict()
112
+ for i in range(gen_res_size):
113
+ res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
114
+
115
+ gts = OrderedDict()
116
+ gt_res_ = [
117
+ [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
118
+ for i in range(len(gt_res))
119
+ ]
120
+ for i in range(gen_res_size):
121
+ gts[i] = gt_res_[gt_idx[i]]
122
+
123
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
124
+ _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
125
+ scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
126
+ return scores
127
+
128
+ @classmethod
129
+ def _wrap_sentence(self, s):
130
+ # ensure the sentence ends with <eos> token
131
+ # in order to keep consistent with cider_cached_tokens
132
+ r = s.strip()
133
+ if r.endswith('.'):
134
+ r = r[:-1]
135
+ r += ' <eos>'
136
+ return r
137
+
138
+ def get_generator_out(self, model, sample):
139
+ def decode(toks):
140
+ hypo = toks.int().cpu()
141
+ hypo_str = self.task.tgt_dict.string(hypo)
142
+ hypo_str = self.task.bpe.decode(hypo_str).strip()
143
+ return hypo, hypo_str
144
+
145
+ model.eval()
146
+ with torch.no_grad():
147
+ self.task.scst_generator.model.eval()
148
+ gen_out = self.task.scst_generator.generate([model], sample)
149
+
150
+ gen_target = []
151
+ gen_res = []
152
+ gt_res = []
153
+ for i in range(len(gen_out)):
154
+ for j in range(len(gen_out[i])):
155
+ hypo, hypo_str = decode(gen_out[i][j]["tokens"])
156
+ gen_target.append(hypo)
157
+ gen_res.append(hypo_str)
158
+ gt_res.append(
159
+ decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
160
+ )
161
+
162
+ return gen_target, gen_res, gt_res
163
+
164
+ def get_reward_and_scores(self, gen_res, gt_res, device):
165
+ batch_size = len(gt_res)
166
+ gen_res_size = len(gen_res)
167
+ seq_per_img = gen_res_size // batch_size
168
+
169
+ gt_idx = [i // seq_per_img for i in range(gen_res_size)]
170
+ scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
171
+ sc_ = scores.reshape(batch_size, seq_per_img)
172
+ baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
173
+ # sample - baseline
174
+ reward = scores.reshape(batch_size, seq_per_img)
175
+ reward = reward - baseline
176
+ reward = reward.reshape(gen_res_size)
177
+ reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
178
+
179
+ return reward, scores
180
+
181
+ def get_net_output(self, model, sample, gen_target):
182
+ def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
183
+ return data_utils.collate_tokens(
184
+ sample_list,
185
+ pad_idx=self.padding_idx,
186
+ eos_idx=eos,
187
+ left_pad=False,
188
+ move_eos_to_beginning=move_eos_to_beginning,
189
+ )
190
+
191
+ batch_size = len(sample["target"])
192
+ gen_target_size = len(gen_target)
193
+ seq_per_img = gen_target_size // batch_size
194
+
195
+ model.train()
196
+ sample_src_tokens = torch.repeat_interleave(
197
+ sample['net_input']['src_tokens'], seq_per_img, dim=0
198
+ )
199
+ sample_src_lengths = torch.repeat_interleave(
200
+ sample['net_input']['src_lengths'], seq_per_img, dim=0
201
+ )
202
+ sample_patch_images = torch.repeat_interleave(
203
+ sample['net_input']['patch_images'], seq_per_img, dim=0
204
+ )
205
+ sample_patch_masks = torch.repeat_interleave(
206
+ sample['net_input']['patch_masks'], seq_per_img, dim=0
207
+ )
208
+ gen_prev_output_tokens = torch.as_tensor(
209
+ merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
210
+ device=sample["target"].device, dtype=torch.int64
211
+ )
212
+ gen_target_tokens = torch.as_tensor(
213
+ merge(gen_target), device=sample["target"].device, dtype=torch.int64
214
+ )
215
+ net_output = model(
216
+ src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
217
+ patch_images=sample_patch_images, patch_masks=sample_patch_masks,
218
+ prev_output_tokens=gen_prev_output_tokens
219
+ )
220
+
221
+ return net_output, gen_target_tokens
222
+
223
+ def get_lprobs_and_target(self, model, net_output, gen_target):
224
+ if self.constraint_start is not None and self.constraint_end is not None:
225
+ net_output[0][:, :, 4:self.constraint_start] = -math.inf
226
+ net_output[0][:, :, self.constraint_end:] = -math.inf
227
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
228
+ if self.ignore_prefix_size > 0:
229
+ if getattr(lprobs, "batch_first", False):
230
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
231
+ gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
232
+ else:
233
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
234
+ gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
235
+ return lprobs, gen_target
236
+
237
+ def compute_loss(self, model, sample, reduce=True):
238
+ gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
239
+ reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
240
+ net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
241
+ gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
242
+ loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
243
+ nsentences = gen_target_tokens.size(0)
244
+
245
+ return loss, scores.sum(), ntokens, nsentences
246
+
247
+ @classmethod
248
+ def reduce_metrics(cls, logging_outputs) -> None:
249
+ """Aggregate logging outputs from data parallel training."""
250
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
251
+ score_sum = sum(log.get("score", 0) for log in logging_outputs)
252
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
253
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
254
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
255
+
256
+ metrics.log_scalar(
257
+ "loss", loss_sum / sample_size, sample_size, round=3
258
+ )
259
+ metrics.log_scalar(
260
+ "score", score_sum / nsentences, nsentences, round=3
261
+ )
262
+
263
+ metrics.log_scalar(
264
+ "ntokens", ntokens, 1, round=3
265
+ )
266
+ metrics.log_scalar(
267
+ "nsentences", nsentences, 1, round=3
268
+ )
269
+ metrics.log_scalar(
270
+ "sample_size", sample_size, 1, round=3
271
+ )
272
+
273
+ @staticmethod
274
+ def logging_outputs_can_be_summed() -> bool:
275
+ """
276
+ Whether the logging outputs returned by `forward` can be summed
277
+ across workers prior to calling `reduce_metrics`. Setting this
278
+ to True will improve distributed training speed.
279
+ """
280
+ return True
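Note (not part of the commit): in get_reward_and_scores above, each sampled caption's reward is its CIDEr score minus a leave-one-out baseline, i.e. the mean score of the other samples generated for the same image. A small NumPy sketch with made-up scores for 2 images x 3 samples each:

    import numpy as np

    scores = np.array([0.9, 0.5, 0.1, 0.8, 0.6, 0.4])   # flattened (batch_size * seq_per_img)
    sc_ = scores.reshape(2, 3)                           # (batch_size, seq_per_img)
    baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
    reward = (sc_ - baseline).reshape(-1)
    # reward[0] = 0.9 - (0.5 + 0.1) / 2 = 0.6: captions that beat the average of their
    # siblings get positive reward, worse-than-average captions get negative reward.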
data/__init__.py ADDED
File without changes
data/data_utils.py ADDED
@@ -0,0 +1,601 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from collections.abc import Iterable
8
+ except ImportError:
9
+ from collections import Iterable
10
+ import contextlib
11
+ import itertools
12
+ import logging
13
+ import re
14
+ import warnings
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from fairseq.file_io import PathManager
21
+ from fairseq import utils
22
+ import os
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def infer_language_pair(path):
28
+ """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
29
+ src, dst = None, None
30
+ for filename in PathManager.ls(path):
31
+ parts = filename.split(".")
32
+ if len(parts) >= 3 and len(parts[1].split("-")) == 2:
33
+ return parts[1].split("-")
34
+ return src, dst
35
+
36
+
37
+ def collate_tokens(
38
+ values,
39
+ pad_idx,
40
+ eos_idx=None,
41
+ left_pad=False,
42
+ move_eos_to_beginning=False,
43
+ pad_to_length=None,
44
+ pad_to_multiple=1,
45
+ pad_to_bsz=None,
46
+ ):
47
+ """Convert a list of 1d tensors into a padded 2d tensor."""
48
+ size = max(v.size(0) for v in values)
49
+ size = size if pad_to_length is None else max(size, pad_to_length)
50
+ if pad_to_multiple != 1 and size % pad_to_multiple != 0:
51
+ size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
52
+
53
+ def copy_tensor(src, dst):
54
+ assert dst.numel() == src.numel()
55
+ if move_eos_to_beginning:
56
+ if eos_idx is None:
57
+ # if no eos_idx is specified, then use the last token in src
58
+ dst[0] = src[-1]
59
+ else:
60
+ dst[0] = eos_idx
61
+ dst[1:] = src[:-1]
62
+ else:
63
+ dst.copy_(src)
64
+
65
+ if values[0].dim() == 1:
66
+ res = values[0].new(len(values), size).fill_(pad_idx)
67
+ elif values[0].dim() == 2:
68
+ assert move_eos_to_beginning is False
69
+ res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ for i, v in enumerate(values):
74
+ copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
75
+ return res
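Note (not part of the commit): collate_tokens is the padding workhorse used by scst_loss.py above. A minimal sketch of its behaviour with right padding and pad_idx=1, using only the signature defined here:

    import torch

    seqs = [torch.tensor([5, 6, 2]), torch.tensor([7, 2]), torch.tensor([8, 9, 10, 2])]
    batch = collate_tokens(seqs, pad_idx=1)
    # batch is a 3 x 4 tensor, each sequence copied in and the remainder filled with pad_idx:
    # [[ 5,  6,  2,  1],
    #  [ 7,  2,  1,  1],
    #  [ 8,  9, 10,  2]]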
76
+
77
+
78
+ def load_indexed_dataset(
79
+ path, dictionary=None, dataset_impl=None, combine=False, default="cached"
80
+ ):
81
+ """A helper function for loading indexed datasets.
82
+
83
+ Args:
84
+ path (str): path to indexed dataset (e.g., 'data-bin/train')
85
+ dictionary (~fairseq.data.Dictionary): data dictionary
86
+ dataset_impl (str, optional): which dataset implementation to use. If
87
+ not provided, it will be inferred automatically. For legacy indexed
88
+ data we use the 'cached' implementation by default.
89
+ combine (bool, optional): automatically load and combine multiple
90
+ datasets. For example, if *path* is 'data-bin/train', then we will
91
+ combine 'data-bin/train', 'data-bin/train1', ... and return a
92
+ single ConcatDataset instance.
93
+ """
94
+ import fairseq.data.indexed_dataset as indexed_dataset
95
+ from fairseq.data.concat_dataset import ConcatDataset
96
+
97
+ datasets = []
98
+ for k in itertools.count():
99
+ path_k = path + (str(k) if k > 0 else "")
100
+ try:
101
+ path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
102
+ except Exception as e:
103
+ if "StorageException: [404] Path not found" in str(e):
104
+ logger.warning(f"path_k: {e} not found")
105
+ else:
106
+ raise e
107
+
108
+ dataset_impl_k = dataset_impl
109
+ if dataset_impl_k is None:
110
+ dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
111
+ dataset = indexed_dataset.make_dataset(
112
+ path_k,
113
+ impl=dataset_impl_k or default,
114
+ fix_lua_indexing=True,
115
+ dictionary=dictionary,
116
+ )
117
+ if dataset is None:
118
+ break
119
+ logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
120
+ datasets.append(dataset)
121
+ if not combine:
122
+ break
123
+ if len(datasets) == 0:
124
+ return None
125
+ elif len(datasets) == 1:
126
+ return datasets[0]
127
+ else:
128
+ return ConcatDataset(datasets)
129
+
130
+
131
+ @contextlib.contextmanager
132
+ def numpy_seed(seed, *addl_seeds):
133
+ """Context manager which seeds the NumPy PRNG with the specified seed and
134
+ restores the state afterward"""
135
+ if seed is None:
136
+ yield
137
+ return
138
+ if len(addl_seeds) > 0:
139
+ seed = int(hash((seed, *addl_seeds)) % 1e6)
140
+ state = np.random.get_state()
141
+ np.random.seed(seed)
142
+ try:
143
+ yield
144
+ finally:
145
+ np.random.set_state(state)
146
+
147
+
148
+ def collect_filtered(function, iterable, filtered):
149
+ """
150
+ Similar to :func:`filter` but collects filtered elements in ``filtered``.
151
+
152
+ Args:
153
+ function (callable): function that returns ``False`` for elements that
154
+ should be filtered
155
+ iterable (iterable): iterable to filter
156
+ filtered (list): list to store filtered elements
157
+ """
158
+ for el in iterable:
159
+ if function(el):
160
+ yield el
161
+ else:
162
+ filtered.append(el)
163
+
164
+
165
+ def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
166
+ def compare_leq(a, b):
167
+ return a <= b if not isinstance(a, tuple) else max(a) <= b
168
+
169
+ def check_size(idx):
170
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
171
+ return size_fn(idx) <= max_positions
172
+ elif isinstance(max_positions, dict):
173
+ idx_size = size_fn(idx)
174
+ assert isinstance(idx_size, dict)
175
+ intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
176
+ return all(
177
+ all(
178
+ a is None or b is None or a <= b
179
+ for a, b in zip(idx_size[key], max_positions[key])
180
+ )
181
+ for key in intersect_keys
182
+ )
183
+ else:
184
+ # For MultiCorpusSampledDataset, will generalize it later
185
+ if not isinstance(size_fn(idx), Iterable):
186
+ return all(size_fn(idx) <= b for b in max_positions)
187
+ return all(
188
+ a is None or b is None or a <= b
189
+ for a, b in zip(size_fn(idx), max_positions)
190
+ )
191
+
192
+ ignored = []
193
+ itr = collect_filtered(check_size, indices, ignored)
194
+ indices = np.fromiter(itr, dtype=np.int64, count=-1)
195
+ return indices, ignored
196
+
197
+
198
+ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
199
+ """
200
+ [deprecated] Filter indices based on their size.
201
+ Use `FairseqDataset::filter_indices_by_size` instead.
202
+
203
+ Args:
204
+ indices (List[int]): ordered list of dataset indices
205
+ dataset (FairseqDataset): fairseq dataset instance
206
+ max_positions (tuple): filter elements larger than this size.
207
+ Comparisons are done component-wise.
208
+ raise_exception (bool, optional): if ``True``, raise an exception if
209
+ any elements are filtered (default: False).
210
+ """
211
+ warnings.warn(
212
+ "data_utils.filter_by_size is deprecated. "
213
+ "Use `FairseqDataset::filter_indices_by_size` instead.",
214
+ stacklevel=2,
215
+ )
216
+ if isinstance(max_positions, float) or isinstance(max_positions, int):
217
+ if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
218
+ ignored = indices[dataset.sizes[indices] > max_positions].tolist()
219
+ indices = indices[dataset.sizes[indices] <= max_positions]
220
+ elif (
221
+ hasattr(dataset, "sizes")
222
+ and isinstance(dataset.sizes, list)
223
+ and len(dataset.sizes) == 1
224
+ ):
225
+ ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
226
+ indices = indices[dataset.sizes[0][indices] <= max_positions]
227
+ else:
228
+ indices, ignored = _filter_by_size_dynamic(
229
+ indices, dataset.size, max_positions
230
+ )
231
+ else:
232
+ indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
233
+
234
+ if len(ignored) > 0 and raise_exception:
235
+ raise Exception(
236
+ (
237
+ "Size of sample #{} is invalid (={}) since max_positions={}, "
238
+ "skip this example with --skip-invalid-size-inputs-valid-test"
239
+ ).format(ignored[0], dataset.size(ignored[0]), max_positions)
240
+ )
241
+ if len(ignored) > 0:
242
+ logger.warning(
243
+ (
244
+ "{} samples have invalid sizes and will be skipped, "
245
+ "max_positions={}, first few sample ids={}"
246
+ ).format(len(ignored), max_positions, ignored[:10])
247
+ )
248
+ return indices
249
+
250
+
251
+ def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
252
+ """Filter a list of sample indices. Remove those that are longer
253
+ than specified in max_sizes.
254
+
255
+ Args:
256
+ indices (np.array): original array of sample indices
257
+ max_sizes (int or list[int] or tuple[int]): max sample size,
258
+ can be defined separately for src and tgt (then list or tuple)
259
+
260
+ Returns:
261
+ np.array: filtered sample array
262
+ list: list of removed indices
263
+ """
264
+ if max_sizes is None:
265
+ return indices, []
266
+ if type(max_sizes) in (int, float):
267
+ max_src_size, max_tgt_size = max_sizes, max_sizes
268
+ else:
269
+ max_src_size, max_tgt_size = max_sizes
270
+ if tgt_sizes is None:
271
+ ignored = indices[src_sizes[indices] > max_src_size]
272
+ else:
273
+ ignored = indices[
274
+ (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
275
+ ]
276
+ if len(ignored) > 0:
277
+ if tgt_sizes is None:
278
+ indices = indices[src_sizes[indices] <= max_src_size]
279
+ else:
280
+ indices = indices[
281
+ (src_sizes[indices] <= max_src_size)
282
+ & (tgt_sizes[indices] <= max_tgt_size)
283
+ ]
284
+ return indices, ignored.tolist()
285
+
286
+
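As a quick sanity check of the paired filtering above, here is a hedged example on toy size arrays (the `data.data_utils` import path is an assumption):

```python
import numpy as np

from data.data_utils import filter_paired_dataset_indices_by_size  # hypothetical import path

src_sizes = np.array([3, 12, 5, 9])
tgt_sizes = np.array([4, 6, 20, 2])
indices = np.arange(4)

# keep pairs whose source and target sizes both fit within (10, 10)
kept, ignored = filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, (10, 10))
print(kept.tolist())   # [0, 3]
print(ignored)         # [1, 2]
```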
287
+ def batch_by_size(
288
+ indices,
289
+ num_tokens_fn,
290
+ num_tokens_vec=None,
291
+ max_tokens=None,
292
+ max_sentences=None,
293
+ required_batch_size_multiple=1,
294
+ fixed_shapes=None,
295
+ ):
296
+ """
297
+ Yield mini-batches of indices bucketed by size. Batches may contain
298
+ sequences of different lengths.
299
+
300
+ Args:
301
+ indices (List[int]): ordered list of dataset indices
302
+ num_tokens_fn (callable): function that returns the number of tokens at
303
+ a given index
304
+ num_tokens_vec (List[int], optional): precomputed vector of the number
305
+ of tokens for each index in indices (to enable faster batch generation)
306
+ max_tokens (int, optional): max number of tokens in each batch
307
+ (default: None).
308
+ max_sentences (int, optional): max number of sentences in each
309
+ batch (default: None).
310
+ required_batch_size_multiple (int, optional): require batch size to
311
+ be less than N or a multiple of N (default: 1).
312
+ fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
313
+ only be created with the given shapes. *max_sentences* and
314
+ *required_batch_size_multiple* will be ignored (default: None).
315
+ """
316
+ try:
317
+ from fairseq.data.data_utils_fast import (
318
+ batch_by_size_fn,
319
+ batch_by_size_vec,
320
+ batch_fixed_shapes_fast,
321
+ )
322
+ except ImportError:
323
+ raise ImportError(
324
+ "Please build Cython components with: "
325
+ "`python setup.py build_ext --inplace`"
326
+ )
327
+ except ValueError:
328
+ raise ValueError(
329
+ "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
330
+ )
331
+
332
+ # added int() to avoid TypeError: an integer is required
333
+ max_tokens = (
334
+ int(max_tokens) if max_tokens is not None else -1
335
+ )
336
+ max_sentences = max_sentences if max_sentences is not None else -1
337
+ bsz_mult = required_batch_size_multiple
338
+
339
+ if not isinstance(indices, np.ndarray):
340
+ indices = np.fromiter(indices, dtype=np.int64, count=-1)
341
+
342
+ if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
343
+ num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
344
+
345
+ if fixed_shapes is None:
346
+ if num_tokens_vec is None:
347
+ return batch_by_size_fn(
348
+ indices,
349
+ num_tokens_fn,
350
+ max_tokens,
351
+ max_sentences,
352
+ bsz_mult,
353
+ )
354
+ else:
355
+ return batch_by_size_vec(
356
+ indices,
357
+ num_tokens_vec,
358
+ max_tokens,
359
+ max_sentences,
360
+ bsz_mult,
361
+ )
362
+
363
+ else:
364
+ fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
365
+ sort_order = np.lexsort(
366
+ [
367
+ fixed_shapes[:, 1].argsort(), # length
368
+ fixed_shapes[:, 0].argsort(), # bsz
369
+ ]
370
+ )
371
+ fixed_shapes_sorted = fixed_shapes[sort_order]
372
+ return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
373
+
374
+
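A hedged usage sketch of `batch_by_size` above. It assumes fairseq is installed with its Cython data utilities built (`python setup.py build_ext --inplace`); the grouping shown in the comment is only illustrative, since the actual batching is delegated to the compiled helpers.

```python
import numpy as np

from fairseq.data.data_utils import batch_by_size  # assumes a built fairseq installation

lengths = np.array([4, 6, 3, 7, 5])       # tokens per example
indices = np.argsort(lengths)             # batching is usually done over length-sorted indices

batches = batch_by_size(
    indices,
    num_tokens_fn=lambda i: int(lengths[i]),
    max_tokens=10,                        # cap on (batch size) * (longest item in the batch)
)
print(batches)  # a list of index arrays, e.g. [array([2, 0]), array([4]), array([1]), array([3])]
```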
375
+ def post_process(sentence: str, symbol: str):
376
+ if symbol == "sentencepiece":
377
+ sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
378
+ elif symbol == "wordpiece":
379
+ sentence = sentence.replace(" ", "").replace("_", " ").strip()
380
+ elif symbol == "letter":
381
+ sentence = sentence.replace(" ", "").replace("|", " ").strip()
382
+ elif symbol == "silence":
383
+ import re
384
+ sentence = sentence.replace("<SIL>", "")
385
+ sentence = re.sub(' +', ' ', sentence).strip()
386
+ elif symbol == "_EOW":
387
+ sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
388
+ elif symbol in {"subword_nmt", "@@ ", "@@"}:
389
+ if symbol == "subword_nmt":
390
+ symbol = "@@ "
391
+ sentence = (sentence + " ").replace(symbol, "").rstrip()
392
+ elif symbol == "none":
393
+ pass
394
+ elif symbol is not None:
395
+ raise NotImplementedError(f"Unknown post_process option: {symbol}")
396
+ return sentence
397
+
398
+
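Two quick examples of what `post_process` above does to detokenized model output (the `data.data_utils` import path is an assumption):

```python
from data.data_utils import post_process  # hypothetical import path

# sentencepiece: drop inner spaces, then turn "\u2581" piece markers back into word boundaries
print(post_process("\u2581he llo \u2581wor ld", "sentencepiece"))   # -> "hello world"

# subword_nmt ("@@ "): strip BPE continuation markers
print(post_process("he@@ llo wor@@ ld", "subword_nmt"))             # -> "hello world"
```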
399
+ def compute_mask_indices(
400
+ shape: Tuple[int, int],
401
+ padding_mask: Optional[torch.Tensor],
402
+ mask_prob: float,
403
+ mask_length: int,
404
+ mask_type: str = "static",
405
+ mask_other: float = 0.0,
406
+ min_masks: int = 0,
407
+ no_overlap: bool = False,
408
+ min_space: int = 0,
409
+ ) -> np.ndarray:
410
+ """
411
+ Computes random mask spans for a given shape
412
+
413
+ Args:
414
+ shape: the shape for which to compute masks.
415
+ should be of size 2, where the first element is the batch size and the second is the number of timesteps
416
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
417
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
418
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
419
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
420
+ mask_type: how to compute mask lengths
421
+ static = fixed size
422
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
423
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
424
+ poisson = sample from poisson distribution with lambda = mask length
425
+ min_masks: minimum number of masked spans
426
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
427
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
428
+ """
429
+
430
+ bsz, all_sz = shape
431
+ mask = np.full((bsz, all_sz), False)
432
+
433
+ all_num_mask = int(
434
+ # add a random number for probabilistic rounding
435
+ mask_prob * all_sz / float(mask_length)
436
+ + np.random.rand()
437
+ )
438
+
439
+ all_num_mask = max(min_masks, all_num_mask)
440
+
441
+ mask_idcs = []
442
+ for i in range(bsz):
443
+ if padding_mask is not None:
444
+ sz = all_sz - padding_mask[i].long().sum().item()
445
+ num_mask = int(
446
+ # add a random number for probabilistic rounding
447
+ mask_prob * sz / float(mask_length)
448
+ + np.random.rand()
449
+ )
450
+ num_mask = max(min_masks, num_mask)
451
+ else:
452
+ sz = all_sz
453
+ num_mask = all_num_mask
454
+
455
+ if mask_type == "static":
456
+ lengths = np.full(num_mask, mask_length)
457
+ elif mask_type == "uniform":
458
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
459
+ elif mask_type == "normal":
460
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
461
+ lengths = [max(1, int(round(x))) for x in lengths]
462
+ elif mask_type == "poisson":
463
+ lengths = np.random.poisson(mask_length, size=num_mask)
464
+ lengths = [int(round(x)) for x in lengths]
465
+ else:
466
+ raise Exception("unknown mask selection " + mask_type)
467
+
468
+ if sum(lengths) == 0:
469
+ lengths[0] = min(mask_length, sz - 1)
470
+
471
+ if no_overlap:
472
+ mask_idc = []
473
+
474
+ def arrange(s, e, length, keep_length):
475
+ span_start = np.random.randint(s, e - length)
476
+ mask_idc.extend(span_start + i for i in range(length))
477
+
478
+ new_parts = []
479
+ if span_start - s - min_space >= keep_length:
480
+ new_parts.append((s, span_start - min_space + 1))
481
+ if e - span_start - keep_length - min_space > keep_length:
482
+ new_parts.append((span_start + length + min_space, e))
483
+ return new_parts
484
+
485
+ parts = [(0, sz)]
486
+ min_length = min(lengths)
487
+ for length in sorted(lengths, reverse=True):
488
+ lens = np.fromiter(
489
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
490
+ np.int64,
491
+ )
492
+ l_sum = np.sum(lens)
493
+ if l_sum == 0:
494
+ break
495
+ probs = lens / np.sum(lens)
496
+ c = np.random.choice(len(parts), p=probs)
497
+ s, e = parts.pop(c)
498
+ parts.extend(arrange(s, e, length, min_length))
499
+ mask_idc = np.asarray(mask_idc)
500
+ else:
501
+ min_len = min(lengths)
502
+ if sz - min_len <= num_mask:
503
+ min_len = sz - num_mask - 1
504
+
505
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
506
+
507
+ mask_idc = np.asarray(
508
+ [
509
+ mask_idc[j] + offset
510
+ for j in range(len(mask_idc))
511
+ for offset in range(lengths[j])
512
+ ]
513
+ )
514
+
515
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
516
+
517
+ min_len = min([len(m) for m in mask_idcs])
518
+ for i, mask_idc in enumerate(mask_idcs):
519
+ if len(mask_idc) > min_len:
520
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
521
+ mask[i, mask_idc] = True
522
+
523
+ return mask
524
+
525
+
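A small, hedged example of calling `compute_mask_indices` above for a wav2vec-style span mask (import path assumed; the numbers only illustrate the rough masking budget):

```python
import numpy as np

from data.data_utils import compute_mask_indices  # hypothetical import path

np.random.seed(0)
mask = compute_mask_indices(
    shape=(2, 100),      # (batch size, timesteps)
    padding_mask=None,
    mask_prob=0.15,      # aim to mask roughly 15% of all timesteps
    mask_length=5,       # "static" policy: every span covers 5 timesteps
)
print(mask.shape)         # (2, 100)
print(mask.sum(axis=1))   # roughly mask_prob * 100 masked positions per row (fewer when spans overlap)
```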
526
+ def get_mem_usage():
527
+ try:
528
+ import psutil
529
+
530
+ mb = 1024 * 1024
531
+ return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
532
+ except ImportError:
533
+ return "N/A"
534
+
535
+
536
+ # lens: torch.LongTensor
537
+ # returns: torch.BoolTensor
538
+ def lengths_to_padding_mask(lens):
539
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
540
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
541
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
542
+ return mask
543
+
544
+
545
+ # lens: torch.LongTensor
546
+ # returns: torch.BoolTensor
547
+ def lengths_to_mask(lens):
548
+ return ~lengths_to_padding_mask(lens)
549
+
550
+
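For reference, a tiny example of the padding-mask helpers above (import path assumed):

```python
import torch

from data.data_utils import lengths_to_padding_mask, lengths_to_mask  # hypothetical import path

lens = torch.tensor([3, 1, 2])
print(lengths_to_padding_mask(lens))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])
print(lengths_to_mask(lens))   # the negation: True wherever a real (non-pad) position exists
```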
551
+ def get_buckets(sizes, num_buckets):
552
+ buckets = np.unique(
553
+ np.percentile(
554
+ sizes,
555
+ np.linspace(0, 100, num_buckets + 1),
556
+ interpolation='lower',
557
+ )[1:]
558
+ )
559
+ return buckets
560
+
561
+
562
+ def get_bucketed_sizes(orig_sizes, buckets):
563
+ sizes = np.copy(orig_sizes)
564
+ assert np.min(sizes) >= 0
565
+ start_val = -1
566
+ for end_val in buckets:
567
+ mask = (sizes > start_val) & (sizes <= end_val)
568
+ sizes[mask] = end_val
569
+ start_val = end_val
570
+ return sizes
571
+
572
+
573
+
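A short example of the bucketing helpers above (import path assumed): bucket edges are percentiles of the observed sizes, and every size is then rounded up to its bucket edge.

```python
import numpy as np

from data.data_utils import get_buckets, get_bucketed_sizes  # hypothetical import path

sizes = np.array([3, 5, 7, 8, 12, 20])
buckets = get_buckets(sizes, num_buckets=3)
print(buckets)                             # [ 5  8 20]
print(get_bucketed_sizes(sizes, buckets))  # [ 5  5  8  8 20 20]
```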
574
+ def _find_extra_valid_paths(dataset_path: str) -> set:
575
+ paths = utils.split_paths(dataset_path)
576
+ all_valid_paths = set()
577
+ for sub_dir in paths:
578
+ contents = PathManager.ls(sub_dir)
579
+ valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
580
+ all_valid_paths |= {os.path.basename(p) for p in valid_paths}
581
+ # Remove .bin, .idx etc
582
+ roots = {os.path.splitext(p)[0] for p in all_valid_paths}
583
+ return roots
584
+
585
+
586
+ def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
587
+ """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
588
+ if (
589
+ train_cfg.dataset.ignore_unused_valid_subsets
590
+ or train_cfg.dataset.combine_valid_subsets
591
+ or train_cfg.dataset.disable_validation
592
+ or not hasattr(train_cfg.task, "data")
593
+ ):
594
+ return
595
+ other_paths = _find_extra_valid_paths(train_cfg.task.data)
596
+ specified_subsets = train_cfg.dataset.valid_subset.split(",")
597
+ ignored_paths = [p for p in other_paths if p not in specified_subsets]
598
+ if ignored_paths:
599
+ advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
600
+ msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
601
+ raise ValueError(msg)
data/file_dataset.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import torch
3
+ import pickle
4
+
5
+
6
+ class FileDataset:
7
+ def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
8
+ self.file_path = file_path
9
+ assert os.path.exists(self.file_path), "Error: The local datafile {} does not exist!".format(self.file_path)
10
+
11
+ self.separator = separator
12
+ if selected_col_ids is None:
13
+ # default to all fields
14
+ self.selected_col_ids = list(
15
+ range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
16
+ else:
17
+ self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
18
+ if dtypes is None:
19
+ # default to str
20
+ self.dtypes = [str for col_id in self.selected_col_ids]
21
+ else:
22
+ self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
23
+ assert len(self.dtypes) == len(self.selected_col_ids)
24
+
25
+ self.data_cnt = 0
26
+ try:
27
+ self.slice_id = torch.distributed.get_rank()
28
+ self.slice_count = torch.distributed.get_world_size()
29
+ except Exception:
30
+ self.slice_id = 0
31
+ self.slice_count = 1
32
+ self.cached_index = cached_index
33
+ self._init_seek_index()
34
+ self._reader = self._get_reader()
35
+ print("file {} slice_id {} row count {} total row count {}".format(
36
+ self.file_path, self.slice_id, self.row_count, self.total_row_count)
37
+ )
38
+
39
+ def _init_seek_index(self):
40
+ if self.cached_index:
41
+ cache_path = "{}.index".format(self.file_path)
42
+ assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
43
+ self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
44
+ print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
45
+ self.file_path, self.slice_id))
46
+ else:
47
+ # make an iteration over the file to get row_count and line_idx-to-offset mapping
48
+ fp = open(self.file_path, "r")
49
+ print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
50
+ self.file_path, self.slice_id))
51
+ self.total_row_count = 0
52
+ offset = 0
53
+ self.lineid_to_offset = []
54
+ for line in fp:
55
+ self.lineid_to_offset.append(offset)
56
+ self.total_row_count += 1
57
+ offset += len(line.encode('utf-8'))
58
+ self._compute_start_pos_and_row_count()
59
+ print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
60
+ self.file_path, self.slice_id))
61
+
62
+ def _compute_start_pos_and_row_count(self):
63
+ self.row_count = self.total_row_count // self.slice_count
64
+ if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
65
+ self.row_count += 1
66
+ self.start_pos = self.row_count * self.slice_id
67
+ else:
68
+ self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
69
+
70
+ def _get_reader(self):
71
+ fp = open(self.file_path, "r")
72
+ fp.seek(self.lineid_to_offset[self.start_pos])
73
+ return fp
74
+
75
+ def _seek(self, offset=0):
76
+ try:
77
+ print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
78
+ self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
79
+ self.data_cnt = offset
80
+ except Exception:
81
+ print("slice_id {} seek offset {}".format(self.slice_id, offset))
82
+ self._reader.seek(self.lineid_to_offset[offset])
83
+ self.data_cnt = offset
84
+
85
+ def __del__(self):
86
+ self._reader.close()
87
+
88
+ def __len__(self):
89
+ return self.row_count
90
+
91
+ def get_total_row_count(self):
92
+ return self.total_row_count
93
+
94
+ def __getitem__(self, index):
95
+ if self.data_cnt == self.row_count:
96
+ print("reach the end of datafile, start a new reader")
97
+ self.data_cnt = 0
98
+ self._reader = self._get_reader()
99
+ column_l = self._reader.readline().rstrip("\n").split(self.separator)
100
+ self.data_cnt += 1
101
+ column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
102
+ return column_l
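A minimal sketch of reading a tab-separated file with the `FileDataset` above (the `data.file_dataset` import path and the toy file are assumptions). Note that rows are read sequentially from the slice assigned to this rank; the index passed to `__getitem__` only triggers a reader reset once a full pass has been made.

```python
from data.file_dataset import FileDataset  # hypothetical import path

# build a tiny tab-separated file: id \t caption \t score
with open("toy.tsv", "w") as f:
    f.write("0\ta red bus\t0.9\n")
    f.write("1\ttwo dogs\t0.7\n")

ds = FileDataset("toy.tsv", selected_col_ids="0,1,2", dtypes="int,str,float")
print(len(ds))   # 2 (single-process run: the whole file is one slice)
print(ds[0])     # [0, 'a red bus', 0.9]
print(ds[1])     # [1, 'two dogs', 0.7]
```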
data/mm_data/__init__.py ADDED
File without changes
data/mm_data/caption_dataset.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+ import string
10
+
11
+ import numpy as np
12
+ import torch
13
+ import base64
14
+ from torchvision import transforms
15
+
16
+ from PIL import Image, ImageFile
17
+
18
+ from data import data_utils
19
+ from data.ofa_dataset import OFADataset
20
+
21
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
22
+ ImageFile.MAX_IMAGE_PIXELS = None
23
+ Image.MAX_IMAGE_PIXELS = None
24
+
25
+ logger = logging.getLogger(__name__)
26
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
+
28
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
+
31
+
32
+ def collate(samples, pad_idx, eos_idx):
33
+ if len(samples) == 0:
34
+ return {}
35
+
36
+ def merge(key):
37
+ return data_utils.collate_tokens(
38
+ [s[key] for s in samples],
39
+ pad_idx,
40
+ eos_idx=eos_idx,
41
+ )
42
+
43
+ id = np.array([s["id"] for s in samples])
44
+ src_tokens = merge("source")
45
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
+
47
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
+
50
+ prev_output_tokens = None
51
+ target = None
52
+ if samples[0].get("target", None) is not None:
53
+ target = merge("target")
54
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
55
+ ntokens = tgt_lengths.sum().item()
56
+
57
+ if samples[0].get("prev_output_tokens", None) is not None:
58
+ prev_output_tokens = merge("prev_output_tokens")
59
+ else:
60
+ ntokens = src_lengths.sum().item()
61
+
62
+ batch = {
63
+ "id": id,
64
+ "nsentences": len(samples),
65
+ "ntokens": ntokens,
66
+ "net_input": {
67
+ "src_tokens": src_tokens,
68
+ "src_lengths": src_lengths,
69
+ "patch_images": patch_images,
70
+ "patch_masks": patch_masks,
71
+ "prev_output_tokens": prev_output_tokens
72
+ },
73
+ "target": target,
74
+ }
75
+
76
+ return batch
77
+
78
+
79
+ class CaptionDataset(OFADataset):
80
+ def __init__(
81
+ self,
82
+ split,
83
+ dataset,
84
+ bpe,
85
+ src_dict,
86
+ tgt_dict=None,
87
+ max_src_length=128,
88
+ max_tgt_length=30,
89
+ patch_image_size=224,
90
+ imagenet_default_mean_and_std=False,
91
+ scst=False
92
+ ):
93
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
94
+ self.max_src_length = max_src_length
95
+ self.max_tgt_length = max_tgt_length
96
+ self.patch_image_size = patch_image_size
97
+ self.scst = scst
98
+
99
+ self.transtab = str.maketrans({key: None for key in string.punctuation})
100
+
101
+ if imagenet_default_mean_and_std:
102
+ mean = IMAGENET_DEFAULT_MEAN
103
+ std = IMAGENET_DEFAULT_STD
104
+ else:
105
+ mean = [0.5, 0.5, 0.5]
106
+ std = [0.5, 0.5, 0.5]
107
+
108
+ self.patch_resize_transform = transforms.Compose([
109
+ lambda image: image.convert("RGB"),
110
+ transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
111
+ transforms.ToTensor(),
112
+ transforms.Normalize(mean=mean, std=std),
113
+ ])
114
+
115
+ def __getitem__(self, index):
116
+ uniq_id, image, caption = self.dataset[index]
117
+
118
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
119
+ patch_image = self.patch_resize_transform(image)
120
+ patch_mask = torch.tensor([True])
121
+
122
+ if self.split == 'train' and not self.scst:
123
+ caption = caption.translate(self.transtab).strip()
124
+ caption_token_list = caption.strip().split()
125
+ tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
126
+ else:
127
+ caption = ' '.join(caption.strip().split())
128
+ caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
129
+ tgt_caption = '&&'.join(caption_list)
130
+ src_item = self.encode_text(" what does the image describe?")
131
+ tgt_item = self.encode_text(" {}".format(tgt_caption))
132
+
133
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
134
+ target_item = torch.cat([tgt_item, self.eos_item])
135
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
136
+
137
+ example = {
138
+ "id": uniq_id,
139
+ "source": src_item,
140
+ "patch_image": patch_image,
141
+ "patch_mask": patch_mask,
142
+ "target": target_item,
143
+ "prev_output_tokens": prev_output_item
144
+ }
145
+ return example
146
+
147
+ def collater(self, samples, pad_to_length=None):
148
+ """Merge a list of samples to form a mini-batch.
149
+ Args:
150
+ samples (List[dict]): samples to collate
151
+ Returns:
152
+ dict: a mini-batch with keys "id", "nsentences", "ntokens", "net_input" and "target"
153
+ """
154
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
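To make the batch layout produced by `collate` above concrete, here is a hedged sketch with two hand-built samples (the token ids, pad/eos indices, and import path are made up for illustration; it assumes the OFA requirements such as fairseq and torchvision are installed):

```python
import torch

from data.mm_data.caption_dataset import collate  # hypothetical import path

pad, eos = 1, 2
samples = [
    {
        "id": 0,
        "source": torch.LongTensor([0, 11, 12, 2]),
        "patch_image": torch.zeros(3, 224, 224),
        "patch_mask": torch.tensor([True]),
        "target": torch.LongTensor([21, 22, 2]),
        "prev_output_tokens": torch.LongTensor([0, 21, 22]),
    },
    {
        "id": 1,
        "source": torch.LongTensor([0, 13, 2]),
        "patch_image": torch.zeros(3, 224, 224),
        "patch_mask": torch.tensor([True]),
        "target": torch.LongTensor([23, 2]),
        "prev_output_tokens": torch.LongTensor([0, 23]),
    },
]

batch = collate(samples, pad_idx=pad, eos_idx=eos)
print(batch["net_input"]["src_tokens"].shape)  # torch.Size([2, 4]) -- the shorter source is padded with pad_idx
print(batch["ntokens"])                        # 5 -- total non-pad target tokens
```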
data/mm_data/refcoco_dataset.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from io import BytesIO
6
+
7
+ import logging
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ import base64
13
+ import utils.transforms as T
14
+
15
+ from PIL import Image, ImageFile
16
+
17
+ from data import data_utils
18
+ from data.ofa_dataset import OFADataset
19
+
20
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
21
+ ImageFile.MAX_IMAGE_PIXELS = None
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ logger = logging.getLogger(__name__)
25
+ warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
26
+
27
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
28
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
29
+
30
+
31
+ def collate(samples, pad_idx, eos_idx):
32
+ if len(samples) == 0:
33
+ return {}
34
+
35
+ def merge(key):
36
+ return data_utils.collate_tokens(
37
+ [s[key] for s in samples],
38
+ pad_idx,
39
+ eos_idx=eos_idx,
40
+ )
41
+
42
+ id = np.array([s["id"] for s in samples])
43
+ src_tokens = merge("source")
44
+ src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
45
+
46
+ patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
47
+ patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
48
+
49
+ w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
50
+ h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
51
+ region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
52
+
53
+ prev_output_tokens = None
54
+ target = None
55
+ if samples[0].get("target", None) is not None:
56
+ target = merge("target")
57
+ tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
58
+ ntokens = tgt_lengths.sum().item()
59
+
60
+ if samples[0].get("prev_output_tokens", None) is not None:
61
+ prev_output_tokens = merge("prev_output_tokens")
62
+ else:
63
+ ntokens = src_lengths.sum().item()
64
+
65
+ batch = {
66
+ "id": id,
67
+ "nsentences": len(samples),
68
+ "ntokens": ntokens,
69
+ "net_input": {
70
+ "src_tokens": src_tokens,
71
+ "src_lengths": src_lengths,
72
+ "patch_images": patch_images,
73
+ "patch_masks": patch_masks,
74
+ "prev_output_tokens": prev_output_tokens
75
+ },
76
+ "target": target,
77
+ "w_resize_ratios": w_resize_ratios,
78
+ "h_resize_ratios": h_resize_ratios,
79
+ "region_coords": region_coords
80
+ }
81
+
82
+ return batch
83
+
84
+
85
+ class RefcocoDataset(OFADataset):
86
+ def __init__(
87
+ self,
88
+ split,
89
+ dataset,
90
+ bpe,
91
+ src_dict,
92
+ tgt_dict=None,
93
+ max_src_length=80,
94
+ max_tgt_length=30,
95
+ patch_image_size=512,
96
+ imagenet_default_mean_and_std=False,
97
+ num_bins=1000,
98
+ max_image_size=512
99
+ ):
100
+ super().__init__(split, dataset, bpe, src_dict, tgt_dict)
101
+ self.max_src_length = max_src_length
102
+ self.max_tgt_length = max_tgt_length
103
+ self.patch_image_size = patch_image_size
104
+ self.num_bins = num_bins
105
+
106
+ if imagenet_default_mean_and_std:
107
+ mean = IMAGENET_DEFAULT_MEAN
108
+ std = IMAGENET_DEFAULT_STD
109
+ else:
110
+ mean = [0.5, 0.5, 0.5]
111
+ std = [0.5, 0.5, 0.5]
112
+
113
+ # for positioning
114
+ self.positioning_transform = T.Compose([
115
+ T.RandomResize([patch_image_size], max_size=patch_image_size),
116
+ T.ToTensor(),
117
+ T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
118
+ ])
119
+
120
+ def __getitem__(self, index):
121
+ uniq_id, base64_str, text, region_coord = self.dataset[index]
122
+
123
+ image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
124
+ w, h = image.size
125
+ boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
126
+ x0, y0, x1, y1 = region_coord.strip().split(',')
127
+ region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
128
+ boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
129
+ boxes_target["labels"] = np.array([0])
130
+ boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
131
+
132
+ patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
133
+ resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
134
+ patch_mask = torch.tensor([True])
135
+ quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
136
+ quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
137
+ quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
138
+ quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
139
+ region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
140
+ src_caption = self.pre_caption(text, self.max_src_length)
141
+ src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
142
+ tgt_item = self.encode_text(region_coord, use_bpe=False)
143
+
144
+ src_item = torch.cat([self.bos_item, src_item, self.eos_item])
145
+ target_item = torch.cat([tgt_item, self.eos_item])
146
+ prev_output_item = torch.cat([self.bos_item, tgt_item])
147
+
148
+ example = {
149
+ "id": uniq_id,
150
+ "source": src_item,
151
+ "patch_image": patch_image,
152
+ "patch_mask": patch_mask,
153
+ "target": target_item,
154
+ "prev_output_tokens": prev_output_item,
155
+ "w_resize_ratio": resize_w / w,
156
+ "h_resize_ratio": resize_h / h,
157
+ "region_coord": region
158
+ }
159
+ return example
160
+
161
+ def collater(self, samples, pad_to_length=None):
162
+ """Merge a list of samples to form a mini-batch.
163
+ Args:
164
+ samples (List[dict]): samples to collate
165
+ Returns:
166
+ dict: a mini-batch with keys "id", "nsentences", "ntokens", "net_input", "target", "w_resize_ratios", "h_resize_ratios" and "region_coords"
167
+ """
168
+ return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
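The grounding target above is a sequence of quantized location tokens. A standalone sketch of that quantization step (mirroring the code above, with coordinates already normalized to [0, 1] and the dataset default `num_bins=1000`):

```python
def quantize_box(box, num_bins=1000):
    """Turn normalized (x0, y0, x1, y1) coordinates into <bin_k> tokens."""
    return " ".join("<bin_{}>".format(int(round(c * (num_bins - 1)))) for c in box)

print(quantize_box((0.10, 0.25, 0.60, 0.90)))
# <bin_100> <bin_250> <bin_599> <bin_899>
```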
data/ofa_dataset.py ADDED
@@ -0,0 +1,74 @@
1
+ import logging
2
+ import re
3
+ import torch.utils.data
4
+ from fairseq.data import FairseqDataset
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class OFADataset(FairseqDataset):
10
+ def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
11
+ self.split = split
12
+ self.dataset = dataset
13
+ self.bpe = bpe
14
+ self.src_dict = src_dict
15
+ self.tgt_dict = tgt_dict
16
+
17
+ self.bos = src_dict.bos()
18
+ self.eos = src_dict.eos()
19
+ self.pad = src_dict.pad()
20
+ self.bos_item = torch.LongTensor([self.bos])
21
+ self.eos_item = torch.LongTensor([self.eos])
22
+
23
+ def __len__(self):
24
+ return len(self.dataset)
25
+
26
+ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
27
+ s = self.tgt_dict.encode_line(
28
+ line=self.bpe.encode(text) if use_bpe else text,
29
+ add_if_not_exist=False,
30
+ append_eos=False
31
+ ).long()
32
+ if length is not None:
33
+ s = s[:length]
34
+ if append_bos:
35
+ s = torch.cat([self.bos_item, s])
36
+ if append_eos:
37
+ s = torch.cat([s, self.eos_item])
38
+ return s
39
+
40
+ def pre_question(self, question, max_ques_words):
41
+ question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
42
+
43
+ question = re.sub(
44
+ r"\s{2,}",
45
+ ' ',
46
+ question,
47
+ )
48
+ question = question.rstrip('\n')
49
+ question = question.strip(' ')
50
+
51
+ # truncate question
52
+ question_words = question.split(' ')
53
+ if len(question_words) > max_ques_words:
54
+ question = ' '.join(question_words[:max_ques_words])
55
+
56
+ return question
57
+
58
+ def pre_caption(self, caption, max_words):
59
+ caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
60
+
61
+ caption = re.sub(
62
+ r"\s{2,}",
63
+ ' ',
64
+ caption,
65
+ )
66
+ caption = caption.rstrip('\n')
67
+ caption = caption.strip(' ')
68
+
69
+ # truncate caption
70
+ caption_words = caption.split(' ')
71
+ if len(caption_words) > max_words:
72
+ caption = ' '.join(caption_words[:max_words])
73
+
74
+ return caption
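A standalone sketch of the `pre_caption` normalization above, kept outside the class so it can run without a dictionary or BPE model (same steps as the method):

```python
import re


def pre_caption(caption, max_words):
    caption = (
        caption.lower()
        .lstrip(",.!?*#:;~")
        .replace("-", " ")
        .replace("/", " ")
        .replace("<person>", "person")
    )
    caption = re.sub(r"\s{2,}", " ", caption).rstrip("\n").strip(" ")
    words = caption.split(" ")
    return " ".join(words[:max_words])


print(pre_caption("A red double-decker bus -- near <person>!", 6))
# -> "a red double decker bus near"
```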
evaluate.py ADDED
@@ -0,0 +1,152 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import json
11
+ from itertools import chain
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ from fairseq import distributed_utils, options, tasks, utils
17
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
18
+ from fairseq.logging import progress_bar
19
+ from fairseq.utils import reset_logging
20
+ from omegaconf import DictConfig
21
+
22
+ from utils import checkpoint_utils
23
+ from utils.eval_utils import eval_step
24
+
25
+ logging.basicConfig(
26
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S",
28
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
29
+ stream=sys.stdout,
30
+ )
31
+ logger = logging.getLogger("ofa.evaluate")
32
+
33
+
34
+ def apply_half(t):
35
+ if t.dtype is torch.float32:
36
+ return t.to(dtype=torch.half)
37
+ return t
38
+
39
+
40
+ def main(cfg: DictConfig):
41
+ utils.import_user_module(cfg.common)
42
+
43
+ reset_logging()
44
+ logger.info(cfg)
45
+
46
+ assert (
47
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
48
+ ), "Must specify batch size either with --max-tokens or --batch-size"
49
+
50
+ # Fix seed for stochastic decoding
51
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
52
+ np.random.seed(cfg.common.seed)
53
+ utils.set_torch_seed(cfg.common.seed)
54
+
55
+ use_fp16 = cfg.common.fp16
56
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
57
+
58
+ if use_cuda:
59
+ torch.cuda.set_device(cfg.distributed_training.device_id)
60
+
61
+ # Load ensemble
62
+ overrides = eval(cfg.common_eval.model_overrides)
63
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
64
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
65
+ utils.split_paths(cfg.common_eval.path),
66
+ arg_overrides=overrides,
67
+ suffix=cfg.checkpoint.checkpoint_suffix,
68
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
69
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
70
+ )
71
+
72
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
73
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
74
+
75
+ # Move models to GPU
76
+ for model in models:
77
+ model.eval()
78
+ if use_fp16:
79
+ model.half()
80
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
81
+ model.cuda()
82
+ model.prepare_for_inference_(cfg)
83
+
84
+ # Load dataset (possibly sharded)
85
+ itr = task.get_batch_iterator(
86
+ dataset=task.dataset(cfg.dataset.gen_subset),
87
+ max_tokens=cfg.dataset.max_tokens,
88
+ max_sentences=cfg.dataset.batch_size,
89
+ max_positions=utils.resolve_max_positions(
90
+ task.max_positions(), *[m.max_positions() for m in models]
91
+ ),
92
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
93
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
94
+ seed=cfg.common.seed,
95
+ num_shards=cfg.distributed_training.distributed_world_size,
96
+ shard_id=cfg.distributed_training.distributed_rank,
97
+ num_workers=cfg.dataset.num_workers,
98
+ data_buffer_size=cfg.dataset.data_buffer_size,
99
+ ).next_epoch_itr(shuffle=False)
100
+ progress = progress_bar.progress_bar(
101
+ itr,
102
+ log_format=cfg.common.log_format,
103
+ log_interval=cfg.common.log_interval,
104
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
105
+ )
106
+
107
+ # Initialize generator
108
+ generator = task.build_generator(models, cfg.generation)
109
+
110
+ results = []
111
+ score_sum = torch.FloatTensor([0]).cuda()
112
+ score_cnt = torch.FloatTensor([0]).cuda()
113
+ for sample in progress:
114
+ if "net_input" not in sample:
115
+ continue
116
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
117
+ sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
118
+ with torch.no_grad():
119
+ result, scores = eval_step(task, generator, models, sample)
120
+ results += result
121
+ score_sum += sum(scores) if scores is not None else 0
122
+ score_cnt += len(scores) if scores is not None else 0
123
+ progress.log({"sentences": sample["nsentences"]})
124
+
125
+ gather_results = None
126
+ if cfg.distributed_training.distributed_world_size > 1:
127
+ gather_results = [None for _ in range(dist.get_world_size())]
128
+ dist.all_gather_object(gather_results, results)
129
+ dist.all_reduce(score_sum.data)
130
+ dist.all_reduce(score_cnt.data)
131
+ if score_cnt.item() > 0:
132
+ logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
133
+ score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
134
+ ))
135
+
136
+ if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
137
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
138
+ output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
139
+ gather_results = list(chain(*gather_results)) if gather_results is not None else results
140
+ with open(output_path, 'w') as fw:
141
+ json.dump(gather_results, fw)
142
+
143
+
144
+ def cli_main():
145
+ parser = options.get_generation_parser()
146
+ args = options.parse_args_and_arch(parser)
147
+ cfg = convert_namespace_to_omegaconf(args)
148
+ distributed_utils.call_main(cfg, main)
149
+
150
+
151
+ if __name__ == "__main__":
152
+ cli_main()
fairseq/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
1
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
2
+
3
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
1
+ ---
2
+ name: 🐛 Bug Report
3
+ about: Submit a bug report to help us improve
4
+ labels: 'bug, needs triage'
5
+ ---
6
+
7
+ ## 🐛 Bug
8
+
9
+ <!-- A clear and concise description of what the bug is. -->
10
+
11
+ ### To Reproduce
12
+
13
+ Steps to reproduce the behavior (**always include the command you ran**):
14
+
15
+ 1. Run cmd '....'
16
+ 2. See error
17
+
18
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
19
+
20
+
21
+ #### Code sample
22
+ <!-- Ideally attach a minimal code sample to reproduce the described issue.
23
+ Minimal means having the shortest code but still preserving the bug. -->
24
+
25
+ ### Expected behavior
26
+
27
+ <!-- A clear and concise description of what you expected to happen. -->
28
+
29
+ ### Environment
30
+
31
+ - fairseq Version (e.g., 1.0 or main):
32
+ - PyTorch Version (e.g., 1.0)
33
+ - OS (e.g., Linux):
34
+ - How you installed fairseq (`pip`, source):
35
+ - Build command you used (if compiling from source):
36
+ - Python version:
37
+ - CUDA/cuDNN version:
38
+ - GPU models and configuration:
39
+ - Any other relevant information:
40
+
41
+ ### Additional context
42
+
43
+ <!-- Add any other context about the problem here. -->
fairseq/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
1
+ ---
2
+ name: 📚 Documentation/Typos
3
+ about: Report an issue related to documentation or a typo
4
+ labels: 'documentation, needs triage'
5
+ ---
6
+
7
+ ## 📚 Documentation
8
+
9
+ For typos and doc fixes, please go ahead and:
10
+
11
+ 1. Create an issue.
12
+ 2. Fix the typo.
13
+ 3. Submit a PR.
14
+
15
+ Thanks!
fairseq/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
1
+ ---
2
+ name: 🚀 Feature Request
3
+ about: Submit a proposal/request for a new feature
4
+ labels: 'enhancement, help wanted, needs triage'
5
+ ---
6
+
7
+ ## 🚀 Feature Request
8
+ <!-- A clear and concise description of the feature proposal -->
9
+
10
+ ### Motivation
11
+
12
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
13
+
14
+ ### Pitch
15
+
16
+ <!-- A clear and concise description of what you want to happen. -->
17
+
18
+ ### Alternatives
19
+
20
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
21
+
22
+ ### Additional context
23
+
24
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
1
+ ---
2
+ name: ❓ Questions/Help
3
+ about: If you have questions, please first search existing issues and docs
4
+ labels: 'question, needs triage'
5
+ ---
6
+
7
+ ## ❓ Questions and Help
8
+
9
+ ### Before asking:
10
+ 1. search the issues.
11
+ 2. search the docs.
12
+
13
+ <!-- If you still can't find what you need: -->
14
+
15
+ #### What is your question?
16
+
17
+ #### Code
18
+
19
+ <!-- Please paste a code snippet if your question requires it! -->
20
+
21
+ #### What have you tried?
22
+
23
+ #### What's your environment?
24
+
25
+ - fairseq Version (e.g., 1.0 or main):
26
+ - PyTorch Version (e.g., 1.0)
27
+ - OS (e.g., Linux):
28
+ - How you installed fairseq (`pip`, source):
29
+ - Build command you used (if compiling from source):
30
+ - Python version:
31
+ - CUDA/cuDNN version:
32
+ - GPU models and configuration:
33
+ - Any other relevant information:
fairseq/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,16 @@
1
+ # Before submitting
2
+
3
+ - [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
4
+ - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
5
+ - [ ] Did you make sure to update the docs?
6
+ - [ ] Did you write any new necessary tests?
7
+
8
+ ## What does this PR do?
9
+ Fixes # (issue).
10
+
11
+ ## PR review
12
+ Anyone in the community is free to review the PR once the tests have passed.
13
+ If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
14
+
15
+ ## Did you have fun?
16
+ Make sure you had fun coding 🙃
fairseq/.github/stale.yml ADDED
@@ -0,0 +1,30 @@
1
+ # Configuration for probot-stale - https://github.com/probot/stale
2
+ # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
3
+ # Number of days of inactivity before an issue becomes stale
4
+ daysUntilStale: 90
5
+ # Number of days of inactivity before a stale issue is closed
6
+ daysUntilClose: 7
7
+ # Issues with these labels will never be considered stale
8
+ exemptLabels:
9
+ - bug
10
+ # Label to use when marking an issue as stale
11
+ staleLabel: stale
12
+ issues:
13
+ # Comment to post when marking an issue as stale.
14
+ markComment: >
15
+ This issue has been automatically marked as stale.
16
+ **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
17
+ We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
18
+ # Comment to post when closing a stale issue.
19
+ closeComment: >
20
+ Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
21
+ pulls:
22
+ # Comment to post when marking a pull request as stale.
23
+ markComment: >
24
+ This pull request has been automatically marked as stale.
25
+ **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
26
+ We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
27
+ # Comment to post when closing a stale pull request.
28
+ closeComment: >
29
+ Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
30
+
fairseq/.github/workflows/build.yml ADDED
@@ -0,0 +1,55 @@
1
+ name: build
2
+
3
+ on:
4
+ # Trigger the workflow on push to main or any pull request
5
+ push:
6
+ branches:
7
+ - main
8
+ pull_request:
9
+
10
+ jobs:
11
+ build:
12
+
13
+ strategy:
14
+ max-parallel: 4
15
+ matrix:
16
+ platform: [ubuntu-latest, macos-latest]
17
+ python-version: [3.6, 3.7]
18
+
19
+ runs-on: ${{ matrix.platform }}
20
+
21
+ steps:
22
+ - uses: actions/checkout@v2
23
+
24
+ - name: Set up Python ${{ matrix.python-version }}
25
+ uses: actions/setup-python@v2
26
+ with:
27
+ python-version: ${{ matrix.python-version }}
28
+
29
+ - name: Conditionally install pytorch
30
+ if: matrix.platform == 'windows-latest'
31
+ run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
32
+
33
+ - name: Install locally
34
+ run: |
35
+ python -m pip install --upgrade pip
36
+ git submodule update --init --recursive
37
+ python setup.py build_ext --inplace
38
+ python -m pip install --editable .
39
+
40
+ - name: Install optional test requirements
41
+ run: |
42
+ python -m pip install iopath transformers pyarrow
43
+ python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
44
+
45
+ - name: Lint with flake8
46
+ run: |
47
+ pip install flake8
48
+ # stop the build if there are Python syntax errors or undefined names
49
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
50
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
51
+ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
52
+
53
+ - name: Run tests
54
+ run: |
55
+ python setup.py test
fairseq/.github/workflows/build_wheels.yml ADDED
@@ -0,0 +1,41 @@
1
+ name: build_wheels
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - v[0-9]+.[0-9]+.[x0-9]+
7
+ tags:
8
+ - v*
9
+
10
+ jobs:
11
+ build_wheels:
12
+ name: Build wheels on ${{ matrix.os }}
13
+ runs-on: ${{ matrix.os }}
14
+ strategy:
15
+ matrix:
16
+ os: [ubuntu-latest, macos-latest]
17
+
18
+ steps:
19
+ - uses: actions/checkout@v2
20
+
21
+ - name: Install Python
22
+ uses: actions/setup-python@v2
23
+ with:
24
+ python-version: '3.7'
25
+
26
+ - name: Install cibuildwheel
27
+ run: |
28
+ python -m pip install cibuildwheel
29
+
30
+ - name: Build wheels for CPython
31
+ run: |
32
+ python -m cibuildwheel --output-dir dist
33
+ env:
34
+ CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
35
+ CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
36
+ CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
37
+
38
+ - uses: actions/upload-artifact@v2
39
+ with:
40
+ name: wheels
41
+ path: ./dist/*.whl
fairseq/.gitmodules ADDED
@@ -0,0 +1,4 @@
1
+ [submodule "fairseq/model_parallel/megatron"]
2
+ path = fairseq/model_parallel/megatron
3
+ url = https://github.com/ngoyal2707/Megatron-LM
4
+ branch = fairseq
fairseq/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <conduct@pytorch.org>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
fairseq/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
fairseq/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fairseq/README.md ADDED
@@ -0,0 +1,229 @@
1
+ <p align="center">
2
+ <img src="docs/fairseq_logo.png" width="150">
3
+ <br />
4
+ <br />
5
+ <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
6
+ <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
7
+ <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
8
+ <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
9
+ </p>
10
+
11
+ --------------------------------------------------------------------------------
12
+
13
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
14
+ developers to train custom models for translation, summarization, language
15
+ modeling and other text generation tasks.
16
+
17
+ We provide reference implementations of various sequence modeling papers:
18
+
19
+ <details><summary>List of implemented papers</summary><p>
20
+
21
+ * **Convolutional Neural Networks (CNN)**
22
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
23
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
24
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
25
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
26
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
27
+ * **LightConv and DynamicConv models**
28
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
29
+ * **Long Short-Term Memory (LSTM) networks**
30
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
31
+ * **Transformer (self-attention) networks**
32
+ + Attention Is All You Need (Vaswani et al., 2017)
33
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
34
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
35
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
36
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
37
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
38
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
39
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
40
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
41
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
42
+ [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
43
+ [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
44
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
45
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
46
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
47
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
48
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
49
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
50
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
51
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
52
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
53
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
54
+ * **Non-autoregressive Transformers**
55
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
56
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
57
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
58
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
59
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
60
+ * **Finetuning**
61
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
62
+
63
+ </p></details>
64
+
65
+ ### What's New:
66
+
67
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
68
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
69
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
70
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
71
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
72
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
73
+ * February 2021 [Added LASER training code](examples/laser/README.md)
74
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
75
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
76
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
77
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
78
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
79
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
80
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
81
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
82
+
83
+ <details><summary>Previous updates</summary><p>
84
+
85
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
86
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
87
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
88
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
89
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
90
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
91
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
92
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
93
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
94
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
95
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
96
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
97
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
98
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
99
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
100
+ * November 2019: [BART model and code released](examples/bart/README.md)
101
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
102
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
103
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
104
+ * July 2019: fairseq relicensed under MIT license
105
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
106
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
107
+
108
+ </p></details>
109
+
110
+ ### Features:
111
+
112
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
113
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
114
+ + beam search
115
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
116
+ + sampling (unconstrained, top-k and top-p/nucleus)
117
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
118
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
119
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
120
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
121
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
122
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
123
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
124
+
125
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
126
+ with a convenient `torch.hub` interface:
127
+
128
+ ``` python
129
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
130
+ en2de.translate('Hello world', beam=5)
131
+ # 'Hallo Welt'
132
+ ```
133
+
134
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
135
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
136
+
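+ As a quick illustration (a minimal sketch, not part of the original README, assuming the `roberta.large` hub entry from the RoBERTa example), the same `torch.hub` interface also exposes pre-trained RoBERTa models:
+
+ ``` python
+ import torch
+
+ # download a pre-trained RoBERTa model (requires internet access on first use)
+ roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
+ roberta.eval()  # disable dropout so feature extraction is deterministic
+
+ tokens = roberta.encode('Hello world!')       # BPE-encode a sentence into a tensor of token ids
+ features = roberta.extract_features(tokens)   # final-layer representations, shape [1, T, 1024]
+ ```
+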
137
+ # Requirements and Installation
138
+
139
+ * [PyTorch](http://pytorch.org/) version >= 1.5.0
140
+ * Python version >= 3.6
141
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
142
+ * **To install fairseq** and develop locally:
143
+
144
+ ``` bash
145
+ git clone https://github.com/pytorch/fairseq
146
+ cd fairseq
147
+ pip install --editable ./
148
+
149
+ # on MacOS:
150
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
151
+
152
+ # to install the latest stable release (0.10.x)
153
+ # pip install fairseq
154
+ ```
155
+
156
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
157
+
158
+ ``` bash
159
+ git clone https://github.com/NVIDIA/apex
160
+ cd apex
161
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
162
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
163
+ --global-option="--fast_multihead_attn" ./
164
+ ```
165
+
166
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
167
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
168
+ as command-line options to `nvidia-docker run`.
169
+
170
+ # Getting Started
171
+
172
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
173
+ for getting started, training new models and extending fairseq with new model
174
+ types and tasks.
175
+
176
+ # Pre-trained models and examples
177
+
178
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
179
+ as well as example training and evaluation commands.
180
+
181
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
182
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
183
+
184
+ We also have more detailed READMEs to reproduce results from specific papers:
185
+
186
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
187
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
188
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
189
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
190
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
191
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
192
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
193
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
194
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
195
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
196
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
197
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
198
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
199
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
200
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
201
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
202
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
203
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
204
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
205
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
206
+
207
+ # Join the fairseq community
208
+
209
+ * Twitter: https://twitter.com/fairseq
210
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
211
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
212
+
213
+ # License
214
+
215
+ fairseq(-py) is MIT-licensed.
216
+ The license applies to the pre-trained models as well.
217
+
218
+ # Citation
219
+
220
+ Please cite as:
221
+
222
+ ``` bibtex
223
+ @inproceedings{ott2019fairseq,
224
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
225
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
226
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
227
+ year = {2019},
228
+ }
229
+ ```
fairseq/examples/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ try:
7
+ from fairseq.version import __version__ # noqa
8
+ except ImportError:
9
+ pass
fairseq/examples/adaptive_span/README.md ADDED
@@ -0,0 +1,90 @@
1
+ # Adaptive Span
2
+
3
+ Adaptive Span is a novel self-attention mechanism that can learn its optimal
4
+ attention span. This allows us to significantly extend the maximum context size
5
+ used in the Transformer while maintaining control over the memory footprint
6
+ and computation time. It uses the Truncated BPTT technique for training,
7
+ as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
8
+
9
+ Adaptive Span was introduced in the paper
10
+ [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
11
+ which achieved state-of-the-art language modeling results at the time of publication.
12
+
13
+ We managed to reproduce their results in fairseq while keeping most of the
14
+ [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
15
+ You can also refer to their sweep file if any hyperparameter combination is unclear.
16
+
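+ Concretely, the span is learned through a soft mask applied to the attention weights. Below is a minimal sketch of that masking idea (the function name is ours; it mirrors the `AdaptiveMask` module in `adaptive_span_attention.py` further down, assuming a single head and the default ramp of 32):
+
+ ```python
+ import torch
+
+ def soft_span_mask(attn, span_param, max_span=8192, ramp=32):
+     # template ramps linearly from (1 - max_span) up to 0 over the last max_span positions
+     template = torch.linspace(1 - max_span, 0, steps=max_span)
+     # shifting the template by the learned parameter and clamping yields a mask that is
+     # 1 inside the current span and decays to 0 over `ramp` positions, which keeps the
+     # span length differentiable so it can be trained with back-propagation
+     mask = ((template + span_param * max_span) / ramp + 1).clamp(0, 1)
+     return attn * mask  # attn: (..., max_span) attention weights over the last max_span keys
+ ```
+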
17
+ ##### 0. Setup
18
+
19
+ First you need to process the Enwik8 dataset; we use the pre-tokenized dataset
20
+ from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
21
+ You can download the dataset, and then run:
22
+ ```bash
23
+ fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
24
+ --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
25
+ --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
26
+ ```
27
+
28
+ ##### 1. Train an Adaptive Span model on Enwik8
29
+
30
+ We will train a 12-layer Adaptive Span model following the [hyperparameters
31
+ used in the original
32
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
33
+
34
+ The following command assumes 4 GPUs, so that the total batch size is 64
35
+ sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
36
+ ```bash
37
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
38
+ --user-dir examples/adaptive_span \
39
+ --data ~/data/enwik8/data-bin/ \
40
+ --fp16 --fp16-no-flatten-grads --max-update 600000 \
41
+ --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
42
+ --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
43
+ --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
44
+ --validate-interval-updates 1000 \
45
+ --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
46
+ --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
47
+ --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
48
+ ```
49
+ This should land at around 1.05 bpc on validation and 1.03 bpc on test. You can lower
50
+ `--aux-loss-scaler` for better performance (a longer span). It gives a ~0.03 bpc
51
+ improvement over the Transformer-XL baseline here.
52
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
53
+ and simulate training on 4 GPUs.
54
+ You can also reproduce the Transformer-XL result on Enwik8 using this code base.
55
+ It should land at around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
56
+ You can try it with:
57
+ ```bash
58
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
59
+ --user-dir examples/truncated_bptt \
60
+ ~/data/enwik8/data-bin/ \
61
+ --task truncated_bptt_lm --fp16 --max-update 400000 \
62
+ --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
63
+ --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
64
+ --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
65
+ --lr-scheduler cosine --warmup-updates 0 \
66
+ --lr 0.0 --lr 0.00025 --batch-size 15 \
67
+ --update-freq 1 --seed 2 --log-format json --log-interval 25 \
68
+ --fp16
69
+ ```
70
+
71
+ ##### 2. Evaluate
72
+ For Adaptive Span:
73
+ ```bash
74
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
75
+ --user-dir examples/adaptive_span \
76
+ --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
77
+ ```
78
+ For Transformer-XL evaluation:
79
+ ```bash
80
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
81
+ --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
82
+ --tokens-per-sample 80 \
83
+ --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
84
+ --gen-subset valid
85
+ ```
86
+
87
+ *Note:* During training the model saw 512 tokens of context
88
+ (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
89
+ settings from [the original
90
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
fairseq/examples/adaptive_span/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+ # automatically import any Python files in the current directory
10
+ cur_dir = os.path.dirname(__file__)
11
+ for file in os.listdir(cur_dir):
12
+ path = os.path.join(cur_dir, file)
13
+ if (
14
+ not file.startswith("_")
15
+ and not file.startswith(".")
16
+ and (file.endswith(".py") or os.path.isdir(path))
17
+ ):
18
+ mod_name = file[: file.find(".py")] if file.endswith(".py") else file
19
+ module = importlib.import_module(__name__ + "." + mod_name)
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.optim import Adagrad
7
+
8
+ from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
9
+
10
+
11
+ @register_optimizer("adagrad_with_grad_clip")
12
+ class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
13
+ def __init__(self, args, params):
14
+ super().__init__(args)
15
+ self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
16
+
17
+ @staticmethod
18
+ def add_args(parser):
19
+ """Add optimizer-specific arguments to the parser."""
20
+ # fmt: off
21
+ parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
22
+ help='weight decay')
23
+ parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
24
+ help='internal grad clip')
25
+ # fmt: on
26
+
27
+ @property
28
+ def optimizer_config(self):
29
+ """
30
+ Return a kwarg dictionary that will be used to override optimizer
31
+ args stored in checkpoints. This allows us to load a checkpoint and
32
+ resume training using a different set of optimizer args, e.g., with a
33
+ different learning rate.
34
+ """
35
+ return {
36
+ "lr": self.args.lr[0],
37
+ "weight_decay": self.args.weight_decay,
38
+ "grad_clip": self.args.adagrad_clip,
39
+ }
40
+
41
+ @property
42
+ def supports_flat_params(self):
43
+ return False
44
+
45
+
46
+ def _clip_grad(clr, grad, group_grad_clip):
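+ # note: rather than clipping the gradient tensor itself, this shrinks the effective
+ # learning rate whenever the gradient norm exceeds the clipping threshold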
47
+ if group_grad_clip > 0:
48
+ norm = grad.norm(2).item()
49
+ if norm > group_grad_clip:
50
+ clr *= group_grad_clip / (norm + 1e-10)
51
+ return clr
52
+
53
+
54
+ class AdagradWithGradClip(Adagrad):
55
+ """Adagrad algorithm with custom gradient clipping"""
56
+
57
+ def __init__(
58
+ self,
59
+ params,
60
+ lr=1e-2,
61
+ lr_decay=0,
62
+ weight_decay=0,
63
+ initial_accumulator_value=0,
64
+ grad_clip=0,
65
+ ):
66
+ Adagrad.__init__(
67
+ self,
68
+ params,
69
+ lr=lr,
70
+ lr_decay=lr_decay,
71
+ weight_decay=weight_decay,
72
+ initial_accumulator_value=initial_accumulator_value,
73
+ )
74
+ self.defaults["grad_clip"] = grad_clip
75
+ self.param_groups[0].setdefault("grad_clip", grad_clip)
76
+
77
+ def step(self, closure=None):
78
+ loss = None
79
+ if closure is not None:
80
+ loss = closure()
81
+
82
+ for group in self.param_groups:
83
+ for p in group["params"]:
84
+ if p.grad is None:
85
+ continue
86
+
87
+ grad = p.grad.data
88
+ state = self.state[p]
89
+
90
+ state["step"] += 1
91
+
92
+ if group["weight_decay"] != 0:
93
+ if p.grad.data.is_sparse:
94
+ raise RuntimeError(
95
+ "weight_decay option is "
96
+ "not compatible with sparse "
97
+ "gradients"
98
+ )
99
+ grad = grad.add(group["weight_decay"], p.data)
100
+
101
+ clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
102
+
103
+ # clip
104
+ clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
105
+
106
+ if grad.is_sparse:
107
+ # the update is non-linear so indices must be unique
108
+ grad = grad.coalesce()
109
+ grad_indices = grad._indices()
110
+ grad_values = grad._values()
111
+ size = grad.size()
112
+
113
+ def make_sparse(values):
114
+ constructor = grad.new
115
+ if grad_indices.dim() == 0 or values.dim() == 0:
116
+ return constructor().resize_as_(grad)
117
+ return constructor(grad_indices, values, size)
118
+
119
+ state["sum"].add_(make_sparse(grad_values.pow(2)))
120
+ std = state["sum"]._sparse_mask(grad)
121
+ std_values = std._values().sqrt_().add_(1e-10)
122
+ p.data.add_(-clr, make_sparse(grad_values / std_values))
123
+ else:
124
+ state["sum"].addcmul_(1, grad, grad)
125
+ std = state["sum"].sqrt().add_(1e-10)
126
+ p.data.addcdiv_(-clr, grad, std)
127
+
128
+ return loss
fairseq/examples/adaptive_span/adaptive_span_attention.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class AdaptiveMask(nn.Module):
13
+ """Soft masking function for adaptive size.
14
+ It masks out the last K values of an input. The masking value
15
+ goes from 1 to 0 gradually, so K can be learned with
16
+ back-propagation.
17
+ Args:
18
+ max_size: maximum size (i.e. input dimension)
19
+ ramp_size: size of the ramp going from 0 to 1
20
+ init_val: initial size proportion not to be masked out
21
+ shape: learn multiple sizes independent of each other
22
+ """
23
+
24
+ def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25
+ nn.Module.__init__(self)
26
+ self._max_size = max_size
27
+ self._ramp_size = ramp_size
28
+ self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29
+ mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30
+ self.register_buffer("mask_template", mask_template)
31
+
32
+ def forward(self, x):
33
+ mask = self.mask_template.float() + self.current_val.float() * self._max_size
34
+ mask = mask / self._ramp_size + 1
35
+ mask = mask.clamp(0, 1)
36
+ if x.size(-1) < self._max_size:
37
+ # the input could have been trimmed beforehand to save computation
38
+ mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39
+ x = (x * mask).type_as(x)
40
+ return x
41
+
42
+ def get_current_max_size(self, include_ramp=True):
43
+ current_size = math.ceil(self.current_val.max().item() * self._max_size)
44
+ if include_ramp:
45
+ current_size += self._ramp_size
46
+ current_size = max(0, min(self._max_size, current_size))
47
+ return current_size
48
+
49
+ def get_current_avg_size(self, include_ramp=True):
50
+ current_size = math.ceil(
51
+ self.current_val.float().mean().item() * self._max_size
52
+ )
53
+ if include_ramp:
54
+ current_size += self._ramp_size
55
+ current_size = max(0, min(self._max_size, current_size))
56
+ return current_size
57
+
58
+ def clamp_param(self):
59
+ """this need to be called after each update"""
60
+ self.current_val.data.clamp_(0, 1)
61
+
62
+
63
+ class AdaptiveSpan(nn.Module):
64
+ """Adaptive attention span for Transformerself.
65
+ This module learns an attention span length from data for each
66
+ self-attention head.
67
+ Args:
68
+ attn_span: maximum attention span
69
+ adapt_span_loss: loss coefficient for the span length
70
+ adapt_span_ramp: length of the masking ramp
71
+ adapt_span_init: initial size ratio
72
+ adapt_span_cache: adapt cache size to reduce memory usage
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ attn_span,
78
+ adapt_span_ramp,
79
+ adapt_span_init,
80
+ n_head,
81
+ adapt_span_layer,
82
+ **kargs
83
+ ):
84
+ nn.Module.__init__(self)
85
+ self._max_span = attn_span
86
+ self._n_head = n_head
87
+ self._adapt_span_layer = adapt_span_layer
88
+ if self._adapt_span_layer:
89
+ self._mask = AdaptiveMask(
90
+ max_size=self._max_span,
91
+ ramp_size=adapt_span_ramp,
92
+ init_val=adapt_span_init,
93
+ )
94
+ else:
95
+ self._mask = AdaptiveMask(
96
+ max_size=self._max_span,
97
+ ramp_size=adapt_span_ramp,
98
+ init_val=adapt_span_init,
99
+ shape=(n_head, 1, 1),
100
+ )
101
+
102
+ def forward(self, attn, normalize=True):
103
+ """mask attention with the right span"""
104
+ # batch and head dimensions are merged together, so separate them first
105
+ self.clamp_param()
106
+ if self._adapt_span_layer:
107
+ attn = self._mask(attn)
108
+ else:
109
+ B = attn.size(0) # batch size
110
+ M = attn.size(1) # block size
111
+ attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112
+ attn = self._mask(attn)
113
+ attn = attn.view(B, M, -1)
114
+ return attn
115
+
116
+ def get_trim_len(self):
117
+ """how much of memory can be trimmed to reduce computation"""
118
+ L = self._max_span
119
+ trim_len = min(L - 1, L - self._mask.get_current_max_size())
120
+ # too fine granularity might be bad for the memory management
121
+ trim_len = math.floor(trim_len / 64) * 64
122
+ return trim_len
123
+
124
+ def trim_memory(self, query, key, value, key_pe):
125
+ """trim out unnecessary memory beforehand to reduce computation"""
126
+ trim_len = self.get_trim_len()
127
+ cache_size = key.size(1) - query.size(1)
128
+ trim_len_cache = trim_len - (self._max_span - cache_size)
129
+ if trim_len_cache > 0:
130
+ key = key[:, trim_len_cache:, :]
131
+ value = value[:, trim_len_cache:, :]
132
+ elif trim_len_cache < 0:
133
+ # cache is too short! this happens when validation resumes
134
+ # after a lot of updates.
135
+ key = F.pad(key, [0, 0, -trim_len_cache, 0])
136
+ value = F.pad(value, [0, 0, -trim_len_cache, 0])
137
+ if trim_len > 0:
138
+ if key_pe is not None:
139
+ key_pe = key_pe[:, :, trim_len:]
140
+ return key, value, key_pe
141
+
142
+ def get_cache_size(self):
143
+ """determine how long the cache should be"""
144
+ trim_len = self.get_trim_len()
145
+ # give a buffer of 64 steps since a span might increase
146
+ # in future updates
147
+ return min(self._max_span, self._max_span - trim_len + 64)
148
+
149
+ def get_loss(self):
150
+ """a loss term for regularizing the span length"""
151
+ return self._max_span * self._mask.current_val.float().mean()
152
+
153
+ def get_current_max_span(self):
154
+ return self._mask.get_current_max_size()
155
+
156
+ def get_current_avg_span(self):
157
+ return self._mask.get_current_avg_size()
158
+
159
+ def clamp_param(self):
160
+ self._mask.clamp_param()
fairseq/examples/adaptive_span/adaptive_span_loss.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import metrics, utils
11
+ from fairseq.criterions import register_criterion
12
+ from fairseq.criterions.cross_entropy import CrossEntropyCriterion
13
+ from fairseq.dataclass import FairseqDataclass
14
+ from omegaconf import II
15
+
16
+
17
+ @dataclass
18
+ class AdaptiveSpanCriterionConfig(FairseqDataclass):
19
+ sentence_avg: bool = II("optimization.sentence_avg")
20
+
21
+
22
+ @register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
23
+ class AdaptiveSpanCriterion(CrossEntropyCriterion):
24
+ def __init__(self, task, sentence_avg):
25
+ super().__init__(task, sentence_avg)
26
+
27
+ def forward(self, model, sample, reduce=True):
28
+ """Compute the loss for the given sample.
29
+
30
+ Returns a tuple with three elements:
31
+ 1) the loss here is summed, different from the adaptive span code
32
+ 2) the sample size, which is used as the denominator for the gradient
33
+ 3) logging outputs to display while training
34
+ """
35
+ net_output = model(**sample["net_input"])
36
+ loss, aux_loss, avg_span, max_span = self.compute_loss(
37
+ model, net_output, sample, reduce=reduce
38
+ )
39
+ sample_size = (
40
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
41
+ )
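+ # the cross-entropy loss is normalized by the true sample size right here, and
+ # sample_size is then reported as 1 so reduce_metrics does not divide it again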
42
+ loss /= sample_size
43
+ total_loss = loss + aux_loss
44
+ sample_size = 1
45
+
46
+ logging_output = {
47
+ "loss": loss.data,
48
+ "ntokens": sample["ntokens"],
49
+ "nsentences": sample["target"].size(0),
50
+ "sample_size": sample_size,
51
+ "total_loss": total_loss.data,
52
+ "avg_span": avg_span * sample_size,
53
+ "max_span": max_span * sample_size,
54
+ }
55
+ return total_loss, sample_size, logging_output
56
+
57
+ def compute_loss(self, model, net_output, sample, reduce=True):
58
+ loss, _ = super().compute_loss(model, net_output, sample, reduce)
59
+ aux_loss = model.get_aux_loss()
60
+ avg_span = model.get_current_avg_span()
61
+ max_span = model.get_current_max_span()
62
+ return loss, aux_loss, avg_span, max_span
63
+
64
+ @staticmethod
65
+ def reduce_metrics(logging_outputs) -> None:
66
+ """Aggregate logging outputs from data parallel training."""
67
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
68
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
69
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
70
+ total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
71
+ avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
72
+ max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
73
+
74
+ # we divide by log(2) to convert the loss from base e to base 2
75
+ metrics.log_scalar(
76
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
77
+ )
78
+ metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
79
+ metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
80
+ # total loss contains the L1 norm on adaptive-span
81
+ metrics.log_scalar(
82
+ "total_loss",
83
+ total_loss_sum / sample_size / math.log(2),
84
+ sample_size,
85
+ round=3,
86
+ )
87
+ if sample_size != ntokens:
88
+ metrics.log_scalar(
89
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
90
+ )
91
+ metrics.log_derived(
92
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
93
+ )
94
+ else:
95
+ metrics.log_derived(
96
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
97
+ )
98
+
99
+ @staticmethod
100
+ def logging_outputs_can_be_summed() -> bool:
101
+ """
102
+ Whether the logging outputs returned by `forward` can be summed
103
+ across workers prior to calling `reduce_metrics`. Setting this
104
+ to True will improves distributed training speed.
105
+ """
106
+ return True
fairseq/examples/adaptive_span/adaptive_span_model.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from fairseq.modules.layer_norm import LayerNorm
14
+
15
+ from .adaptive_span_attention import AdaptiveSpan
16
+
17
+ # Size notations:
18
+ # B = batch_size, H = d_model, M = block_size, L = attn_span
19
+
20
+
21
+ def _skew(X, pad_value):
22
+ """shift every row 1 step to right"""
23
+ # X = B x M x L
24
+ B, M, L = X.size()
25
+ X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
26
+ X = X.view(B, -1) # B x ML+MM+M
27
+ X = X[:, :-M] # B x ML+MM
28
+ X = X.view(B, M, M + L) # B x M x L+M
29
+ return X
30
+
31
+
32
+ def _unskew(X):
33
+ """reverse _skew operation"""
34
+ # X = B x M x L+M
35
+ B, M, L = X.size()
36
+ L -= M
37
+ X = X.view(B, -1) # B x ML+MM
38
+ X = F.pad(X, (0, M)) # B x ML+MM+M
39
+ X = X.view(B, M, M + L + 1) # B x M x L+M+1
40
+ X = X[:, :, :L] # B x M x L
41
+ return X
42
+
43
+
44
+ class SeqAttention(nn.Module):
45
+ """Sequential self-attention layer.
46
+ Each token will attend to its previous fixed number of steps.
47
+ Note that attention doesn't include the current step itself.
48
+ """
49
+
50
+ def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51
+ nn.Module.__init__(self)
52
+ self.dropout = nn.Dropout(dropout)
53
+ self.d_model = d_model # size of a single head
54
+ self.attn_span = attn_span
55
+ self.adaptive_span = AdaptiveSpan(
56
+ attn_span=attn_span,
57
+ n_head=n_head,
58
+ adapt_span_layer=adapt_span_layer,
59
+ **kargs
60
+ )
61
+
62
+ def forward(self, query, key, value, key_pe):
63
+ # query size = B x M x H
64
+ # key, value sizes = B x (M+L) x H
65
+
66
+ key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67
+
68
+ # compute attention from context
69
+ # B x M (dest) x (M+L) (src)
70
+ attn_cont = torch.matmul(query, key.transpose(-1, -2))
71
+ attn_cont = _unskew(attn_cont) # B x M x L
72
+
73
+ # compute the effect of position embedding
74
+ attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75
+ attn = attn_cont + attn_pos
76
+
77
+ attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78
+
79
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80
+
81
+ # trim attention lengths according to the learned span
82
+ attn = self.adaptive_span(attn)
83
+
84
+ attn = self.dropout(attn) # B x M X L_pos
85
+
86
+ attn_cont = _skew(attn, 0) # B x M X (L+M)
87
+ out = torch.matmul(attn_cont, value) # B x M x H
88
+ return out
89
+
90
+ def get_cache_size(self):
91
+ return self.adaptive_span.get_cache_size()
92
+
93
+
94
+ class MultiHeadSeqAttention(nn.Module):
95
+ def __init__(self, d_model, n_head, **kargs):
96
+ nn.Module.__init__(self)
97
+ assert d_model % n_head == 0
98
+ self.n_head = n_head
99
+ self.head_dim = d_model // n_head
100
+ self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101
+ self.proj_query = nn.Linear(d_model, d_model, bias=False)
102
+ nn.init.xavier_normal_(self.proj_query.weight)
103
+ self.proj_out = nn.Linear(d_model, d_model, bias=False)
104
+ nn.init.xavier_normal_(self.proj_out.weight)
105
+ self.proj_val = nn.Linear(d_model, d_model, bias=False)
106
+ nn.init.xavier_normal_(self.proj_val.weight)
107
+ self.proj_key = nn.Linear(d_model, d_model, bias=False)
108
+ nn.init.xavier_normal_(self.proj_key.weight)
109
+
110
+ def head_reshape(self, x):
111
+ K = self.n_head
112
+ D = self.head_dim
113
+ x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114
+ x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115
+ x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116
+ return x
117
+
118
+ def forward(self, query, key, value, key_pe):
119
+ B = query.size(0)
120
+ K = self.n_head
121
+ D = self.head_dim
122
+ M = query.size(1)
123
+
124
+ query = self.proj_query(query)
125
+ query = self.head_reshape(query)
126
+ value = self.proj_val(value)
127
+ value = self.head_reshape(value)
128
+ key = self.proj_key(key)
129
+ key = self.head_reshape(key)
130
+
131
+ out = self.attn(query, key, value, key_pe) # B_K x M x D
132
+ out = out.view(B, K, M, D) # B x K x M x D
133
+ out = out.transpose(1, 2).contiguous() # B x M x K x D
134
+ out = out.view(B, M, -1) # B x M x K_D
135
+ out = self.proj_out(out)
136
+ return out
137
+
138
+
139
+ class FeedForwardLayer(nn.Module):
140
+ def __init__(self, d_model, d_inner, dropout, **kargs):
141
+ nn.Module.__init__(self)
142
+ self.fc1 = nn.Linear(d_model, d_inner)
143
+ self.fc2 = nn.Linear(d_inner, d_model)
144
+ nn.init.xavier_uniform_(self.fc1.weight)
145
+ nn.init.xavier_uniform_(self.fc2.weight)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, h):
149
+ h1 = F.relu(self.fc1(h))
150
+ h1 = self.dropout(h1)
151
+ h2 = self.fc2(h1)
152
+ return h2
153
+
154
+
155
+ class TransformerSeqLayer(nn.Module):
156
+ def __init__(self, d_model, **kargs):
157
+ nn.Module.__init__(self)
158
+ self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159
+ self.norm1 = LayerNorm(d_model)
160
+ self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161
+ self.norm2 = LayerNorm(d_model)
162
+
163
+ def forward(self, h, h_cache, key_pe):
164
+ # h = B x M x H
165
+ # h_cache = B x L x H
166
+ h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
167
+ attn_out = self.attn(h, h_all, h_all, key_pe)
168
+ h = self.norm1(h + attn_out) # B x M x H
169
+ if self.ff is not None:
170
+ ff_out = self.ff(h)
171
+ out = self.norm2(h + ff_out) # B x M x H
172
+ else:
173
+ out = h
174
+ return out
175
+
176
+ def get_cache_size(self):
177
+ return self.attn.attn.get_cache_size()
178
+
179
+
180
+ class TransformerSeq(nn.Module):
181
+ def __init__(
182
+ self,
183
+ vocab_size,
184
+ d_model,
185
+ n_head,
186
+ n_layer,
187
+ attn_span,
188
+ emb_dropout,
189
+ aux_loss_scaler,
190
+ adapt_span_layer,
191
+ **kargs
192
+ ):
193
+ nn.Module.__init__(self)
194
+ # token embeddings
195
+ self.in_emb = nn.Embedding(vocab_size, d_model)
196
+ nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197
+ self.out_emb = nn.Linear(d_model, vocab_size)
198
+ self.aux_loss_scaler = aux_loss_scaler
199
+ if emb_dropout > 0:
200
+ self.emb_dropout = nn.Dropout(emb_dropout)
201
+ else:
202
+ self.emb_dropout = None
203
+ # position embeddings
204
+ self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205
+
206
+ self.layers = nn.ModuleList()
207
+ self.layers.extend(
208
+ TransformerSeqLayer(
209
+ d_model=d_model,
210
+ n_head=n_head,
211
+ attn_span=attn_span,
212
+ adapt_span_layer=adapt_span_layer,
213
+ **kargs
214
+ )
215
+ for _ in range(n_layer)
216
+ )
217
+
218
+ def forward(self, x, h_cache, target=None):
219
+ # x size = B x M
220
+ block_size = x.size(1)
221
+ h = self.in_emb(x) # B x M x H
222
+ if self.emb_dropout is not None:
223
+ h = self.emb_dropout(h)
224
+
225
+ h_cache_next = []
226
+ for l, layer in enumerate(self.layers):
227
+ cache_size = layer.attn.attn.get_cache_size()
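+ # keep only the most recent `cache_size` hidden states as the memory fed to this
+ # layer on the next block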
228
+ if cache_size > block_size:
229
+ h_cache_next_l = torch.cat(
230
+ [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
231
+ ).detach()
232
+ else:
233
+ h_cache_next_l = h[:, -cache_size:, :].detach()
234
+ h_cache_next.append(h_cache_next_l)
235
+ h = layer(h, h_cache[l], self.key_pe) # B x M x H
236
+
237
+ if self.emb_dropout is not None:
238
+ h = self.emb_dropout(h)
239
+
240
+ out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241
+ dummy_loss = None
242
+
243
+ return out, h_cache_next, dummy_loss
244
+
245
+ def get_aux_loss(self):
246
+ loss = 0.0
247
+ for layer in self.layers:
248
+ loss += layer.attn.attn.adaptive_span.get_loss()
249
+ return self.aux_loss_scaler * loss
250
+
251
+ def get_current_max_span(self):
252
+ max_span = 0.0
253
+ for layer in self.layers:
254
+ max_span = max(
255
+ max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256
+ )
257
+ return max_span
258
+
259
+ def get_current_avg_span(self):
260
+ avg_span = 0.0
261
+ for layer in self.layers:
262
+ avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263
+ return avg_span / len(self.layers)
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ )
17
+ from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class AdaptiveSpanSmallConfig(FairseqDataclass):
25
+ # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
26
+ vocab_size: int = 50
27
+ d_model: int = 256
28
+ n_head: int = 4
29
+ d_inner: int = 1024
30
+ n_layer: int = 8
31
+ attn_span: int = 1024
32
+ dropout: float = 0.0
33
+ emb_dropout: float = 0.0
34
+ adapt_span_ramp: int = 32
35
+ adapt_span_init: float = 0.0
36
+ aux_loss_scaler: float = 0.000002
37
+ adapt_span_layer: bool = False
38
+
39
+
40
+ @register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41
+ class AdaptiveSpanTransformer(FairseqLanguageModel):
42
+ @classmethod
43
+ def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44
+ return cls(AdaptiveSpanDecoder(cfg, task))
45
+
46
+ def get_aux_loss(self):
47
+ return self.decoder.get_aux_loss()
48
+
49
+ def get_current_max_span(self):
50
+ return self.decoder.get_current_max_span()
51
+
52
+ def get_current_avg_span(self):
53
+ return self.decoder.get_current_avg_span()
54
+
55
+
56
+ class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57
+ def __init__(self, cfg, task):
58
+
59
+ super().__init__(task.target_dictionary)
60
+
61
+ self.config = cfg
62
+ config = AdaptiveSpanSmallConfig(
63
+ vocab_size=len(task.target_dictionary),
64
+ d_model=cfg.d_model,
65
+ n_head=cfg.n_head,
66
+ d_inner=cfg.d_inner,
67
+ n_layer=cfg.n_layer,
68
+ attn_span=cfg.attn_span,
69
+ dropout=cfg.dropout,
70
+ emb_dropout=cfg.emb_dropout,
71
+ adapt_span_ramp=cfg.adapt_span_ramp,
72
+ adapt_span_init=cfg.adapt_span_init,
73
+ aux_loss_scaler=cfg.aux_loss_scaler,
74
+ adapt_span_layer=cfg.adapt_span_layer,
75
+ )
76
+ logger.info(config)
77
+ self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78
+
79
+ self._mems = None
80
+
81
+ def forward(
82
+ self,
83
+ src_tokens,
84
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85
+ encoder_out=None,
86
+ ):
87
+ bsz = src_tokens.size(0)
88
+ if incremental_state is not None: # used during inference
89
+ mems = self.get_incremental_state("mems")
90
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
91
+ else:
92
+ mems = self._mems
93
+
94
+ if mems is None:
95
+ # first time init
96
+ mems = self.init_hid_cache(bsz)
97
+ output = self.model(x=src_tokens, h_cache=mems,)
98
+ if incremental_state is not None:
99
+ self.set_incremental_state(incremental_state, "mems", output[1])
100
+ else:
101
+ self._mems = output[1]
102
+ return (output[0],)
103
+
104
+ def max_positions(self):
105
+ return self.config.attn_span
106
+
107
+ def init_hid_cache(self, batch_sz):
108
+ hid = []
109
+ for layer in self.model.layers:
110
+ param = next(self.model.parameters())
111
+ h = torch.zeros(
112
+ batch_sz,
113
+ layer.get_cache_size(),
114
+ self.config.d_model,
115
+ dtype=param.dtype,
116
+ device=param.device,
117
+ )
118
+ hid.append(h)
119
+ return hid
120
+
121
+ def get_aux_loss(self):
122
+ return self.model.get_aux_loss()
123
+
124
+ def get_current_max_span(self):
125
+ return self.model.get_current_max_span()
126
+
127
+ def get_current_avg_span(self):
128
+ return self.model.get_current_avg_span()
129
+
130
+ def reorder_incremental_state(
131
+ self,
132
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133
+ new_order: torch.Tensor,
134
+ ):
135
+ """Reorder incremental state.
136
+
137
+ This will be called when the order of the input has changed from the
138
+ previous time step. A typical use case is beam search, where the input
139
+ order changes between time steps based on the selection of beams.
140
+ """
141
+ raise NotImplementedError("This is required for generation/beam search")
142
+ # mems = self.get_incremental_state(incremental_state, "mems")
143
+ # if mems is not None:
144
+ # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145
+ # self.set_incremental_state(incremental_state, "mems", new_mems)
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py ADDED
@@ -0,0 +1,281 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ from fairseq import utils
13
+ from fairseq.data import (
14
+ Dictionary,
15
+ TokenBlockDataset,
16
+ data_utils,
17
+ iterators,
18
+ )
19
+ from fairseq.dataclass import FairseqDataclass
20
+ from fairseq.distributed import utils as dist_utils
21
+ from fairseq.tasks import FairseqTask, register_task
22
+ from omegaconf import II
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class TruncatedBPTTLMConfig(FairseqDataclass):
30
+ data: str = field(default="???", metadata={"help": "path to data directory"})
31
+ tokens_per_sample: int = field(
32
+ default=1024,
33
+ metadata={"help": "max number of tokens per sequence"},
34
+ )
35
+ batch_size: int = II("dataset.batch_size")
36
+ # Some models use *max_target_positions* to know how many positional
37
+ # embeddings to learn. We use II(...) to make it default to
38
+ # *tokens_per_sample*, but in principle there could be more positional
39
+ # embeddings than tokens in a single batch. This may also be irrelevant for
40
+ # custom model implementations.
41
+ max_target_positions: int = II("task.tokens_per_sample")
42
+ # these will be populated automatically if not provided
43
+ data_parallel_rank: Optional[int] = None
44
+ data_parallel_size: Optional[int] = None
45
+
46
+
47
+ @register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
48
+ class TruncatedBPTTLMTask(FairseqTask):
49
+ def __init__(self, cfg: TruncatedBPTTLMConfig):
50
+ super().__init__(cfg)
51
+
52
+ if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
53
+ if torch.distributed.is_initialized():
54
+ cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
55
+ cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
56
+ else:
57
+ cfg.data_parallel_rank = 0
58
+ cfg.data_parallel_size = 1
59
+
60
+ # load the dictionary
61
+ paths = utils.split_paths(cfg.data)
62
+ assert len(paths) > 0
63
+ self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
64
+ logger.info("dictionary: {} types".format(len(self.dictionary)))
65
+
66
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
67
+ """Load a given dataset split (e.g., train, valid, test)"""
68
+
69
+ # support sharded datasets
70
+ paths = utils.split_paths(self.cfg.data)
71
+ assert len(paths) > 0
72
+ data_path = paths[(epoch - 1) % len(paths)]
73
+ split_path = os.path.join(data_path, split)
74
+
75
+ # each element of *data* will be a tensorized line from the original
76
+ # text dataset, similar to ``open(split_path).readlines()``
77
+ data = data_utils.load_indexed_dataset(
78
+ split_path, self.dictionary, combine=combine
79
+ )
80
+ if data is None:
81
+ raise FileNotFoundError(
82
+ "Dataset not found: {} ({})".format(split, split_path)
83
+ )
84
+
85
+ # this is similar to ``data.view(-1).split(tokens_per_sample)``
86
+ data = TokenBlockDataset(
87
+ data,
88
+ data.sizes,
89
+ block_size=self.cfg.tokens_per_sample,
90
+ pad=None, # unused
91
+ eos=None, # unused
92
+ break_mode="none",
93
+ )
94
+
95
+ self.datasets[split] = TruncatedBPTTDataset(
96
+ data=data,
97
+ bsz_per_shard=self.cfg.batch_size,
98
+ shard_id=self.cfg.data_parallel_rank,
99
+ num_shards=self.cfg.data_parallel_size,
100
+ )
101
+
102
+ def dataset(self, split):
103
+ return self.datasets[split]
104
+
105
+ def get_batch_iterator(
106
+ self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
107
+ ):
108
+ return iterators.EpochBatchIterator(
109
+ dataset=dataset,
110
+ collate_fn=self._collate_fn,
111
+ num_workers=num_workers,
112
+ epoch=epoch,
113
+ buffer_size=data_buffer_size,
114
+ # we don't use the batching functionality from EpochBatchIterator;
115
+ # instead every item in *dataset* is a whole batch
116
+ batch_sampler=[[i] for i in range(len(dataset))],
117
+ disable_shuffling=True,
118
+ )
119
+
120
+ def _collate_fn(self, items: List[List[torch.Tensor]]):
121
+ # we don't use fairseq's batching functionality, so we expect a single
122
+ # Tensor of type List[torch.Tensor]
123
+ assert len(items) == 1
124
+
125
+ # item will have shape B x T (the last batch may have length < T)
126
+ id, item = items[0]
127
+ item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
128
+ B, T = item.size()
129
+
130
+ # shift item one position over and append a padding token for the target
131
+ target = torch.nn.functional.pad(
132
+ item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
133
+ )
134
+
135
+ # fairseq expects batches to have the following structure
136
+ return {
137
+ "id": torch.tensor([id]*item.size(0)),
138
+ "net_input": {
139
+ "src_tokens": item,
140
+ },
141
+ "target": target,
142
+ "nsentences": item.size(0),
143
+ "ntokens": item.numel(),
144
+ }
145
+
146
+ def build_dataset_for_inference(
147
+ self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
148
+ ) -> torch.utils.data.Dataset:
149
+ eos = self.source_dictionary.eos()
150
+ dataset = TokenBlockDataset(
151
+ src_tokens,
152
+ src_lengths,
153
+ block_size=None, # ignored for "eos" break mode
154
+ pad=self.source_dictionary.pad(),
155
+ eos=eos,
156
+ break_mode="eos",
157
+ )
158
+
159
+ class Dataset(torch.utils.data.Dataset):
160
+ def __getitem__(self, i):
161
+ item = dataset[i]
162
+ if item[-1] == eos:
163
+ # remove eos to support generating with a prefix
164
+ item = item[:-1]
165
+ return (i, [item])
166
+
167
+ def __len__(self):
168
+ return len(dataset)
169
+
170
+ return Dataset()
171
+
172
+ def inference_step(
173
+ self, generator, models, sample, prefix_tokens=None, constraints=None
174
+ ):
175
+ with torch.no_grad():
176
+ if constraints is not None:
177
+ raise NotImplementedError
178
+
179
+ # SequenceGenerator doesn't use *src_tokens* directly, we need to
180
+ # pass the *prefix_tokens* argument instead.
181
+ if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
182
+ prefix_tokens = sample["net_input"]["src_tokens"]
183
+
184
+ # begin generation with the end-of-sentence token
185
+ bos_token = self.source_dictionary.eos()
186
+
187
+ return generator.generate(
188
+ models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
189
+ )
190
+
191
+ def eval_lm_dataloader(
192
+ self,
193
+ dataset,
194
+ max_tokens: Optional[int] = 36000,
195
+ batch_size: Optional[int] = None,
196
+ max_positions: Optional[int] = None,
197
+ num_shards: int = 1,
198
+ shard_id: int = 0,
199
+ num_workers: int = 1,
200
+ data_buffer_size: int = 10,
201
+ context_window: int = 0,
202
+ ):
203
+ if context_window > 0:
204
+ raise NotImplementedError(
205
+ "Transformer-XL doesn't need --context-window, try "
206
+ "--model-overrides '{\"mem_len\":42}' instead "
207
+ )
208
+ return self.get_batch_iterator(
209
+ dataset=dataset,
210
+ max_tokens=max_tokens,
211
+ max_sentences=batch_size,
212
+ max_positions=max_positions,
213
+ ignore_invalid_inputs=True,
214
+ num_shards=num_shards,
215
+ shard_id=shard_id,
216
+ num_workers=num_workers,
217
+ data_buffer_size=data_buffer_size,
218
+ ).next_epoch_itr(shuffle=False)
219
+
220
+ @property
221
+ def source_dictionary(self):
222
+ return self.dictionary
223
+
224
+ @property
225
+ def target_dictionary(self):
226
+ return self.dictionary
227
+
228
+
229
+ class TruncatedBPTTDataset(torch.utils.data.Dataset):
230
+ def __init__(
231
+ self,
232
+ data: List[torch.Tensor], # ordered list of items
233
+ bsz_per_shard, # number of items processed per GPU per forward pass
234
+ shard_id, # current GPU ID
235
+ num_shards, # number of GPUs
236
+ ):
237
+ super().__init__()
238
+ self.data = data
239
+
240
+ def batchify(data, bsz):
241
+ # Work out how cleanly we can divide the dataset into bsz parts.
242
+ nbatch = data.size(0) // bsz
243
+ # Trim off any extra elements that wouldn't cleanly fit (remainders).
244
+ data = data.narrow(0, 0, nbatch * bsz)
245
+ # Evenly divide the data across the bsz batches.
246
+ data = data.view(bsz, -1).contiguous()
247
+ return data
248
+
249
+ # total number of sequences processed by all GPUs in each forward pass
250
+ global_batch_size = bsz_per_shard * num_shards
251
+
252
+ """
253
+ With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
254
+ *indices* might look like:
255
+
256
+ indices = [[0, 1],
257
+ [2, 3],
258
+ [4, 5],
259
+ [6, 7],
260
+ [8, 9],
261
+ [10, 11]]
262
+
263
+ The size of the TruncatedBPTTDataset instance will be 2,
264
+ and shard 1 will see items:
265
+
266
+ [(0, [data[4], data[6]]),
267
+ (1, [data[5], data[7]])]
268
+ """
269
+ indices = batchify(torch.arange(len(data)), global_batch_size)
270
+ assert indices.size(0) == global_batch_size
271
+
272
+ self.my_indices = indices[
273
+ shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
274
+ ]
275
+ assert self.my_indices.size(0) == bsz_per_shard
276
+
277
+ def __len__(self):
278
+ return self.my_indices.size(1)
279
+
280
+ def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
281
+ return (i, [self.data[idx] for idx in self.my_indices[:, i]])
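+
+ # Illustrative note (not part of the original file): tracing the docstring example
+ # above with a 16-item dataset, bsz_per_shard=2 and num_shards=3:
+ #   batchify(torch.arange(16), 2 * 3)
+ #   # -> tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]])  (items 12-15 are dropped)
+ #   # shard 1 keeps rows 2:4 -> [[4, 5], [6, 7]], so len(dataset) == 2 and
+ #   # dataset[0] == (0, [data[4], data[6]]), dataset[1] == (1, [data[5], data[7]])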
fairseq/examples/backtranslation/README.md ADDED
@@ -0,0 +1,297 @@
1
+ # Understanding Back-Translation at Scale (Edunov et al., 2018)
2
+
3
+ This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
4
+
5
+ ## Pre-trained models
6
+
7
+ Model | Description | Dataset | Download
8
+ ---|---|---|---
9
+ `transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
10
+
11
+ ## Example usage (torch.hub)
12
+
13
+ We require a few additional Python dependencies for preprocessing:
14
+ ```bash
15
+ pip install subword_nmt sacremoses
16
+ ```
17
+
18
+ Then to generate translations from the full model ensemble:
19
+ ```python
20
+ import torch
21
+
22
+ # List available models
23
+ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
24
+
25
+ # Load the WMT'18 En-De ensemble
26
+ en2de_ensemble = torch.hub.load(
27
+ 'pytorch/fairseq', 'transformer.wmt18.en-de',
28
+ checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
29
+ tokenizer='moses', bpe='subword_nmt')
30
+
31
+ # The ensemble contains 5 models
32
+ len(en2de_ensemble.models)
33
+ # 5
34
+
35
+ # Translate
36
+ en2de_ensemble.translate('Hello world!')
37
+ # 'Hallo Welt!'
38
+ ```
39
+
40
+ ## Training your own model (WMT'18 English-German)
41
+
42
+ The following instructions can be adapted to reproduce the models from the paper.
43
+
44
+
45
+ #### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
46
+
47
+ First download and preprocess the data:
48
+ ```bash
49
+ # Download and prepare the data
50
+ cd examples/backtranslation/
51
+ bash prepare-wmt18en2de.sh
52
+ cd ../..
53
+
54
+ # Binarize the data
55
+ TEXT=examples/backtranslation/wmt18_en_de
56
+ fairseq-preprocess \
57
+ --joined-dictionary \
58
+ --source-lang en --target-lang de \
59
+ --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
60
+ --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
61
+ --workers 20
62
+
63
+ # Copy the BPE code into the data-bin directory for future use
64
+ cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
65
+ ```
66
+
67
+ (Optionally) Train a baseline model (English-German) using just the parallel data:
68
+ ```bash
69
+ CHECKPOINT_DIR=checkpoints_en_de_parallel
70
+ fairseq-train --fp16 \
71
+ data-bin/wmt18_en_de \
72
+ --source-lang en --target-lang de \
73
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
74
+ --dropout 0.3 --weight-decay 0.0 \
75
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
76
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
77
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
78
+ --max-tokens 3584 --update-freq 16 \
79
+ --max-update 30000 \
80
+ --save-dir $CHECKPOINT_DIR
81
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
82
+ # different number of GPUs.
83
+ ```
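+
+ (The effective batch size is roughly `number of GPUs x --max-tokens x --update-freq` tokens per update, so with, say, 4 GPUs you would double `--update-freq` to 32 to keep it roughly unchanged.)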
84
+
85
+ Average the last 10 checkpoints:
86
+ ```bash
87
+ python scripts/average_checkpoints.py \
88
+ --inputs $CHECKPOINT_DIR \
89
+ --num-epoch-checkpoints 10 \
90
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
91
+ ```
92
+
93
+ Evaluate BLEU:
94
+ ```bash
95
+ # tokenized BLEU on newstest2017:
96
+ bash examples/backtranslation/tokenized_bleu.sh \
97
+ wmt17 \
98
+ en-de \
99
+ data-bin/wmt18_en_de \
100
+ data-bin/wmt18_en_de/code \
101
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
102
+ # BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
103
+ # compare to 29.46 in Table 1, which is also for tokenized BLEU
104
+
105
+ # generally it's better to report (detokenized) sacrebleu though:
106
+ bash examples/backtranslation/sacrebleu.sh \
107
+ wmt17 \
108
+ en-de \
109
+ data-bin/wmt18_en_de \
110
+ data-bin/wmt18_en_de/code \
111
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
112
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
113
+ ```
114
+
115
+
116
+ #### Step 2. Back-translate monolingual German data
117
+
118
+ Train a reverse model (German-English) to do the back-translation:
119
+ ```bash
120
+ CHECKPOINT_DIR=checkpoints_de_en_parallel
121
+ fairseq-train --fp16 \
122
+ data-bin/wmt18_en_de \
123
+ --source-lang de --target-lang en \
124
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
125
+ --dropout 0.3 --weight-decay 0.0 \
126
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
127
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
128
+ --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
129
+ --max-tokens 3584 --update-freq 16 \
130
+ --max-update 30000 \
131
+ --save-dir $CHECKPOINT_DIR
132
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
133
+ # different number of GPUs.
134
+ ```
135
+
136
+ Let's evaluate the back-translation (BT) model to make sure it is well trained:
137
+ ```bash
138
+ bash examples/backtranslation/sacrebleu.sh \
139
+ wmt17 \
140
+ de-en \
141
+ data-bin/wmt18_en_de \
142
+ data-bin/wmt18_en_de/code \
143
+ $CHECKPOINT_DIR/checkpoint_best.pt
144
+ # BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
145
+ # compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
146
+ ```
147
+
148
+ Next prepare the monolingual data:
149
+ ```bash
150
+ # Download and prepare the monolingual data
151
+ # By default the script samples 25M monolingual sentences, which after
152
+ # deduplication should be just over 24M sentences. These are split into 25
153
+ # shards, each with 1M sentences (except for the last shard).
154
+ cd examples/backtranslation/
155
+ bash prepare-de-monolingual.sh
156
+ cd ../..
157
+
158
+ # Binarize each shard of the monolingual data
159
+ TEXT=examples/backtranslation/wmt18_de_mono
160
+ for SHARD in $(seq -f "%02g" 0 24); do \
161
+ fairseq-preprocess \
162
+ --only-source \
163
+ --source-lang de --target-lang en \
164
+ --joined-dictionary \
165
+ --srcdict data-bin/wmt18_en_de/dict.de.txt \
166
+ --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
167
+ --destdir data-bin/wmt18_de_mono/shard${SHARD} \
168
+ --workers 20; \
169
+ cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
170
+ done
171
+ ```
172
+
173
+ Now we're ready to perform back-translation over the monolingual data. The
174
+ following command generates via sampling, but it's possible to use greedy
175
+ decoding (`--beam 1`), beam search (`--beam 5`),
176
+ top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
177
+ ```bash
178
+ mkdir backtranslation_output
179
+ for SHARD in $(seq -f "%02g" 0 24); do \
180
+ fairseq-generate --fp16 \
181
+ data-bin/wmt18_de_mono/shard${SHARD} \
182
+ --path $CHECKPOINT_DIR/checkpoint_best.pt \
183
+ --skip-invalid-size-inputs-valid-test \
184
+ --max-tokens 4096 \
185
+ --sampling --beam 1 \
186
+ > backtranslation_output/sampling.shard${SHARD}.out; \
187
+ done
188
+ ```
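+
+ For example, the top-k sampling variant mentioned above only changes the generation flags in the same loop (a sketch; the output filename is arbitrary):
+ ```bash
+ for SHARD in $(seq -f "%02g" 0 24); do \
+     fairseq-generate --fp16 \
+     data-bin/wmt18_de_mono/shard${SHARD} \
+     --path $CHECKPOINT_DIR/checkpoint_best.pt \
+     --skip-invalid-size-inputs-valid-test \
+     --max-tokens 4096 \
+     --sampling --beam 1 --sampling-topk 10 \
+     > backtranslation_output/topk10.shard${SHARD}.out; \
+ done
+ ```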
189
+
190
+ After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
191
+ the back-translations and apply length ratio filters:
192
+ ```bash
193
+ python examples/backtranslation/extract_bt_data.py \
194
+ --minlen 1 --maxlen 250 --ratio 1.5 \
195
+ --output backtranslation_output/bt_data --srclang en --tgtlang de \
196
+ backtranslation_output/sampling.shard*.out
197
+
198
+ # Ensure lengths are the same:
199
+ # wc -l backtranslation_output/bt_data.{en,de}
200
+ # 21795614 backtranslation_output/bt_data.en
201
+ # 21795614 backtranslation_output/bt_data.de
202
+ # 43591228 total
203
+ ```
204
+
205
+ Binarize the filtered BT data and combine it with the parallel data:
206
+ ```bash
207
+ TEXT=backtranslation_output
208
+ fairseq-preprocess \
209
+ --source-lang en --target-lang de \
210
+ --joined-dictionary \
211
+ --srcdict data-bin/wmt18_en_de/dict.en.txt \
212
+ --trainpref $TEXT/bt_data \
213
+ --destdir data-bin/wmt18_en_de_bt \
214
+ --workers 20
215
+
216
+ # We want to train on the combined data, so we'll symlink the parallel + BT data
217
+ # in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
218
+ # and the BT data as "train1", so that fairseq will combine them automatically
219
+ # and so that we can use the `--upsample-primary` option to upsample the
220
+ # parallel data (if desired).
221
+ PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
222
+ BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
223
+ COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
224
+ mkdir -p $COMB_DATA
225
+ for LANG in en de; do \
226
+ ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
227
+ for EXT in bin idx; do \
228
+ ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
229
+ ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
230
+ ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
231
+ ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
232
+ done; \
233
+ done
234
+ ```
235
+
236
+
237
+ #### 3. Train an English-German model over the combined parallel + BT data
238
+
239
+ Finally we can train a model over the parallel + BT data:
240
+ ```bash
241
+ CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
242
+ fairseq-train --fp16 \
243
+ data-bin/wmt18_en_de_para_plus_bt \
244
+ --upsample-primary 16 \
245
+ --source-lang en --target-lang de \
246
+ --arch transformer_wmt_en_de_big --share-all-embeddings \
247
+ --dropout 0.3 --weight-decay 0.0 \
248
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
249
+ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
250
+ --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
251
+ --max-tokens 3584 --update-freq 16 \
252
+ --max-update 100000 \
253
+ --save-dir $CHECKPOINT_DIR
254
+ # Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
255
+ # different number of GPUs.
256
+ ```
257
+
258
+ Average the last 10 checkpoints:
259
+ ```bash
260
+ python scripts/average_checkpoints.py \
261
+ --inputs $CHECKPOINT_DIR \
262
+ --num-epoch-checkpoints 10 \
263
+ --output $CHECKPOINT_DIR/checkpoint.avg10.pt
264
+ ```
265
+
266
+ Evaluate BLEU:
267
+ ```bash
268
+ # tokenized BLEU on newstest2017:
269
+ bash examples/backtranslation/tokenized_bleu.sh \
270
+ wmt17 \
271
+ en-de \
272
+ data-bin/wmt18_en_de \
273
+ data-bin/wmt18_en_de/code \
274
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
275
+ # BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
276
+ # compare to 32.35 in Table 1, which is also for tokenized BLEU
277
+
278
+ # generally it's better to report (detokenized) sacrebleu:
279
+ bash examples/backtranslation/sacrebleu.sh \
280
+ wmt17 \
281
+ en-de \
282
+ data-bin/wmt18_en_de \
283
+ data-bin/wmt18_en_de/code \
284
+ $CHECKPOINT_DIR/checkpoint.avg10.pt
285
+ # BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
286
+ ```
287
+
288
+
289
+ ## Citation
290
+ ```bibtex
291
+ @inproceedings{edunov2018backtranslation,
292
+ title = {Understanding Back-Translation at Scale},
293
+ author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
294
+ booktitle = {Conference of the Association for Computational Linguistics (ACL)},
295
+ year = 2018,
296
+ }
297
+ ```
fairseq/examples/backtranslation/deduplicate_lines.py ADDED
@@ -0,0 +1,41 @@
1
+ #!/usr/bin/python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+ import hashlib
10
+ import sys
11
+ from multiprocessing import Pool
12
+
13
+
14
+ def get_hashes_and_lines(raw_line):
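+ # hash the raw bytes of each line; worker processes return (hash, line) pairs
+ # so the parent process can keep only the first occurrence of each line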
15
+ hash = hashlib.md5(raw_line).hexdigest()
16
+ return hash, raw_line
17
+
18
+
19
+ def main():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--workers", type=int, default=10)
22
+ parser.add_argument("files", nargs="*", help="input files")
23
+ args = parser.parse_args()
24
+
25
+ seen = set()
26
+ with fileinput.input(args.files, mode="rb") as h:
27
+ pool = Pool(args.workers)
28
+ results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
29
+ for i, (hash, raw_line) in enumerate(results):
30
+ if hash not in seen:
31
+ seen.add(hash)
32
+ sys.stdout.buffer.write(raw_line)
33
+ if i % 1000000 == 0:
34
+ print(i, file=sys.stderr, end="", flush=True)
35
+ elif i % 100000 == 0:
36
+ print(".", file=sys.stderr, end="", flush=True)
37
+ print(file=sys.stderr, flush=True)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
fairseq/examples/backtranslation/extract_bt_data.py ADDED
@@ -0,0 +1,72 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import fileinput
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(
15
+ description=(
16
+ "Extract back-translations from the stdout of fairseq-generate. "
17
+ "If there are multiply hypotheses for a source, we only keep the first one. "
18
+ )
19
+ )
20
+ parser.add_argument("--output", required=True, help="output prefix")
21
+ parser.add_argument(
22
+ "--srclang", required=True, help="source language (extracted from H-* lines)"
23
+ )
24
+ parser.add_argument(
25
+ "--tgtlang", required=True, help="target language (extracted from S-* lines)"
26
+ )
27
+ parser.add_argument("--minlen", type=int, help="min length filter")
28
+ parser.add_argument("--maxlen", type=int, help="max length filter")
29
+ parser.add_argument("--ratio", type=float, help="ratio filter")
30
+ parser.add_argument("files", nargs="*", help="input files")
31
+ args = parser.parse_args()
32
+
33
+ def validate(src, tgt):
34
+ srclen = len(src.split(" ")) if src != "" else 0
35
+ tgtlen = len(tgt.split(" ")) if tgt != "" else 0
36
+ if (
37
+ (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
38
+ or (
39
+ args.maxlen is not None
40
+ and (srclen > args.maxlen or tgtlen > args.maxlen)
41
+ )
42
+ or (
43
+ args.ratio is not None
44
+ and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
45
+ )
46
+ ):
47
+ return False
48
+ return True
49
+
50
+ def safe_index(toks, index, default):
51
+ try:
52
+ return toks[index]
53
+ except IndexError:
54
+ return default
55
+
56
+ with open(args.output + "." + args.srclang, "w") as src_h, open(
57
+ args.output + "." + args.tgtlang, "w"
58
+ ) as tgt_h:
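+ # S-* lines hold the original monolingual sentence (the new target side);
+ # H-* lines hold the back-translated hypothesis (the new source side)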
59
+ for line in tqdm(fileinput.input(args.files)):
60
+ if line.startswith("S-"):
61
+ tgt = safe_index(line.rstrip().split("\t"), 1, "")
62
+ elif line.startswith("H-"):
63
+ if tgt is not None:
64
+ src = safe_index(line.rstrip().split("\t"), 2, "")
65
+ if validate(src, tgt):
66
+ print(src, file=src_h)
67
+ print(tgt, file=tgt_h)
68
+ tgt = None
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
fairseq/examples/backtranslation/prepare-de-monolingual.sh ADDED
@@ -0,0 +1,98 @@
1
+ #!/bin/bash
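+ # Downloads WMT German news-crawl data, subsamples SUBSAMPLE_SIZE sentences,
+ # normalizes/tokenizes them, applies the WMT'18 en-de BPE code, deduplicates,
+ # and splits the result into 1M-line shards.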
2
+
3
+ SCRIPTS=mosesdecoder/scripts
4
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
5
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
6
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
7
+ BPEROOT=subword-nmt/subword_nmt
8
+
9
+
10
+ BPE_CODE=wmt18_en_de/code
11
+ SUBSAMPLE_SIZE=25000000
12
+ LANG=de
13
+
14
+
15
+ OUTDIR=wmt18_${LANG}_mono
16
+ orig=orig
17
+ tmp=$OUTDIR/tmp
18
+ mkdir -p $OUTDIR $tmp
19
+
20
+
21
+ URLS=(
22
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
23
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
24
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
25
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
26
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
27
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
28
+ "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
29
+ "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
30
+ "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
31
+ "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
32
+ "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
33
+ )
34
+ FILES=(
35
+ "news.2007.de.shuffled.gz"
36
+ "news.2008.de.shuffled.gz"
37
+ "news.2009.de.shuffled.gz"
38
+ "news.2010.de.shuffled.gz"
39
+ "news.2011.de.shuffled.gz"
40
+ "news.2012.de.shuffled.gz"
41
+ "news.2013.de.shuffled.gz"
42
+ "news.2014.de.shuffled.v2.gz"
43
+ "news.2015.de.shuffled.gz"
44
+ "news.2016.de.shuffled.gz"
45
+ "news.2017.de.shuffled.deduped.gz"
46
+ )
47
+
48
+
49
+ cd $orig
50
+ for ((i=0;i<${#URLS[@]};++i)); do
51
+ file=${FILES[i]}
52
+ if [ -f $file ]; then
53
+ echo "$file already exists, skipping download"
54
+ else
55
+ url=${URLS[i]}
56
+ wget "$url"
57
+ fi
58
+ done
59
+ cd ..
60
+
61
+
62
+ if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
63
+ echo "found monolingual sample, skipping shuffle/sample/tokenize"
64
+ else
65
+ gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
66
+ | shuf -n $SUBSAMPLE_SIZE \
67
+ | perl $NORM_PUNC $LANG \
68
+ | perl $REM_NON_PRINT_CHAR \
69
+ | perl $TOKENIZER -threads 8 -a -l $LANG \
70
+ > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
71
+ fi
72
+
73
+
74
+ if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
75
+ echo "found BPE monolingual sample, skipping BPE step"
76
+ else
77
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE \
78
+ < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
79
+ > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
80
+ fi
81
+
82
+
83
+ if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
84
+ echo "found deduplicated monolingual sample, skipping deduplication step"
85
+ else
86
+ python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
87
+ > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
88
+ fi
89
+
90
+
91
+ if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
92
+ echo "found sharded data, skipping sharding step"
93
+ else
94
+ split --lines 1000000 --numeric-suffixes \
95
+ --additional-suffix .${LANG} \
96
+ $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
97
+ $OUTDIR/bpe.monolingual.dedup.
98
+ fi
fairseq/examples/backtranslation/prepare-wmt18en2de.sh ADDED
@@ -0,0 +1,135 @@
1
+ #!/bin/bash
2
+ # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
3
+
4
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
5
+ git clone https://github.com/moses-smt/mosesdecoder.git
6
+
7
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8
+ git clone https://github.com/rsennrich/subword-nmt.git
9
+
10
+ SCRIPTS=mosesdecoder/scripts
11
+ TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
12
+ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
13
+ NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
14
+ REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
15
+ BPEROOT=subword-nmt/subword_nmt
16
+ BPE_TOKENS=32000
17
+
18
+ URLS=(
19
+ "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
20
+ "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
21
+ "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
22
+ "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
23
+ "http://data.statmt.org/wmt17/translation-task/dev.tgz"
24
+ "http://statmt.org/wmt14/test-full.tgz"
25
+ )
26
+ FILES=(
27
+ "training-parallel-europarl-v7.tgz"
28
+ "training-parallel-commoncrawl.tgz"
29
+ "training-parallel-nc-v13.tgz"
30
+ "rapid2016.tgz"
31
+ "dev.tgz"
32
+ "test-full.tgz"
33
+ )
34
+ CORPORA=(
35
+ "training/europarl-v7.de-en"
36
+ "commoncrawl.de-en"
37
+ "training-parallel-nc-v13/news-commentary-v13.de-en"
38
+ "rapid2016.de-en"
39
+ )
40
+
41
+ if [ ! -d "$SCRIPTS" ]; then
42
+ echo "Please set SCRIPTS variable correctly to point to Moses scripts."
43
+ exit 1
44
+ fi
45
+
46
+ OUTDIR=wmt18_en_de
47
+
48
+ src=en
49
+ tgt=de
50
+ lang=en-de
51
+ prep=$OUTDIR
52
+ tmp=$prep/tmp
53
+ orig=orig
54
+
55
+ mkdir -p $orig $tmp $prep
56
+
57
+ cd $orig
58
+
59
+ for ((i=0;i<${#URLS[@]};++i)); do
60
+ file=${FILES[i]}
61
+ if [ -f $file ]; then
62
+ echo "$file already exists, skipping download"
63
+ else
64
+ url=${URLS[i]}
65
+ wget "$url"
66
+ if [ -f $file ]; then
67
+ echo "$url successfully downloaded."
68
+ else
69
+ echo "$url not successfully downloaded."
70
+ exit 1
71
+ fi
72
+ if [ ${file: -4} == ".tgz" ]; then
73
+ tar zxvf $file
74
+ elif [ ${file: -4} == ".tar" ]; then
75
+ tar xvf $file
76
+ fi
77
+ fi
78
+ done
79
+ cd ..
80
+
81
+ echo "pre-processing train data..."
82
+ for l in $src $tgt; do
83
+ rm $tmp/train.tags.$lang.tok.$l
84
+ for f in "${CORPORA[@]}"; do
85
+ cat $orig/$f.$l | \
86
+ perl $NORM_PUNC $l | \
87
+ perl $REM_NON_PRINT_CHAR | \
88
+ perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l
89
+ done
90
+ done
91
+
92
+ echo "pre-processing test data..."
93
+ for l in $src $tgt; do
94
+ if [ "$l" == "$src" ]; then
95
+ t="src"
96
+ else
97
+ t="ref"
98
+ fi
99
+ grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
100
+ sed -e 's/<seg id="[0-9]*">\s*//g' | \
101
+ sed -e 's/\s*<\/seg>\s*//g' | \
102
+ sed -e "s/\’/\'/g" | \
103
+ perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l
104
+ echo ""
105
+ done
106
+
107
+ echo "splitting train and valid..."
108
+ for l in $src $tgt; do
109
+ awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
110
+ awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
111
+ done
112
+
113
+ TRAIN=$tmp/train.de-en
114
+ BPE_CODE=$prep/code
115
+ rm -f $TRAIN
116
+ for l in $src $tgt; do
117
+ cat $tmp/train.$l >> $TRAIN
118
+ done
119
+
120
+ echo "learn_bpe.py on ${TRAIN}..."
121
+ python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE
122
+
123
+ for L in $src $tgt; do
124
+ for f in train.$L valid.$L test.$L; do
125
+ echo "apply_bpe.py to ${f}..."
126
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f
127
+ done
128
+ done
129
+
130
+ perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
131
+ perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
132
+
133
+ for L in $src $tgt; do
134
+ cp $tmp/bpe.test.$L $prep/test.$L
135
+ done
fairseq/examples/backtranslation/sacrebleu.sh ADDED
@@ -0,0 +1,37 @@
1
+ #!/bin/bash
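+ # Detokenized BLEU: tokenize and BPE-encode the sacrebleu test source, translate
+ # with fairseq-interactive, detokenize, and score against the reference with sacrebleu.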
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src \
30
+ | sacremoses tokenize -a -l $SRCLANG -q \
31
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
32
+ | fairseq-interactive $DATABIN --path $MODEL \
33
+ -s $SRCLANG -t $TGTLANG \
34
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
35
+ | grep ^H- | cut -f 3- \
36
+ | sacremoses detokenize -l $TGTLANG -q \
37
+ | sacrebleu -t $DATASET -l $LANGPAIR
fairseq/examples/backtranslation/tokenized_bleu.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
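+ # Tokenized BLEU: tokenize the reference and the system output with Moses and
+ # score with fairseq-score (the paper's Table 1 reports tokenized BLEU).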
2
+
3
+ if [ $# -ne 5 ]; then
4
+ echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
5
+ exit
6
+ fi
7
+
8
+
9
+ DATASET=$1
10
+ LANGPAIR=$2
11
+ DATABIN=$3
12
+ BPECODE=$4
13
+ MODEL=$5
14
+
15
+ SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
16
+ TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
17
+
18
+
19
+ BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
20
+ if [ ! -e $BPEROOT ]; then
21
+ BPEROOT=subword-nmt/subword_nmt
22
+ if [ ! -e $BPEROOT ]; then
23
+ echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
24
+ git clone https://github.com/rsennrich/subword-nmt.git
25
+ fi
26
+ fi
27
+
28
+
29
+ TMP_REF=$(mktemp)
30
+
31
+ sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
32
+ | sacremoses normalize -l $TGTLANG -q \
33
+ | sacremoses tokenize -a -l $TGTLANG -q \
34
+ > $TMP_REF
35
+
36
+ sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
37
+ | sacremoses normalize -l $SRCLANG -q \
38
+ | sacremoses tokenize -a -l $SRCLANG -q \
39
+ | python $BPEROOT/apply_bpe.py -c $BPECODE \
40
+ | fairseq-interactive $DATABIN --path $MODEL \
41
+ -s $SRCLANG -t $TGTLANG \
42
+ --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
43
+ | grep ^H- | cut -f 3- \
44
+ | fairseq-score --ref $TMP_REF
45
+
46
+ rm -f $TMP_REF
fairseq/examples/bart/README.glue.md ADDED
@@ -0,0 +1,99 @@
1
+ # Fine-tuning BART on GLUE tasks
2
+
3
+ ### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
4
+ ```bash
5
+ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
6
+ python download_glue_data.py --data_dir glue_data --tasks all
7
+ ```
8
+
9
+ ### 2) Preprocess GLUE task data (same as RoBERTa):
10
+ ```bash
11
+ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
12
+ ```
13
+ `glue_task_name` is one of the following:
14
+ `{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
15
+ Use `ALL` to preprocess all the GLUE tasks.
16
+
17
+ ### 3) Fine-tuning on GLUE task:
18
+ Example fine-tuning command for the `RTE` task:
19
+ ```bash
20
+ TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21
+ WARMUP_UPDATES=61 # 6 percent of the number of updates
22
+ LR=1e-05 # Peak LR for polynomial LR scheduler.
23
+ NUM_CLASSES=2
24
+ MAX_SENTENCES=16 # Batch size.
25
+ BART_PATH=/path/to/bart/model.pt
26
+
27
+ CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
28
+ --restore-file $BART_PATH \
29
+ --batch-size $MAX_SENTENCES \
30
+ --max-tokens 4400 \
31
+ --task sentence_prediction \
32
+ --add-prev-output-tokens \
33
+ --layernorm-embedding \
34
+ --share-all-embeddings \
35
+ --share-decoder-input-output-embed \
36
+ --reset-optimizer --reset-dataloader --reset-meters \
37
+ --required-batch-size-multiple 1 \
38
+ --init-token 0 \
39
+ --arch bart_large \
40
+ --criterion sentence_prediction \
41
+ --num-classes $NUM_CLASSES \
42
+ --dropout 0.1 --attention-dropout 0.1 \
43
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
44
+ --clip-norm 0.0 \
45
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
46
+ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
47
+ --max-epoch 10 \
48
+ --find-unused-parameters \
49
+ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
50
+ ```
51
+
52
+ For each of the GLUE tasks, you will need to use the following command-line arguments:
53
+
54
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
55
+ ---|---|---|---|---|---|---|---|---
56
+ `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
57
+ `--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
58
+ `bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
59
+ `--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
60
+ `--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
61
+
62
+ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
63
+
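+ For example, adapting the `RTE` command above to `STS-B` with the values from the table would look roughly like this (a sketch, not a tested command):
+ ```bash
+ TOTAL_NUM_UPDATES=1799
+ WARMUP_UPDATES=107
+ LR=2e-05
+ NUM_CLASSES=1
+ MAX_SENTENCES=32
+
+ # ... then the same fairseq-train invocation as above, additionally passing
+ #   --regression-target \
+ #   --best-checkpoint-metric loss \
+ # and dropping --maximize-best-checkpoint-metric.
+ ```
+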
64
+ **Note:**
65
+
66
+ a) `--total-num-update` is used by the `polynomial_decay` learning-rate scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.
67
+
68
+ b) The above command-line arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
69
+
70
+ ### Inference on GLUE task
71
+ After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:
72
+
73
+ ```python
74
+ from fairseq.models.bart import BARTModel
75
+
76
+ bart = BARTModel.from_pretrained(
77
+ 'checkpoints/',
78
+ checkpoint_file='checkpoint_best.pt',
79
+ data_name_or_path='RTE-bin'
80
+ )
81
+
82
+ label_fn = lambda label: bart.task.label_dictionary.string(
83
+ [label + bart.task.label_dictionary.nspecial]
84
+ )
85
+ ncorrect, nsamples = 0, 0
86
+ bart.cuda()
87
+ bart.eval()
88
+ with open('glue_data/RTE/dev.tsv') as fin:
89
+ fin.readline()
90
+ for index, line in enumerate(fin):
91
+ tokens = line.strip().split('\t')
92
+ sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93
+ tokens = bart.encode(sent1, sent2)
94
+ prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95
+ prediction_label = label_fn(prediction)
96
+ ncorrect += int(prediction_label == target)
97
+ nsamples += 1
98
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
99
+ ```
fairseq/examples/bart/README.md ADDED
@@ -0,0 +1,228 @@
1
+ # BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2
+
3
+ [https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461)
4
+
5
+ ## Introduction
6
+
7
+ BART is a sequence-to-sequence model trained with a denoising pretraining objective. We show that this pretraining objective is more general and that we can match [RoBERTa](../roberta) results on SQuAD and GLUE, and achieve state-of-the-art results on summarization (XSum, CNN dataset), long-form generative question answering (ELI5) and dialog response generation (ConvAI2). See the associated paper for more details.
8
+
9
+ ## Pre-trained models
10
+
11
+ Model | Description | # params | Download
12
+ ---|---|---|---
13
+ `bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
14
+ `bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
15
+ `bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
16
+ `bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
17
+ `bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
18
+
19
+ ## Results
20
+
21
+ **[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
22
+ _(dev set, single model, single-task finetuning)_
23
+
24
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
25
+ ---|---|---|---|---|---|---|---|---
26
+ `roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
27
+ `bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
28
+
29
+ **[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
30
+ _(dev set, no additional data used)_
31
+
32
+ Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
33
+ ---|---|---
34
+ `roberta.large` | 88.9/94.6 | 86.5/89.4
35
+ `bart.large` | 88.8/94.6 | 86.1/89.2
36
+
37
+ **[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
38
+ _(test set, no additional data used)_
39
+
40
+ Model | R1 | R2 | RL
41
+ ---|---|---|---
42
+ `BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
43
+ `bart.large` | 44.16 | 21.28 | 40.90
44
+
45
+ ## Example usage
46
+
47
+ ##### Load BART from torch.hub (PyTorch >= 1.1):
48
+ ```python
49
+ import torch
50
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large')
51
+ bart.eval() # disable dropout (or leave in train mode to finetune)
52
+ ```
53
+
54
+ ##### Load BART (for PyTorch 1.0 or custom models):
55
+ ```python
56
+ # Download bart.large model
57
+ wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
58
+ tar -xzvf bart.large.tar.gz
59
+
60
+ # Load the model in fairseq
61
+ from fairseq.models.bart import BARTModel
62
+ bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
63
+ bart.eval() # disable dropout (or leave in train mode to finetune)
64
+ ```
65
+
66
+ ##### Apply Byte-Pair Encoding (BPE) to input text:
67
+ ```python
68
+ tokens = bart.encode('Hello world!')
69
+ assert tokens.tolist() == [0, 31414, 232, 328, 2]
70
+ bart.decode(tokens) # 'Hello world!'
71
+ ```
72
+
73
+ ##### Extract features from BART:
74
+ ```python
75
+ # Extract the last layer's features
76
+ last_layer_features = bart.extract_features(tokens)
77
+ assert last_layer_features.size() == torch.Size([1, 5, 1024])
78
+
79
+ # Extract all layer's features from decoder (layer 0 is the embedding layer)
80
+ all_layers = bart.extract_features(tokens, return_all_hiddens=True)
81
+ assert len(all_layers) == 13
82
+ assert torch.all(all_layers[-1] == last_layer_features)
83
+ ```
84
+
85
+ ##### Use BART for sentence-pair classification tasks:
86
+ ```python
87
+ # Download BART already finetuned for MNLI
88
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
89
+ bart.eval() # disable dropout for evaluation
90
+
91
+ # Encode a pair of sentences and make a prediction
92
+ tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
93
+ bart.predict('mnli', tokens).argmax() # 0: contradiction
94
+
95
+ # Encode another pair of sentences
96
+ tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
97
+ bart.predict('mnli', tokens).argmax() # 2: entailment
98
+ ```
99
+
100
+ ##### Register a new (randomly initialized) classification head:
101
+ ```python
102
+ bart.register_classification_head('new_task', num_classes=3)
103
+ logprobs = bart.predict('new_task', tokens)
104
+ ```
105
+
106
+ ##### Batched prediction:
107
+ ```python
108
+ import torch
109
+ from fairseq.data.data_utils import collate_tokens
110
+
111
+ bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
112
+ bart.eval()
113
+
114
+ batch_of_pairs = [
115
+ ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
116
+ ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
117
+ ]
118
+
119
+ batch = collate_tokens(
120
+ [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
121
+ )
122
+
123
+ logprobs = bart.predict('mnli', batch)
124
+ print(logprobs.argmax(dim=1))
125
+ # tensor([0, 2])
126
+ ```
127
+
128
+ ##### Using the GPU:
129
+ ```python
130
+ bart.cuda()
131
+ bart.predict('new_task', tokens)
132
+ ```
133
+
134
+ #### Filling masks:
135
+
136
+ BART can be used to fill multiple `<mask>` tokens in the input.
137
+ ```python
138
+ bart = torch.hub.load('pytorch/fairseq', 'bart.base')
139
+ bart.eval()
140
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10)
141
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]]
142
+ ```
143
+
144
+ Note that by default we enforce the output length to match the input length.
145
+ This can be disabled by setting ``match_source_len=False``:
146
+ ```python
147
+ bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False)
148
+ # [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]]
149
+ ```
150
+
151
+ Example code to fill masks for a batch of sentences using the GPU:
152
+ ```python
153
+ bart.cuda()
154
+ bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10)
155
+ # [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)),
156
+ ('The dog was asleep on the couch', tensor(-0.6796))]]
157
+ ```
158
+
159
+ #### Evaluating the `bart.large.mnli` model:
160
+
161
+ Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
162
+ ```python
163
+ label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
164
+ ncorrect, nsamples = 0, 0
165
+ bart.cuda()
166
+ bart.eval()
167
+ with open('glue_data/MNLI/dev_matched.tsv') as fin:
168
+ fin.readline()
169
+ for index, line in enumerate(fin):
170
+ tokens = line.strip().split('\t')
171
+ sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
172
+ tokens = bart.encode(sent1, sent2)
173
+ prediction = bart.predict('mnli', tokens).argmax().item()
174
+ prediction_label = label_map[prediction]
175
+ ncorrect += int(prediction_label == target)
176
+ nsamples += 1
177
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
178
+ # Expected output: 0.9010
179
+ ```
180
+
181
+ #### Evaluating the `bart.large.cnn` model:
182
+ - Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
183
+ - For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores
184
+ - `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
185
+ In `huggingface/transformers`, the BART models' paths are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`.
186
+
187
+ In `fairseq`, summaries can be generated using:
188
+
189
+ ```bash
190
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
191
+ python examples/bart/summarize.py \
192
+ --model-dir pytorch/fairseq \
193
+ --model-file bart.large.cnn \
194
+ --src cnn_dm/test.source \
195
+ --out cnn_dm/test.hypo
196
+ ```
197
+
198
+ For calculating rouge, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
199
+
200
+ ```bash
201
+ export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
202
+
203
+ # Tokenize hypothesis and target files.
204
+ cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
205
+ cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
206
+ files2rouge test.hypo.tokenized test.hypo.target
207
+ # Expected output: (ROUGE-2 Average_F: 0.21238)
208
+ ```
209
+
210
+
211
+ ## Finetuning
212
+
213
+ - [Finetuning on GLUE](README.glue.md)
214
+ - [Finetuning on CNN-DM](README.summarization.md)
215
+
216
+ ## Citation
217
+
218
+ ```bibtex
219
+ @article{lewis2019bart,
220
+ title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
221
+ Language Generation, Translation, and Comprehension},
222
+ author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
223
+ Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
224
+ and Luke Zettlemoyer },
225
+ journal={arXiv preprint arXiv:1910.13461},
226
+ year = {2019},
227
+ }
228
+ ```
fairseq/examples/bart/README.summarization.md ADDED
@@ -0,0 +1,102 @@
1
+ # Fine-tuning BART on CNN-Dailymail summarization task
2
+
3
+ ### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
4
+
5
+ Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).
6
+
7
+ Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization (XSum) dataset, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset). Please keep the dataset raw and make sure no tokenization or BPE has been applied to it.
8
+
9
+ ### 2) BPE preprocess:
10
+
11
+ ```bash
12
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
13
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
14
+ wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
15
+
16
+ TASK=cnn_dm
17
+ for SPLIT in train val
18
+ do
19
+ for LANG in source target
20
+ do
21
+ python -m examples.roberta.multiprocessing_bpe_encoder \
22
+ --encoder-json encoder.json \
23
+ --vocab-bpe vocab.bpe \
24
+ --inputs "$TASK/$SPLIT.$LANG" \
25
+ --outputs "$TASK/$SPLIT.bpe.$LANG" \
26
+ --workers 60 \
27
+ --keep-empty;
28
+ done
29
+ done
30
+ ```
31
+
32
+ ### 3) Binarize dataset:
33
+ ```bash
34
+ fairseq-preprocess \
35
+ --source-lang "source" \
36
+ --target-lang "target" \
37
+ --trainpref "${TASK}/train.bpe" \
38
+ --validpref "${TASK}/val.bpe" \
39
+ --destdir "${TASK}-bin/" \
40
+ --workers 60 \
41
+ --srcdict dict.txt \
42
+ --tgtdict dict.txt;
43
+ ```
44
+
45
+ ### 4) Fine-tuning on CNN-DM summarization task:
46
+ Example fine-tuning CNN-DM
47
+ ```bash
48
+ TOTAL_NUM_UPDATES=20000
49
+ WARMUP_UPDATES=500
50
+ LR=3e-05
51
+ MAX_TOKENS=2048
52
+ UPDATE_FREQ=4
53
+ BART_PATH=/path/to/bart/model.pt
54
+
55
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
56
+ --restore-file $BART_PATH \
57
+ --max-tokens $MAX_TOKENS \
58
+ --task translation \
59
+ --source-lang source --target-lang target \
60
+ --truncate-source \
61
+ --layernorm-embedding \
62
+ --share-all-embeddings \
63
+ --share-decoder-input-output-embed \
64
+ --reset-optimizer --reset-dataloader --reset-meters \
65
+ --required-batch-size-multiple 1 \
66
+ --arch bart_large \
67
+ --criterion label_smoothed_cross_entropy \
68
+ --label-smoothing 0.1 \
69
+ --dropout 0.1 --attention-dropout 0.1 \
70
+ --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
71
+ --clip-norm 0.1 \
72
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
73
+ --fp16 --update-freq $UPDATE_FREQ \
74
+ --skip-invalid-size-inputs-valid-test \
75
+ --find-unused-parameters;
76
+ ```
77
+ The above is expected to run on `1` node with `8` 32GB `V100` GPUs.
78
+ Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
79
+
80
+ Use `TOTAL_NUM_UPDATES=15000` and `UPDATE_FREQ=2` for the XSum task.
81
+
82
+ ### Inference for CNN-DM test data using above trained checkpoint.
83
+ After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using `summarize.py`, for example:
84
+
85
+ ```bash
86
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
87
+ python examples/bart/summarize.py \
88
+ --model-dir checkpoints \
89
+ --model-file checkpoint_best.pt \
90
+ --src cnn_dm/test.source \
91
+ --out cnn_dm/test.hypo
92
+ ```
93
+ For XSUM, which uses beam=6, lenpen=1.0, max_len_b=60, min_len=10:
94
+ ```bash
95
+ cp data-bin/cnn_dm/dict.source.txt checkpoints/
96
+ python examples/bart/summarize.py \
97
+ --model-dir checkpoints \
98
+ --model-file checkpoint_best.pt \
99
+ --src cnn_dm/test.source \
100
+ --out cnn_dm/test.hypo \
101
+ --xsum-kwargs
102
+ ```
fairseq/examples/bart/summarize.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from fairseq.models.bart import BARTModel
8
+ import argparse
9
+
10
+ XSUM_KWARGS = dict(beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
11
+ CNN_KWARGS = dict(beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
12
+
13
+
14
+ @torch.no_grad()
15
+ def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
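+ # Stream `infile` line by line, sending batches of `bsz` source lines through
+ # bart.sample() and writing one hypothesis per line to `outfile`.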
16
+ count = 1
17
+
18
+ # if n_obs is not None: bsz = min(bsz, n_obs)
19
+
20
+ with open(infile) as source, open(outfile, "w") as fout:
21
+ sline = source.readline().strip()
22
+ slines = [sline]
23
+ for sline in source:
24
+ if n_obs is not None and count > n_obs:
25
+ break
26
+ if count % bsz == 0:
27
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
28
+ for hypothesis in hypotheses_batch:
29
+ fout.write(hypothesis + "\n")
30
+ fout.flush()
31
+ slines = []
32
+
33
+ slines.append(sline.strip())
34
+ count += 1
35
+
36
+ if slines != []:
37
+ hypotheses_batch = bart.sample(slines, **eval_kwargs)
38
+ for hypothesis in hypotheses_batch:
39
+ fout.write(hypothesis + "\n")
40
+ fout.flush()
41
+
42
+
43
+ def main():
44
+ """
45
+ Usage::
46
+
47
+ python examples/bart/summarize.py \
48
+ --model-dir $HOME/bart.large.cnn \
49
+ --model-file model.pt \
50
+ --src $HOME/data-bin/cnn_dm/test.source
51
+ """
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument(
54
+ "--model-dir",
55
+ required=True,
56
+ type=str,
57
+ default="bart.large.cnn/",
58
+ help="path containing model file and src_dict.txt",
59
+ )
60
+ parser.add_argument(
61
+ "--model-file",
62
+ default="checkpoint_best.pt",
63
+ help="where in model_dir are weights saved",
64
+ )
65
+ parser.add_argument(
66
+ "--src", default="test.source", help="text to summarize", type=str
67
+ )
68
+ parser.add_argument(
69
+ "--out", default="test.hypo", help="where to save summaries", type=str
70
+ )
71
+ parser.add_argument("--bsz", default=32, help="batch size for generation", type=int)
72
+ parser.add_argument(
73
+ "--n", default=None, help="how many examples to summarize", type=int
74
+ )
75
+ parser.add_argument(
76
+ "--xsum-kwargs",
77
+ action="store_true",
78
+ default=False,
79
+ help="if true use XSUM_KWARGS else CNN_KWARGS",
80
+ )
81
+ args = parser.parse_args()
82
+ eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS
83
+ if args.model_dir == "pytorch/fairseq":
84
+ bart = torch.hub.load("pytorch/fairseq", args.model_file)
85
+ else:
86
+ bart = BARTModel.from_pretrained(
87
+ args.model_dir,
88
+ checkpoint_file=args.model_file,
89
+ data_name_or_path=args.model_dir,
90
+ )
91
+ bart = bart.eval()
92
+ if torch.cuda.is_available():
93
+ bart = bart.cuda().half()
94
+ generate(
95
+ bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
96
+ )
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()