Spaces:
Running
Running
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +201 -0
- README.md +4 -4
- airship.jpg +0 -0
- app.py +153 -0
- checkpoints.md +13 -0
- colab.md +8 -0
- criterions/__init__.py +2 -0
- criterions/label_smoothed_cross_entropy.py +343 -0
- criterions/scst_loss.py +280 -0
- data/__init__.py +0 -0
- data/data_utils.py +601 -0
- data/file_dataset.py +102 -0
- data/mm_data/__init__.py +0 -0
- data/mm_data/caption_dataset.py +154 -0
- data/mm_data/refcoco_dataset.py +168 -0
- data/mm_data/vqa_gen_dataset.py +211 -0
- data/ofa_dataset.py +74 -0
- datasets.md +10 -0
- evaluate.py +156 -0
- fairseq/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq/.github/stale.yml +30 -0
- fairseq/.github/workflows/build.yml +55 -0
- fairseq/.github/workflows/build_wheels.yml +41 -0
- fairseq/.gitignore +136 -0
- fairseq/.gitmodules +4 -0
- fairseq/CODE_OF_CONDUCT.md +77 -0
- fairseq/CONTRIBUTING.md +28 -0
- fairseq/LICENSE +21 -0
- fairseq/README.md +229 -0
- fairseq/docs/Makefile +20 -0
- fairseq/docs/_static/theme_overrides.css +9 -0
- fairseq/docs/command_line_tools.rst +85 -0
- fairseq/docs/conf.py +134 -0
- fairseq/docs/criterions.rst +31 -0
- fairseq/docs/data.rst +58 -0
- fairseq/docs/docutils.conf +2 -0
- fairseq/docs/fairseq_logo.png +0 -0
- fairseq/docs/getting_started.rst +216 -0
- fairseq/docs/hydra_integration.md +284 -0
- fairseq/docs/index.rst +49 -0
- fairseq/docs/lr_scheduler.rst +34 -0
- fairseq/docs/make.bat +36 -0
- fairseq/docs/models.rst +104 -0
- fairseq/docs/modules.rst +9 -0
- fairseq/docs/optim.rst +38 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 1999-2022 Alibaba Group Holding Ltd.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title: OFA
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
|
|
1 |
---
|
2 |
+
title: OFA-Open_Domain_VQA
|
3 |
+
emoji: 💩
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
app_file: app.py
|
8 |
pinned: false
|
airship.jpg
ADDED
app.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;'
|
4 |
+
'pip install --use-feature=in-tree-build ./; cd ..')
|
5 |
+
os.system('ls -l')
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import re
|
10 |
+
from fairseq import utils,tasks
|
11 |
+
from fairseq import checkpoint_utils
|
12 |
+
from fairseq import distributed_utils, options, tasks, utils
|
13 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
14 |
+
from utils.zero_shot_utils import zero_shot_step
|
15 |
+
from tasks.mm_tasks.vqa_gen import VqaGenTask
|
16 |
+
from models.ofa import OFAModel
|
17 |
+
from PIL import Image
|
18 |
+
from torchvision import transforms
|
19 |
+
import gradio as gr
|
20 |
+
|
21 |
+
# Register VQA task
|
22 |
+
tasks.register_task('vqa_gen',VqaGenTask)
|
23 |
+
# turn on cuda if GPU is available
|
24 |
+
use_cuda = torch.cuda.is_available()
|
25 |
+
# use fp16 only when GPU is available
|
26 |
+
use_fp16 = False
|
27 |
+
|
28 |
+
os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/ofa_large_384.pt; '
|
29 |
+
'mkdir -p checkpoints; mv ofa_large_384.pt checkpoints/ofa_large_384.pt')
|
30 |
+
|
31 |
+
# specify some options for evaluation
|
32 |
+
parser = options.get_generation_parser()
|
33 |
+
input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", "--path=checkpoints/ofa_large_384.pt", "--bpe-dir=utils/BPE"]
|
34 |
+
args = options.parse_args_and_arch(parser, input_args)
|
35 |
+
cfg = convert_namespace_to_omegaconf(args)
|
36 |
+
|
37 |
+
# Load pretrained ckpt & config
|
38 |
+
task = tasks.setup_task(cfg.task)
|
39 |
+
models, cfg = checkpoint_utils.load_model_ensemble(
|
40 |
+
utils.split_paths(cfg.common_eval.path),
|
41 |
+
task=task
|
42 |
+
)
|
43 |
+
|
44 |
+
# Move models to GPU
|
45 |
+
for model in models:
|
46 |
+
model.eval()
|
47 |
+
if use_fp16:
|
48 |
+
model.half()
|
49 |
+
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
50 |
+
model.cuda()
|
51 |
+
model.prepare_for_inference_(cfg)
|
52 |
+
|
53 |
+
# Initialize generator
|
54 |
+
generator = task.build_generator(models, cfg.generation)
|
55 |
+
|
56 |
+
# Image transform
|
57 |
+
from torchvision import transforms
|
58 |
+
mean = [0.5, 0.5, 0.5]
|
59 |
+
std = [0.5, 0.5, 0.5]
|
60 |
+
|
61 |
+
patch_resize_transform = transforms.Compose([
|
62 |
+
lambda image: image.convert("RGB"),
|
63 |
+
transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
|
64 |
+
transforms.ToTensor(),
|
65 |
+
transforms.Normalize(mean=mean, std=std),
|
66 |
+
])
|
67 |
+
|
68 |
+
# Text preprocess
|
69 |
+
bos_item = torch.LongTensor([task.src_dict.bos()])
|
70 |
+
eos_item = torch.LongTensor([task.src_dict.eos()])
|
71 |
+
pad_idx = task.src_dict.pad()
|
72 |
+
|
73 |
+
# Normalize the question
|
74 |
+
def pre_question(question, max_ques_words):
|
75 |
+
question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
|
76 |
+
question = re.sub(
|
77 |
+
r"\s{2,}",
|
78 |
+
' ',
|
79 |
+
question,
|
80 |
+
)
|
81 |
+
question = question.rstrip('\n')
|
82 |
+
question = question.strip(' ')
|
83 |
+
# truncate question
|
84 |
+
question_words = question.split(' ')
|
85 |
+
if len(question_words) > max_ques_words:
|
86 |
+
question = ' '.join(question_words[:max_ques_words])
|
87 |
+
return question
|
88 |
+
|
89 |
+
def encode_text(text, length=None, append_bos=False, append_eos=False):
|
90 |
+
s = task.tgt_dict.encode_line(
|
91 |
+
line=task.bpe.encode(text),
|
92 |
+
add_if_not_exist=False,
|
93 |
+
append_eos=False
|
94 |
+
).long()
|
95 |
+
if length is not None:
|
96 |
+
s = s[:length]
|
97 |
+
if append_bos:
|
98 |
+
s = torch.cat([bos_item, s])
|
99 |
+
if append_eos:
|
100 |
+
s = torch.cat([s, eos_item])
|
101 |
+
return s
|
102 |
+
|
103 |
+
# Construct input for open-domain VQA task
|
104 |
+
def construct_sample(image: Image, question: str):
|
105 |
+
patch_image = patch_resize_transform(image).unsqueeze(0)
|
106 |
+
patch_mask = torch.tensor([True])
|
107 |
+
|
108 |
+
question = pre_question(question, task.cfg.max_src_length)
|
109 |
+
question = question + '?' if not question.endswith('?') else question
|
110 |
+
src_text = encode_text(' {}'.format(question), append_bos=True, append_eos=True).unsqueeze(0)
|
111 |
+
|
112 |
+
src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
|
113 |
+
ref_dict = np.array([{'yes': 1.0}]) # just placeholder
|
114 |
+
sample = {
|
115 |
+
"id":np.array(['42']),
|
116 |
+
"net_input": {
|
117 |
+
"src_tokens": src_text,
|
118 |
+
"src_lengths": src_length,
|
119 |
+
"patch_images": patch_image,
|
120 |
+
"patch_masks": patch_mask,
|
121 |
+
},
|
122 |
+
"ref_dict": ref_dict,
|
123 |
+
}
|
124 |
+
return sample
|
125 |
+
|
126 |
+
# Function to turn FP32 to FP16
|
127 |
+
def apply_half(t):
|
128 |
+
if t.dtype is torch.float32:
|
129 |
+
return t.to(dtype=torch.half)
|
130 |
+
return t
|
131 |
+
|
132 |
+
|
133 |
+
# Function for image captioning
|
134 |
+
def open_domain_vqa(Image, Question):
|
135 |
+
sample = construct_sample(Image, Question)
|
136 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
137 |
+
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
|
138 |
+
# Run eval step for open-domain VQA
|
139 |
+
with torch.no_grad():
|
140 |
+
result, scores = zero_shot_step(task, generator, models, sample)
|
141 |
+
return result[0]['answer']
|
142 |
+
|
143 |
+
|
144 |
+
title = "OFA-Open_Domain_VQA"
|
145 |
+
description = "Gradio Demo for OFA-Open_Domain_VQA. Upload your own image or click any one of the examples, and click " \
|
146 |
+
"\"Submit\" and then wait for OFA's answer. "
|
147 |
+
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
|
148 |
+
"Repo</a></p> "
|
149 |
+
examples = [['money_tree.png', 'what is grown on the plant?'], ['airship.jpg', 'what does the red-roofed building right to the big airship look like?'], ['sitting_man.png', 'what is the man sitting on?']]
|
150 |
+
io = gr.Interface(fn=open_domain_vqa, inputs=[gr.inputs.Image(type='pil'), "textbox"], outputs=gr.outputs.Textbox(label="Answer"),
|
151 |
+
title=title, description=description, article=article, examples=examples,
|
152 |
+
allow_flagging=False, allow_screenshot=False)
|
153 |
+
io.launch(cache_examples=True)
|
checkpoints.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Checkpoints
|
2 |
+
|
3 |
+
We provide links for you to download our checkpoints. We will release all the checkpoints including pretrained and finetuned models on different tasks.
|
4 |
+
|
5 |
+
## Pretraining
|
6 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a>
|
7 |
+
|
8 |
+
## Finetuning
|
9 |
+
|
10 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
|
11 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
|
12 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
|
13 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
|
colab.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Colab Notebooks
|
2 |
+
|
3 |
+
We provide Colab notebooks of different downstream task for you guys to enjoy OFA. See below.
|
4 |
+
|
5 |
+
* Image Captioning: [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
|
6 |
+
* Referring Expression Comprehension: [![][colab]](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
|
7 |
+
|
8 |
+
[colab]: <https://colab.research.google.com/assets/colab-badge.svg>
|
criterions/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .scst_loss import ScstRewardCriterion
|
2 |
+
from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
|
criterions/label_smoothed_cross_entropy.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
from fairseq import metrics, utils
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from omegaconf import II
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
|
21 |
+
label_smoothing: float = field(
|
22 |
+
default=0.0,
|
23 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
24 |
+
)
|
25 |
+
report_accuracy: bool = field(
|
26 |
+
default=False,
|
27 |
+
metadata={"help": "report accuracy metric"},
|
28 |
+
)
|
29 |
+
ignore_prefix_size: int = field(
|
30 |
+
default=0,
|
31 |
+
metadata={"help": "Ignore first N tokens"},
|
32 |
+
)
|
33 |
+
ignore_eos: bool = field(
|
34 |
+
default=False,
|
35 |
+
metadata={"help": "Ignore eos token"},
|
36 |
+
)
|
37 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
38 |
+
drop_worst_ratio: float = field(
|
39 |
+
default=0.0,
|
40 |
+
metadata={"help": "ratio for discarding bad samples"},
|
41 |
+
)
|
42 |
+
drop_worst_after: int = field(
|
43 |
+
default=0,
|
44 |
+
metadata={"help": "steps for discarding bad samples"},
|
45 |
+
)
|
46 |
+
use_rdrop: bool = field(
|
47 |
+
default=False, metadata={"help": "use R-Drop"}
|
48 |
+
)
|
49 |
+
reg_alpha: float = field(
|
50 |
+
default=1.0, metadata={"help": "weight for R-Drop"}
|
51 |
+
)
|
52 |
+
sample_patch_num: int = field(
|
53 |
+
default=196, metadata={"help": "sample patchs for v1"}
|
54 |
+
)
|
55 |
+
constraint_range: Optional[str] = field(
|
56 |
+
default=None,
|
57 |
+
metadata={"help": "constraint range"}
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
def construct_rdrop_sample(x):
|
62 |
+
if isinstance(x, dict):
|
63 |
+
for key in x:
|
64 |
+
x[key] = construct_rdrop_sample(x[key])
|
65 |
+
return x
|
66 |
+
elif isinstance(x, torch.Tensor):
|
67 |
+
return x.repeat(2, *([1] * (x.dim()-1)))
|
68 |
+
elif isinstance(x, int):
|
69 |
+
return x * 2
|
70 |
+
elif isinstance(x, np.ndarray):
|
71 |
+
return x.repeat(2)
|
72 |
+
else:
|
73 |
+
raise NotImplementedError
|
74 |
+
|
75 |
+
|
76 |
+
def kl_loss(p, q):
|
77 |
+
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
78 |
+
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
79 |
+
loss = (p_loss + q_loss) / 2
|
80 |
+
return loss
|
81 |
+
|
82 |
+
|
83 |
+
def label_smoothed_nll_loss(
|
84 |
+
lprobs, target, epsilon, update_num, reduce=True,
|
85 |
+
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
|
86 |
+
constraint_masks=None, constraint_start=None, constraint_end=None
|
87 |
+
):
|
88 |
+
if target.dim() == lprobs.dim() - 1:
|
89 |
+
target = target.unsqueeze(-1)
|
90 |
+
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
91 |
+
if constraint_masks is not None:
|
92 |
+
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
|
93 |
+
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
94 |
+
elif constraint_start is not None and constraint_end is not None:
|
95 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
96 |
+
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
|
97 |
+
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
98 |
+
else:
|
99 |
+
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
100 |
+
eps_i = epsilon / (lprobs.size(-1) - 1)
|
101 |
+
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
102 |
+
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
103 |
+
if use_rdrop:
|
104 |
+
true_batch_size = loss.size(0) // 2
|
105 |
+
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
|
106 |
+
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
|
107 |
+
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
|
108 |
+
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
|
109 |
+
else:
|
110 |
+
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
|
111 |
+
nll_loss = nll_loss[indices]
|
112 |
+
lprobs = lprobs[indices]
|
113 |
+
|
114 |
+
ntokens = loss.numel()
|
115 |
+
nll_loss = nll_loss.sum()
|
116 |
+
loss = loss.sum()
|
117 |
+
if use_rdrop:
|
118 |
+
true_batch_size = lprobs.size(0) // 2
|
119 |
+
p = lprobs[:true_batch_size]
|
120 |
+
q = lprobs[true_batch_size:]
|
121 |
+
if constraint_start is not None and constraint_end is not None:
|
122 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
123 |
+
p = p[:, constraint_range]
|
124 |
+
q = q[:, constraint_range]
|
125 |
+
loss += kl_loss(p, q) * reg_alpha
|
126 |
+
|
127 |
+
return loss, nll_loss, ntokens
|
128 |
+
|
129 |
+
|
130 |
+
@register_criterion(
|
131 |
+
"ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
|
132 |
+
)
|
133 |
+
class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
task,
|
137 |
+
sentence_avg,
|
138 |
+
label_smoothing,
|
139 |
+
ignore_prefix_size=0,
|
140 |
+
ignore_eos=False,
|
141 |
+
report_accuracy=False,
|
142 |
+
drop_worst_ratio=0,
|
143 |
+
drop_worst_after=0,
|
144 |
+
use_rdrop=False,
|
145 |
+
reg_alpha=1.0,
|
146 |
+
sample_patch_num=196,
|
147 |
+
constraint_range=None
|
148 |
+
):
|
149 |
+
super().__init__(task)
|
150 |
+
self.sentence_avg = sentence_avg
|
151 |
+
self.eps = label_smoothing
|
152 |
+
self.ignore_prefix_size = ignore_prefix_size
|
153 |
+
self.ignore_eos = ignore_eos
|
154 |
+
self.report_accuracy = report_accuracy
|
155 |
+
self.drop_worst_ratio = drop_worst_ratio
|
156 |
+
self.drop_worst_after = drop_worst_after
|
157 |
+
self.use_rdrop = use_rdrop
|
158 |
+
self.reg_alpha = reg_alpha
|
159 |
+
self.sample_patch_num = sample_patch_num
|
160 |
+
|
161 |
+
self.constraint_start = None
|
162 |
+
self.constraint_end = None
|
163 |
+
if constraint_range is not None:
|
164 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
165 |
+
self.constraint_start = int(constraint_start)
|
166 |
+
self.constraint_end = int(constraint_end)
|
167 |
+
|
168 |
+
def forward(self, model, sample, update_num=0, reduce=True):
|
169 |
+
"""Compute the loss for the given sample.
|
170 |
+
|
171 |
+
Returns a tuple with three elements:
|
172 |
+
1) the loss
|
173 |
+
2) the sample size, which is used as the denominator for the gradient
|
174 |
+
3) logging outputs to display while training
|
175 |
+
"""
|
176 |
+
if isinstance(sample, list):
|
177 |
+
if self.sample_patch_num > 0:
|
178 |
+
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
|
179 |
+
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
|
180 |
+
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
|
181 |
+
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
|
182 |
+
sample_size = 1
|
183 |
+
logging_output = {
|
184 |
+
"loss": loss.data,
|
185 |
+
"loss_v1": loss_v1.data,
|
186 |
+
"loss_v2": loss_v2.data,
|
187 |
+
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
|
188 |
+
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
|
189 |
+
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
|
190 |
+
"sample_size": 1,
|
191 |
+
"sample_size_v1": sample_size_v1,
|
192 |
+
"sample_size_v2": sample_size_v2,
|
193 |
+
}
|
194 |
+
return loss, sample_size, logging_output
|
195 |
+
|
196 |
+
if self.use_rdrop:
|
197 |
+
construct_rdrop_sample(sample)
|
198 |
+
|
199 |
+
net_output = model(**sample["net_input"])
|
200 |
+
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
|
201 |
+
sample_size = (
|
202 |
+
sample["target"].size(0) if self.sentence_avg else ntokens
|
203 |
+
)
|
204 |
+
logging_output = {
|
205 |
+
"loss": loss.data,
|
206 |
+
"nll_loss": nll_loss.data,
|
207 |
+
"ntokens": sample["ntokens"],
|
208 |
+
"nsentences": sample["nsentences"],
|
209 |
+
"sample_size": sample_size,
|
210 |
+
}
|
211 |
+
if self.report_accuracy:
|
212 |
+
n_correct, total = self.compute_accuracy(model, net_output, sample)
|
213 |
+
logging_output["n_correct"] = utils.item(n_correct.data)
|
214 |
+
logging_output["total"] = utils.item(total.data)
|
215 |
+
return loss, sample_size, logging_output
|
216 |
+
|
217 |
+
def get_lprobs_and_target(self, model, net_output, sample):
|
218 |
+
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
|
219 |
+
constraint_masks = None
|
220 |
+
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
|
221 |
+
constraint_masks = sample["constraint_masks"]
|
222 |
+
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
223 |
+
if self.constraint_start is not None and self.constraint_end is not None:
|
224 |
+
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
225 |
+
net_output[0][:, :, self.constraint_end:] = -math.inf
|
226 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
|
227 |
+
target = model.get_targets(sample, net_output)
|
228 |
+
if self.ignore_prefix_size > 0:
|
229 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
230 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
231 |
+
if constraint_masks is not None:
|
232 |
+
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
|
233 |
+
if self.ignore_eos:
|
234 |
+
bsz, seq_len, embed_dim = lprobs.size()
|
235 |
+
eos_indices = target.eq(self.task.tgt_dict.eos())
|
236 |
+
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
237 |
+
target = target[~eos_indices].reshape(bsz, seq_len-1)
|
238 |
+
if constraint_masks is not None:
|
239 |
+
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
240 |
+
if constraint_masks is not None:
|
241 |
+
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
|
242 |
+
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
|
243 |
+
|
244 |
+
def compute_loss(self, model, net_output, sample, update_num, reduce=True):
|
245 |
+
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
|
246 |
+
if constraint_masks is not None:
|
247 |
+
constraint_masks = constraint_masks[target != self.padding_idx]
|
248 |
+
lprobs = lprobs[target != self.padding_idx]
|
249 |
+
target = target[target != self.padding_idx]
|
250 |
+
loss, nll_loss, ntokens = label_smoothed_nll_loss(
|
251 |
+
lprobs,
|
252 |
+
target,
|
253 |
+
self.eps,
|
254 |
+
update_num,
|
255 |
+
reduce=reduce,
|
256 |
+
drop_worst_ratio=self.drop_worst_ratio,
|
257 |
+
drop_worst_after=self.drop_worst_after,
|
258 |
+
use_rdrop=self.use_rdrop,
|
259 |
+
reg_alpha=self.reg_alpha,
|
260 |
+
constraint_masks=constraint_masks,
|
261 |
+
constraint_start=self.constraint_start,
|
262 |
+
constraint_end=self.constraint_end
|
263 |
+
)
|
264 |
+
return loss, nll_loss, ntokens
|
265 |
+
|
266 |
+
def compute_accuracy(self, model, net_output, sample):
|
267 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
268 |
+
mask = target.ne(self.padding_idx)
|
269 |
+
n_correct = torch.sum(
|
270 |
+
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
271 |
+
)
|
272 |
+
total = torch.sum(mask)
|
273 |
+
return n_correct, total
|
274 |
+
|
275 |
+
@classmethod
|
276 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
277 |
+
"""Aggregate logging outputs from data parallel training."""
|
278 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
279 |
+
loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
|
280 |
+
loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
|
281 |
+
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
|
282 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
283 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
284 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
285 |
+
sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
|
286 |
+
sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
|
287 |
+
|
288 |
+
metrics.log_scalar(
|
289 |
+
"loss", loss_sum / sample_size, sample_size, round=3
|
290 |
+
)
|
291 |
+
metrics.log_scalar(
|
292 |
+
"loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
|
293 |
+
)
|
294 |
+
metrics.log_scalar(
|
295 |
+
"loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
|
296 |
+
)
|
297 |
+
metrics.log_scalar(
|
298 |
+
"nll_loss", nll_loss_sum / sample_size, ntokens, round=3
|
299 |
+
)
|
300 |
+
metrics.log_derived(
|
301 |
+
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
302 |
+
)
|
303 |
+
|
304 |
+
metrics.log_scalar(
|
305 |
+
"ntokens", ntokens, 1, round=3
|
306 |
+
)
|
307 |
+
metrics.log_scalar(
|
308 |
+
"nsentences", nsentences, 1, round=3
|
309 |
+
)
|
310 |
+
metrics.log_scalar(
|
311 |
+
"sample_size", sample_size, 1, round=3
|
312 |
+
)
|
313 |
+
metrics.log_scalar(
|
314 |
+
"sample_size_v1", sample_size_v1, 1, round=3
|
315 |
+
)
|
316 |
+
metrics.log_scalar(
|
317 |
+
"sample_size_v2", sample_size_v2, 1, round=3
|
318 |
+
)
|
319 |
+
|
320 |
+
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
321 |
+
if total > 0:
|
322 |
+
metrics.log_scalar("total", total)
|
323 |
+
n_correct = utils.item(
|
324 |
+
sum(log.get("n_correct", 0) for log in logging_outputs)
|
325 |
+
)
|
326 |
+
metrics.log_scalar("n_correct", n_correct)
|
327 |
+
metrics.log_derived(
|
328 |
+
"accuracy",
|
329 |
+
lambda meters: round(
|
330 |
+
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
|
331 |
+
)
|
332 |
+
if meters["total"].sum > 0
|
333 |
+
else float("nan"),
|
334 |
+
)
|
335 |
+
|
336 |
+
@staticmethod
|
337 |
+
def logging_outputs_can_be_summed() -> bool:
|
338 |
+
"""
|
339 |
+
Whether the logging outputs returned by `forward` can be summed
|
340 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
341 |
+
to True will improves distributed training speed.
|
342 |
+
"""
|
343 |
+
return True
|
criterions/scst_loss.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import string
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from collections import OrderedDict
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from fairseq import metrics, utils
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from omegaconf import II
|
17 |
+
|
18 |
+
from data import data_utils
|
19 |
+
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
|
20 |
+
|
21 |
+
|
22 |
+
def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
|
23 |
+
loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
|
24 |
+
if ignore_index is not None:
|
25 |
+
pad_mask = target.eq(ignore_index)
|
26 |
+
loss.masked_fill_(pad_mask, 0.0)
|
27 |
+
ntokens = (~pad_mask).sum()
|
28 |
+
else:
|
29 |
+
loss = loss.squeeze(-1)
|
30 |
+
ntokens = target.numel()
|
31 |
+
if reduce:
|
32 |
+
loss = loss.sum()
|
33 |
+
return loss, ntokens
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class ScstRewardCriterionConfig(FairseqDataclass):
|
37 |
+
scst_cider_cached_tokens: str = field(
|
38 |
+
default="coco-train-words.p",
|
39 |
+
metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
|
40 |
+
)
|
41 |
+
ignore_prefix_size: int = field(
|
42 |
+
default=0,
|
43 |
+
metadata={"help": "Ignore first N tokens"},
|
44 |
+
)
|
45 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
46 |
+
constraint_range: Optional[str] = field(
|
47 |
+
default=None,
|
48 |
+
metadata={"help": "constraint range"}
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
@register_criterion(
|
53 |
+
"scst_reward_criterion", dataclass=ScstRewardCriterionConfig
|
54 |
+
)
|
55 |
+
class ScstRewardCriterion(FairseqCriterion):
|
56 |
+
CIDER_REWARD_WEIGHT = 1
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
task,
|
61 |
+
scst_cider_cached_tokens,
|
62 |
+
sentence_avg,
|
63 |
+
ignore_prefix_size=0,
|
64 |
+
constraint_range=None
|
65 |
+
):
|
66 |
+
super().__init__(task)
|
67 |
+
self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
|
68 |
+
self.sentence_avg = sentence_avg
|
69 |
+
self.ignore_prefix_size = ignore_prefix_size
|
70 |
+
self.transtab = str.maketrans({key: None for key in string.punctuation})
|
71 |
+
|
72 |
+
self.constraint_start = None
|
73 |
+
self.constraint_end = None
|
74 |
+
if constraint_range is not None:
|
75 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
76 |
+
self.constraint_start = int(constraint_start)
|
77 |
+
self.constraint_end = int(constraint_end)
|
78 |
+
|
79 |
+
def forward(self, model, sample, update_num=0, reduce=True):
|
80 |
+
"""Compute the loss for the given sample.
|
81 |
+
|
82 |
+
Returns a tuple with three elements:
|
83 |
+
1) the loss
|
84 |
+
2) the sample size, which is used as the denominator for the gradient
|
85 |
+
3) logging outputs to display while training
|
86 |
+
"""
|
87 |
+
loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
|
88 |
+
|
89 |
+
sample_size = (
|
90 |
+
nsentences if self.sentence_avg else ntokens
|
91 |
+
)
|
92 |
+
logging_output = {
|
93 |
+
"loss": loss.data,
|
94 |
+
"score": score,
|
95 |
+
"ntokens": ntokens,
|
96 |
+
"nsentences": nsentences,
|
97 |
+
"sample_size": sample_size,
|
98 |
+
}
|
99 |
+
return loss, sample_size, logging_output
|
100 |
+
|
101 |
+
def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
|
102 |
+
'''
|
103 |
+
gen_res: generated captions, list of str
|
104 |
+
gt_idx: list of int, of the same length as gen_res
|
105 |
+
gt_res: ground truth captions, list of list of str.
|
106 |
+
gen_res[i] corresponds to gt_res[gt_idx[i]]
|
107 |
+
Each image can have multiple ground truth captions
|
108 |
+
'''
|
109 |
+
gen_res_size = len(gen_res)
|
110 |
+
|
111 |
+
res = OrderedDict()
|
112 |
+
for i in range(gen_res_size):
|
113 |
+
res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
|
114 |
+
|
115 |
+
gts = OrderedDict()
|
116 |
+
gt_res_ = [
|
117 |
+
[self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
|
118 |
+
for i in range(len(gt_res))
|
119 |
+
]
|
120 |
+
for i in range(gen_res_size):
|
121 |
+
gts[i] = gt_res_[gt_idx[i]]
|
122 |
+
|
123 |
+
res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
|
124 |
+
_, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
|
125 |
+
scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
|
126 |
+
return scores
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def _wrap_sentence(self, s):
|
130 |
+
# ensure the sentence ends with <eos> token
|
131 |
+
# in order to keep consisitent with cider_cached_tokens
|
132 |
+
r = s.strip()
|
133 |
+
if r.endswith('.'):
|
134 |
+
r = r[:-1]
|
135 |
+
r += ' <eos>'
|
136 |
+
return r
|
137 |
+
|
138 |
+
def get_generator_out(self, model, sample):
|
139 |
+
def decode(toks):
|
140 |
+
hypo = toks.int().cpu()
|
141 |
+
hypo_str = self.task.tgt_dict.string(hypo)
|
142 |
+
hypo_str = self.task.bpe.decode(hypo_str).strip()
|
143 |
+
return hypo, hypo_str
|
144 |
+
|
145 |
+
model.eval()
|
146 |
+
with torch.no_grad():
|
147 |
+
self.task.scst_generator.model.eval()
|
148 |
+
gen_out = self.task.scst_generator.generate([model], sample)
|
149 |
+
|
150 |
+
gen_target = []
|
151 |
+
gen_res = []
|
152 |
+
gt_res = []
|
153 |
+
for i in range(len(gen_out)):
|
154 |
+
for j in range(len(gen_out[i])):
|
155 |
+
hypo, hypo_str = decode(gen_out[i][j]["tokens"])
|
156 |
+
gen_target.append(hypo)
|
157 |
+
gen_res.append(hypo_str)
|
158 |
+
gt_res.append(
|
159 |
+
decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
|
160 |
+
)
|
161 |
+
|
162 |
+
return gen_target, gen_res, gt_res
|
163 |
+
|
164 |
+
def get_reward_and_scores(self, gen_res, gt_res, device):
|
165 |
+
batch_size = len(gt_res)
|
166 |
+
gen_res_size = len(gen_res)
|
167 |
+
seq_per_img = gen_res_size // batch_size
|
168 |
+
|
169 |
+
gt_idx = [i // seq_per_img for i in range(gen_res_size)]
|
170 |
+
scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
|
171 |
+
sc_ = scores.reshape(batch_size, seq_per_img)
|
172 |
+
baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
|
173 |
+
# sample - baseline
|
174 |
+
reward = scores.reshape(batch_size, seq_per_img)
|
175 |
+
reward = reward - baseline
|
176 |
+
reward = reward.reshape(gen_res_size)
|
177 |
+
reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
|
178 |
+
|
179 |
+
return reward, scores
|
180 |
+
|
181 |
+
def get_net_output(self, model, sample, gen_target):
|
182 |
+
def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
|
183 |
+
return data_utils.collate_tokens(
|
184 |
+
sample_list,
|
185 |
+
pad_idx=self.padding_idx,
|
186 |
+
eos_idx=eos,
|
187 |
+
left_pad=False,
|
188 |
+
move_eos_to_beginning=move_eos_to_beginning,
|
189 |
+
)
|
190 |
+
|
191 |
+
batch_size = len(sample["target"])
|
192 |
+
gen_target_size = len(gen_target)
|
193 |
+
seq_per_img = gen_target_size // batch_size
|
194 |
+
|
195 |
+
model.train()
|
196 |
+
sample_src_tokens = torch.repeat_interleave(
|
197 |
+
sample['net_input']['src_tokens'], seq_per_img, dim=0
|
198 |
+
)
|
199 |
+
sample_src_lengths = torch.repeat_interleave(
|
200 |
+
sample['net_input']['src_lengths'], seq_per_img, dim=0
|
201 |
+
)
|
202 |
+
sample_patch_images = torch.repeat_interleave(
|
203 |
+
sample['net_input']['patch_images'], seq_per_img, dim=0
|
204 |
+
)
|
205 |
+
sample_patch_masks = torch.repeat_interleave(
|
206 |
+
sample['net_input']['patch_masks'], seq_per_img, dim=0
|
207 |
+
)
|
208 |
+
gen_prev_output_tokens = torch.as_tensor(
|
209 |
+
merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
|
210 |
+
device=sample["target"].device, dtype=torch.int64
|
211 |
+
)
|
212 |
+
gen_target_tokens = torch.as_tensor(
|
213 |
+
merge(gen_target), device=sample["target"].device, dtype=torch.int64
|
214 |
+
)
|
215 |
+
net_output = model(
|
216 |
+
src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
|
217 |
+
patch_images=sample_patch_images, patch_masks=sample_patch_masks,
|
218 |
+
prev_output_tokens=gen_prev_output_tokens
|
219 |
+
)
|
220 |
+
|
221 |
+
return net_output, gen_target_tokens
|
222 |
+
|
223 |
+
def get_lprobs_and_target(self, model, net_output, gen_target):
|
224 |
+
if self.constraint_start is not None and self.constraint_end is not None:
|
225 |
+
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
226 |
+
net_output[0][:, :, self.constraint_end:] = -math.inf
|
227 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
228 |
+
if self.ignore_prefix_size > 0:
|
229 |
+
if getattr(lprobs, "batch_first", False):
|
230 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
231 |
+
gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
|
232 |
+
else:
|
233 |
+
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
234 |
+
gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
|
235 |
+
return lprobs, gen_target
|
236 |
+
|
237 |
+
def compute_loss(self, model, sample, reduce=True):
|
238 |
+
gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
|
239 |
+
reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
|
240 |
+
net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
|
241 |
+
gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
|
242 |
+
loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
|
243 |
+
nsentences = gen_target_tokens.size(0)
|
244 |
+
|
245 |
+
return loss, scores.sum(), ntokens, nsentences
|
246 |
+
|
247 |
+
@classmethod
|
248 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
249 |
+
"""Aggregate logging outputs from data parallel training."""
|
250 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
251 |
+
score_sum = sum(log.get("score", 0) for log in logging_outputs)
|
252 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
253 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
254 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
255 |
+
|
256 |
+
metrics.log_scalar(
|
257 |
+
"loss", loss_sum / sample_size, sample_size, round=3
|
258 |
+
)
|
259 |
+
metrics.log_scalar(
|
260 |
+
"score", score_sum / nsentences, nsentences, round=3
|
261 |
+
)
|
262 |
+
|
263 |
+
metrics.log_scalar(
|
264 |
+
"ntokens", ntokens, 1, round=3
|
265 |
+
)
|
266 |
+
metrics.log_scalar(
|
267 |
+
"nsentences", nsentences, 1, round=3
|
268 |
+
)
|
269 |
+
metrics.log_scalar(
|
270 |
+
"sample_size", sample_size, 1, round=3
|
271 |
+
)
|
272 |
+
|
273 |
+
@staticmethod
|
274 |
+
def logging_outputs_can_be_summed() -> bool:
|
275 |
+
"""
|
276 |
+
Whether the logging outputs returned by `forward` can be summed
|
277 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
278 |
+
to True will improves distributed training speed.
|
279 |
+
"""
|
280 |
+
return True
|
data/__init__.py
ADDED
File without changes
|
data/data_utils.py
ADDED
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
try:
|
7 |
+
from collections.abc import Iterable
|
8 |
+
except ImportError:
|
9 |
+
from collections import Iterable
|
10 |
+
import contextlib
|
11 |
+
import itertools
|
12 |
+
import logging
|
13 |
+
import re
|
14 |
+
import warnings
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from fairseq.file_io import PathManager
|
21 |
+
from fairseq import utils
|
22 |
+
import os
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
def infer_language_pair(path):
|
28 |
+
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
|
29 |
+
src, dst = None, None
|
30 |
+
for filename in PathManager.ls(path):
|
31 |
+
parts = filename.split(".")
|
32 |
+
if len(parts) >= 3 and len(parts[1].split("-")) == 2:
|
33 |
+
return parts[1].split("-")
|
34 |
+
return src, dst
|
35 |
+
|
36 |
+
|
37 |
+
def collate_tokens(
|
38 |
+
values,
|
39 |
+
pad_idx,
|
40 |
+
eos_idx=None,
|
41 |
+
left_pad=False,
|
42 |
+
move_eos_to_beginning=False,
|
43 |
+
pad_to_length=None,
|
44 |
+
pad_to_multiple=1,
|
45 |
+
pad_to_bsz=None,
|
46 |
+
):
|
47 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
48 |
+
size = max(v.size(0) for v in values)
|
49 |
+
size = size if pad_to_length is None else max(size, pad_to_length)
|
50 |
+
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
|
51 |
+
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
|
52 |
+
|
53 |
+
def copy_tensor(src, dst):
|
54 |
+
assert dst.numel() == src.numel()
|
55 |
+
if move_eos_to_beginning:
|
56 |
+
if eos_idx is None:
|
57 |
+
# if no eos_idx is specified, then use the last token in src
|
58 |
+
dst[0] = src[-1]
|
59 |
+
else:
|
60 |
+
dst[0] = eos_idx
|
61 |
+
dst[1:] = src[:-1]
|
62 |
+
else:
|
63 |
+
dst.copy_(src)
|
64 |
+
|
65 |
+
if values[0].dim() == 1:
|
66 |
+
res = values[0].new(len(values), size).fill_(pad_idx)
|
67 |
+
elif values[0].dim() == 2:
|
68 |
+
assert move_eos_to_beginning is False
|
69 |
+
res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
|
70 |
+
else:
|
71 |
+
raise NotImplementedError
|
72 |
+
|
73 |
+
for i, v in enumerate(values):
|
74 |
+
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
|
75 |
+
return res
|
76 |
+
|
77 |
+
|
78 |
+
def load_indexed_dataset(
|
79 |
+
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
|
80 |
+
):
|
81 |
+
"""A helper function for loading indexed datasets.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
path (str): path to indexed dataset (e.g., 'data-bin/train')
|
85 |
+
dictionary (~fairseq.data.Dictionary): data dictionary
|
86 |
+
dataset_impl (str, optional): which dataset implementation to use. If
|
87 |
+
not provided, it will be inferred automatically. For legacy indexed
|
88 |
+
data we use the 'cached' implementation by default.
|
89 |
+
combine (bool, optional): automatically load and combine multiple
|
90 |
+
datasets. For example, if *path* is 'data-bin/train', then we will
|
91 |
+
combine 'data-bin/train', 'data-bin/train1', ... and return a
|
92 |
+
single ConcatDataset instance.
|
93 |
+
"""
|
94 |
+
import fairseq.data.indexed_dataset as indexed_dataset
|
95 |
+
from fairseq.data.concat_dataset import ConcatDataset
|
96 |
+
|
97 |
+
datasets = []
|
98 |
+
for k in itertools.count():
|
99 |
+
path_k = path + (str(k) if k > 0 else "")
|
100 |
+
try:
|
101 |
+
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
|
102 |
+
except Exception as e:
|
103 |
+
if "StorageException: [404] Path not found" in str(e):
|
104 |
+
logger.warning(f"path_k: {e} not found")
|
105 |
+
else:
|
106 |
+
raise e
|
107 |
+
|
108 |
+
dataset_impl_k = dataset_impl
|
109 |
+
if dataset_impl_k is None:
|
110 |
+
dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
|
111 |
+
dataset = indexed_dataset.make_dataset(
|
112 |
+
path_k,
|
113 |
+
impl=dataset_impl_k or default,
|
114 |
+
fix_lua_indexing=True,
|
115 |
+
dictionary=dictionary,
|
116 |
+
)
|
117 |
+
if dataset is None:
|
118 |
+
break
|
119 |
+
logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
|
120 |
+
datasets.append(dataset)
|
121 |
+
if not combine:
|
122 |
+
break
|
123 |
+
if len(datasets) == 0:
|
124 |
+
return None
|
125 |
+
elif len(datasets) == 1:
|
126 |
+
return datasets[0]
|
127 |
+
else:
|
128 |
+
return ConcatDataset(datasets)
|
129 |
+
|
130 |
+
|
131 |
+
@contextlib.contextmanager
|
132 |
+
def numpy_seed(seed, *addl_seeds):
|
133 |
+
"""Context manager which seeds the NumPy PRNG with the specified seed and
|
134 |
+
restores the state afterward"""
|
135 |
+
if seed is None:
|
136 |
+
yield
|
137 |
+
return
|
138 |
+
if len(addl_seeds) > 0:
|
139 |
+
seed = int(hash((seed, *addl_seeds)) % 1e6)
|
140 |
+
state = np.random.get_state()
|
141 |
+
np.random.seed(seed)
|
142 |
+
try:
|
143 |
+
yield
|
144 |
+
finally:
|
145 |
+
np.random.set_state(state)
|
146 |
+
|
147 |
+
|
148 |
+
def collect_filtered(function, iterable, filtered):
|
149 |
+
"""
|
150 |
+
Similar to :func:`filter` but collects filtered elements in ``filtered``.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
function (callable): function that returns ``False`` for elements that
|
154 |
+
should be filtered
|
155 |
+
iterable (iterable): iterable to filter
|
156 |
+
filtered (list): list to store filtered elements
|
157 |
+
"""
|
158 |
+
for el in iterable:
|
159 |
+
if function(el):
|
160 |
+
yield el
|
161 |
+
else:
|
162 |
+
filtered.append(el)
|
163 |
+
|
164 |
+
|
165 |
+
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
|
166 |
+
def compare_leq(a, b):
|
167 |
+
return a <= b if not isinstance(a, tuple) else max(a) <= b
|
168 |
+
|
169 |
+
def check_size(idx):
|
170 |
+
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
171 |
+
return size_fn(idx) <= max_positions
|
172 |
+
elif isinstance(max_positions, dict):
|
173 |
+
idx_size = size_fn(idx)
|
174 |
+
assert isinstance(idx_size, dict)
|
175 |
+
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
|
176 |
+
return all(
|
177 |
+
all(
|
178 |
+
a is None or b is None or a <= b
|
179 |
+
for a, b in zip(idx_size[key], max_positions[key])
|
180 |
+
)
|
181 |
+
for key in intersect_keys
|
182 |
+
)
|
183 |
+
else:
|
184 |
+
# For MultiCorpusSampledDataset, will generalize it later
|
185 |
+
if not isinstance(size_fn(idx), Iterable):
|
186 |
+
return all(size_fn(idx) <= b for b in max_positions)
|
187 |
+
return all(
|
188 |
+
a is None or b is None or a <= b
|
189 |
+
for a, b in zip(size_fn(idx), max_positions)
|
190 |
+
)
|
191 |
+
|
192 |
+
ignored = []
|
193 |
+
itr = collect_filtered(check_size, indices, ignored)
|
194 |
+
indices = np.fromiter(itr, dtype=np.int64, count=-1)
|
195 |
+
return indices, ignored
|
196 |
+
|
197 |
+
|
198 |
+
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
|
199 |
+
"""
|
200 |
+
[deprecated] Filter indices based on their size.
|
201 |
+
Use `FairseqDataset::filter_indices_by_size` instead.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
indices (List[int]): ordered list of dataset indices
|
205 |
+
dataset (FairseqDataset): fairseq dataset instance
|
206 |
+
max_positions (tuple): filter elements larger than this size.
|
207 |
+
Comparisons are done component-wise.
|
208 |
+
raise_exception (bool, optional): if ``True``, raise an exception if
|
209 |
+
any elements are filtered (default: False).
|
210 |
+
"""
|
211 |
+
warnings.warn(
|
212 |
+
"data_utils.filter_by_size is deprecated. "
|
213 |
+
"Use `FairseqDataset::filter_indices_by_size` instead.",
|
214 |
+
stacklevel=2,
|
215 |
+
)
|
216 |
+
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
217 |
+
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
|
218 |
+
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
|
219 |
+
indices = indices[dataset.sizes[indices] <= max_positions]
|
220 |
+
elif (
|
221 |
+
hasattr(dataset, "sizes")
|
222 |
+
and isinstance(dataset.sizes, list)
|
223 |
+
and len(dataset.sizes) == 1
|
224 |
+
):
|
225 |
+
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
|
226 |
+
indices = indices[dataset.sizes[0][indices] <= max_positions]
|
227 |
+
else:
|
228 |
+
indices, ignored = _filter_by_size_dynamic(
|
229 |
+
indices, dataset.size, max_positions
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
|
233 |
+
|
234 |
+
if len(ignored) > 0 and raise_exception:
|
235 |
+
raise Exception(
|
236 |
+
(
|
237 |
+
"Size of sample #{} is invalid (={}) since max_positions={}, "
|
238 |
+
"skip this example with --skip-invalid-size-inputs-valid-test"
|
239 |
+
).format(ignored[0], dataset.size(ignored[0]), max_positions)
|
240 |
+
)
|
241 |
+
if len(ignored) > 0:
|
242 |
+
logger.warning(
|
243 |
+
(
|
244 |
+
"{} samples have invalid sizes and will be skipped, "
|
245 |
+
"max_positions={}, first few sample ids={}"
|
246 |
+
).format(len(ignored), max_positions, ignored[:10])
|
247 |
+
)
|
248 |
+
return indices
|
249 |
+
|
250 |
+
|
251 |
+
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
|
252 |
+
"""Filter a list of sample indices. Remove those that are longer
|
253 |
+
than specified in max_sizes.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
indices (np.array): original array of sample indices
|
257 |
+
max_sizes (int or list[int] or tuple[int]): max sample size,
|
258 |
+
can be defined separately for src and tgt (then list or tuple)
|
259 |
+
|
260 |
+
Returns:
|
261 |
+
np.array: filtered sample array
|
262 |
+
list: list of removed indices
|
263 |
+
"""
|
264 |
+
if max_sizes is None:
|
265 |
+
return indices, []
|
266 |
+
if type(max_sizes) in (int, float):
|
267 |
+
max_src_size, max_tgt_size = max_sizes, max_sizes
|
268 |
+
else:
|
269 |
+
max_src_size, max_tgt_size = max_sizes
|
270 |
+
if tgt_sizes is None:
|
271 |
+
ignored = indices[src_sizes[indices] > max_src_size]
|
272 |
+
else:
|
273 |
+
ignored = indices[
|
274 |
+
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
|
275 |
+
]
|
276 |
+
if len(ignored) > 0:
|
277 |
+
if tgt_sizes is None:
|
278 |
+
indices = indices[src_sizes[indices] <= max_src_size]
|
279 |
+
else:
|
280 |
+
indices = indices[
|
281 |
+
(src_sizes[indices] <= max_src_size)
|
282 |
+
& (tgt_sizes[indices] <= max_tgt_size)
|
283 |
+
]
|
284 |
+
return indices, ignored.tolist()
|
285 |
+
|
286 |
+
|
287 |
+
def batch_by_size(
|
288 |
+
indices,
|
289 |
+
num_tokens_fn,
|
290 |
+
num_tokens_vec=None,
|
291 |
+
max_tokens=None,
|
292 |
+
max_sentences=None,
|
293 |
+
required_batch_size_multiple=1,
|
294 |
+
fixed_shapes=None,
|
295 |
+
):
|
296 |
+
"""
|
297 |
+
Yield mini-batches of indices bucketed by size. Batches may contain
|
298 |
+
sequences of different lengths.
|
299 |
+
|
300 |
+
Args:
|
301 |
+
indices (List[int]): ordered list of dataset indices
|
302 |
+
num_tokens_fn (callable): function that returns the number of tokens at
|
303 |
+
a given index
|
304 |
+
num_tokens_vec (List[int], optional): precomputed vector of the number
|
305 |
+
of tokens for each index in indices (to enable faster batch generation)
|
306 |
+
max_tokens (int, optional): max number of tokens in each batch
|
307 |
+
(default: None).
|
308 |
+
max_sentences (int, optional): max number of sentences in each
|
309 |
+
batch (default: None).
|
310 |
+
required_batch_size_multiple (int, optional): require batch size to
|
311 |
+
be less than N or a multiple of N (default: 1).
|
312 |
+
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
|
313 |
+
only be created with the given shapes. *max_sentences* and
|
314 |
+
*required_batch_size_multiple* will be ignored (default: None).
|
315 |
+
"""
|
316 |
+
try:
|
317 |
+
from fairseq.data.data_utils_fast import (
|
318 |
+
batch_by_size_fn,
|
319 |
+
batch_by_size_vec,
|
320 |
+
batch_fixed_shapes_fast,
|
321 |
+
)
|
322 |
+
except ImportError:
|
323 |
+
raise ImportError(
|
324 |
+
"Please build Cython components with: "
|
325 |
+
"`python setup.py build_ext --inplace`"
|
326 |
+
)
|
327 |
+
except ValueError:
|
328 |
+
raise ValueError(
|
329 |
+
"Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
|
330 |
+
)
|
331 |
+
|
332 |
+
# added int() to avoid TypeError: an integer is required
|
333 |
+
max_tokens = (
|
334 |
+
int(max_tokens) if max_tokens is not None else -1
|
335 |
+
)
|
336 |
+
max_sentences = max_sentences if max_sentences is not None else -1
|
337 |
+
bsz_mult = required_batch_size_multiple
|
338 |
+
|
339 |
+
if not isinstance(indices, np.ndarray):
|
340 |
+
indices = np.fromiter(indices, dtype=np.int64, count=-1)
|
341 |
+
|
342 |
+
if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
|
343 |
+
num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
|
344 |
+
|
345 |
+
if fixed_shapes is None:
|
346 |
+
if num_tokens_vec is None:
|
347 |
+
return batch_by_size_fn(
|
348 |
+
indices,
|
349 |
+
num_tokens_fn,
|
350 |
+
max_tokens,
|
351 |
+
max_sentences,
|
352 |
+
bsz_mult,
|
353 |
+
)
|
354 |
+
else:
|
355 |
+
return batch_by_size_vec(
|
356 |
+
indices,
|
357 |
+
num_tokens_vec,
|
358 |
+
max_tokens,
|
359 |
+
max_sentences,
|
360 |
+
bsz_mult,
|
361 |
+
)
|
362 |
+
|
363 |
+
else:
|
364 |
+
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
|
365 |
+
sort_order = np.lexsort(
|
366 |
+
[
|
367 |
+
fixed_shapes[:, 1].argsort(), # length
|
368 |
+
fixed_shapes[:, 0].argsort(), # bsz
|
369 |
+
]
|
370 |
+
)
|
371 |
+
fixed_shapes_sorted = fixed_shapes[sort_order]
|
372 |
+
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
|
373 |
+
|
374 |
+
|
375 |
+
def post_process(sentence: str, symbol: str):
|
376 |
+
if symbol == "sentencepiece":
|
377 |
+
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
378 |
+
elif symbol == "wordpiece":
|
379 |
+
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
380 |
+
elif symbol == "letter":
|
381 |
+
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
382 |
+
elif symbol == "silence":
|
383 |
+
import re
|
384 |
+
sentence = sentence.replace("<SIL>", "")
|
385 |
+
sentence = re.sub(' +', ' ', sentence).strip()
|
386 |
+
elif symbol == "_EOW":
|
387 |
+
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
388 |
+
elif symbol in {"subword_nmt", "@@ ", "@@"}:
|
389 |
+
if symbol == "subword_nmt":
|
390 |
+
symbol = "@@ "
|
391 |
+
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
392 |
+
elif symbol == "none":
|
393 |
+
pass
|
394 |
+
elif symbol is not None:
|
395 |
+
raise NotImplementedError(f"Unknown post_process option: {symbol}")
|
396 |
+
return sentence
|
397 |
+
|
398 |
+
|
399 |
+
def compute_mask_indices(
|
400 |
+
shape: Tuple[int, int],
|
401 |
+
padding_mask: Optional[torch.Tensor],
|
402 |
+
mask_prob: float,
|
403 |
+
mask_length: int,
|
404 |
+
mask_type: str = "static",
|
405 |
+
mask_other: float = 0.0,
|
406 |
+
min_masks: int = 0,
|
407 |
+
no_overlap: bool = False,
|
408 |
+
min_space: int = 0,
|
409 |
+
) -> np.ndarray:
|
410 |
+
"""
|
411 |
+
Computes random mask spans for a given shape
|
412 |
+
|
413 |
+
Args:
|
414 |
+
shape: the the shape for which to compute masks.
|
415 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
416 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
417 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
418 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
419 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
420 |
+
mask_type: how to compute mask lengths
|
421 |
+
static = fixed size
|
422 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
423 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
424 |
+
poisson = sample from possion distribution with lambda = mask length
|
425 |
+
min_masks: minimum number of masked spans
|
426 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
427 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
428 |
+
"""
|
429 |
+
|
430 |
+
bsz, all_sz = shape
|
431 |
+
mask = np.full((bsz, all_sz), False)
|
432 |
+
|
433 |
+
all_num_mask = int(
|
434 |
+
# add a random number for probabilistic rounding
|
435 |
+
mask_prob * all_sz / float(mask_length)
|
436 |
+
+ np.random.rand()
|
437 |
+
)
|
438 |
+
|
439 |
+
all_num_mask = max(min_masks, all_num_mask)
|
440 |
+
|
441 |
+
mask_idcs = []
|
442 |
+
for i in range(bsz):
|
443 |
+
if padding_mask is not None:
|
444 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
445 |
+
num_mask = int(
|
446 |
+
# add a random number for probabilistic rounding
|
447 |
+
mask_prob * sz / float(mask_length)
|
448 |
+
+ np.random.rand()
|
449 |
+
)
|
450 |
+
num_mask = max(min_masks, num_mask)
|
451 |
+
else:
|
452 |
+
sz = all_sz
|
453 |
+
num_mask = all_num_mask
|
454 |
+
|
455 |
+
if mask_type == "static":
|
456 |
+
lengths = np.full(num_mask, mask_length)
|
457 |
+
elif mask_type == "uniform":
|
458 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
459 |
+
elif mask_type == "normal":
|
460 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
461 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
462 |
+
elif mask_type == "poisson":
|
463 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
464 |
+
lengths = [int(round(x)) for x in lengths]
|
465 |
+
else:
|
466 |
+
raise Exception("unknown mask selection " + mask_type)
|
467 |
+
|
468 |
+
if sum(lengths) == 0:
|
469 |
+
lengths[0] = min(mask_length, sz - 1)
|
470 |
+
|
471 |
+
if no_overlap:
|
472 |
+
mask_idc = []
|
473 |
+
|
474 |
+
def arrange(s, e, length, keep_length):
|
475 |
+
span_start = np.random.randint(s, e - length)
|
476 |
+
mask_idc.extend(span_start + i for i in range(length))
|
477 |
+
|
478 |
+
new_parts = []
|
479 |
+
if span_start - s - min_space >= keep_length:
|
480 |
+
new_parts.append((s, span_start - min_space + 1))
|
481 |
+
if e - span_start - keep_length - min_space > keep_length:
|
482 |
+
new_parts.append((span_start + length + min_space, e))
|
483 |
+
return new_parts
|
484 |
+
|
485 |
+
parts = [(0, sz)]
|
486 |
+
min_length = min(lengths)
|
487 |
+
for length in sorted(lengths, reverse=True):
|
488 |
+
lens = np.fromiter(
|
489 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
490 |
+
np.int,
|
491 |
+
)
|
492 |
+
l_sum = np.sum(lens)
|
493 |
+
if l_sum == 0:
|
494 |
+
break
|
495 |
+
probs = lens / np.sum(lens)
|
496 |
+
c = np.random.choice(len(parts), p=probs)
|
497 |
+
s, e = parts.pop(c)
|
498 |
+
parts.extend(arrange(s, e, length, min_length))
|
499 |
+
mask_idc = np.asarray(mask_idc)
|
500 |
+
else:
|
501 |
+
min_len = min(lengths)
|
502 |
+
if sz - min_len <= num_mask:
|
503 |
+
min_len = sz - num_mask - 1
|
504 |
+
|
505 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
506 |
+
|
507 |
+
mask_idc = np.asarray(
|
508 |
+
[
|
509 |
+
mask_idc[j] + offset
|
510 |
+
for j in range(len(mask_idc))
|
511 |
+
for offset in range(lengths[j])
|
512 |
+
]
|
513 |
+
)
|
514 |
+
|
515 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
516 |
+
|
517 |
+
min_len = min([len(m) for m in mask_idcs])
|
518 |
+
for i, mask_idc in enumerate(mask_idcs):
|
519 |
+
if len(mask_idc) > min_len:
|
520 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
521 |
+
mask[i, mask_idc] = True
|
522 |
+
|
523 |
+
return mask
|
524 |
+
|
525 |
+
|
526 |
+
def get_mem_usage():
|
527 |
+
try:
|
528 |
+
import psutil
|
529 |
+
|
530 |
+
mb = 1024 * 1024
|
531 |
+
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
|
532 |
+
except ImportError:
|
533 |
+
return "N/A"
|
534 |
+
|
535 |
+
|
536 |
+
# lens: torch.LongTensor
|
537 |
+
# returns: torch.BoolTensor
|
538 |
+
def lengths_to_padding_mask(lens):
|
539 |
+
bsz, max_lens = lens.size(0), torch.max(lens).item()
|
540 |
+
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
|
541 |
+
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
|
542 |
+
return mask
|
543 |
+
|
544 |
+
|
545 |
+
# lens: torch.LongTensor
|
546 |
+
# returns: torch.BoolTensor
|
547 |
+
def lengths_to_mask(lens):
|
548 |
+
return ~lengths_to_padding_mask(lens)
|
549 |
+
|
550 |
+
|
551 |
+
def get_buckets(sizes, num_buckets):
|
552 |
+
buckets = np.unique(
|
553 |
+
np.percentile(
|
554 |
+
sizes,
|
555 |
+
np.linspace(0, 100, num_buckets + 1),
|
556 |
+
interpolation='lower',
|
557 |
+
)[1:]
|
558 |
+
)
|
559 |
+
return buckets
|
560 |
+
|
561 |
+
|
562 |
+
def get_bucketed_sizes(orig_sizes, buckets):
|
563 |
+
sizes = np.copy(orig_sizes)
|
564 |
+
assert np.min(sizes) >= 0
|
565 |
+
start_val = -1
|
566 |
+
for end_val in buckets:
|
567 |
+
mask = (sizes > start_val) & (sizes <= end_val)
|
568 |
+
sizes[mask] = end_val
|
569 |
+
start_val = end_val
|
570 |
+
return sizes
|
571 |
+
|
572 |
+
|
573 |
+
|
574 |
+
def _find_extra_valid_paths(dataset_path: str) -> set:
|
575 |
+
paths = utils.split_paths(dataset_path)
|
576 |
+
all_valid_paths = set()
|
577 |
+
for sub_dir in paths:
|
578 |
+
contents = PathManager.ls(sub_dir)
|
579 |
+
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
|
580 |
+
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
|
581 |
+
# Remove .bin, .idx etc
|
582 |
+
roots = {os.path.splitext(p)[0] for p in all_valid_paths}
|
583 |
+
return roots
|
584 |
+
|
585 |
+
|
586 |
+
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
|
587 |
+
"""Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
|
588 |
+
if (
|
589 |
+
train_cfg.dataset.ignore_unused_valid_subsets
|
590 |
+
or train_cfg.dataset.combine_valid_subsets
|
591 |
+
or train_cfg.dataset.disable_validation
|
592 |
+
or not hasattr(train_cfg.task, "data")
|
593 |
+
):
|
594 |
+
return
|
595 |
+
other_paths = _find_extra_valid_paths(train_cfg.task.data)
|
596 |
+
specified_subsets = train_cfg.dataset.valid_subset.split(",")
|
597 |
+
ignored_paths = [p for p in other_paths if p not in specified_subsets]
|
598 |
+
if ignored_paths:
|
599 |
+
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
|
600 |
+
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
|
601 |
+
raise ValueError(msg)
|
data/file_dataset.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
|
6 |
+
class FileDataset:
|
7 |
+
def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
|
8 |
+
self.file_path = file_path
|
9 |
+
assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)
|
10 |
+
|
11 |
+
self.separator = separator
|
12 |
+
if selected_col_ids is None:
|
13 |
+
# default to all fields
|
14 |
+
self.selected_col_ids = list(
|
15 |
+
range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
|
16 |
+
else:
|
17 |
+
self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
|
18 |
+
if dtypes is None:
|
19 |
+
# default to str
|
20 |
+
self.dtypes = [str for col_id in self.selected_col_ids]
|
21 |
+
else:
|
22 |
+
self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
|
23 |
+
assert len(self.dtypes) == len(self.selected_col_ids)
|
24 |
+
|
25 |
+
self.data_cnt = 0
|
26 |
+
try:
|
27 |
+
self.slice_id = torch.distributed.get_rank()
|
28 |
+
self.slice_count = torch.distributed.get_world_size()
|
29 |
+
except Exception:
|
30 |
+
self.slice_id = 0
|
31 |
+
self.slice_count = 1
|
32 |
+
self.cached_index = cached_index
|
33 |
+
self._init_seek_index()
|
34 |
+
self._reader = self._get_reader()
|
35 |
+
print("file {} slice_id {} row count {} total row count {}".format(
|
36 |
+
self.file_path, self.slice_id, self.row_count, self.total_row_count)
|
37 |
+
)
|
38 |
+
|
39 |
+
def _init_seek_index(self):
|
40 |
+
if self.cached_index:
|
41 |
+
cache_path = "{}.index".format(self.file_path)
|
42 |
+
assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
|
43 |
+
self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
|
44 |
+
print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
|
45 |
+
self.file_path, self.slice_id))
|
46 |
+
else:
|
47 |
+
# make an iteration over the file to get row_count and line_idx-to-offset mapping
|
48 |
+
fp = open(self.file_path, "r")
|
49 |
+
print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
|
50 |
+
self.file_path, self.slice_id))
|
51 |
+
self.total_row_count = 0
|
52 |
+
offset = 0
|
53 |
+
self.lineid_to_offset = []
|
54 |
+
for line in fp:
|
55 |
+
self.lineid_to_offset.append(offset)
|
56 |
+
self.total_row_count += 1
|
57 |
+
offset += len(line.encode('utf-8'))
|
58 |
+
self._compute_start_pos_and_row_count()
|
59 |
+
print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
|
60 |
+
self.file_path, self.slice_id))
|
61 |
+
|
62 |
+
def _compute_start_pos_and_row_count(self):
|
63 |
+
self.row_count = self.total_row_count // self.slice_count
|
64 |
+
if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
|
65 |
+
self.row_count += 1
|
66 |
+
self.start_pos = self.row_count * self.slice_id
|
67 |
+
else:
|
68 |
+
self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
|
69 |
+
|
70 |
+
def _get_reader(self):
|
71 |
+
fp = open(self.file_path, "r")
|
72 |
+
fp.seek(self.lineid_to_offset[self.start_pos])
|
73 |
+
return fp
|
74 |
+
|
75 |
+
def _seek(self, offset=0):
|
76 |
+
try:
|
77 |
+
print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
|
78 |
+
self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
|
79 |
+
self.data_cnt = offset
|
80 |
+
except Exception:
|
81 |
+
print("slice_id {} seek offset {}".format(self.slice_id, offset))
|
82 |
+
self._reader.seek(self.lineid_to_offset[offset])
|
83 |
+
self.data_cnt = offset
|
84 |
+
|
85 |
+
def __del__(self):
|
86 |
+
self._reader.close()
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
return self.row_count
|
90 |
+
|
91 |
+
def get_total_row_count(self):
|
92 |
+
return self.total_row_count
|
93 |
+
|
94 |
+
def __getitem__(self, index):
|
95 |
+
if self.data_cnt == self.row_count:
|
96 |
+
print("reach the end of datafile, start a new reader")
|
97 |
+
self.data_cnt = 0
|
98 |
+
self._reader = self._get_reader()
|
99 |
+
column_l = self._reader.readline().rstrip("\n").split(self.separator)
|
100 |
+
self.data_cnt += 1
|
101 |
+
column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
|
102 |
+
return column_l
|
data/mm_data/__init__.py
ADDED
File without changes
|
data/mm_data/caption_dataset.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import warnings
|
9 |
+
import string
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import base64
|
14 |
+
from torchvision import transforms
|
15 |
+
|
16 |
+
from PIL import Image, ImageFile
|
17 |
+
|
18 |
+
from data import data_utils
|
19 |
+
from data.ofa_dataset import OFADataset
|
20 |
+
|
21 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
22 |
+
ImageFile.MAX_IMAGE_PIXELS = None
|
23 |
+
Image.MAX_IMAGE_PIXELS = None
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
|
27 |
+
|
28 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
29 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
30 |
+
|
31 |
+
|
32 |
+
def collate(samples, pad_idx, eos_idx):
|
33 |
+
if len(samples) == 0:
|
34 |
+
return {}
|
35 |
+
|
36 |
+
def merge(key):
|
37 |
+
return data_utils.collate_tokens(
|
38 |
+
[s[key] for s in samples],
|
39 |
+
pad_idx,
|
40 |
+
eos_idx=eos_idx,
|
41 |
+
)
|
42 |
+
|
43 |
+
id = np.array([s["id"] for s in samples])
|
44 |
+
src_tokens = merge("source")
|
45 |
+
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
|
46 |
+
|
47 |
+
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
|
48 |
+
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
|
49 |
+
|
50 |
+
prev_output_tokens = None
|
51 |
+
target = None
|
52 |
+
if samples[0].get("target", None) is not None:
|
53 |
+
target = merge("target")
|
54 |
+
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
|
55 |
+
ntokens = tgt_lengths.sum().item()
|
56 |
+
|
57 |
+
if samples[0].get("prev_output_tokens", None) is not None:
|
58 |
+
prev_output_tokens = merge("prev_output_tokens")
|
59 |
+
else:
|
60 |
+
ntokens = src_lengths.sum().item()
|
61 |
+
|
62 |
+
batch = {
|
63 |
+
"id": id,
|
64 |
+
"nsentences": len(samples),
|
65 |
+
"ntokens": ntokens,
|
66 |
+
"net_input": {
|
67 |
+
"src_tokens": src_tokens,
|
68 |
+
"src_lengths": src_lengths,
|
69 |
+
"patch_images": patch_images,
|
70 |
+
"patch_masks": patch_masks,
|
71 |
+
"prev_output_tokens": prev_output_tokens
|
72 |
+
},
|
73 |
+
"target": target,
|
74 |
+
}
|
75 |
+
|
76 |
+
return batch
|
77 |
+
|
78 |
+
|
79 |
+
class CaptionDataset(OFADataset):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
split,
|
83 |
+
dataset,
|
84 |
+
bpe,
|
85 |
+
src_dict,
|
86 |
+
tgt_dict=None,
|
87 |
+
max_src_length=128,
|
88 |
+
max_tgt_length=30,
|
89 |
+
patch_image_size=224,
|
90 |
+
imagenet_default_mean_and_std=False,
|
91 |
+
scst=False
|
92 |
+
):
|
93 |
+
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
|
94 |
+
self.max_src_length = max_src_length
|
95 |
+
self.max_tgt_length = max_tgt_length
|
96 |
+
self.patch_image_size = patch_image_size
|
97 |
+
self.scst = scst
|
98 |
+
|
99 |
+
self.transtab = str.maketrans({key: None for key in string.punctuation})
|
100 |
+
|
101 |
+
if imagenet_default_mean_and_std:
|
102 |
+
mean = IMAGENET_DEFAULT_MEAN
|
103 |
+
std = IMAGENET_DEFAULT_STD
|
104 |
+
else:
|
105 |
+
mean = [0.5, 0.5, 0.5]
|
106 |
+
std = [0.5, 0.5, 0.5]
|
107 |
+
|
108 |
+
self.patch_resize_transform = transforms.Compose([
|
109 |
+
lambda image: image.convert("RGB"),
|
110 |
+
transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
|
111 |
+
transforms.ToTensor(),
|
112 |
+
transforms.Normalize(mean=mean, std=std),
|
113 |
+
])
|
114 |
+
|
115 |
+
def __getitem__(self, index):
|
116 |
+
uniq_id, image, caption = self.dataset[index]
|
117 |
+
|
118 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
|
119 |
+
patch_image = self.patch_resize_transform(image)
|
120 |
+
patch_mask = torch.tensor([True])
|
121 |
+
|
122 |
+
if self.split == 'train' and not self.scst:
|
123 |
+
caption = caption.translate(self.transtab).strip()
|
124 |
+
caption_token_list = caption.strip().split()
|
125 |
+
tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
|
126 |
+
else:
|
127 |
+
caption = ' '.join(caption.strip().split())
|
128 |
+
caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
|
129 |
+
tgt_caption = '&&'.join(caption_list)
|
130 |
+
src_item = self.encode_text(" what does the image describe?")
|
131 |
+
tgt_item = self.encode_text(" {}".format(tgt_caption))
|
132 |
+
|
133 |
+
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
134 |
+
target_item = torch.cat([tgt_item, self.eos_item])
|
135 |
+
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
136 |
+
|
137 |
+
example = {
|
138 |
+
"id": uniq_id,
|
139 |
+
"source": src_item,
|
140 |
+
"patch_image": patch_image,
|
141 |
+
"patch_mask": patch_mask,
|
142 |
+
"target": target_item,
|
143 |
+
"prev_output_tokens": prev_output_item
|
144 |
+
}
|
145 |
+
return example
|
146 |
+
|
147 |
+
def collater(self, samples, pad_to_length=None):
|
148 |
+
"""Merge a list of samples to form a mini-batch.
|
149 |
+
Args:
|
150 |
+
samples (List[dict]): samples to collate
|
151 |
+
Returns:
|
152 |
+
dict: a mini-batch with the following keys:
|
153 |
+
"""
|
154 |
+
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
|
data/mm_data/refcoco_dataset.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import base64
|
13 |
+
import utils.transforms as T
|
14 |
+
|
15 |
+
from PIL import Image, ImageFile
|
16 |
+
|
17 |
+
from data import data_utils
|
18 |
+
from data.ofa_dataset import OFADataset
|
19 |
+
|
20 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
21 |
+
ImageFile.MAX_IMAGE_PIXELS = None
|
22 |
+
Image.MAX_IMAGE_PIXELS = None
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
|
26 |
+
|
27 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
28 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
29 |
+
|
30 |
+
|
31 |
+
def collate(samples, pad_idx, eos_idx):
|
32 |
+
if len(samples) == 0:
|
33 |
+
return {}
|
34 |
+
|
35 |
+
def merge(key):
|
36 |
+
return data_utils.collate_tokens(
|
37 |
+
[s[key] for s in samples],
|
38 |
+
pad_idx,
|
39 |
+
eos_idx=eos_idx,
|
40 |
+
)
|
41 |
+
|
42 |
+
id = np.array([s["id"] for s in samples])
|
43 |
+
src_tokens = merge("source")
|
44 |
+
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
|
45 |
+
|
46 |
+
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
|
47 |
+
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
|
48 |
+
|
49 |
+
w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
|
50 |
+
h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
|
51 |
+
region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
|
52 |
+
|
53 |
+
prev_output_tokens = None
|
54 |
+
target = None
|
55 |
+
if samples[0].get("target", None) is not None:
|
56 |
+
target = merge("target")
|
57 |
+
tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
|
58 |
+
ntokens = tgt_lengths.sum().item()
|
59 |
+
|
60 |
+
if samples[0].get("prev_output_tokens", None) is not None:
|
61 |
+
prev_output_tokens = merge("prev_output_tokens")
|
62 |
+
else:
|
63 |
+
ntokens = src_lengths.sum().item()
|
64 |
+
|
65 |
+
batch = {
|
66 |
+
"id": id,
|
67 |
+
"nsentences": len(samples),
|
68 |
+
"ntokens": ntokens,
|
69 |
+
"net_input": {
|
70 |
+
"src_tokens": src_tokens,
|
71 |
+
"src_lengths": src_lengths,
|
72 |
+
"patch_images": patch_images,
|
73 |
+
"patch_masks": patch_masks,
|
74 |
+
"prev_output_tokens": prev_output_tokens
|
75 |
+
},
|
76 |
+
"target": target,
|
77 |
+
"w_resize_ratios": w_resize_ratios,
|
78 |
+
"h_resize_ratios": h_resize_ratios,
|
79 |
+
"region_coords": region_coords
|
80 |
+
}
|
81 |
+
|
82 |
+
return batch
|
83 |
+
|
84 |
+
|
85 |
+
class RefcocoDataset(OFADataset):
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
split,
|
89 |
+
dataset,
|
90 |
+
bpe,
|
91 |
+
src_dict,
|
92 |
+
tgt_dict=None,
|
93 |
+
max_src_length=80,
|
94 |
+
max_tgt_length=30,
|
95 |
+
patch_image_size=512,
|
96 |
+
imagenet_default_mean_and_std=False,
|
97 |
+
num_bins=1000,
|
98 |
+
max_image_size=512
|
99 |
+
):
|
100 |
+
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
|
101 |
+
self.max_src_length = max_src_length
|
102 |
+
self.max_tgt_length = max_tgt_length
|
103 |
+
self.patch_image_size = patch_image_size
|
104 |
+
self.num_bins = num_bins
|
105 |
+
|
106 |
+
if imagenet_default_mean_and_std:
|
107 |
+
mean = IMAGENET_DEFAULT_MEAN
|
108 |
+
std = IMAGENET_DEFAULT_STD
|
109 |
+
else:
|
110 |
+
mean = [0.5, 0.5, 0.5]
|
111 |
+
std = [0.5, 0.5, 0.5]
|
112 |
+
|
113 |
+
# for positioning
|
114 |
+
self.positioning_transform = T.Compose([
|
115 |
+
T.RandomResize([patch_image_size], max_size=patch_image_size),
|
116 |
+
T.ToTensor(),
|
117 |
+
T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
|
118 |
+
])
|
119 |
+
|
120 |
+
def __getitem__(self, index):
|
121 |
+
uniq_id, base64_str, text, region_coord = self.dataset[index]
|
122 |
+
|
123 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
|
124 |
+
w, h = image.size
|
125 |
+
boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
|
126 |
+
x0, y0, x1, y1 = region_coord.strip().split(',')
|
127 |
+
region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
|
128 |
+
boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
|
129 |
+
boxes_target["labels"] = np.array([0])
|
130 |
+
boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
|
131 |
+
|
132 |
+
patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
|
133 |
+
resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
|
134 |
+
patch_mask = torch.tensor([True])
|
135 |
+
quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
|
136 |
+
quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
|
137 |
+
quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
|
138 |
+
quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
|
139 |
+
region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
|
140 |
+
src_caption = self.pre_caption(text, self.max_src_length)
|
141 |
+
src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
|
142 |
+
tgt_item = self.encode_text(region_coord, use_bpe=False)
|
143 |
+
|
144 |
+
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
145 |
+
target_item = torch.cat([tgt_item, self.eos_item])
|
146 |
+
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
147 |
+
|
148 |
+
example = {
|
149 |
+
"id": uniq_id,
|
150 |
+
"source": src_item,
|
151 |
+
"patch_image": patch_image,
|
152 |
+
"patch_mask": patch_mask,
|
153 |
+
"target": target_item,
|
154 |
+
"prev_output_tokens": prev_output_item,
|
155 |
+
"w_resize_ratio": resize_w / w,
|
156 |
+
"h_resize_ratio": resize_h / h,
|
157 |
+
"region_coord": region
|
158 |
+
}
|
159 |
+
return example
|
160 |
+
|
161 |
+
def collater(self, samples, pad_to_length=None):
|
162 |
+
"""Merge a list of samples to form a mini-batch.
|
163 |
+
Args:
|
164 |
+
samples (List[dict]): samples to collate
|
165 |
+
Returns:
|
166 |
+
dict: a mini-batch with the following keys:
|
167 |
+
"""
|
168 |
+
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
|
data/mm_data/vqa_gen_dataset.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import base64
|
13 |
+
from torchvision import transforms
|
14 |
+
|
15 |
+
from PIL import Image, ImageFile
|
16 |
+
|
17 |
+
from data import data_utils
|
18 |
+
from data.ofa_dataset import OFADataset
|
19 |
+
|
20 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
21 |
+
ImageFile.MAX_IMAGE_PIXELS = None
|
22 |
+
Image.MAX_IMAGE_PIXELS = None
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
|
26 |
+
|
27 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
28 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
29 |
+
|
30 |
+
|
31 |
+
def collate(samples, pad_idx, eos_idx):
|
32 |
+
if len(samples) == 0:
|
33 |
+
return {}
|
34 |
+
|
35 |
+
def merge(key):
|
36 |
+
return data_utils.collate_tokens(
|
37 |
+
[s[key] for s in samples],
|
38 |
+
pad_idx,
|
39 |
+
eos_idx=eos_idx,
|
40 |
+
)
|
41 |
+
|
42 |
+
id = np.array([s["id"] for s in samples])
|
43 |
+
src_tokens = merge("source")
|
44 |
+
src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
|
45 |
+
|
46 |
+
patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
|
47 |
+
patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
|
48 |
+
|
49 |
+
conf = None
|
50 |
+
if samples[0].get("conf", None) is not None:
|
51 |
+
conf = torch.cat([s['conf'] for s in samples], dim=0)
|
52 |
+
|
53 |
+
ref_dict = None
|
54 |
+
if samples[0].get("ref_dict", None) is not None:
|
55 |
+
ref_dict = np.array([s['ref_dict'] for s in samples])
|
56 |
+
|
57 |
+
constraint_masks = None
|
58 |
+
if samples[0].get("constraint_mask", None) is not None:
|
59 |
+
constraint_masks = merge("constraint_mask")
|
60 |
+
|
61 |
+
decoder_prompts = None
|
62 |
+
if samples[0].get("decoder_prompt", None) is not None:
|
63 |
+
decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])
|
64 |
+
|
65 |
+
prev_output_tokens = None
|
66 |
+
target = None
|
67 |
+
if samples[0].get("target", None) is not None:
|
68 |
+
target = merge("target")
|
69 |
+
tgt_lengths = torch.LongTensor(
|
70 |
+
[s["target"].ne(pad_idx).long().sum() for s in samples]
|
71 |
+
)
|
72 |
+
ntokens = tgt_lengths.sum().item()
|
73 |
+
|
74 |
+
if samples[0].get("prev_output_tokens", None) is not None:
|
75 |
+
prev_output_tokens = merge("prev_output_tokens")
|
76 |
+
else:
|
77 |
+
ntokens = src_lengths.sum().item()
|
78 |
+
|
79 |
+
batch = {
|
80 |
+
"id": id,
|
81 |
+
"nsentences": len(samples),
|
82 |
+
"ntokens": ntokens,
|
83 |
+
"net_input": {
|
84 |
+
"src_tokens": src_tokens,
|
85 |
+
"src_lengths": src_lengths,
|
86 |
+
"patch_images": patch_images,
|
87 |
+
"patch_masks": patch_masks,
|
88 |
+
"prev_output_tokens": prev_output_tokens
|
89 |
+
},
|
90 |
+
"conf": conf,
|
91 |
+
"ref_dict": ref_dict,
|
92 |
+
"constraint_masks": constraint_masks,
|
93 |
+
"decoder_prompts": decoder_prompts,
|
94 |
+
"target": target
|
95 |
+
}
|
96 |
+
|
97 |
+
return batch
|
98 |
+
|
99 |
+
|
100 |
+
class VqaGenDataset(OFADataset):
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
split,
|
104 |
+
dataset,
|
105 |
+
bpe,
|
106 |
+
src_dict,
|
107 |
+
tgt_dict=None,
|
108 |
+
max_src_length=128,
|
109 |
+
max_object_length=30,
|
110 |
+
max_tgt_length=30,
|
111 |
+
patch_image_size=224,
|
112 |
+
add_object=False,
|
113 |
+
constraint_trie=None,
|
114 |
+
imagenet_default_mean_and_std=False,
|
115 |
+
prompt_type="none"
|
116 |
+
):
|
117 |
+
super().__init__(split, dataset, bpe, src_dict, tgt_dict)
|
118 |
+
self.max_src_length = max_src_length
|
119 |
+
self.max_object_length = max_object_length
|
120 |
+
self.max_tgt_length = max_tgt_length
|
121 |
+
self.patch_image_size = patch_image_size
|
122 |
+
|
123 |
+
self.add_object = add_object
|
124 |
+
self.constraint_trie = constraint_trie
|
125 |
+
self.prompt_type = prompt_type
|
126 |
+
|
127 |
+
if imagenet_default_mean_and_std:
|
128 |
+
mean = IMAGENET_DEFAULT_MEAN
|
129 |
+
std = IMAGENET_DEFAULT_STD
|
130 |
+
else:
|
131 |
+
mean = [0.5, 0.5, 0.5]
|
132 |
+
std = [0.5, 0.5, 0.5]
|
133 |
+
|
134 |
+
self.patch_resize_transform = transforms.Compose([
|
135 |
+
lambda image: image.convert("RGB"),
|
136 |
+
transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
|
137 |
+
transforms.ToTensor(),
|
138 |
+
transforms.Normalize(mean=mean, std=std),
|
139 |
+
])
|
140 |
+
|
141 |
+
def __getitem__(self, index):
|
142 |
+
item = self.dataset[index]
|
143 |
+
if len(item) == 5:
|
144 |
+
uniq_id, image, question, ref, predict_objects = item
|
145 |
+
else:
|
146 |
+
uniq_id, image, question, ref, predict_objects, caption = item
|
147 |
+
|
148 |
+
image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
|
149 |
+
patch_image = self.patch_resize_transform(image)
|
150 |
+
patch_mask = torch.tensor([True])
|
151 |
+
|
152 |
+
question = self.pre_question(question, self.max_src_length)
|
153 |
+
question = question + '?' if not question.endswith('?') else question
|
154 |
+
src_item = self.encode_text(' {}'.format(question))
|
155 |
+
|
156 |
+
ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
|
157 |
+
answer = max(ref_dict, key=ref_dict.get)
|
158 |
+
conf = torch.tensor([ref_dict[answer]])
|
159 |
+
tgt_item = self.encode_text(" {}".format(answer))
|
160 |
+
|
161 |
+
if self.add_object and predict_objects is not None:
|
162 |
+
predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])
|
163 |
+
predict_object_item = self.encode_text(" object: {}".format(predict_object_seq))
|
164 |
+
src_item = torch.cat([src_item, predict_object_item])
|
165 |
+
|
166 |
+
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
167 |
+
if self.prompt_type == 'none':
|
168 |
+
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
169 |
+
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
170 |
+
decoder_prompt = self.bos_item
|
171 |
+
elif self.prompt_type == 'src':
|
172 |
+
prev_output_item = torch.cat([src_item, tgt_item])
|
173 |
+
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
174 |
+
decoder_prompt = src_item
|
175 |
+
elif self.prompt_type == 'prev_output':
|
176 |
+
prev_output_item = torch.cat([src_item[:-1], tgt_item])
|
177 |
+
target_item = torch.cat([prev_output_item[1:], self.eos_item])
|
178 |
+
decoder_prompt = src_item[:-1]
|
179 |
+
else:
|
180 |
+
raise NotImplementedError
|
181 |
+
target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()
|
182 |
+
|
183 |
+
example = {
|
184 |
+
"id": uniq_id,
|
185 |
+
"source": src_item,
|
186 |
+
"patch_image": patch_image,
|
187 |
+
"patch_mask": patch_mask,
|
188 |
+
"target": target_item,
|
189 |
+
"prev_output_tokens": prev_output_item,
|
190 |
+
"decoder_prompt": decoder_prompt,
|
191 |
+
"ref_dict": ref_dict,
|
192 |
+
"conf": conf,
|
193 |
+
}
|
194 |
+
if self.constraint_trie is not None:
|
195 |
+
constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
|
196 |
+
start_idx = len(target_item) - len(tgt_item) - 1
|
197 |
+
for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
|
198 |
+
constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
|
199 |
+
constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
|
200 |
+
constraint_mask[i][constraint_nodes] = True
|
201 |
+
example["constraint_mask"] = constraint_mask
|
202 |
+
return example
|
203 |
+
|
204 |
+
def collater(self, samples, pad_to_length=None):
|
205 |
+
"""Merge a list of samples to form a mini-batch.
|
206 |
+
Args:
|
207 |
+
samples (List[dict]): samples to collate
|
208 |
+
Returns:
|
209 |
+
dict: a mini-batch with the following keys:
|
210 |
+
"""
|
211 |
+
return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
|
data/ofa_dataset.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
import torch.utils.data
|
4 |
+
from fairseq.data import FairseqDataset
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
class OFADataset(FairseqDataset):
|
10 |
+
def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
|
11 |
+
self.split = split
|
12 |
+
self.dataset = dataset
|
13 |
+
self.bpe = bpe
|
14 |
+
self.src_dict = src_dict
|
15 |
+
self.tgt_dict = tgt_dict
|
16 |
+
|
17 |
+
self.bos = src_dict.bos()
|
18 |
+
self.eos = src_dict.eos()
|
19 |
+
self.pad = src_dict.pad()
|
20 |
+
self.bos_item = torch.LongTensor([self.bos])
|
21 |
+
self.eos_item = torch.LongTensor([self.eos])
|
22 |
+
|
23 |
+
def __len__(self):
|
24 |
+
return len(self.dataset)
|
25 |
+
|
26 |
+
def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
|
27 |
+
s = self.tgt_dict.encode_line(
|
28 |
+
line=self.bpe.encode(text) if use_bpe else text,
|
29 |
+
add_if_not_exist=False,
|
30 |
+
append_eos=False
|
31 |
+
).long()
|
32 |
+
if length is not None:
|
33 |
+
s = s[:length]
|
34 |
+
if append_bos:
|
35 |
+
s = torch.cat([self.bos_item, s])
|
36 |
+
if append_eos:
|
37 |
+
s = torch.cat([s, self.eos_item])
|
38 |
+
return s
|
39 |
+
|
40 |
+
def pre_question(self, question, max_ques_words):
|
41 |
+
question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
|
42 |
+
|
43 |
+
question = re.sub(
|
44 |
+
r"\s{2,}",
|
45 |
+
' ',
|
46 |
+
question,
|
47 |
+
)
|
48 |
+
question = question.rstrip('\n')
|
49 |
+
question = question.strip(' ')
|
50 |
+
|
51 |
+
# truncate question
|
52 |
+
question_words = question.split(' ')
|
53 |
+
if len(question_words) > max_ques_words:
|
54 |
+
question = ' '.join(question_words[:max_ques_words])
|
55 |
+
|
56 |
+
return question
|
57 |
+
|
58 |
+
def pre_caption(self, caption, max_words):
|
59 |
+
caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
|
60 |
+
|
61 |
+
caption = re.sub(
|
62 |
+
r"\s{2,}",
|
63 |
+
' ',
|
64 |
+
caption,
|
65 |
+
)
|
66 |
+
caption = caption.rstrip('\n')
|
67 |
+
caption = caption.strip(' ')
|
68 |
+
|
69 |
+
# truncate caption
|
70 |
+
caption_words = caption.split(' ')
|
71 |
+
if len(caption_words) > max_words:
|
72 |
+
caption = ' '.join(caption_words[:max_words])
|
73 |
+
|
74 |
+
return caption
|
datasets.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Datasets
|
2 |
+
|
3 |
+
We provide links to download our preprocessed dataset. If you would like to process the data on your own, we will soon provide scripts for you to do so.
|
4 |
+
|
5 |
+
## Finetuning
|
6 |
+
|
7 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
|
8 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
|
9 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
|
10 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
|
evaluate.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3 -u
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import json
|
11 |
+
from itertools import chain
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.distributed as dist
|
16 |
+
from fairseq import distributed_utils, options, tasks, utils
|
17 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
18 |
+
from fairseq.logging import progress_bar
|
19 |
+
from fairseq.utils import reset_logging
|
20 |
+
from omegaconf import DictConfig
|
21 |
+
|
22 |
+
from utils import checkpoint_utils
|
23 |
+
from utils.eval_utils import eval_step
|
24 |
+
|
25 |
+
logging.basicConfig(
|
26 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
27 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
28 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
29 |
+
stream=sys.stdout,
|
30 |
+
)
|
31 |
+
logger = logging.getLogger("ofa.evaluate")
|
32 |
+
|
33 |
+
|
34 |
+
def apply_half(t):
|
35 |
+
if t.dtype is torch.float32:
|
36 |
+
return t.to(dtype=torch.half)
|
37 |
+
return t
|
38 |
+
|
39 |
+
|
40 |
+
def main(cfg: DictConfig, **kwargs):
|
41 |
+
utils.import_user_module(cfg.common)
|
42 |
+
|
43 |
+
reset_logging()
|
44 |
+
logger.info(cfg)
|
45 |
+
|
46 |
+
assert (
|
47 |
+
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
|
48 |
+
), "Must specify batch size either with --max-tokens or --batch-size"
|
49 |
+
|
50 |
+
# Fix seed for stochastic decoding
|
51 |
+
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
|
52 |
+
np.random.seed(cfg.common.seed)
|
53 |
+
utils.set_torch_seed(cfg.common.seed)
|
54 |
+
|
55 |
+
use_fp16 = cfg.common.fp16
|
56 |
+
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
57 |
+
|
58 |
+
if use_cuda:
|
59 |
+
torch.cuda.set_device(cfg.distributed_training.device_id)
|
60 |
+
|
61 |
+
# Load ensemble
|
62 |
+
overrides = eval(cfg.common_eval.model_overrides)
|
63 |
+
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
64 |
+
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
65 |
+
utils.split_paths(cfg.common_eval.path),
|
66 |
+
arg_overrides=overrides,
|
67 |
+
suffix=cfg.checkpoint.checkpoint_suffix,
|
68 |
+
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
|
69 |
+
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
70 |
+
)
|
71 |
+
|
72 |
+
# loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
|
73 |
+
task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
|
74 |
+
|
75 |
+
# Move models to GPU
|
76 |
+
for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
|
77 |
+
if kwargs['ema_eval']:
|
78 |
+
logger.info("loading EMA weights from {}".format(ckpt_path))
|
79 |
+
model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
|
80 |
+
model.eval()
|
81 |
+
if use_fp16:
|
82 |
+
model.half()
|
83 |
+
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
84 |
+
model.cuda()
|
85 |
+
model.prepare_for_inference_(cfg)
|
86 |
+
|
87 |
+
# Load dataset (possibly sharded)
|
88 |
+
itr = task.get_batch_iterator(
|
89 |
+
dataset=task.dataset(cfg.dataset.gen_subset),
|
90 |
+
max_tokens=cfg.dataset.max_tokens,
|
91 |
+
max_sentences=cfg.dataset.batch_size,
|
92 |
+
max_positions=utils.resolve_max_positions(
|
93 |
+
task.max_positions(), *[m.max_positions() for m in models]
|
94 |
+
),
|
95 |
+
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
96 |
+
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
97 |
+
seed=cfg.common.seed,
|
98 |
+
num_shards=cfg.distributed_training.distributed_world_size,
|
99 |
+
shard_id=cfg.distributed_training.distributed_rank,
|
100 |
+
num_workers=cfg.dataset.num_workers,
|
101 |
+
data_buffer_size=cfg.dataset.data_buffer_size,
|
102 |
+
).next_epoch_itr(shuffle=False)
|
103 |
+
progress = progress_bar.progress_bar(
|
104 |
+
itr,
|
105 |
+
log_format=cfg.common.log_format,
|
106 |
+
log_interval=cfg.common.log_interval,
|
107 |
+
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
108 |
+
)
|
109 |
+
|
110 |
+
# Initialize generator
|
111 |
+
generator = task.build_generator(models, cfg.generation)
|
112 |
+
|
113 |
+
results = []
|
114 |
+
score_sum = torch.FloatTensor([0]).cuda()
|
115 |
+
score_cnt = torch.FloatTensor([0]).cuda()
|
116 |
+
for sample in progress:
|
117 |
+
if "net_input" not in sample:
|
118 |
+
continue
|
119 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
120 |
+
sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
|
121 |
+
with torch.no_grad():
|
122 |
+
result, scores = eval_step(task, generator, models, sample)
|
123 |
+
results += result
|
124 |
+
score_sum += sum(scores) if scores is not None else 0
|
125 |
+
score_cnt += len(scores) if scores is not None else 0
|
126 |
+
progress.log({"sentences": sample["nsentences"]})
|
127 |
+
|
128 |
+
gather_results = None
|
129 |
+
if cfg.distributed_training.distributed_world_size > 1:
|
130 |
+
gather_results = [None for _ in range(dist.get_world_size())]
|
131 |
+
dist.all_gather_object(gather_results, results)
|
132 |
+
dist.all_reduce(score_sum.data)
|
133 |
+
dist.all_reduce(score_cnt.data)
|
134 |
+
if score_cnt.item() > 0:
|
135 |
+
logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
|
136 |
+
score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
|
137 |
+
))
|
138 |
+
|
139 |
+
if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
|
140 |
+
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
|
141 |
+
output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
|
142 |
+
gather_results = list(chain(*gather_results)) if gather_results is not None else results
|
143 |
+
with open(output_path, 'w') as fw:
|
144 |
+
json.dump(gather_results, fw)
|
145 |
+
|
146 |
+
|
147 |
+
def cli_main():
|
148 |
+
parser = options.get_generation_parser()
|
149 |
+
parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
|
150 |
+
args = options.parse_args_and_arch(parser)
|
151 |
+
cfg = convert_namespace_to_omegaconf(args)
|
152 |
+
distributed_utils.call_main(cfg, main, ema_eval=args.ema_eval)
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
cli_main()
|
fairseq/.github/ISSUE_TEMPLATE.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
|
2 |
+
|
3 |
+
Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
|
fairseq/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🐛 Bug Report
|
3 |
+
about: Submit a bug report to help us improve
|
4 |
+
labels: 'bug, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 🐛 Bug
|
8 |
+
|
9 |
+
<!-- A clear and concise description of what the bug is. -->
|
10 |
+
|
11 |
+
### To Reproduce
|
12 |
+
|
13 |
+
Steps to reproduce the behavior (**always include the command you ran**):
|
14 |
+
|
15 |
+
1. Run cmd '....'
|
16 |
+
2. See error
|
17 |
+
|
18 |
+
<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
|
19 |
+
|
20 |
+
|
21 |
+
#### Code sample
|
22 |
+
<!-- Ideally attach a minimal code sample to reproduce the decried issue.
|
23 |
+
Minimal means having the shortest code but still preserving the bug. -->
|
24 |
+
|
25 |
+
### Expected behavior
|
26 |
+
|
27 |
+
<!-- A clear and concise description of what you expected to happen. -->
|
28 |
+
|
29 |
+
### Environment
|
30 |
+
|
31 |
+
- fairseq Version (e.g., 1.0 or main):
|
32 |
+
- PyTorch Version (e.g., 1.0)
|
33 |
+
- OS (e.g., Linux):
|
34 |
+
- How you installed fairseq (`pip`, source):
|
35 |
+
- Build command you used (if compiling from source):
|
36 |
+
- Python version:
|
37 |
+
- CUDA/cuDNN version:
|
38 |
+
- GPU models and configuration:
|
39 |
+
- Any other relevant information:
|
40 |
+
|
41 |
+
### Additional context
|
42 |
+
|
43 |
+
<!-- Add any other context about the problem here. -->
|
fairseq/.github/ISSUE_TEMPLATE/documentation.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 📚 Documentation/Typos
|
3 |
+
about: Report an issue related to documentation or a typo
|
4 |
+
labels: 'documentation, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 📚 Documentation
|
8 |
+
|
9 |
+
For typos and doc fixes, please go ahead and:
|
10 |
+
|
11 |
+
1. Create an issue.
|
12 |
+
2. Fix the typo.
|
13 |
+
3. Submit a PR.
|
14 |
+
|
15 |
+
Thanks!
|
fairseq/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: 🚀 Feature Request
|
3 |
+
about: Submit a proposal/request for a new feature
|
4 |
+
labels: 'enhancement, help wanted, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## 🚀 Feature Request
|
8 |
+
<!-- A clear and concise description of the feature proposal -->
|
9 |
+
|
10 |
+
### Motivation
|
11 |
+
|
12 |
+
<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
|
13 |
+
|
14 |
+
### Pitch
|
15 |
+
|
16 |
+
<!-- A clear and concise description of what you want to happen. -->
|
17 |
+
|
18 |
+
### Alternatives
|
19 |
+
|
20 |
+
<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
|
21 |
+
|
22 |
+
### Additional context
|
23 |
+
|
24 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
fairseq/.github/ISSUE_TEMPLATE/how-to-question.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: ❓ Questions/Help
|
3 |
+
about: If you have questions, please first search existing issues and docs
|
4 |
+
labels: 'question, needs triage'
|
5 |
+
---
|
6 |
+
|
7 |
+
## ❓ Questions and Help
|
8 |
+
|
9 |
+
### Before asking:
|
10 |
+
1. search the issues.
|
11 |
+
2. search the docs.
|
12 |
+
|
13 |
+
<!-- If you still can't find what you need: -->
|
14 |
+
|
15 |
+
#### What is your question?
|
16 |
+
|
17 |
+
#### Code
|
18 |
+
|
19 |
+
<!-- Please paste a code snippet if your question requires it! -->
|
20 |
+
|
21 |
+
#### What have you tried?
|
22 |
+
|
23 |
+
#### What's your environment?
|
24 |
+
|
25 |
+
- fairseq Version (e.g., 1.0 or main):
|
26 |
+
- PyTorch Version (e.g., 1.0)
|
27 |
+
- OS (e.g., Linux):
|
28 |
+
- How you installed fairseq (`pip`, source):
|
29 |
+
- Build command you used (if compiling from source):
|
30 |
+
- Python version:
|
31 |
+
- CUDA/cuDNN version:
|
32 |
+
- GPU models and configuration:
|
33 |
+
- Any other relevant information:
|
fairseq/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Before submitting
|
2 |
+
|
3 |
+
- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
|
4 |
+
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
|
5 |
+
- [ ] Did you make sure to update the docs?
|
6 |
+
- [ ] Did you write any new necessary tests?
|
7 |
+
|
8 |
+
## What does this PR do?
|
9 |
+
Fixes # (issue).
|
10 |
+
|
11 |
+
## PR review
|
12 |
+
Anyone in the community is free to review the PR once the tests have passed.
|
13 |
+
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
|
14 |
+
|
15 |
+
## Did you have fun?
|
16 |
+
Make sure you had fun coding 🙃
|
fairseq/.github/stale.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for probot-stale - https://github.com/probot/stale
|
2 |
+
# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
|
3 |
+
# Number of days of inactivity before an issue becomes stale
|
4 |
+
daysUntilStale: 90
|
5 |
+
# Number of days of inactivity before a stale issue is closed
|
6 |
+
daysUntilClose: 7
|
7 |
+
# Issues with these labels will never be considered stale
|
8 |
+
exemptLabels:
|
9 |
+
- bug
|
10 |
+
# Label to use when marking an issue as stale
|
11 |
+
staleLabel: stale
|
12 |
+
issues:
|
13 |
+
# Comment to post when marking an issue as stale.
|
14 |
+
markComment: >
|
15 |
+
This issue has been automatically marked as stale.
|
16 |
+
**If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
|
17 |
+
We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
|
18 |
+
# Comment to post when closing a stale issue.
|
19 |
+
closeComment: >
|
20 |
+
Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
|
21 |
+
pulls:
|
22 |
+
# Comment to post when marking a pull request as stale.
|
23 |
+
markComment: >
|
24 |
+
This pull request has been automatically marked as stale.
|
25 |
+
**If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
|
26 |
+
We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
|
27 |
+
# Comment to post when closing a stale pull request.
|
28 |
+
closeComment: >
|
29 |
+
Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
|
30 |
+
|
fairseq/.github/workflows/build.yml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: build
|
2 |
+
|
3 |
+
on:
|
4 |
+
# Trigger the workflow on push to main or any pull request
|
5 |
+
push:
|
6 |
+
branches:
|
7 |
+
- main
|
8 |
+
pull_request:
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
build:
|
12 |
+
|
13 |
+
strategy:
|
14 |
+
max-parallel: 4
|
15 |
+
matrix:
|
16 |
+
platform: [ubuntu-latest, macos-latest]
|
17 |
+
python-version: [3.6, 3.7]
|
18 |
+
|
19 |
+
runs-on: ${{ matrix.platform }}
|
20 |
+
|
21 |
+
steps:
|
22 |
+
- uses: actions/checkout@v2
|
23 |
+
|
24 |
+
- name: Set up Python ${{ matrix.python-version }}
|
25 |
+
uses: actions/setup-python@v2
|
26 |
+
with:
|
27 |
+
python-version: ${{ matrix.python-version }}
|
28 |
+
|
29 |
+
- name: Conditionally install pytorch
|
30 |
+
if: matrix.platform == 'windows-latest'
|
31 |
+
run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
|
32 |
+
|
33 |
+
- name: Install locally
|
34 |
+
run: |
|
35 |
+
python -m pip install --upgrade pip
|
36 |
+
git submodule update --init --recursive
|
37 |
+
python setup.py build_ext --inplace
|
38 |
+
python -m pip install --editable .
|
39 |
+
|
40 |
+
- name: Install optional test requirements
|
41 |
+
run: |
|
42 |
+
python -m pip install iopath transformers pyarrow
|
43 |
+
python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
|
44 |
+
|
45 |
+
- name: Lint with flake8
|
46 |
+
run: |
|
47 |
+
pip install flake8
|
48 |
+
# stop the build if there are Python syntax errors or undefined names
|
49 |
+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
|
50 |
+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
51 |
+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
|
52 |
+
|
53 |
+
- name: Run tests
|
54 |
+
run: |
|
55 |
+
python setup.py test
|
fairseq/.github/workflows/build_wheels.yml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: build_wheels
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- v[0-9]+.[0-9]+.[x0-9]+
|
7 |
+
tags:
|
8 |
+
- v*
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
build_wheels:
|
12 |
+
name: Build wheels on ${{ matrix.os }}
|
13 |
+
runs-on: ${{ matrix.os }}
|
14 |
+
strategy:
|
15 |
+
matrix:
|
16 |
+
os: [ubuntu-latest, macos-latest]
|
17 |
+
|
18 |
+
steps:
|
19 |
+
- uses: actions/checkout@v2
|
20 |
+
|
21 |
+
- name: Install Python
|
22 |
+
uses: actions/setup-python@v2
|
23 |
+
with:
|
24 |
+
python-version: '3.7'
|
25 |
+
|
26 |
+
- name: Install cibuildwheel
|
27 |
+
run: |
|
28 |
+
python -m pip install cibuildwheel
|
29 |
+
|
30 |
+
- name: Build wheels for CPython
|
31 |
+
run: |
|
32 |
+
python -m cibuildwheel --output-dir dist
|
33 |
+
env:
|
34 |
+
CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
|
35 |
+
CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
|
36 |
+
CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
|
37 |
+
|
38 |
+
- uses: actions/upload-artifact@v2
|
39 |
+
with:
|
40 |
+
name: wheels
|
41 |
+
path: ./dist/*.whl
|
fairseq/.gitignore
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# JetBrains PyCharm IDE
|
2 |
+
.idea/
|
3 |
+
|
4 |
+
# Byte-compiled / optimized / DLL files
|
5 |
+
__pycache__/
|
6 |
+
*.py[cod]
|
7 |
+
*$py.class
|
8 |
+
|
9 |
+
# C extensions
|
10 |
+
*.so
|
11 |
+
|
12 |
+
# macOS dir files
|
13 |
+
.DS_Store
|
14 |
+
|
15 |
+
# Distribution / packaging
|
16 |
+
.Python
|
17 |
+
env/
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
*.egg-info/
|
31 |
+
.installed.cfg
|
32 |
+
*.egg
|
33 |
+
|
34 |
+
# Checkpoints
|
35 |
+
checkpoints
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache
|
53 |
+
nosetests.xml
|
54 |
+
coverage.xml
|
55 |
+
*.cover
|
56 |
+
.hypothesis/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# pyenv
|
83 |
+
.python-version
|
84 |
+
|
85 |
+
# celery beat schedule file
|
86 |
+
celerybeat-schedule
|
87 |
+
|
88 |
+
# SageMath parsed files
|
89 |
+
*.sage.py
|
90 |
+
|
91 |
+
# dotenv
|
92 |
+
.env
|
93 |
+
|
94 |
+
# virtualenv
|
95 |
+
.venv
|
96 |
+
venv/
|
97 |
+
ENV/
|
98 |
+
|
99 |
+
# Spyder project settings
|
100 |
+
.spyderproject
|
101 |
+
.spyproject
|
102 |
+
|
103 |
+
# Rope project settings
|
104 |
+
.ropeproject
|
105 |
+
|
106 |
+
# mkdocs documentation
|
107 |
+
/site
|
108 |
+
|
109 |
+
# mypy
|
110 |
+
.mypy_cache/
|
111 |
+
|
112 |
+
# Generated files
|
113 |
+
/fairseq/temporal_convolution_tbc
|
114 |
+
/fairseq/modules/*_layer/*_forward.cu
|
115 |
+
/fairseq/modules/*_layer/*_backward.cu
|
116 |
+
/fairseq/version.py
|
117 |
+
|
118 |
+
# data
|
119 |
+
data-bin/
|
120 |
+
|
121 |
+
# reranking
|
122 |
+
/examples/reranking/rerank_data
|
123 |
+
|
124 |
+
# Cython-generated C++ source files
|
125 |
+
/fairseq/data/data_utils_fast.cpp
|
126 |
+
/fairseq/data/token_block_utils_fast.cpp
|
127 |
+
|
128 |
+
# VSCODE
|
129 |
+
.vscode/ftp-sync.json
|
130 |
+
.vscode/settings.json
|
131 |
+
|
132 |
+
# Experimental Folder
|
133 |
+
experimental/*
|
134 |
+
|
135 |
+
# Weights and Biases logs
|
136 |
+
wandb/
|
fairseq/.gitmodules
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "fairseq/model_parallel/megatron"]
|
2 |
+
path = fairseq/model_parallel/megatron
|
3 |
+
url = https://github.com/ngoyal2707/Megatron-LM
|
4 |
+
branch = fairseq
|
fairseq/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
In the interest of fostering an open and welcoming environment, we as
|
6 |
+
contributors and maintainers pledge to make participation in our project and
|
7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
9 |
+
level of experience, education, socio-economic status, nationality, personal
|
10 |
+
appearance, race, religion, or sexual identity and orientation.
|
11 |
+
|
12 |
+
## Our Standards
|
13 |
+
|
14 |
+
Examples of behavior that contributes to creating a positive environment
|
15 |
+
include:
|
16 |
+
|
17 |
+
* Using welcoming and inclusive language
|
18 |
+
* Being respectful of differing viewpoints and experiences
|
19 |
+
* Gracefully accepting constructive criticism
|
20 |
+
* Focusing on what is best for the community
|
21 |
+
* Showing empathy towards other community members
|
22 |
+
|
23 |
+
Examples of unacceptable behavior by participants include:
|
24 |
+
|
25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
26 |
+
advances
|
27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
28 |
+
* Public or private harassment
|
29 |
+
* Publishing others' private information, such as a physical or electronic
|
30 |
+
address, without explicit permission
|
31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
32 |
+
professional setting
|
33 |
+
|
34 |
+
## Our Responsibilities
|
35 |
+
|
36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
38 |
+
response to any instances of unacceptable behavior.
|
39 |
+
|
40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
44 |
+
threatening, offensive, or harmful.
|
45 |
+
|
46 |
+
## Scope
|
47 |
+
|
48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
49 |
+
an individual is representing the project or its community in public spaces.
|
50 |
+
Examples of representing a project or community include using an official
|
51 |
+
project e-mail address, posting via an official social media account, or acting
|
52 |
+
as an appointed representative at an online or offline event. Representation of
|
53 |
+
a project may be further defined and clarified by project maintainers.
|
54 |
+
|
55 |
+
## Enforcement
|
56 |
+
|
57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
58 |
+
reported by contacting the project team at <conduct@pytorch.org>. All
|
59 |
+
complaints will be reviewed and investigated and will result in a response that
|
60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
62 |
+
Further details of specific enforcement policies may be posted separately.
|
63 |
+
|
64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
65 |
+
faith may face temporary or permanent repercussions as determined by other
|
66 |
+
members of the project's leadership.
|
67 |
+
|
68 |
+
## Attribution
|
69 |
+
|
70 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
71 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
72 |
+
|
73 |
+
[homepage]: https://www.contributor-covenant.org
|
74 |
+
|
75 |
+
For answers to common questions about this code of conduct, see
|
76 |
+
https://www.contributor-covenant.org/faq
|
77 |
+
|
fairseq/CONTRIBUTING.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
|
2 |
+
We want to make contributing to this project as easy and transparent as
|
3 |
+
possible.
|
4 |
+
|
5 |
+
## Pull Requests
|
6 |
+
We actively welcome your pull requests.
|
7 |
+
|
8 |
+
1. Fork the repo and create your branch from `main`.
|
9 |
+
2. If you've added code that should be tested, add tests.
|
10 |
+
3. If you've changed APIs, update the documentation.
|
11 |
+
4. Ensure the test suite passes.
|
12 |
+
5. Make sure your code lints.
|
13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
14 |
+
|
15 |
+
## Contributor License Agreement ("CLA")
|
16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
17 |
+
to do this once to work on any of Facebook's open source projects.
|
18 |
+
|
19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
20 |
+
|
21 |
+
## Issues
|
22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
24 |
+
|
25 |
+
## License
|
26 |
+
By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
|
27 |
+
you agree that your contributions will be licensed under the LICENSE file in
|
28 |
+
the root directory of this source tree.
|
fairseq/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
fairseq/README.md
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="docs/fairseq_logo.png" width="150">
|
3 |
+
<br />
|
4 |
+
<br />
|
5 |
+
<a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
|
6 |
+
<a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
|
7 |
+
<a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
|
8 |
+
<a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
--------------------------------------------------------------------------------
|
12 |
+
|
13 |
+
Fairseq(-py) is a sequence modeling toolkit that allows researchers and
|
14 |
+
developers to train custom models for translation, summarization, language
|
15 |
+
modeling and other text generation tasks.
|
16 |
+
|
17 |
+
We provide reference implementations of various sequence modeling papers:
|
18 |
+
|
19 |
+
<details><summary>List of implemented papers</summary><p>
|
20 |
+
|
21 |
+
* **Convolutional Neural Networks (CNN)**
|
22 |
+
+ [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
|
23 |
+
+ [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
|
24 |
+
+ [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
|
25 |
+
+ [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
|
26 |
+
+ [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
|
27 |
+
* **LightConv and DynamicConv models**
|
28 |
+
+ [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
|
29 |
+
* **Long Short-Term Memory (LSTM) networks**
|
30 |
+
+ Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
|
31 |
+
* **Transformer (self-attention) networks**
|
32 |
+
+ Attention Is All You Need (Vaswani et al., 2017)
|
33 |
+
+ [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
|
34 |
+
+ [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
|
35 |
+
+ [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
|
36 |
+
+ [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
|
37 |
+
+ [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
|
38 |
+
+ [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
|
39 |
+
+ [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
|
40 |
+
+ [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
|
41 |
+
+ [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
|
42 |
+
+ [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
|
43 |
+
+ [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
|
44 |
+
+ [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
|
45 |
+
+ [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
|
46 |
+
+ [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
|
47 |
+
+ [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
|
48 |
+
+ [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
|
49 |
+
+ [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
|
50 |
+
+ [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
|
51 |
+
+ [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
|
52 |
+
+ [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
|
53 |
+
+ [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
|
54 |
+
* **Non-autoregressive Transformers**
|
55 |
+
+ Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
|
56 |
+
+ Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
|
57 |
+
+ Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
|
58 |
+
+ Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
|
59 |
+
+ [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
|
60 |
+
* **Finetuning**
|
61 |
+
+ [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
|
62 |
+
|
63 |
+
</p></details>
|
64 |
+
|
65 |
+
### What's New:
|
66 |
+
|
67 |
+
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
|
68 |
+
* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
|
69 |
+
* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
|
70 |
+
* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
|
71 |
+
* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
|
72 |
+
* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
|
73 |
+
* February 2021 [Added LASER training code](examples/laser/README.md)
|
74 |
+
* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
|
75 |
+
* December 2020: [GottBERT model and code released](examples/gottbert/README.md)
|
76 |
+
* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
|
77 |
+
* [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
|
78 |
+
* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
|
79 |
+
* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
|
80 |
+
* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
|
81 |
+
* October 2020: [Added CRISS models and code](examples/criss/README.md)
|
82 |
+
|
83 |
+
<details><summary>Previous updates</summary><p>
|
84 |
+
|
85 |
+
* September 2020: [Added Linformer code](examples/linformer/README.md)
|
86 |
+
* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
|
87 |
+
* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
|
88 |
+
* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
|
89 |
+
* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
|
90 |
+
* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
|
91 |
+
* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
|
92 |
+
* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
|
93 |
+
* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
|
94 |
+
* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
|
95 |
+
* February 2020: [mBART model and code released](examples/mbart/README.md)
|
96 |
+
* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
|
97 |
+
* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
|
98 |
+
* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
|
99 |
+
* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
|
100 |
+
* November 2019: [BART model and code released](examples/bart/README.md)
|
101 |
+
* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
|
102 |
+
* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
|
103 |
+
* August 2019: [WMT'19 models released](examples/wmt19/README.md)
|
104 |
+
* July 2019: fairseq relicensed under MIT license
|
105 |
+
* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
|
106 |
+
* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
|
107 |
+
|
108 |
+
</p></details>
|
109 |
+
|
110 |
+
### Features:
|
111 |
+
|
112 |
+
* multi-GPU training on one machine or across multiple machines (data and model parallel)
|
113 |
+
* fast generation on both CPU and GPU with multiple search algorithms implemented:
|
114 |
+
+ beam search
|
115 |
+
+ Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
|
116 |
+
+ sampling (unconstrained, top-k and top-p/nucleus)
|
117 |
+
+ [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
|
118 |
+
* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
|
119 |
+
* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
|
120 |
+
* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
|
121 |
+
* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
|
122 |
+
* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
|
123 |
+
* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
|
124 |
+
|
125 |
+
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
|
126 |
+
with a convenient `torch.hub` interface:
|
127 |
+
|
128 |
+
``` python
|
129 |
+
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
|
130 |
+
en2de.translate('Hello world', beam=5)
|
131 |
+
# 'Hallo Welt'
|
132 |
+
```
|
133 |
+
|
134 |
+
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
|
135 |
+
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
|
136 |
+
|
137 |
+
# Requirements and Installation
|
138 |
+
|
139 |
+
* [PyTorch](http://pytorch.org/) version >= 1.5.0
|
140 |
+
* Python version >= 3.6
|
141 |
+
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
|
142 |
+
* **To install fairseq** and develop locally:
|
143 |
+
|
144 |
+
``` bash
|
145 |
+
git clone https://github.com/pytorch/fairseq
|
146 |
+
cd fairseq
|
147 |
+
pip install --editable ./
|
148 |
+
|
149 |
+
# on MacOS:
|
150 |
+
# CFLAGS="-stdlib=libc++" pip install --editable ./
|
151 |
+
|
152 |
+
# to install the latest stable release (0.10.x)
|
153 |
+
# pip install fairseq
|
154 |
+
```
|
155 |
+
|
156 |
+
* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
|
157 |
+
|
158 |
+
``` bash
|
159 |
+
git clone https://github.com/NVIDIA/apex
|
160 |
+
cd apex
|
161 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
|
162 |
+
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
|
163 |
+
--global-option="--fast_multihead_attn" ./
|
164 |
+
```
|
165 |
+
|
166 |
+
* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
|
167 |
+
* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
|
168 |
+
as command line options to `nvidia-docker run` .
|
169 |
+
|
170 |
+
# Getting Started
|
171 |
+
|
172 |
+
The [full documentation](https://fairseq.readthedocs.io/) contains instructions
|
173 |
+
for getting started, training new models and extending fairseq with new model
|
174 |
+
types and tasks.
|
175 |
+
|
176 |
+
# Pre-trained models and examples
|
177 |
+
|
178 |
+
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
|
179 |
+
as well as example training and evaluation commands.
|
180 |
+
|
181 |
+
* [Translation](examples/translation/README.md): convolutional and transformer models are available
|
182 |
+
* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
|
183 |
+
|
184 |
+
We also have more detailed READMEs to reproduce results from specific papers:
|
185 |
+
|
186 |
+
* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
|
187 |
+
* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
|
188 |
+
* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
|
189 |
+
* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
|
190 |
+
* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
|
191 |
+
* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
|
192 |
+
* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
|
193 |
+
* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
|
194 |
+
* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
|
195 |
+
* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
|
196 |
+
* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
|
197 |
+
* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
|
198 |
+
* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
|
199 |
+
* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
|
200 |
+
* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
|
201 |
+
* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
|
202 |
+
* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
|
203 |
+
* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
|
204 |
+
* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
|
205 |
+
* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
|
206 |
+
|
207 |
+
# Join the fairseq community
|
208 |
+
|
209 |
+
* Twitter: https://twitter.com/fairseq
|
210 |
+
* Facebook page: https://www.facebook.com/groups/fairseq.users
|
211 |
+
* Google group: https://groups.google.com/forum/#!forum/fairseq-users
|
212 |
+
|
213 |
+
# License
|
214 |
+
|
215 |
+
fairseq(-py) is MIT-licensed.
|
216 |
+
The license applies to the pre-trained models as well.
|
217 |
+
|
218 |
+
# Citation
|
219 |
+
|
220 |
+
Please cite as:
|
221 |
+
|
222 |
+
``` bibtex
|
223 |
+
@inproceedings{ott2019fairseq,
|
224 |
+
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
|
225 |
+
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
|
226 |
+
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
|
227 |
+
year = {2019},
|
228 |
+
}
|
229 |
+
```
|
fairseq/docs/Makefile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Minimal makefile for Sphinx documentation
|
2 |
+
#
|
3 |
+
|
4 |
+
# You can set these variables from the command line.
|
5 |
+
SPHINXOPTS =
|
6 |
+
SPHINXBUILD = python -msphinx
|
7 |
+
SPHINXPROJ = fairseq
|
8 |
+
SOURCEDIR = .
|
9 |
+
BUILDDIR = _build
|
10 |
+
|
11 |
+
# Put it first so that "make" without argument is like "make help".
|
12 |
+
help:
|
13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
14 |
+
|
15 |
+
.PHONY: help Makefile
|
16 |
+
|
17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
19 |
+
%: Makefile
|
20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
fairseq/docs/_static/theme_overrides.css
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.wy-table-responsive table td kbd {
|
2 |
+
white-space: nowrap;
|
3 |
+
}
|
4 |
+
.wy-table-responsive table td {
|
5 |
+
white-space: normal !important;
|
6 |
+
}
|
7 |
+
.wy-table-responsive {
|
8 |
+
overflow: visible !important;
|
9 |
+
}
|
fairseq/docs/command_line_tools.rst
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. _Command-line Tools:
|
2 |
+
|
3 |
+
Command-line Tools
|
4 |
+
==================
|
5 |
+
|
6 |
+
Fairseq provides several command-line tools for training and evaluating models:
|
7 |
+
|
8 |
+
- :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
|
9 |
+
- :ref:`fairseq-train`: Train a new model on one or multiple GPUs
|
10 |
+
- :ref:`fairseq-generate`: Translate pre-processed data with a trained model
|
11 |
+
- :ref:`fairseq-interactive`: Translate raw text with a trained model
|
12 |
+
- :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
|
13 |
+
- :ref:`fairseq-eval-lm`: Language model evaluation
|
14 |
+
|
15 |
+
|
16 |
+
.. _fairseq-preprocess:
|
17 |
+
|
18 |
+
fairseq-preprocess
|
19 |
+
~~~~~~~~~~~~~~~~~~
|
20 |
+
.. automodule:: fairseq_cli.preprocess
|
21 |
+
|
22 |
+
.. argparse::
|
23 |
+
:module: fairseq.options
|
24 |
+
:func: get_preprocessing_parser
|
25 |
+
:prog: fairseq-preprocess
|
26 |
+
|
27 |
+
|
28 |
+
.. _fairseq-train:
|
29 |
+
|
30 |
+
fairseq-train
|
31 |
+
~~~~~~~~~~~~~
|
32 |
+
.. automodule:: fairseq_cli.train
|
33 |
+
|
34 |
+
.. argparse::
|
35 |
+
:module: fairseq.options
|
36 |
+
:func: get_training_parser
|
37 |
+
:prog: fairseq-train
|
38 |
+
|
39 |
+
|
40 |
+
.. _fairseq-generate:
|
41 |
+
|
42 |
+
fairseq-generate
|
43 |
+
~~~~~~~~~~~~~~~~
|
44 |
+
.. automodule:: fairseq_cli.generate
|
45 |
+
|
46 |
+
.. argparse::
|
47 |
+
:module: fairseq.options
|
48 |
+
:func: get_generation_parser
|
49 |
+
:prog: fairseq-generate
|
50 |
+
|
51 |
+
|
52 |
+
.. _fairseq-interactive:
|
53 |
+
|
54 |
+
fairseq-interactive
|
55 |
+
~~~~~~~~~~~~~~~~~~~
|
56 |
+
.. automodule:: fairseq_cli.interactive
|
57 |
+
|
58 |
+
.. argparse::
|
59 |
+
:module: fairseq.options
|
60 |
+
:func: get_interactive_generation_parser
|
61 |
+
:prog: fairseq-interactive
|
62 |
+
|
63 |
+
|
64 |
+
.. _fairseq-score:
|
65 |
+
|
66 |
+
fairseq-score
|
67 |
+
~~~~~~~~~~~~~
|
68 |
+
.. automodule:: fairseq_cli.score
|
69 |
+
|
70 |
+
.. argparse::
|
71 |
+
:module: fairseq_cli.score
|
72 |
+
:func: get_parser
|
73 |
+
:prog: fairseq-score
|
74 |
+
|
75 |
+
|
76 |
+
.. _fairseq-eval-lm:
|
77 |
+
|
78 |
+
fairseq-eval-lm
|
79 |
+
~~~~~~~~~~~~~~~
|
80 |
+
.. automodule:: fairseq_cli.eval_lm
|
81 |
+
|
82 |
+
.. argparse::
|
83 |
+
:module: fairseq.options
|
84 |
+
:func: get_eval_lm_parser
|
85 |
+
:prog: fairseq-eval-lm
|
fairseq/docs/conf.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
#
|
4 |
+
# fairseq documentation build configuration file, created by
|
5 |
+
# sphinx-quickstart on Fri Aug 17 21:45:30 2018.
|
6 |
+
#
|
7 |
+
# This file is execfile()d with the current directory set to its
|
8 |
+
# containing dir.
|
9 |
+
#
|
10 |
+
# Note that not all possible configuration values are present in this
|
11 |
+
# autogenerated file.
|
12 |
+
#
|
13 |
+
# All configuration values have a default; values that are commented out
|
14 |
+
# serve to show the default.
|
15 |
+
|
16 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
17 |
+
# add these directories to sys.path here. If the directory is relative to the
|
18 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
19 |
+
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
from fairseq import __version__
|
23 |
+
|
24 |
+
|
25 |
+
# source code directory, relative to this file, for sphinx-autobuild
|
26 |
+
sys.path.insert(0, os.path.abspath(".."))
|
27 |
+
|
28 |
+
source_suffix = [".rst"]
|
29 |
+
|
30 |
+
# -- General configuration ------------------------------------------------
|
31 |
+
|
32 |
+
# If your documentation needs a minimal Sphinx version, state it here.
|
33 |
+
#
|
34 |
+
# needs_sphinx = '1.0'
|
35 |
+
|
36 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
37 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
38 |
+
# ones.
|
39 |
+
extensions = [
|
40 |
+
"sphinx.ext.autodoc",
|
41 |
+
"sphinx.ext.intersphinx",
|
42 |
+
"sphinx.ext.viewcode",
|
43 |
+
"sphinx.ext.napoleon",
|
44 |
+
"sphinxarg.ext",
|
45 |
+
]
|
46 |
+
|
47 |
+
# Add any paths that contain templates here, relative to this directory.
|
48 |
+
templates_path = ["_templates"]
|
49 |
+
|
50 |
+
# The master toctree document.
|
51 |
+
master_doc = "index"
|
52 |
+
|
53 |
+
# General information about the project.
|
54 |
+
project = "fairseq"
|
55 |
+
copyright = "Facebook AI Research (FAIR)"
|
56 |
+
author = "Facebook AI Research (FAIR)"
|
57 |
+
|
58 |
+
github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
|
59 |
+
|
60 |
+
# The version info for the project you're documenting, acts as replacement for
|
61 |
+
# |version| and |release|, also used in various other places throughout the
|
62 |
+
# built documents.
|
63 |
+
#
|
64 |
+
# The short X.Y version.
|
65 |
+
version = __version__
|
66 |
+
# The full version, including alpha/beta/rc tags.
|
67 |
+
release = __version__
|
68 |
+
|
69 |
+
# The language for content autogenerated by Sphinx. Refer to documentation
|
70 |
+
# for a list of supported languages.
|
71 |
+
#
|
72 |
+
# This is also used if you do content translation via gettext catalogs.
|
73 |
+
# Usually you set "language" from the command line for these cases.
|
74 |
+
language = None
|
75 |
+
|
76 |
+
# List of patterns, relative to source directory, that match files and
|
77 |
+
# directories to ignore when looking for source files.
|
78 |
+
# This patterns also effect to html_static_path and html_extra_path
|
79 |
+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
80 |
+
|
81 |
+
# The name of the Pygments (syntax highlighting) style to use.
|
82 |
+
pygments_style = "sphinx"
|
83 |
+
highlight_language = "python"
|
84 |
+
|
85 |
+
# If true, `todo` and `todoList` produce output, else they produce nothing.
|
86 |
+
todo_include_todos = False
|
87 |
+
|
88 |
+
|
89 |
+
# -- Options for HTML output ----------------------------------------------
|
90 |
+
|
91 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
92 |
+
# a list of builtin themes.
|
93 |
+
#
|
94 |
+
html_theme = "sphinx_rtd_theme"
|
95 |
+
|
96 |
+
# Theme options are theme-specific and customize the look and feel of a theme
|
97 |
+
# further. For a list of options available for each theme, see the
|
98 |
+
# documentation.
|
99 |
+
#
|
100 |
+
# html_theme_options = {}
|
101 |
+
|
102 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
103 |
+
# relative to this directory. They are copied after the builtin static files,
|
104 |
+
# so a file named "default.css" will overwrite the builtin "default.css".
|
105 |
+
html_static_path = ["_static"]
|
106 |
+
|
107 |
+
html_context = {
|
108 |
+
"css_files": [
|
109 |
+
"_static/theme_overrides.css", # override wide tables in RTD theme
|
110 |
+
],
|
111 |
+
}
|
112 |
+
|
113 |
+
# Custom sidebar templates, must be a dictionary that maps document names
|
114 |
+
# to template names.
|
115 |
+
#
|
116 |
+
# This is required for the alabaster theme
|
117 |
+
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
|
118 |
+
# html_sidebars = {
|
119 |
+
# '**': [
|
120 |
+
# 'about.html',
|
121 |
+
# 'navigation.html',
|
122 |
+
# 'relations.html', # needs 'show_related': True theme option to display
|
123 |
+
# 'searchbox.html',
|
124 |
+
# 'donate.html',
|
125 |
+
# ]
|
126 |
+
# }
|
127 |
+
|
128 |
+
|
129 |
+
# Example configuration for intersphinx: refer to the Python standard library.
|
130 |
+
intersphinx_mapping = {
|
131 |
+
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
|
132 |
+
"python": ("https://docs.python.org/", None),
|
133 |
+
"torch": ("https://pytorch.org/docs/master/", None),
|
134 |
+
}
|
fairseq/docs/criterions.rst
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _Criterions:
|
5 |
+
|
6 |
+
Criterions
|
7 |
+
==========
|
8 |
+
|
9 |
+
Criterions compute the loss function given the model and batch, roughly::
|
10 |
+
|
11 |
+
loss = criterion(model, batch)
|
12 |
+
|
13 |
+
.. automodule:: fairseq.criterions
|
14 |
+
:members:
|
15 |
+
|
16 |
+
.. autoclass:: fairseq.criterions.FairseqCriterion
|
17 |
+
:members:
|
18 |
+
:undoc-members:
|
19 |
+
|
20 |
+
.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
|
21 |
+
:members:
|
22 |
+
:undoc-members:
|
23 |
+
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
|
24 |
+
:members:
|
25 |
+
:undoc-members:
|
26 |
+
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
|
27 |
+
:members:
|
28 |
+
:undoc-members:
|
29 |
+
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
|
30 |
+
:members:
|
31 |
+
:undoc-members:
|
fairseq/docs/data.rst
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. module:: fairseq.data
|
5 |
+
|
6 |
+
Data Loading and Utilities
|
7 |
+
==========================
|
8 |
+
|
9 |
+
.. _datasets:
|
10 |
+
|
11 |
+
Datasets
|
12 |
+
--------
|
13 |
+
|
14 |
+
**Datasets** define the data format and provide helpers for creating
|
15 |
+
mini-batches.
|
16 |
+
|
17 |
+
.. autoclass:: fairseq.data.FairseqDataset
|
18 |
+
:members:
|
19 |
+
.. autoclass:: fairseq.data.LanguagePairDataset
|
20 |
+
:members:
|
21 |
+
.. autoclass:: fairseq.data.MonolingualDataset
|
22 |
+
:members:
|
23 |
+
|
24 |
+
**Helper Datasets**
|
25 |
+
|
26 |
+
These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
|
27 |
+
provide additional functionality:
|
28 |
+
|
29 |
+
.. autoclass:: fairseq.data.BacktranslationDataset
|
30 |
+
:members:
|
31 |
+
.. autoclass:: fairseq.data.ConcatDataset
|
32 |
+
:members:
|
33 |
+
.. autoclass:: fairseq.data.ResamplingDataset
|
34 |
+
:members:
|
35 |
+
.. autoclass:: fairseq.data.RoundRobinZipDatasets
|
36 |
+
:members:
|
37 |
+
.. autoclass:: fairseq.data.TransformEosDataset
|
38 |
+
:members:
|
39 |
+
|
40 |
+
|
41 |
+
Dictionary
|
42 |
+
----------
|
43 |
+
|
44 |
+
.. autoclass:: fairseq.data.Dictionary
|
45 |
+
:members:
|
46 |
+
|
47 |
+
|
48 |
+
Iterators
|
49 |
+
---------
|
50 |
+
|
51 |
+
.. autoclass:: fairseq.data.CountingIterator
|
52 |
+
:members:
|
53 |
+
.. autoclass:: fairseq.data.EpochBatchIterator
|
54 |
+
:members:
|
55 |
+
.. autoclass:: fairseq.data.GroupedIterator
|
56 |
+
:members:
|
57 |
+
.. autoclass:: fairseq.data.ShardedIterator
|
58 |
+
:members:
|
fairseq/docs/docutils.conf
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[writers]
|
2 |
+
option-limit=0
|
fairseq/docs/fairseq_logo.png
ADDED
fairseq/docs/getting_started.rst
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Evaluating Pre-trained Models
|
2 |
+
=============================
|
3 |
+
|
4 |
+
First, download a pre-trained model along with its vocabularies:
|
5 |
+
|
6 |
+
.. code-block:: console
|
7 |
+
|
8 |
+
> curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
|
9 |
+
|
10 |
+
This model uses a `Byte Pair Encoding (BPE)
|
11 |
+
vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
|
12 |
+
the encoding to the source text before it can be translated. This can be
|
13 |
+
done with the
|
14 |
+
`apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
|
15 |
+
script using the ``wmt14.en-fr.fconv-cuda/bpecodes`` file. ``@@`` is
|
16 |
+
used as a continuation marker and the original text can be easily
|
17 |
+
recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
|
18 |
+
flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
|
19 |
+
using ``tokenizer.perl`` from
|
20 |
+
`mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
|
21 |
+
|
22 |
+
Let's use :ref:`fairseq-interactive` to generate translations interactively.
|
23 |
+
Here, we use a beam size of 5 and preprocess the input with the Moses
|
24 |
+
tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
|
25 |
+
remove the BPE continuation markers and detokenize the output.
|
26 |
+
|
27 |
+
.. code-block:: console
|
28 |
+
|
29 |
+
> MODEL_DIR=wmt14.en-fr.fconv-py
|
30 |
+
> fairseq-interactive \
|
31 |
+
--path $MODEL_DIR/model.pt $MODEL_DIR \
|
32 |
+
--beam 5 --source-lang en --target-lang fr \
|
33 |
+
--tokenizer moses \
|
34 |
+
--bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
|
35 |
+
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
|
36 |
+
| [en] dictionary: 44206 types
|
37 |
+
| [fr] dictionary: 44463 types
|
38 |
+
| Type the input sentence and press return:
|
39 |
+
Why is it rare to discover new marine mammal species?
|
40 |
+
S-0 Why is it rare to discover new marine mam@@ mal species ?
|
41 |
+
H-0 -0.0643349438905716 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
|
42 |
+
P-0 -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
|
43 |
+
|
44 |
+
This generation script produces three types of outputs: a line prefixed
|
45 |
+
with *O* is a copy of the original source sentence; *H* is the
|
46 |
+
hypothesis along with an average log-likelihood; and *P* is the
|
47 |
+
positional score per token position, including the
|
48 |
+
end-of-sentence marker which is omitted from the text.
|
49 |
+
|
50 |
+
Other types of output lines you might see are *D*, the detokenized hypothesis,
|
51 |
+
*T*, the reference target, *A*, alignment info, *E* the history of generation steps.
|
52 |
+
|
53 |
+
See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
|
54 |
+
full list of pre-trained models available.
|
55 |
+
|
56 |
+
Training a New Model
|
57 |
+
====================
|
58 |
+
|
59 |
+
The following tutorial is for machine translation. For an example of how
|
60 |
+
to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
|
61 |
+
``examples/`` directory.
|
62 |
+
|
63 |
+
Data Pre-processing
|
64 |
+
-------------------
|
65 |
+
|
66 |
+
Fairseq contains example pre-processing scripts for several translation
|
67 |
+
datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
|
68 |
+
2014 (English-German). To pre-process and binarize the IWSLT dataset:
|
69 |
+
|
70 |
+
.. code-block:: console
|
71 |
+
|
72 |
+
> cd examples/translation/
|
73 |
+
> bash prepare-iwslt14.sh
|
74 |
+
> cd ../..
|
75 |
+
> TEXT=examples/translation/iwslt14.tokenized.de-en
|
76 |
+
> fairseq-preprocess --source-lang de --target-lang en \
|
77 |
+
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
|
78 |
+
--destdir data-bin/iwslt14.tokenized.de-en
|
79 |
+
|
80 |
+
This will write binarized data that can be used for model training to
|
81 |
+
``data-bin/iwslt14.tokenized.de-en``.
|
82 |
+
|
83 |
+
Training
|
84 |
+
--------
|
85 |
+
|
86 |
+
Use :ref:`fairseq-train` to train a new model. Here a few example settings that work
|
87 |
+
well for the IWSLT 2014 dataset:
|
88 |
+
|
89 |
+
.. code-block:: console
|
90 |
+
|
91 |
+
> mkdir -p checkpoints/fconv
|
92 |
+
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
|
93 |
+
--optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
|
94 |
+
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
|
95 |
+
|
96 |
+
By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
|
97 |
+
``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
|
98 |
+
change the number of GPU devices that will be used.
|
99 |
+
|
100 |
+
Also note that the batch size is specified in terms of the maximum
|
101 |
+
number of tokens per batch (``--max-tokens``). You may need to use a
|
102 |
+
smaller value depending on the available GPU memory on your system.
|
103 |
+
|
104 |
+
Generation
|
105 |
+
----------
|
106 |
+
|
107 |
+
Once your model is trained, you can generate translations using
|
108 |
+
:ref:`fairseq-generate` **(for binarized data)** or
|
109 |
+
:ref:`fairseq-interactive` **(for raw text)**:
|
110 |
+
|
111 |
+
.. code-block:: console
|
112 |
+
|
113 |
+
> fairseq-generate data-bin/iwslt14.tokenized.de-en \
|
114 |
+
--path checkpoints/fconv/checkpoint_best.pt \
|
115 |
+
--batch-size 128 --beam 5
|
116 |
+
| [de] dictionary: 35475 types
|
117 |
+
| [en] dictionary: 24739 types
|
118 |
+
| data-bin/iwslt14.tokenized.de-en test 6750 examples
|
119 |
+
| model fconv
|
120 |
+
| loaded checkpoint trainings/fconv/checkpoint_best.pt
|
121 |
+
S-721 danke .
|
122 |
+
T-721 thank you .
|
123 |
+
...
|
124 |
+
|
125 |
+
To generate translations with only a CPU, use the ``--cpu`` flag. BPE
|
126 |
+
continuation markers can be removed with the ``--remove-bpe`` flag.
|
127 |
+
|
128 |
+
Advanced Training Options
|
129 |
+
=========================
|
130 |
+
|
131 |
+
Large mini-batch training with delayed updates
|
132 |
+
----------------------------------------------
|
133 |
+
|
134 |
+
The ``--update-freq`` option can be used to accumulate gradients from
|
135 |
+
multiple mini-batches and delay updating, creating a larger effective
|
136 |
+
batch size. Delayed updates can also improve training speed by reducing
|
137 |
+
inter-GPU communication costs and by saving idle time caused by variance
|
138 |
+
in workload across GPUs. See `Ott et al.
|
139 |
+
(2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
|
140 |
+
|
141 |
+
To train on a single GPU with an effective batch size that is equivalent
|
142 |
+
to training on 8 GPUs:
|
143 |
+
|
144 |
+
.. code-block:: console
|
145 |
+
|
146 |
+
> CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
|
147 |
+
|
148 |
+
Training with half precision floating point (FP16)
|
149 |
+
--------------------------------------------------
|
150 |
+
|
151 |
+
.. note::
|
152 |
+
|
153 |
+
FP16 training requires a Volta GPU and CUDA 9.1 or greater
|
154 |
+
|
155 |
+
Recent GPUs enable efficient half precision floating point computation,
|
156 |
+
e.g., using `Nvidia Tensor Cores
|
157 |
+
<https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
|
158 |
+
Fairseq supports FP16 training with the ``--fp16`` flag:
|
159 |
+
|
160 |
+
.. code-block:: console
|
161 |
+
|
162 |
+
> fairseq-train --fp16 (...)
|
163 |
+
|
164 |
+
Distributed training
|
165 |
+
--------------------
|
166 |
+
|
167 |
+
Distributed training in fairseq is implemented on top of ``torch.distributed``.
|
168 |
+
The easiest way to launch jobs is with the `torch.distributed.launch
|
169 |
+
<https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
|
170 |
+
|
171 |
+
For example, to train a large English-German Transformer model on 2 nodes each
|
172 |
+
with 8 GPUs (in total 16 GPUs), run the following command on each node,
|
173 |
+
replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
|
174 |
+
sure to update ``--master_addr`` to the IP address of the first node:
|
175 |
+
|
176 |
+
.. code-block:: console
|
177 |
+
|
178 |
+
> python -m torch.distributed.launch --nproc_per_node=8 \
|
179 |
+
--nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
|
180 |
+
--master_port=12345 \
|
181 |
+
$(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
|
182 |
+
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
|
183 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
|
184 |
+
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
|
185 |
+
--lr 0.0005 \
|
186 |
+
--dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
187 |
+
--max-tokens 3584 \
|
188 |
+
--max-epoch 70 \
|
189 |
+
--fp16
|
190 |
+
|
191 |
+
On SLURM clusters, fairseq will automatically detect the number of nodes and
|
192 |
+
GPUs, but a port number must be provided:
|
193 |
+
|
194 |
+
.. code-block:: console
|
195 |
+
|
196 |
+
> salloc --gpus=16 --nodes 2 (...)
|
197 |
+
> srun fairseq-train --distributed-port 12345 (...).
|
198 |
+
|
199 |
+
Sharding very large datasets
|
200 |
+
----------------------------
|
201 |
+
|
202 |
+
It can be challenging to train over very large datasets, particularly if your
|
203 |
+
machine does not have much system RAM. Most tasks in fairseq support training
|
204 |
+
over "sharded" datasets, in which the original dataset has been preprocessed
|
205 |
+
into non-overlapping chunks (or "shards").
|
206 |
+
|
207 |
+
For example, instead of preprocessing all your data into a single "data-bin"
|
208 |
+
directory, you can split the data and create "data-bin1", "data-bin2", etc.
|
209 |
+
Then you can adapt your training command like so:
|
210 |
+
|
211 |
+
.. code-block:: console
|
212 |
+
|
213 |
+
> fairseq-train data-bin1:data-bin2:data-bin3 (...)
|
214 |
+
|
215 |
+
Training will now iterate over each shard, one by one, with each shard
|
216 |
+
corresponding to an "epoch", thus reducing system memory usage.
|
fairseq/docs/hydra_integration.md
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Hydra
|
2 |
+
|
3 |
+
[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
|
4 |
+
framework that simplifies the development of research and other complex
|
5 |
+
applications. The key feature is the ability to dynamically create a
|
6 |
+
hierarchical configuration by composition and override it through config files
|
7 |
+
and the command line. The name Hydra comes from its ability to run multiple
|
8 |
+
similar jobs - much like a Hydra with multiple heads.
|
9 |
+
|
10 |
+
## Motivation
|
11 |
+
|
12 |
+
Until recently, all components in fairseq were configured through a shared
|
13 |
+
`args` namespace that was created at application startup. Components declared
|
14 |
+
their own `add_args` method to update the argparse parser, hoping that the names
|
15 |
+
would not clash with arguments from other components. While this model works for
|
16 |
+
smaller applications, as fairseq grew and became integrated into other
|
17 |
+
applications, this became problematic. In order to determine how to configure
|
18 |
+
each component, one needed to a) examine what args were added by this component,
|
19 |
+
and b) read the code to figure out what shared arguments it is using that were
|
20 |
+
added in other places. Reproducing models involved sharing commands that often
|
21 |
+
contained dozens of command line switches.
|
22 |
+
|
23 |
+
The model described above is still supported by fairseq for backward
|
24 |
+
compatibility, but will be deprecated some time in the future.
|
25 |
+
|
26 |
+
New components in fairseq should now create a dataclass that encapsulates all
|
27 |
+
parameters required to configure this component. The dataclass is registered
|
28 |
+
along with the component, and fairseq takes care of constructing and providing
|
29 |
+
this configuration object to the component's constructor. Note that sharing
|
30 |
+
parameters can optionally still work, but one has to explicitly point to the
|
31 |
+
"source of truth" (see inheritance example below). These changes make components
|
32 |
+
in fairseq more independent and re-usable by other applications: all that is
|
33 |
+
needed to create a component is to initialize its dataclass and overwrite some
|
34 |
+
of the defaults.
|
35 |
+
|
36 |
+
While configuring fairseq through command line (using either the legacy argparse
|
37 |
+
based or the new Hydra based entry points) is still fully supported, you can now
|
38 |
+
take advantage of configuring fairseq completely or piece-by-piece through
|
39 |
+
hierarchical YAML configuration files. These files can also be shipped as
|
40 |
+
examples that others can use to run an identically configured job.
|
41 |
+
|
42 |
+
Additionally, Hydra has a rich and growing [library of
|
43 |
+
plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
|
44 |
+
provide functionality such as hyperparameter sweeping (including using bayesian
|
45 |
+
optimization through the [Ax](https://github.com/facebook/Ax) library), job
|
46 |
+
launching across various platforms, and more.
|
47 |
+
|
48 |
+
## Creating or migrating components
|
49 |
+
|
50 |
+
In general, each new (or updated) component should provide a companion
|
51 |
+
[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are
|
52 |
+
typically located in the same file as the component and are passed as arguments
|
53 |
+
to the `register_*()` functions. Top-level configs that should be present in
|
54 |
+
every fairseq application are placed in the
|
55 |
+
[global](fairseq/dataclass/configs.py) config file and added to the
|
56 |
+
`FairseqConfig` object.
|
57 |
+
|
58 |
+
Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
|
59 |
+
classes are decorated with a `@dataclass` decorator, and typically inherit from
|
60 |
+
`FairseqDataclass` (which adds some functionality for backward compatibility).
|
61 |
+
Each field must have a type, and generally has metadata (such as a help string)
|
62 |
+
and a default value. Only primitive types or other config objects are allowed as
|
63 |
+
data types for each field.
|
64 |
+
|
65 |
+
#### Example:
|
66 |
+
|
67 |
+
```python
|
68 |
+
from dataclasses import dataclass, field
|
69 |
+
from fairseq.dataclass import FairseqDataclass
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class InteractiveConfig(FairseqDataclass):
|
73 |
+
buffer_size: int = field(
|
74 |
+
default=0,
|
75 |
+
metadata={
|
76 |
+
"help": "read this many sentences into a buffer before processing them"
|
77 |
+
},
|
78 |
+
)
|
79 |
+
input: str = field(
|
80 |
+
default="-",
|
81 |
+
metadata={"help": "file to read from; use - for stdin"},
|
82 |
+
)
|
83 |
+
```
|
84 |
+
|
85 |
+
### Inherting values
|
86 |
+
|
87 |
+
Some components require sharing a value. For example, a learning rate scheduler
|
88 |
+
and an optimizer may both need to know the initial learning rate value. One can
|
89 |
+
declare a field that, by default, will inherit its value from another config
|
90 |
+
node in the same hierarchy:
|
91 |
+
|
92 |
+
```python
|
93 |
+
@dataclass
|
94 |
+
FairseqAdamConfig(FairseqDataclass):
|
95 |
+
...
|
96 |
+
lr: List[float] = II("optimization.lr")
|
97 |
+
...
|
98 |
+
```
|
99 |
+
|
100 |
+
`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
|
101 |
+
the value one can use in a YAML config file or through command line to achieve
|
102 |
+
the same effect. Note that this assumes that there is an "optimization" config
|
103 |
+
object in the root config and it has a field called "lr".
|
104 |
+
|
105 |
+
### Tasks and Models
|
106 |
+
|
107 |
+
Creating Tasks and Models works same as before, except that legacy
|
108 |
+
implementations now inherit from `LegacyFairseq*` base classes, while new
|
109 |
+
components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
|
110 |
+
to the `register_*()` functions.
|
111 |
+
|
112 |
+
#### Task example:
|
113 |
+
|
114 |
+
```python
|
115 |
+
@dataclass
|
116 |
+
class LanguageModelingConfig(FairseqDataclass):
|
117 |
+
data: Optional[str] = field(
|
118 |
+
default=None, metadata={"help": "path to data directory"}
|
119 |
+
)
|
120 |
+
...
|
121 |
+
|
122 |
+
@register_task("language_modeling", dataclass=LanguageModelingConfig)
|
123 |
+
class LanguageModelingTask(FairseqTask):
|
124 |
+
...
|
125 |
+
@classmethod
|
126 |
+
def setup_task(cls, cfg: LanguageModelingConfig):
|
127 |
+
...
|
128 |
+
```
|
129 |
+
|
130 |
+
#### Model example:
|
131 |
+
|
132 |
+
```python
|
133 |
+
@dataclass
|
134 |
+
class TransformerLanguageModelConfig(FairseqDataclass):
|
135 |
+
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
136 |
+
default="relu", metadata={"help": "activation function to use"}
|
137 |
+
)
|
138 |
+
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
139 |
+
...
|
140 |
+
|
141 |
+
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
|
142 |
+
class TransformerLanguageModel(FairseqLanguageModel):
|
143 |
+
...
|
144 |
+
@classmethod
|
145 |
+
def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
|
146 |
+
...
|
147 |
+
```
|
148 |
+
|
149 |
+
### Other components
|
150 |
+
|
151 |
+
Other components work as before, but they now take their configuration dataclass
|
152 |
+
as the only constructor argument:
|
153 |
+
|
154 |
+
```python
|
155 |
+
@dataclass
|
156 |
+
class MosesTokenizerConfig(FairseqDataclass):
|
157 |
+
source_lang: str = field(default="en", metadata={"help": "source language"})
|
158 |
+
...
|
159 |
+
|
160 |
+
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
|
161 |
+
class MosesTokenizer(object):
|
162 |
+
def __init__(self, cfg: MosesTokenizerConfig):
|
163 |
+
...
|
164 |
+
```
|
165 |
+
|
166 |
+
Note that if you are adding a new registry for a new set of components, you need
|
167 |
+
to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
|
168 |
+
|
169 |
+
```python
|
170 |
+
@dataclass
|
171 |
+
class FairseqConfig(object):
|
172 |
+
...
|
173 |
+
my_new_registry: Any = None
|
174 |
+
```
|
175 |
+
|
176 |
+
## Training with `fairseq-hydra-train`
|
177 |
+
|
178 |
+
To fully take advantage of configuration flexibility offered by Hydra, you may
|
179 |
+
want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
|
180 |
+
tools such as `fairseq-train` will remain supported for the foreseeable future
|
181 |
+
but will be deprecated eventually.
|
182 |
+
|
183 |
+
On startup, Hydra will create a configuration object that contains a hierarchy
|
184 |
+
of all the necessary dataclasses populated with their default values in the
|
185 |
+
code. The default values are overwritten by values found in YAML files in
|
186 |
+
`fairseq/config` directory (which currently sets minimal defaults) and then
|
187 |
+
further overwritten by values provided through command line arguments.
|
188 |
+
|
189 |
+
Some of the most common use cases are shown below:
|
190 |
+
|
191 |
+
### 1. Override default values through command line:
|
192 |
+
|
193 |
+
```shell script
|
194 |
+
$ fairseq-hydra-train \
|
195 |
+
distributed_training.distributed_world_size=1 \
|
196 |
+
dataset.batch_size=2 \
|
197 |
+
task.data=data-bin \
|
198 |
+
model=transformer_lm/transformer_lm_gpt \
|
199 |
+
task=language_modeling \
|
200 |
+
optimization.max_update=5000
|
201 |
+
```
|
202 |
+
|
203 |
+
Note that along with explicitly providing values for parameters such as
|
204 |
+
`dataset.batch_size`, this also tells Hydra to overlay configuration found in
|
205 |
+
`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
|
206 |
+
values in the dataclass. If you want to train a model without specifying a
|
207 |
+
particular architecture you can simply specify `model=transformer_lm`. This only
|
208 |
+
works for migrated tasks and models.
|
209 |
+
|
210 |
+
### 2. Replace bundled configs with an external config:
|
211 |
+
|
212 |
+
```shell script
|
213 |
+
$ fairseq-hydra-train \
|
214 |
+
--config-dir /path/to/external/configs \
|
215 |
+
--config-name wiki103
|
216 |
+
```
|
217 |
+
|
218 |
+
where `/path/to/external/configs/wiki103.yaml` contains:
|
219 |
+
|
220 |
+
```yaml
|
221 |
+
# @package _group_
|
222 |
+
|
223 |
+
model:
|
224 |
+
_name: transformer_lm
|
225 |
+
distributed_training:
|
226 |
+
distributed_world_size: 1
|
227 |
+
dataset:
|
228 |
+
batch_size: 2
|
229 |
+
task:
|
230 |
+
_name: language_modeling
|
231 |
+
data: /path/to/data
|
232 |
+
add_bos_token: false
|
233 |
+
max_target_positions: 1024
|
234 |
+
optimization:
|
235 |
+
max_update: 50000
|
236 |
+
lr: [ 0.25 ]
|
237 |
+
criterion: cross_entropy
|
238 |
+
optimizer: adam
|
239 |
+
lr_scheduler:
|
240 |
+
_name: cosine
|
241 |
+
```
|
242 |
+
|
243 |
+
Note that here bundled configs from `fairseq/config` directory are not used,
|
244 |
+
however the defaults from each dataclass will still be used (unless overwritten
|
245 |
+
by your external config).
|
246 |
+
|
247 |
+
Additionally you can choose to break up your configs by creating a directory
|
248 |
+
structure in the same location as your main config file, with the names of the
|
249 |
+
top-level fields (such as "model", "dataset", etc), and placing config files
|
250 |
+
with meaningful names that would populate that specific section of your
|
251 |
+
top-level config file (for example, you might have
|
252 |
+
`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
|
253 |
+
can then specify the correct configuration via command line, defaults in the
|
254 |
+
main config, or even launch all of them as a sweep (see Hydra documentation on
|
255 |
+
how to do this).
|
256 |
+
|
257 |
+
### 3. Add an external config directory to Hydra search path:
|
258 |
+
|
259 |
+
This allows combining default configuration (including using any bundled config
|
260 |
+
files), while specifying your own config files for some parts of the
|
261 |
+
configuration.
|
262 |
+
|
263 |
+
```shell script
|
264 |
+
$ fairseq-hydra-train \
|
265 |
+
distributed_training.distributed_world_size=1 \
|
266 |
+
dataset.batch_size=2 \
|
267 |
+
task.data=/path/to/data/ \
|
268 |
+
model=transformer_lm/2_layers \
|
269 |
+
task=language_modeling \
|
270 |
+
optimization.max_update=5000 \
|
271 |
+
--config-dir /path/to/external/configs
|
272 |
+
```
|
273 |
+
|
274 |
+
where `/path/to/external/configs` has the following structure:
|
275 |
+
```
|
276 |
+
.
|
277 |
+
+-- model
|
278 |
+
| +-- transformer_lm
|
279 |
+
| | +-- 2_layers.yaml
|
280 |
+
```
|
281 |
+
|
282 |
+
and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
|
283 |
+
`decoder_layers` set to 2. You can add other configs to configure other
|
284 |
+
components as well.
|
fairseq/docs/index.rst
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. fairseq documentation master file, created by
|
2 |
+
sphinx-quickstart on Fri Aug 17 21:45:30 2018.
|
3 |
+
You can adapt this file completely to your liking, but it should at least
|
4 |
+
contain the root `toctree` directive.
|
5 |
+
|
6 |
+
:github_url: https://github.com/pytorch/fairseq
|
7 |
+
|
8 |
+
|
9 |
+
fairseq documentation
|
10 |
+
=====================
|
11 |
+
|
12 |
+
Fairseq is a sequence modeling toolkit written in `PyTorch
|
13 |
+
<http://pytorch.org/>`_ that allows researchers and developers to
|
14 |
+
train custom models for translation, summarization, language modeling and other
|
15 |
+
text generation tasks.
|
16 |
+
|
17 |
+
.. toctree::
|
18 |
+
:maxdepth: 1
|
19 |
+
:caption: Getting Started
|
20 |
+
|
21 |
+
getting_started
|
22 |
+
command_line_tools
|
23 |
+
|
24 |
+
.. toctree::
|
25 |
+
:maxdepth: 1
|
26 |
+
:caption: Extending Fairseq
|
27 |
+
|
28 |
+
overview
|
29 |
+
tutorial_simple_lstm
|
30 |
+
tutorial_classifying_names
|
31 |
+
|
32 |
+
.. toctree::
|
33 |
+
:maxdepth: 2
|
34 |
+
:caption: Library Reference
|
35 |
+
|
36 |
+
tasks
|
37 |
+
models
|
38 |
+
criterions
|
39 |
+
optim
|
40 |
+
lr_scheduler
|
41 |
+
data
|
42 |
+
modules
|
43 |
+
|
44 |
+
|
45 |
+
Indices and tables
|
46 |
+
==================
|
47 |
+
|
48 |
+
* :ref:`genindex`
|
49 |
+
* :ref:`search`
|
fairseq/docs/lr_scheduler.rst
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _Learning Rate Schedulers:
|
5 |
+
|
6 |
+
Learning Rate Schedulers
|
7 |
+
========================
|
8 |
+
|
9 |
+
Learning Rate Schedulers update the learning rate over the course of training.
|
10 |
+
Learning rates can be updated after each update via :func:`step_update` or at
|
11 |
+
epoch boundaries via :func:`step`.
|
12 |
+
|
13 |
+
.. automodule:: fairseq.optim.lr_scheduler
|
14 |
+
:members:
|
15 |
+
|
16 |
+
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
|
17 |
+
:members:
|
18 |
+
:undoc-members:
|
19 |
+
|
20 |
+
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
|
21 |
+
:members:
|
22 |
+
:undoc-members:
|
23 |
+
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
|
24 |
+
:members:
|
25 |
+
:undoc-members:
|
26 |
+
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
|
27 |
+
:members:
|
28 |
+
:undoc-members:
|
29 |
+
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
|
30 |
+
:members:
|
31 |
+
:undoc-members:
|
32 |
+
.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
|
33 |
+
:members:
|
34 |
+
:undoc-members:
|
fairseq/docs/make.bat
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@ECHO OFF
|
2 |
+
|
3 |
+
pushd %~dp0
|
4 |
+
|
5 |
+
REM Command file for Sphinx documentation
|
6 |
+
|
7 |
+
if "%SPHINXBUILD%" == "" (
|
8 |
+
set SPHINXBUILD=python -msphinx
|
9 |
+
)
|
10 |
+
set SOURCEDIR=.
|
11 |
+
set BUILDDIR=_build
|
12 |
+
set SPHINXPROJ=fairseq
|
13 |
+
|
14 |
+
if "%1" == "" goto help
|
15 |
+
|
16 |
+
%SPHINXBUILD% >NUL 2>NUL
|
17 |
+
if errorlevel 9009 (
|
18 |
+
echo.
|
19 |
+
echo.The Sphinx module was not found. Make sure you have Sphinx installed,
|
20 |
+
echo.then set the SPHINXBUILD environment variable to point to the full
|
21 |
+
echo.path of the 'sphinx-build' executable. Alternatively you may add the
|
22 |
+
echo.Sphinx directory to PATH.
|
23 |
+
echo.
|
24 |
+
echo.If you don't have Sphinx installed, grab it from
|
25 |
+
echo.http://sphinx-doc.org/
|
26 |
+
exit /b 1
|
27 |
+
)
|
28 |
+
|
29 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
30 |
+
goto end
|
31 |
+
|
32 |
+
:help
|
33 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
34 |
+
|
35 |
+
:end
|
36 |
+
popd
|
fairseq/docs/models.rst
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. module:: fairseq.models
|
5 |
+
|
6 |
+
.. _Models:
|
7 |
+
|
8 |
+
Models
|
9 |
+
======
|
10 |
+
|
11 |
+
A Model defines the neural network's ``forward()`` method and encapsulates all
|
12 |
+
of the learnable parameters in the network. Each model also provides a set of
|
13 |
+
named *architectures* that define the precise network configuration (e.g.,
|
14 |
+
embedding dimension, number of layers, etc.).
|
15 |
+
|
16 |
+
Both the model type and architecture are selected via the ``--arch``
|
17 |
+
command-line argument. Once selected, a model may expose additional command-line
|
18 |
+
arguments for further configuration.
|
19 |
+
|
20 |
+
.. note::
|
21 |
+
|
22 |
+
All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
|
23 |
+
:class:`torch.nn.Module`. Thus any fairseq Model can be used as a
|
24 |
+
stand-alone Module in other PyTorch code.
|
25 |
+
|
26 |
+
|
27 |
+
Convolutional Neural Networks (CNN)
|
28 |
+
-----------------------------------
|
29 |
+
|
30 |
+
.. module:: fairseq.models.fconv
|
31 |
+
.. autoclass:: fairseq.models.fconv.FConvModel
|
32 |
+
:members:
|
33 |
+
.. autoclass:: fairseq.models.fconv.FConvEncoder
|
34 |
+
:members:
|
35 |
+
:undoc-members:
|
36 |
+
.. autoclass:: fairseq.models.fconv.FConvDecoder
|
37 |
+
:members:
|
38 |
+
|
39 |
+
|
40 |
+
Long Short-Term Memory (LSTM) networks
|
41 |
+
--------------------------------------
|
42 |
+
|
43 |
+
.. module:: fairseq.models.lstm
|
44 |
+
.. autoclass:: fairseq.models.lstm.LSTMModel
|
45 |
+
:members:
|
46 |
+
.. autoclass:: fairseq.models.lstm.LSTMEncoder
|
47 |
+
:members:
|
48 |
+
.. autoclass:: fairseq.models.lstm.LSTMDecoder
|
49 |
+
:members:
|
50 |
+
|
51 |
+
|
52 |
+
Transformer (self-attention) networks
|
53 |
+
-------------------------------------
|
54 |
+
|
55 |
+
.. module:: fairseq.models.transformer
|
56 |
+
.. autoclass:: fairseq.models.transformer.TransformerModel
|
57 |
+
:members:
|
58 |
+
.. autoclass:: fairseq.models.transformer.TransformerEncoder
|
59 |
+
:members:
|
60 |
+
.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
|
61 |
+
:members:
|
62 |
+
.. autoclass:: fairseq.models.transformer.TransformerDecoder
|
63 |
+
:members:
|
64 |
+
.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
|
65 |
+
:members:
|
66 |
+
|
67 |
+
|
68 |
+
Adding new models
|
69 |
+
-----------------
|
70 |
+
|
71 |
+
.. currentmodule:: fairseq.models
|
72 |
+
.. autofunction:: fairseq.models.register_model
|
73 |
+
.. autofunction:: fairseq.models.register_model_architecture
|
74 |
+
.. autoclass:: fairseq.models.BaseFairseqModel
|
75 |
+
:members:
|
76 |
+
:undoc-members:
|
77 |
+
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
|
78 |
+
:members:
|
79 |
+
:undoc-members:
|
80 |
+
.. autoclass:: fairseq.models.FairseqEncoderModel
|
81 |
+
:members:
|
82 |
+
:undoc-members:
|
83 |
+
.. autoclass:: fairseq.models.FairseqLanguageModel
|
84 |
+
:members:
|
85 |
+
:undoc-members:
|
86 |
+
.. autoclass:: fairseq.models.FairseqMultiModel
|
87 |
+
:members:
|
88 |
+
:undoc-members:
|
89 |
+
.. autoclass:: fairseq.models.FairseqEncoder
|
90 |
+
:members:
|
91 |
+
.. autoclass:: fairseq.models.CompositeEncoder
|
92 |
+
:members:
|
93 |
+
.. autoclass:: fairseq.models.FairseqDecoder
|
94 |
+
:members:
|
95 |
+
|
96 |
+
|
97 |
+
.. _Incremental decoding:
|
98 |
+
|
99 |
+
Incremental decoding
|
100 |
+
--------------------
|
101 |
+
|
102 |
+
.. autoclass:: fairseq.models.FairseqIncrementalDecoder
|
103 |
+
:members:
|
104 |
+
:undoc-members:
|
fairseq/docs/modules.rst
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Modules
|
2 |
+
=======
|
3 |
+
|
4 |
+
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
|
5 |
+
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
|
6 |
+
|
7 |
+
.. automodule:: fairseq.modules
|
8 |
+
:members:
|
9 |
+
:undoc-members:
|
fairseq/docs/optim.rst
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.. role:: hidden
|
2 |
+
:class: hidden-section
|
3 |
+
|
4 |
+
.. _optimizers:
|
5 |
+
|
6 |
+
Optimizers
|
7 |
+
==========
|
8 |
+
|
9 |
+
Optimizers update the Model parameters based on the gradients.
|
10 |
+
|
11 |
+
.. automodule:: fairseq.optim
|
12 |
+
:members:
|
13 |
+
|
14 |
+
.. autoclass:: fairseq.optim.FairseqOptimizer
|
15 |
+
:members:
|
16 |
+
:undoc-members:
|
17 |
+
|
18 |
+
.. autoclass:: fairseq.optim.adadelta.Adadelta
|
19 |
+
:members:
|
20 |
+
:undoc-members:
|
21 |
+
.. autoclass:: fairseq.optim.adagrad.Adagrad
|
22 |
+
:members:
|
23 |
+
:undoc-members:
|
24 |
+
.. autoclass:: fairseq.optim.adafactor.FairseqAdafactor
|
25 |
+
:members:
|
26 |
+
:undoc-members:
|
27 |
+
.. autoclass:: fairseq.optim.adam.FairseqAdam
|
28 |
+
:members:
|
29 |
+
:undoc-members:
|
30 |
+
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
|
31 |
+
:members:
|
32 |
+
:undoc-members:
|
33 |
+
.. autoclass:: fairseq.optim.nag.FairseqNAG
|
34 |
+
:members:
|
35 |
+
:undoc-members:
|
36 |
+
.. autoclass:: fairseq.optim.sgd.SGD
|
37 |
+
:members:
|
38 |
+
:undoc-members:
|