Spaces: Runtime error

JustinLin610 committed on
Commit • 085ecd3
1 Parent(s): 38764eb

change the service

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- LICENSE +0 -201
- README.md +0 -27
- README_EncouragingLoss.md +0 -34
- app.py +1 -2
- benpao.jpeg +0 -0
- checkpoints.md +0 -37
- checkpoints_cn.md +0 -82
- colab.md +0 -9
- criterions/__init__.py +0 -4
- criterions/clip_scst_loss.py +0 -277
- criterions/label_smoothed_cross_entropy.py +0 -343
- criterions/label_smoothed_encouraging_loss.py +0 -395
- criterions/scst_loss.py +0 -281
- data/__init__.py +0 -0
- data/cv_data/image_classify_dataset.py +0 -196
- data/data_utils.py +0 -601
- data/file_dataset.py +0 -107
- data/mm_data/__init__.py +0 -0
- data/mm_data/caption_dataset.py +0 -160
- data/mm_data/ocr_dataset.py +0 -210
- data/mm_data/refcoco_dataset.py +0 -174
- data/mm_data/snli_ve_dataset.py +0 -203
- data/mm_data/vqa_gen_dataset.py +0 -218
- data/nlg_data/summary_dataset.py +0 -131
- data/nlu_data/cola_dataset.py +0 -138
- data/nlu_data/mnli_dataset.py +0 -143
- data/nlu_data/mrpc_dataset.py +0 -141
- data/nlu_data/qnli_dataset.py +0 -141
- data/nlu_data/qqp_dataset.py +0 -141
- data/nlu_data/rte_dataset.py +0 -141
- data/nlu_data/sst2_dataset.py +0 -138
- data/ofa_dataset.py +0 -79
- data/pretrain_data/unify_dataset.py +0 -636
- datasets.md +0 -44
- evaluate.py +0 -160
- ezocr/LICENSE +0 -201
- ezocr/README.md +0 -49
- ezocr/build/lib/easyocrlite/__init__.py +0 -1
- ezocr/build/lib/easyocrlite/reader.py +0 -272
- ezocr/build/lib/easyocrlite/types.py +0 -5
- ezocr/easyocrlite.egg-info/PKG-INFO +0 -8
- ezocr/easyocrlite.egg-info/SOURCES.txt +0 -11
- ezocr/easyocrlite.egg-info/dependency_links.txt +0 -1
- ezocr/easyocrlite.egg-info/requires.txt +0 -5
- ezocr/easyocrlite.egg-info/top_level.txt +0 -1
- ezocr/easyocrlite/__init__.py +0 -1
- ezocr/easyocrlite/model/__init__.py +0 -1
- ezocr/easyocrlite/model/craft.py +0 -174
- ezocr/easyocrlite/reader.py +0 -271
- ezocr/easyocrlite/types.py +0 -5
LICENSE
DELETED
@@ -1,201 +0,0 @@
Removed the full text of the Apache License, Version 2.0 (January 2004, http://www.apache.org/licenses/), 201 lines in total: the standard terms and conditions (Sections 1-9), the appendix describing how to apply the license, and the boilerplate notice "Copyright 1999-2022 Alibaba Group Holding Ltd." licensed under the Apache License, Version 2.0.
README.md
DELETED
@@ -1,27 +0,0 @@
---
title: Chinese OCR
emoji: 📖
colorFrom: red
colorTo: indigo
sdk: gradio
sdk_version: 3.9.1
app_file: app.py
pinned: false
---
# Configuration
`title`: _string_
OFA Image Caption
`emoji`: _string_
🖼
`colorFrom`: _string_
red
`colorTo`: _string_
indigo
`sdk`: _string_
gradio
`app_file`: _string_
app.py

`pinned`: _boolean_
false
README_EncouragingLoss.md
DELETED
@@ -1,34 +0,0 @@
# Finetuning with Encouraging Loss (EL)
Below we provide methods for finetuning with the label smoothed encouraging loss proposed in [_Well-classified Examples are Underestimated in Classification with Deep Neural Networks_](https://arxiv.org/pdf/2110.06537.pdf) on different downstream tasks.
The implementation is in [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
You can set `--criterion` to `adjust_label_smoothed_encouraging_loss` to use it. This criterion has a hyper-parameter `--log-end`.
`--log-end < 1` results in an approximated and conservative version of the full encouraging loss.
A higher log_end more strongly counteracts gradient vanishing, enhances the modeling of the data, and increases the growth rate of the margin, but it also brings a larger gradient norm, which challenges the existing optimization setup.
We recommend a higher log_end for cases with higher performance, and 0.75 or 0.5 as your first try.

## Image Captioning
We provide procedures for image captioning with EL below. The preprocessing is identical to the default setting.

<details>
<summary><b>Finetuning</b></summary>
<p>
We provide two scripts for stage 1.
</p>
<pre>
cd run_scripts/caption
nohup sh train_caption_stage1_el.sh > train_stage1_el.out &  # stage 1, train with encouraging loss, expected CIDEr 1.403
nohup sh train_caption_stage1_el_db.sh > train_stage1_el.out &  # stage 1, train with encouraging loss and drop best examples, expected CIDEr 1.404
</pre>
</details>

## Referring Expression Comprehension
We provide procedures for referring expression comprehension with EL below. The preprocessing is identical to the default setting.
<details>
<summary><b>Finetuning</b></summary>
<pre>
cd run_scripts/refcoco
nohup sh train_refcoco_el.sh > train_refcoco_el.out &  # finetune for refcoco
nohup sh train_refcocoplus_el.sh > train_refcocoplus_el.out &  # finetune for refcoco+
nohup sh train_refcocog_el.sh > train_refcocog_el.out &  # finetune for refcocog
</pre>
</details>
Evaluation is also the same as the default setting.
app.py
CHANGED
@@ -193,8 +193,7 @@ description = "Gradio Demo for Chinese OCR based on OFA-Base. "\
Before:

193     article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
194               "Repo</a></p> "
195     examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
196 -               ['qiaodaima.png'], ['
197 -               ['xsd.jpg', 'General']]
198     io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
199                       outputs=[gr.outputs.Image(type='pil', label='Image'),
200                       gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],

After:

193     article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
194               "Repo</a></p> "
195     examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
196 +               ['qiaodaima.png'], ['xsd.jpg']]
197     io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
198                       outputs=[gr.outputs.Image(type='pil', label='Image'),
199                       gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
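The only functional change in this commit is the repair of the broken `examples` list (the old lines 196-197 left an unterminated string literal and a stray `'General'` mode argument). For context, below is a minimal sketch of how such a list is typically wired into the Gradio demo; the stub `ocr` handler and the surrounding call are stand-ins for the real ones defined earlier in app.py, not a verbatim excerpt.

<pre>
# Minimal sketch (not a verbatim excerpt of app.py): the repaired `examples`
# list feeds one image path per example row into the Gradio interface.
import gradio as gr
from PIL import Image

def ocr(image_path):
    # stand-in for the real OCR handler defined earlier in app.py
    return Image.open(image_path), [['0', 'dummy text']]

examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
            ['qiaodaima.png'], ['xsd.jpg']]  # the fixed list from this commit

io = gr.Interface(fn=ocr,
                  inputs=gr.inputs.Image(type='filepath', label='Image'),
                  outputs=[gr.outputs.Image(type='pil', label='Image'),
                           gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
                  examples=examples)
io.launch()
</pre>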
benpao.jpeg
DELETED
Binary file (6.38 kB)
checkpoints.md
DELETED
@@ -1,37 +0,0 @@
# Checkpoints

We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.

## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a> (~470M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"> Pre-trained checkpoint (OFA-Base) </a> (~180M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt"> Pre-trained checkpoint (OFA-Medium) </a> (~93M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt"> Pre-trained checkpoint (OFA-Tiny) </a> (~33M parameters)

## Finetuning (OFA-Huge)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_huge_best.pt"> Finetuned checkpoint for Caption on COCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_huge_best.pt"> Finetuned checkpoint for VQAv2 </a>

## Finetuning (OFA-Large)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_stage1_best.pt"> Finetuned checkpoint for Caption on COCO during Stage 1 Finetuning </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_large_best.pt"> Finetuned checkpoint for VQAv2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_large_best.pt"> Finetuned checkpoint for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_large_best.zip"> Finetuned checkpoint for Text-to-Image Generation on COCO, plus CLIP checkpoint and VQGAN checkpoint </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/imagenet_1k_large_best.pt"> Finetuned checkpoint for ImageNet-1K </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/gigaword_large_best.pt"> Finetuned checkpoint for Gigaword </a>

## Finetuning (OFA-Base)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_base_best.pt"> Finetuned base checkpoint for Caption on COCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_base_best.pt"> Finetuned base checkpoint for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_base_best.pt"> Finetuned base checkpoint for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_base_best.pt"> Finetuned base checkpoint for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_base_best.pt"> Finetuned base checkpoint for VQAv2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_base_best.pt"> Finetuned base checkpoint for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_base_best.pt"> Finetuned base checkpoint for Text-to-Image Generation on COCO </a>
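Since the checkpoint links above are plain object-storage URLs, a minimal sketch of fetching one of them is shown below; the choice of `ofa_base.pt` and the local target path are arbitrary examples, and any HTTP client would work equally well.

<pre>
# Minimal sketch: download one of the checkpoints listed above.
# The chosen file (ofa_base.pt) and the local path are arbitrary examples.
import urllib.request

url = "https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"
urllib.request.urlretrieve(url, "ofa_base.pt")  # ~180M-parameter pretrained OFA-Base
print("saved ofa_base.pt")
</pre>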
checkpoints_cn.md
DELETED
@@ -1,82 +0,0 @@
# Checkpoints (OFA-CN)

We provide checkpoints of OFA-CN, which is the Chinese version of OFA. We provide Base-size and Large-size models, including pretrained and finetuned models on image captioning and referring expression comprehension. Note that we translated the texts in the RefCOCO(-/+/g) datasets and finetuned OFA-CN on them. We plan to release the related new datasets in the near future.

## Checkpoints
Below we provide the links for downloading the Chinese OFA checkpoints.

### Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt"> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)

### Finetuning (OFA-Large)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>

### Finetuning (OFA-Base)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>

## Model Card
Below we provide the basic information of the base-size and large-size OFA-CN.

| Model | #Params | Backbone | Hidden Size | Intermediate Size | #Heads | #Enc. Layers | #Dec. Layers |
| --- | --- | --- | --- | --- | --- | --- | --- |
| OFA<sub>Base</sub> | 160M | ResNet101 | 768 | 3072 | 12 | 6 | 6 |
| OFA<sub>Large</sub> | 443M | ResNet152 | 1024 | 4096 | 16 | 12 | 12 |

## Results
Below we provide the results of OFA-CN and the baselines for comparison.

### [MUGE Caption](https://tianchi.aliyun.com/muge)
| Model | BLEU@4 | ROUGE-L | CIDEr-D |
| --- | --- | --- | --- |
| Trm | 7.33 | 51.51 | 11.00 |
| M6 | 16.19 | 55.06 | 30.75 |
| OFA<sub>Base</sub> | 26.23 | 58.95 | 50.70 |
| OFA<sub>Large</sub> | **27.32** | **59.20** | **53.51** |

### RefCOCO-CN Series
| Model | RefCOCO (val/testA/testB) | RefCOCO+ (val/testA/testB) | RefCOCOg (val/test-u) |
| --- | --- | --- | --- |
| OFA<sub>Base</sub> (random-init) | 30.13/35.07/25.03 | 17.89/20.90/15.83 | 20.30/20.45 |
| OFA<sub>Base</sub> | 82.18/86.07/**76.68** | 69.38/77.26/60.14 | **73.57/72.53** |
| OFA<sub>Large</sub> | **82.84/86.54**/76.50 | **71.30/78.56/61.85** | 71.96/71.30 |
colab.md
DELETED
@@ -1,9 +0,0 @@
# Colab Notebooks

We provide Colab notebooks of different downstream tasks for you to try OFA. See below.

* [Image Captioning in Huggingface Transformers](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
* [Generic Interface](https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m?usp=sharing#scrollTo=s9Vni6YUZOpC) (using different instructions to perform various tasks with just one model)
* [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
* [Referring Expression Comprehension](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
* [Open-Domain Visual Question Answering](https://colab.research.google.com/drive/1lsMsF-Vum3MVyXwSVF5E-Y23rHFvj_3y?usp=sharing)
criterions/__init__.py
DELETED
@@ -1,4 +0,0 @@
from .scst_loss import ScstRewardCriterion
from .label_smoothed_cross_entropy import AdjustLabelSmoothedCrossEntropyCriterion
from .clip_scst_loss import ClipScstRewardCriterion
from .label_smoothed_encouraging_loss import AdjustLabelSmoothedEncouragingLossCriterion
criterions/clip_scst_loss.py
DELETED
@@ -1,277 +0,0 @@
# Copyright 2022 The OFA-Sys Team. All rights reserved.
# This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory.

The removed module (277 lines) implemented CLIP-based self-critical sequence training (SCST) for text-to-image generation. Its main pieces, in order:

* Imports: `math`, `dataclasses`, `typing.Optional`, `PIL.Image`, `torchvision.transforms`, `torch`, `numpy`, fairseq (`metrics`, `data_utils`, `FairseqCriterion`, `register_criterion`, `FairseqDataclass`, `utils`), `omegaconf.II`, and `models.clip`.
* `custom_to_pil(x)`: clamps a generated image tensor to [-1, 1], rescales it to [0, 255], and converts it to an RGB `PIL.Image`.
* `scst_loss(lprobs, target, reward, ignore_index=None, reduce=True)`: the policy-gradient loss,

<pre>
def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
    loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        loss.masked_fill_(pad_mask, 0.0)
        ntokens = (~pad_mask).sum()
    else:
        loss = loss.squeeze(-1)
        ntokens = target.numel()
    if reduce:
        loss = loss.sum()
    return loss, ntokens
</pre>

* `ClipScstRewardCriterionConfig`: a `FairseqDataclass` with `ignore_prefix_size` (default 0), `sentence_avg` (bound to `optimization.sentence_avg`), and an optional `constraint_range`.
* `ClipScstRewardCriterion` (registered as `clip_scst_reward_criterion`, with `CLIP_REWARD_WEIGHT = 2.5`) and its methods:
  * `forward(model, sample, update_num=0, reduce=True)`: returns the loss, the sample size used as the gradient denominator, and the logging outputs.
  * `_calculate_clip_scores(gen_res, gt_text, device)`: CLIP-encodes generated images and input captions, normalizes the features, and scores image-text similarity scaled by `CLIP_REWARD_WEIGHT`.
  * `get_generator_out(model, sample)`: runs the SCST generator, decodes generated code tokens into images via `self.task.image_tokenizer.decode_code`, and recovers the caption text (stripping the 38-character task-instruction prefix).
  * `get_reward_and_scores(gen_res, gt_text, device)`: subtracts a leave-one-out baseline from the CLIP scores to form the reward.
  * `get_net_output(model, sample, gen_target)`: repeats source tokens, lengths, and code masks once per generated image and re-runs the model on the generated targets.
  * `get_lprobs_and_target(model, net_output, gen_target)`: applies the constraint range and prefix trimming before computing normalized log-probabilities.
  * `compute_loss(model, sample, reduce=True)`: ties the above together and calls `scst_loss` with `ignore_index=self.padding_idx`.
  * `reduce_metrics(logging_outputs)`: logs `loss`, `score`, `ntokens`, `nsentences`, and `sample_size`; `logging_outputs_can_be_summed()` returns True.
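To make the reward weighting of the removed `scst_loss` helper concrete, here is a small self-contained check on toy tensors; the shapes, the padding index, and the advantage values are illustrative assumptions, not values taken from the project's configs.

<pre>
# Toy check of the scst_loss helper shown above (shapes and pad index are
# illustrative assumptions, not project settings).
import torch

def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
    # same behaviour as the removed helper: negative log-prob of sampled tokens,
    # weighted by the per-sequence advantage, with padding positions zeroed out
    loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        loss.masked_fill_(pad_mask, 0.0)
        ntokens = (~pad_mask).sum()
    else:
        loss = loss.squeeze(-1)
        ntokens = target.numel()
    if reduce:
        loss = loss.sum()
    return loss, ntokens

lprobs = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)  # (batch, time, vocab)
target = torch.tensor([[2, 4, 0], [1, 3, 3]])             # 0 acts as padding here
reward = torch.tensor([0.7, -0.4])                        # advantage per sequence
loss, ntokens = scst_loss(lprobs, target, reward, ignore_index=0)
# minimizing this loss raises the probability of tokens whose sequence has a
# positive advantage and lowers it for sequences with a negative advantage
print(loss.item(), ntokens.item())
</pre>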
criterions/label_smoothed_cross_entropy.py
DELETED
@@ -1,343 +0,0 @@
|
|
1 |
-
# Copyright 2022 The OFA-Sys Team.
|
2 |
-
# All rights reserved.
|
3 |
-
# This source code is licensed under the Apache 2.0 license
|
4 |
-
# found in the LICENSE file in the root directory.
|
5 |
-
|
6 |
-
import math
|
7 |
-
from dataclasses import dataclass, field
|
8 |
-
from typing import Optional
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.nn.functional as F
|
12 |
-
import numpy as np
|
13 |
-
from fairseq import metrics, utils
|
14 |
-
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
-
from fairseq.dataclass import FairseqDataclass
|
16 |
-
from omegaconf import II
|
17 |
-
|
18 |
-
|
19 |
-
@dataclass
|
20 |
-
class AdjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
|
21 |
-
label_smoothing: float = field(
|
22 |
-
default=0.0,
|
23 |
-
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
24 |
-
)
|
25 |
-
report_accuracy: bool = field(
|
26 |
-
default=False,
|
27 |
-
metadata={"help": "report accuracy metric"},
|
28 |
-
)
|
29 |
-
ignore_prefix_size: int = field(
|
30 |
-
default=0,
|
31 |
-
metadata={"help": "Ignore first N tokens"},
|
32 |
-
)
|
33 |
-
ignore_eos: bool = field(
|
34 |
-
default=False,
|
35 |
-
metadata={"help": "Ignore eos token"},
|
36 |
-
)
|
37 |
-
sentence_avg: bool = II("optimization.sentence_avg")
|
38 |
-
drop_worst_ratio: float = field(
|
39 |
-
default=0.0,
|
40 |
-
metadata={"help": "ratio for discarding bad samples"},
|
41 |
-
)
|
42 |
-
drop_worst_after: int = field(
|
43 |
-
default=0,
|
44 |
-
metadata={"help": "steps for discarding bad samples"},
|
45 |
-
)
|
46 |
-
use_rdrop: bool = field(
|
47 |
-
default=False, metadata={"help": "use R-Drop"}
|
48 |
-
)
|
49 |
-
reg_alpha: float = field(
|
50 |
-
default=1.0, metadata={"help": "weight for R-Drop"}
|
51 |
-
)
|
52 |
-
sample_patch_num: int = field(
|
53 |
-
default=196, metadata={"help": "sample patches for v1"}
|
54 |
-
)
|
55 |
-
constraint_range: Optional[str] = field(
|
56 |
-
default=None,
|
57 |
-
metadata={"help": "constraint range"}
|
58 |
-
)
|
59 |
-
|
60 |
-
|
61 |
-
def construct_rdrop_sample(x):
|
62 |
-
if isinstance(x, dict):
|
63 |
-
for key in x:
|
64 |
-
x[key] = construct_rdrop_sample(x[key])
|
65 |
-
return x
|
66 |
-
elif isinstance(x, torch.Tensor):
|
67 |
-
return x.repeat(2, *([1] * (x.dim()-1)))
|
68 |
-
elif isinstance(x, int):
|
69 |
-
return x * 2
|
70 |
-
elif isinstance(x, np.ndarray):
|
71 |
-
return x.repeat(2)
|
72 |
-
else:
|
73 |
-
raise NotImplementedError
|
74 |
-
|
75 |
-
|
76 |
-
def kl_loss(p, q):
|
77 |
-
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
78 |
-
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
79 |
-
loss = (p_loss + q_loss) / 2
|
80 |
-
return loss
|
81 |
-
|
82 |
-
|
83 |
-
def label_smoothed_nll_loss(
|
84 |
-
lprobs, target, epsilon, update_num, reduce=True,
|
85 |
-
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
|
86 |
-
constraint_masks=None, constraint_start=None, constraint_end=None
|
87 |
-
):
|
88 |
-
if target.dim() == lprobs.dim() - 1:
|
89 |
-
target = target.unsqueeze(-1)
|
90 |
-
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
91 |
-
if constraint_masks is not None:
|
92 |
-
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
|
93 |
-
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
94 |
-
elif constraint_start is not None and constraint_end is not None:
|
95 |
-
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
96 |
-
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
|
97 |
-
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
98 |
-
else:
|
99 |
-
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
100 |
-
eps_i = epsilon / (lprobs.size(-1) - 1)
|
101 |
-
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
102 |
-
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
103 |
-
if use_rdrop:
|
104 |
-
true_batch_size = loss.size(0) // 2
|
105 |
-
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
|
106 |
-
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
|
107 |
-
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
|
108 |
-
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
|
109 |
-
else:
|
110 |
-
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
|
111 |
-
nll_loss = nll_loss[indices]
|
112 |
-
lprobs = lprobs[indices]
|
113 |
-
|
114 |
-
ntokens = loss.numel()
|
115 |
-
nll_loss = nll_loss.sum()
|
116 |
-
loss = loss.sum()
|
117 |
-
if use_rdrop:
|
118 |
-
true_batch_size = lprobs.size(0) // 2
|
119 |
-
p = lprobs[:true_batch_size]
|
120 |
-
q = lprobs[true_batch_size:]
|
121 |
-
if constraint_start is not None and constraint_end is not None:
|
122 |
-
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
123 |
-
p = p[:, constraint_range]
|
124 |
-
q = q[:, constraint_range]
|
125 |
-
loss += kl_loss(p, q) * reg_alpha
|
126 |
-
|
127 |
-
return loss, nll_loss, ntokens
|
128 |
-
|
129 |
-
|
130 |
-
@register_criterion(
|
131 |
-
"adjust_label_smoothed_cross_entropy", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
|
132 |
-
)
|
133 |
-
class AdjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
|
134 |
-
def __init__(
|
135 |
-
self,
|
136 |
-
task,
|
137 |
-
sentence_avg,
|
138 |
-
label_smoothing,
|
139 |
-
ignore_prefix_size=0,
|
140 |
-
ignore_eos=False,
|
141 |
-
report_accuracy=False,
|
142 |
-
drop_worst_ratio=0,
|
143 |
-
drop_worst_after=0,
|
144 |
-
use_rdrop=False,
|
145 |
-
reg_alpha=1.0,
|
146 |
-
sample_patch_num=196,
|
147 |
-
constraint_range=None
|
148 |
-
):
|
149 |
-
super().__init__(task)
|
150 |
-
self.sentence_avg = sentence_avg
|
151 |
-
self.eps = label_smoothing
|
152 |
-
self.ignore_prefix_size = ignore_prefix_size
|
153 |
-
self.ignore_eos = ignore_eos
|
154 |
-
self.report_accuracy = report_accuracy
|
155 |
-
self.drop_worst_ratio = drop_worst_ratio
|
156 |
-
self.drop_worst_after = drop_worst_after
|
157 |
-
self.use_rdrop = use_rdrop
|
158 |
-
self.reg_alpha = reg_alpha
|
159 |
-
self.sample_patch_num = sample_patch_num
|
160 |
-
|
161 |
-
self.constraint_start = None
|
162 |
-
self.constraint_end = None
|
163 |
-
if constraint_range is not None:
|
164 |
-
constraint_start, constraint_end = constraint_range.split(',')
|
165 |
-
self.constraint_start = int(constraint_start)
|
166 |
-
self.constraint_end = int(constraint_end)
|
167 |
-
|
168 |
-
def forward(self, model, sample, update_num=0, reduce=True):
|
169 |
-
"""Compute the loss for the given sample.
|
170 |
-
|
171 |
-
Returns a tuple with three elements:
|
172 |
-
1) the loss
|
173 |
-
2) the sample size, which is used as the denominator for the gradient
|
174 |
-
3) logging outputs to display while training
|
175 |
-
"""
|
176 |
-
if isinstance(sample, list):
|
177 |
-
if self.sample_patch_num > 0:
|
178 |
-
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
|
179 |
-
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
|
180 |
-
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
|
181 |
-
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
|
182 |
-
sample_size = 1
|
183 |
-
logging_output = {
|
184 |
-
"loss": loss.data,
|
185 |
-
"loss_v1": loss_v1.data,
|
186 |
-
"loss_v2": loss_v2.data,
|
187 |
-
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
|
188 |
-
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
|
189 |
-
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
|
190 |
-
"sample_size": 1,
|
191 |
-
"sample_size_v1": sample_size_v1,
|
192 |
-
"sample_size_v2": sample_size_v2,
|
193 |
-
}
|
194 |
-
return loss, sample_size, logging_output
|
195 |
-
|
196 |
-
if self.use_rdrop:
|
197 |
-
construct_rdrop_sample(sample)
|
198 |
-
|
199 |
-
net_output = model(**sample["net_input"])
|
200 |
-
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
|
201 |
-
sample_size = (
|
202 |
-
sample["target"].size(0) if self.sentence_avg else ntokens
|
203 |
-
)
|
204 |
-
logging_output = {
|
205 |
-
"loss": loss.data,
|
206 |
-
"nll_loss": nll_loss.data,
|
207 |
-
"ntokens": sample["ntokens"],
|
208 |
-
"nsentences": sample["nsentences"],
|
209 |
-
"sample_size": sample_size,
|
210 |
-
}
|
211 |
-
if self.report_accuracy:
|
212 |
-
n_correct, total = self.compute_accuracy(model, net_output, sample)
|
213 |
-
logging_output["n_correct"] = utils.item(n_correct.data)
|
214 |
-
logging_output["total"] = utils.item(total.data)
|
215 |
-
return loss, sample_size, logging_output
|
216 |
-
|
217 |
-
def get_lprobs_and_target(self, model, net_output, sample):
|
218 |
-
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
|
219 |
-
constraint_masks = None
|
220 |
-
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
|
221 |
-
constraint_masks = sample["constraint_masks"]
|
222 |
-
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
223 |
-
if self.constraint_start is not None and self.constraint_end is not None:
|
224 |
-
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
225 |
-
net_output[0][:, :, self.constraint_end:] = -math.inf
|
226 |
-
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
|
227 |
-
target = model.get_targets(sample, net_output)
|
228 |
-
if self.ignore_prefix_size > 0:
|
229 |
-
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
230 |
-
target = target[:, self.ignore_prefix_size :].contiguous()
|
231 |
-
if constraint_masks is not None:
|
232 |
-
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
|
233 |
-
if self.ignore_eos:
|
234 |
-
bsz, seq_len, embed_dim = lprobs.size()
|
235 |
-
eos_indices = target.eq(self.task.tgt_dict.eos())
|
236 |
-
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
237 |
-
target = target[~eos_indices].reshape(bsz, seq_len-1)
|
238 |
-
if constraint_masks is not None:
|
239 |
-
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
240 |
-
if constraint_masks is not None:
|
241 |
-
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
|
242 |
-
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
|
243 |
-
|
244 |
-
-    def compute_loss(self, model, net_output, sample, update_num, reduce=True):
-        lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
-        if constraint_masks is not None:
-            constraint_masks = constraint_masks[target != self.padding_idx]
-        lprobs = lprobs[target != self.padding_idx]
-        target = target[target != self.padding_idx]
-        loss, nll_loss, ntokens = label_smoothed_nll_loss(
-            lprobs,
-            target,
-            self.eps,
-            update_num,
-            reduce=reduce,
-            drop_worst_ratio=self.drop_worst_ratio,
-            drop_worst_after=self.drop_worst_after,
-            use_rdrop=self.use_rdrop,
-            reg_alpha=self.reg_alpha,
-            constraint_masks=constraint_masks,
-            constraint_start=self.constraint_start,
-            constraint_end=self.constraint_end
-        )
-        return loss, nll_loss, ntokens
-
-    def compute_accuracy(self, model, net_output, sample):
-        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        mask = target.ne(self.padding_idx)
-        n_correct = torch.sum(
-            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
-        )
-        total = torch.sum(mask)
-        return n_correct, total
-
-    @classmethod
-    def reduce_metrics(cls, logging_outputs) -> None:
-        """Aggregate logging outputs from data parallel training."""
-        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
-        loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
-        loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
-        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
-        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
-        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
-        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
-        sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
-        sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
-
-        metrics.log_scalar(
-            "loss", loss_sum / sample_size, sample_size, round=3
-        )
-        metrics.log_scalar(
-            "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
-        )
-        metrics.log_scalar(
-            "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
-        )
-        metrics.log_scalar(
-            "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
-        )
-        metrics.log_derived(
-            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
-        )
-
-        metrics.log_scalar(
-            "ntokens", ntokens, 1, round=3
-        )
-        metrics.log_scalar(
-            "nsentences", nsentences, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size", sample_size, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size_v1", sample_size_v1, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size_v2", sample_size_v2, 1, round=3
-        )
-
-        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
-        if total > 0:
-            metrics.log_scalar("total", total)
-            n_correct = utils.item(
-                sum(log.get("n_correct", 0) for log in logging_outputs)
-            )
-            metrics.log_scalar("n_correct", n_correct)
-            metrics.log_derived(
-                "accuracy",
-                lambda meters: round(
-                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
-                )
-                if meters["total"].sum > 0
-                else float("nan"),
-            )
-
-    @staticmethod
-    def logging_outputs_can_be_summed() -> bool:
-        """
-        Whether the logging outputs returned by `forward` can be summed
-        across workers prior to calling `reduce_metrics`. Setting this
-        to True will improves distributed training speed.
-        """
-        return True
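Note (not part of the deleted file): the compute_accuracy method above counts correct tokens only at non-padding positions. A minimal standalone sketch of the same idea, assuming plain PyTorch tensors and a hypothetical padding index of 1; the function name token_accuracy is illustrative, not from the repository.

import torch

def token_accuracy(lprobs: torch.Tensor, target: torch.Tensor, padding_idx: int = 1):
    # keep only positions that are not padding
    mask = target.ne(padding_idx)
    # argmax over the vocabulary dimension gives the predicted token ids
    pred = lprobs.argmax(dim=-1)
    n_correct = pred.masked_select(mask).eq(target.masked_select(mask)).sum()
    total = mask.sum()
    return n_correct, total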
criterions/label_smoothed_encouraging_loss.py
DELETED
@@ -1,395 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import math
-from dataclasses import dataclass, field
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-from fairseq import metrics, utils
-from fairseq.criterions import FairseqCriterion, register_criterion
-from fairseq.dataclass import FairseqDataclass
-from omegaconf import II
-
-
-@dataclass
-class AdjustLabelSmoothedEncouragingLossConfig(FairseqDataclass):
-    label_smoothing: float = field(
-        default=0.0,
-        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
-    )
-    report_accuracy: bool = field(
-        default=False,
-        metadata={"help": "report accuracy metric"},
-    )
-    ignore_prefix_size: int = field(
-        default=0,
-        metadata={"help": "Ignore first N tokens"},
-    )
-    ignore_eos: bool = field(
-        default=False,
-        metadata={"help": "Ignore eos token"},
-    )
-    sentence_avg: bool = II("optimization.sentence_avg")
-    drop_worst_ratio: float = field(
-        default=0.0,
-        metadata={"help": "ratio for discarding bad samples"},
-    )
-    drop_worst_after: int = field(
-        default=0,
-        metadata={"help": "steps for discarding bad samples"},
-    )
-    use_rdrop: bool = field(
-        default=False, metadata={"help": "use R-Drop"}
-    )
-    reg_alpha: float = field(
-        default=1.0, metadata={"help": "weight for R-Drop"}
-    )
-    sample_patch_num: int = field(
-        default=196, metadata={"help": "sample patchs for v1"}
-    )
-    constraint_range: Optional[str] = field(
-        default=None,
-        metadata={"help": "constraint range"}
-    )
-    log_end: float = field(
-        default=0.75,
-        metadata={"help": "higher log_end is for cases with higher performance,"
-                          " we recommend 0.75 or 0.5 as your first try."}
-    )
-    drop_best_ratio: float = field(
-        default=0.0,
-        metadata={"help": "ratio for discarding best samples"},
-    )
-    drop_best_after: int = field(
-        default=0,
-        metadata={"help": "steps for discarding best samples"},
-    )
-
-
-
-def construct_rdrop_sample(x):
-    if isinstance(x, dict):
-        for key in x:
-            x[key] = construct_rdrop_sample(x[key])
-        return x
-    elif isinstance(x, torch.Tensor):
-        return x.repeat(2, *([1] * (x.dim()-1)))
-    elif isinstance(x, int):
-        return x * 2
-    elif isinstance(x, np.ndarray):
-        return x.repeat(2)
-    else:
-        raise NotImplementedError
-
-
-def kl_loss(p, q):
-    p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
-    q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
-    loss = (p_loss + q_loss) / 2
-    return loss
-
-
-def label_smoothed_nll_loss(
-        lprobs, target, epsilon, update_num, reduce=True,
-        drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
-        constraint_masks=None, constraint_start=None, constraint_end=None, drop_best_ratio=0.0,
-        drop_best_after=0,
-):
-    if target.dim() == lprobs.dim() - 1:
-        target = target.unsqueeze(-1)
-    nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
-    if constraint_masks is not None:
-        smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
-        eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
-    elif constraint_start is not None and constraint_end is not None:
-        constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
-        smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
-        eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
-    else:
-        smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
-        eps_i = epsilon / (lprobs.size(-1) - 1)
-    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
-    if drop_worst_ratio > 0 and update_num > drop_worst_after:
-        if use_rdrop:
-            true_batch_size = loss.size(0) // 2
-            _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
-            loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
-            nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
-            lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
-        else:
-            loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
-            nll_loss = nll_loss[indices]
-            lprobs = lprobs[indices]
-            target = target[indices]
-        if update_num > drop_best_after:
-            loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_best_ratio)), largest=True)
-            nll_loss = nll_loss[indices]
-            lprobs = lprobs[indices]
-            target = target[indices]
-
-    ntokens = loss.numel()
-    nll_loss = nll_loss.sum()
-    loss = loss.sum()
-    if use_rdrop:
-        true_batch_size = lprobs.size(0) // 2
-        p = lprobs[:true_batch_size]
-        q = lprobs[true_batch_size:]
-        if constraint_start is not None and constraint_end is not None:
-            constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
-            p = p[:, constraint_range]
-            q = q[:, constraint_range]
-        loss += kl_loss(p, q) * reg_alpha
-
-    return loss, nll_loss, ntokens,lprobs,target
-
-
-@register_criterion(
-    "adjust_label_smoothed_encouraging_loss", dataclass=AdjustLabelSmoothedEncouragingLossConfig
-)
-class AdjustLabelSmoothedEncouragingLossCriterion(FairseqCriterion):
-    def __init__(
-        self,
-        task,
-        sentence_avg,
-        label_smoothing,
-        ignore_prefix_size=0,
-        ignore_eos=False,
-        report_accuracy=False,
-        drop_worst_ratio=0,
-        drop_worst_after=0,
-        use_rdrop=False,
-        reg_alpha=1.0,
-        sample_patch_num=196,
-        constraint_range=None,
-        log_end=0.75,
-        drop_best_ratio=0.0,
-        drop_best_after=0,
-    ):
-        super().__init__(task)
-        self.sentence_avg = sentence_avg
-        self.eps = label_smoothing
-        self.ignore_prefix_size = ignore_prefix_size
-        self.ignore_eos = ignore_eos
-        self.report_accuracy = report_accuracy
-        self.drop_worst_ratio = drop_worst_ratio
-        self.drop_worst_after = drop_worst_after
-        self.use_rdrop = use_rdrop
-        self.reg_alpha = reg_alpha
-        self.sample_patch_num = sample_patch_num
-
-        self.constraint_start = None
-        self.constraint_end = None
-        if constraint_range is not None:
-            constraint_start, constraint_end = constraint_range.split(',')
-            self.constraint_start = int(constraint_start)
-            self.constraint_end = int(constraint_end)
-        self.log_end = log_end
-        self.drop_best_ratio = drop_best_ratio
-        self.drop_best_after = drop_best_after
-        print('el, self.log_end=', self.log_end)
-    # @staticmethod
-    # def add_args(parser):
-    #     """Add criterion-specific arguments to the parser."""
-    #     # fmt: off
-    #     parser.add_argument('--log_end', type=float, default=1.0)
-
-    def forward(self, model, sample, update_num=0, reduce=True):
-        """Compute the loss for the given sample.
-
-        Returns a tuple with three elements:
-        1) the loss
-        2) the sample size, which is used as the denominator for the gradient
-        3) logging outputs to display while training
-        """
-        if isinstance(sample, list):
-            if self.sample_patch_num > 0:
-                sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
-            loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
-            loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
-            loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
-            sample_size = 1
-            logging_output = {
-                "loss": loss.data,
-                "loss_v1": loss_v1.data,
-                "loss_v2": loss_v2.data,
-                "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
-                "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
-                "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
-                "sample_size": 1,
-                "sample_size_v1": sample_size_v1,
-                "sample_size_v2": sample_size_v2,
-            }
-            return loss, sample_size, logging_output
-
-        if self.use_rdrop:
-            construct_rdrop_sample(sample)
-
-        net_output = model(**sample["net_input"])
-        loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
-        sample_size = (
-            sample["target"].size(0) if self.sentence_avg else ntokens
-        )
-        logging_output = {
-            "loss": loss.data,
-            "nll_loss": nll_loss.data,
-            "ntokens": sample["ntokens"],
-            "nsentences": sample["nsentences"],
-            "sample_size": sample_size,
-        }
-        if self.report_accuracy:
-            n_correct, total = self.compute_accuracy(model, net_output, sample)
-            logging_output["n_correct"] = utils.item(n_correct.data)
-            logging_output["total"] = utils.item(total.data)
-        return loss, sample_size, logging_output
-
-    def get_lprobs_and_target(self, model, net_output, sample):
-        conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
-        constraint_masks = None
-        if "constraint_masks" in sample and sample["constraint_masks"] is not None:
-            constraint_masks = sample["constraint_masks"]
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
-        if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
-        target = model.get_targets(sample, net_output)
-        if self.ignore_prefix_size > 0:
-            lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
-            target = target[:, self.ignore_prefix_size :].contiguous()
-            if constraint_masks is not None:
-                constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
-        if self.ignore_eos:
-            bsz, seq_len, embed_dim = lprobs.size()
-            eos_indices = target.eq(self.task.tgt_dict.eos())
-            lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
-            target = target[~eos_indices].reshape(bsz, seq_len-1)
-            if constraint_masks is not None:
-                constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
-        if constraint_masks is not None:
-            constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
-        return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
-
-    def compute_loss(self, model, net_output, sample, update_num, reduce=True):
-        lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
-        if constraint_masks is not None:
-            constraint_masks = constraint_masks[target != self.padding_idx]
-        lprobs = lprobs[target != self.padding_idx]
-        target = target[target != self.padding_idx]
-        loss, nll_loss, ntokens, lprobs, target = label_smoothed_nll_loss(
-            lprobs,
-            target,
-            self.eps,
-            update_num,
-            reduce=reduce,
-            drop_worst_ratio=self.drop_worst_ratio,
-            drop_worst_after=self.drop_worst_after,
-            use_rdrop=self.use_rdrop,
-            reg_alpha=self.reg_alpha,
-            constraint_masks=constraint_masks,
-            constraint_start=self.constraint_start,
-            constraint_end=self.constraint_end
-        )
-        # for encouraging loss
-        probs = torch.exp(lprobs)
-        bonus = torch.log(torch.clamp((torch.ones_like(probs) - probs), min=1e-5))  # likelihood bonus
-        log_end = self.log_end
-        if log_end != 1.0:  # e.g. 0.9
-            y_log_end = torch.log(torch.ones_like(probs) - log_end)
-            bonus_after_log_end = 1 / (log_end - torch.ones_like(probs)) * (probs - log_end) + y_log_end
-            # x:log_end, y torch.log(torch.clamp((torch.ones_like(probs) - probs), min=self.cl_eps))
-            bonus = torch.where(probs > log_end, bonus_after_log_end, bonus)
-        c_loss = F.nll_loss(
-            -bonus,
-            target.view(-1),
-            reduction='sum',
-        )
-        smoothing_c_loss = bonus.sum(dim=-1)
-        smoothing_c_loss = smoothing_c_loss.sum()
-        c_loss = c_loss * (1 - self.eps) + (self.eps / lprobs.size(-1)) * smoothing_c_loss
-        loss = loss + c_loss
-        # end for encouraging loss
-        return loss, nll_loss, ntokens
-
-    def compute_accuracy(self, model, net_output, sample):
-        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
-        mask = target.ne(self.padding_idx)
-        n_correct = torch.sum(
-            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
-        )
-        total = torch.sum(mask)
-        return n_correct, total
-
-    @classmethod
-    def reduce_metrics(cls, logging_outputs) -> None:
-        """Aggregate logging outputs from data parallel training."""
-        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
-        loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
-        loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
-        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
-        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
-        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
-        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
-        sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
-        sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
-
-        metrics.log_scalar(
-            "loss", loss_sum / sample_size, sample_size, round=3
-        )
-        metrics.log_scalar(
-            "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
-        )
-        metrics.log_scalar(
-            "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
-        )
-        metrics.log_scalar(
-            "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
-        )
-        metrics.log_derived(
-            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
-        )
-
-        metrics.log_scalar(
-            "ntokens", ntokens, 1, round=3
-        )
-        metrics.log_scalar(
-            "nsentences", nsentences, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size", sample_size, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size_v1", sample_size_v1, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size_v2", sample_size_v2, 1, round=3
-        )
-
-        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
-        if total > 0:
-            metrics.log_scalar("total", total)
-            n_correct = utils.item(
-                sum(log.get("n_correct", 0) for log in logging_outputs)
-            )
-            metrics.log_scalar("n_correct", n_correct)
-            metrics.log_derived(
-                "accuracy",
-                lambda meters: round(
-                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
-                )
-                if meters["total"].sum > 0
-                else float("nan"),
-            )
-
-    @staticmethod
-    def logging_outputs_can_be_summed() -> bool:
-        """
-        Whether the logging outputs returned by `forward` can be summed
-        across workers prior to calling `reduce_metrics`. Setting this
-        to True will improves distributed training speed.
-        """
-        return True
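Note (not part of the deleted file): the encouraging-loss term above adds a bonus of log(1 - p) for the target token, and beyond a probability threshold log_end it switches to a straight-line continuation so the bonus does not blow up as p approaches 1. A minimal sketch of that bonus alone, assuming plain PyTorch log-probabilities; the function name encouraging_bonus is illustrative, not from the repository.

import torch

def encouraging_bonus(lprobs: torch.Tensor, log_end: float = 0.75, eps: float = 1e-5) -> torch.Tensor:
    # bonus = log(1 - p), clamped for numerical stability
    probs = torch.exp(lprobs)
    bonus = torch.log(torch.clamp(1.0 - probs, min=eps))
    if log_end != 1.0:
        # past p = log_end, continue with the straight line used in the deleted code:
        # slope 1 / (log_end - 1), passing through (log_end, log(1 - log_end))
        y_log_end = torch.log(torch.ones_like(probs) - log_end)
        linear_tail = (probs - log_end) / (log_end - 1.0) + y_log_end
        bonus = torch.where(probs > log_end, linear_tail, bonus)
    return bonus

The deleted criterion then treats -bonus like a second set of log-probabilities and feeds it to F.nll_loss over the target tokens, adding the result to the label-smoothed loss.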
criterions/scst_loss.py
DELETED
@@ -1,281 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import math
-import string
-from dataclasses import dataclass, field
-from collections import OrderedDict
-from typing import Optional
-
-import torch
-from fairseq import metrics, utils
-from fairseq.criterions import FairseqCriterion, register_criterion
-from fairseq.dataclass import FairseqDataclass
-from omegaconf import II
-
-from data import data_utils
-from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
-
-
-def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
-    loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
-    if ignore_index is not None:
-        pad_mask = target.eq(ignore_index)
-        loss.masked_fill_(pad_mask, 0.0)
-        ntokens = (~pad_mask).sum()
-    else:
-        loss = loss.squeeze(-1)
-        ntokens = target.numel()
-    if reduce:
-        loss = loss.sum()
-    return loss, ntokens
-
-
-@dataclass
-class ScstRewardCriterionConfig(FairseqDataclass):
-    scst_cider_cached_tokens: str = field(
-        default="coco-train-words.p",
-        metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
-    )
-    ignore_prefix_size: int = field(
-        default=0,
-        metadata={"help": "Ignore first N tokens"},
-    )
-    sentence_avg: bool = II("optimization.sentence_avg")
-    constraint_range: Optional[str] = field(
-        default=None,
-        metadata={"help": "constraint range"}
-    )
-
-
-@register_criterion(
-    "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
-)
-class ScstRewardCriterion(FairseqCriterion):
-    CIDER_REWARD_WEIGHT = 1
-
-    def __init__(
-        self,
-        task,
-        scst_cider_cached_tokens,
-        sentence_avg,
-        ignore_prefix_size=0,
-        constraint_range=None
-    ):
-        super().__init__(task)
-        self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
-        self.sentence_avg = sentence_avg
-        self.ignore_prefix_size = ignore_prefix_size
-        self.transtab = str.maketrans({key: None for key in string.punctuation})
-
-        self.constraint_start = None
-        self.constraint_end = None
-        if constraint_range is not None:
-            constraint_start, constraint_end = constraint_range.split(',')
-            self.constraint_start = int(constraint_start)
-            self.constraint_end = int(constraint_end)
-
-    def forward(self, model, sample, update_num=0, reduce=True):
-        """Compute the loss for the given sample.
-
-        Returns a tuple with three elements:
-        1) the loss
-        2) the sample size, which is used as the denominator for the gradient
-        3) logging outputs to display while training
-        """
-        loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
-
-        sample_size = (
-            nsentences if self.sentence_avg else ntokens
-        )
-        logging_output = {
-            "loss": loss.data,
-            "score": score,
-            "ntokens": ntokens,
-            "nsentences": nsentences,
-            "sample_size": sample_size,
-        }
-        return loss, sample_size, logging_output
-
-    def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
-        '''
-        gen_res: generated captions, list of str
-        gt_idx: list of int, of the same length as gen_res
-        gt_res: ground truth captions, list of list of str.
-            gen_res[i] corresponds to gt_res[gt_idx[i]]
-            Each image can have multiple ground truth captions
-        '''
-        gen_res_size = len(gen_res)
-
-        res = OrderedDict()
-        for i in range(gen_res_size):
-            res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
-
-        gts = OrderedDict()
-        gt_res_ = [
-            [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
-            for i in range(len(gt_res))
-        ]
-        for i in range(gen_res_size):
-            gts[i] = gt_res_[gt_idx[i]]
-
-        res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
-        _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
-        scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
-        return scores
-
-    @classmethod
-    def _wrap_sentence(self, s):
-        # ensure the sentence ends with <eos> token
-        # in order to keep consisitent with cider_cached_tokens
-        r = s.strip()
-        if r.endswith('.'):
-            r = r[:-1]
-        r += ' <eos>'
-        return r
-
-    def get_generator_out(self, model, sample):
-        def decode(toks):
-            hypo = toks.int().cpu()
-            hypo_str = self.task.tgt_dict.string(hypo)
-            hypo_str = self.task.bpe.decode(hypo_str).strip()
-            return hypo, hypo_str
-
-        model.eval()
-        with torch.no_grad():
-            self.task.scst_generator.model.eval()
-            gen_out = self.task.scst_generator.generate([model], sample)
-
-        gen_target = []
-        gen_res = []
-        gt_res = []
-        for i in range(len(gen_out)):
-            for j in range(len(gen_out[i])):
-                hypo, hypo_str = decode(gen_out[i][j]["tokens"])
-                gen_target.append(hypo)
-                gen_res.append(hypo_str)
-            gt_res.append(
-                decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
-            )
-
-        return gen_target, gen_res, gt_res
-
-    def get_reward_and_scores(self, gen_res, gt_res, device):
-        batch_size = len(gt_res)
-        gen_res_size = len(gen_res)
-        seq_per_img = gen_res_size // batch_size
-
-        gt_idx = [i // seq_per_img for i in range(gen_res_size)]
-        scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
-        sc_ = scores.reshape(batch_size, seq_per_img)
-        baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
-        # sample - baseline
-        reward = scores.reshape(batch_size, seq_per_img)
-        reward = reward - baseline
-        reward = reward.reshape(gen_res_size)
-        reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
-
-        return reward, scores
-
-    def get_net_output(self, model, sample, gen_target):
-        def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
-            return data_utils.collate_tokens(
-                sample_list,
-                pad_idx=self.padding_idx,
-                eos_idx=eos,
-                left_pad=False,
-                move_eos_to_beginning=move_eos_to_beginning,
-            )
-
-        batch_size = len(sample["target"])
-        gen_target_size = len(gen_target)
-        seq_per_img = gen_target_size // batch_size
-
-        model.train()
-        sample_src_tokens = torch.repeat_interleave(
-            sample['net_input']['src_tokens'], seq_per_img, dim=0
-        )
-        sample_src_lengths = torch.repeat_interleave(
-            sample['net_input']['src_lengths'], seq_per_img, dim=0
-        )
-        sample_patch_images = torch.repeat_interleave(
-            sample['net_input']['patch_images'], seq_per_img, dim=0
-        )
-        sample_patch_masks = torch.repeat_interleave(
-            sample['net_input']['patch_masks'], seq_per_img, dim=0
-        )
-        gen_prev_output_tokens = torch.as_tensor(
-            merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
-            device=sample["target"].device, dtype=torch.int64
-        )
-        gen_target_tokens = torch.as_tensor(
-            merge(gen_target), device=sample["target"].device, dtype=torch.int64
-        )
-        net_output = model(
-            src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
-            patch_images=sample_patch_images, patch_masks=sample_patch_masks,
-            prev_output_tokens=gen_prev_output_tokens
-        )
-
-        return net_output, gen_target_tokens
-
-    def get_lprobs_and_target(self, model, net_output, gen_target):
-        if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        if self.ignore_prefix_size > 0:
-            if getattr(lprobs, "batch_first", False):
-                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
-                gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
-            else:
-                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
-                gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
-        return lprobs, gen_target
-
-    def compute_loss(self, model, sample, reduce=True):
-        gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
-        reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
-        net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
-        gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
-        loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
-        nsentences = gen_target_tokens.size(0)
-
-        return loss, scores.sum(), ntokens, nsentences
-
-    @classmethod
-    def reduce_metrics(cls, logging_outputs) -> None:
-        """Aggregate logging outputs from data parallel training."""
-        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
-        score_sum = sum(log.get("score", 0) for log in logging_outputs)
-        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
-        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
-        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
-
-        metrics.log_scalar(
-            "loss", loss_sum / sample_size, sample_size, round=3
-        )
-        metrics.log_scalar(
-            "score", score_sum / nsentences, nsentences, round=3
-        )
-
-        metrics.log_scalar(
-            "ntokens", ntokens, 1, round=3
-        )
-        metrics.log_scalar(
-            "nsentences", nsentences, 1, round=3
-        )
-        metrics.log_scalar(
-            "sample_size", sample_size, 1, round=3
-        )
-
-    @staticmethod
-    def logging_outputs_can_be_summed() -> bool:
-        """
-        Whether the logging outputs returned by `forward` can be summed
-        across workers prior to calling `reduce_metrics`. Setting this
-        to True will improves distributed training speed.
-        """
-        return True
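Note (not part of the deleted file): get_reward_and_scores above applies a self-critical baseline: for each sampled caption, the baseline is the mean CIDEr of the other captions sampled for the same image, and the reward is the score minus that baseline. A minimal standalone sketch of that step, assuming a flat NumPy array of scores; the function name self_critical_reward is illustrative, not from the repository.

import numpy as np

def self_critical_reward(scores: np.ndarray, batch_size: int, seq_per_img: int) -> np.ndarray:
    # scores: CIDEr scores for batch_size * seq_per_img sampled captions, image-major order
    sc = scores.reshape(batch_size, seq_per_img)
    # leave-one-out mean of the other samples drawn for the same image
    baseline = (sc.sum(1, keepdims=True) - sc) / (sc.shape[1] - 1)
    return (sc - baseline).reshape(-1)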
data/__init__.py
DELETED
File without changes
data/cv_data/image_classify_dataset.py
DELETED
@@ -1,196 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-from io import BytesIO
-
-import logging
-import warnings
-import functools
-
-import numpy as np
-import torch
-import base64
-from torchvision import transforms
-from timm.data import create_transform
-from utils.vision_helper import RandomAugment
-
-from PIL import Image, ImageFile
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-ImageFile.MAX_IMAGE_PIXELS = None
-Image.MAX_IMAGE_PIXELS = None
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    id = np.array([s["id"] for s in samples])
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
-    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
-
-    conf = None
-    if samples[0].get("conf", None) is not None:
-        conf = torch.cat([s['conf'] for s in samples], dim=0)
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "id": id,
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "patch_images": patch_images,
-            "patch_masks": patch_masks,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "conf": conf,
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class ImageClassifyDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=128,
-        max_tgt_length=30,
-        patch_image_size=224,
-        constraint_trie=None,
-        imagenet_default_mean_and_std=False
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.patch_image_size = patch_image_size
-
-        self.constraint_trie = constraint_trie
-
-        if imagenet_default_mean_and_std:
-            mean = IMAGENET_DEFAULT_MEAN
-            std = IMAGENET_DEFAULT_STD
-        else:
-            mean = [0.5, 0.5, 0.5]
-            std = [0.5, 0.5, 0.5]
-
-        if self.split != 'train':
-            self.patch_resize_transform = transforms.Compose([
-                lambda image: image.convert("RGB"),
-                transforms.Resize([patch_image_size, patch_image_size], interpolation=Image.BICUBIC),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=mean, std=std),
-            ])
-            logger.info("val split, do not use random augmentation.")
-        else:
-            self.patch_resize_transform = create_transform(
-                input_size=patch_image_size,
-                is_training=True,
-                color_jitter=0.4,
-                auto_augment='rand-m9-mstd0.5-inc1',
-                interpolation='bicubic',
-                re_prob=0.25,
-                re_mode='pixel',
-                re_count=1,
-                mean=mean,
-                std=std,
-            )
-            self.patch_resize_transform = transforms.Compose(functools.reduce(lambda x, y:x + y, [
-                [lambda image: image.convert("RGB"),],
-                self.patch_resize_transform.transforms[:2],
-                [self.patch_resize_transform.transforms[2]],
-                [RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), ],
-                self.patch_resize_transform.transforms[3:],
-            ]))
-            logger.info("train split, use random augmentation.")
-
-    def __getitem__(self, index):
-        image, label_name = self.dataset[index]
-
-        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
-        patch_image = self.patch_resize_transform(image)
-        patch_mask = torch.tensor([True])
-
-        src_item = self.encode_text(' what does the image describe?')
-        tgt_item = self.encode_text(" {}".format(label_name))
-        ref_dict = {label_name: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        target_item = torch.cat([tgt_item, self.eos_item])
-        prev_output_item = torch.cat([self.bos_item, tgt_item])
-
-        example = {
-            "id": index,
-            "source": src_item,
-            "patch_image": patch_image,
-            "patch_mask": patch_mask,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            for i in range(len(prev_output_item)):
-                constraint_prefix_token = prev_output_item[:i+1].tolist()
-                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
-                constraint_mask[i][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
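Note (not part of the deleted file): the constraint_trie block in __getitem__ above builds a boolean mask where row i marks which vocabulary ids may follow the first i+1 decoder tokens, which is what restricts generation to valid class names. A minimal standalone sketch of that construction, using a toy trie stand-in with hypothetical data; the names _ToyTrie and build_constraint_mask are illustrative, not from the repository.

import torch

class _ToyTrie:
    # toy stand-in for the constraint trie: maps a token-prefix tuple to the
    # list of allowed next-token ids (hypothetical data, for illustration only)
    def __init__(self, table):
        self.table = table

    def get_next_layer(self, prefix):
        return self.table.get(tuple(prefix), [])

def build_constraint_mask(prev_output_tokens: torch.Tensor, vocab_size: int, trie) -> torch.Tensor:
    # row i: which vocabulary ids may follow the prefix prev_output_tokens[:i+1]
    mask = torch.zeros(len(prev_output_tokens), vocab_size, dtype=torch.bool)
    for i in range(len(prev_output_tokens)):
        allowed = trie.get_next_layer(prev_output_tokens[: i + 1].tolist())
        mask[i, allowed] = True
    return mask

# usage sketch: after <bos> (id 0) only tokens 5 or 7 are allowed, after [0, 5] only 9
trie = _ToyTrie({(0,): [5, 7], (0, 5): [9]})
mask = build_constraint_mask(torch.tensor([0, 5]), vocab_size=12, trie=trie)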
data/data_utils.py
DELETED
@@ -1,601 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-try:
-    from collections.abc import Iterable
-except ImportError:
-    from collections import Iterable
-import contextlib
-import itertools
-import logging
-import re
-import warnings
-from typing import Optional, Tuple
-
-import numpy as np
-import torch
-
-from fairseq.file_io import PathManager
-from fairseq import utils
-import os
-
-logger = logging.getLogger(__name__)
-
-
-def infer_language_pair(path):
-    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
-    src, dst = None, None
-    for filename in PathManager.ls(path):
-        parts = filename.split(".")
-        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
-            return parts[1].split("-")
-    return src, dst
-
-
-def collate_tokens(
-    values,
-    pad_idx,
-    eos_idx=None,
-    left_pad=False,
-    move_eos_to_beginning=False,
-    pad_to_length=None,
-    pad_to_multiple=1,
-    pad_to_bsz=None,
-):
-    """Convert a list of 1d tensors into a padded 2d tensor."""
-    size = max(v.size(0) for v in values)
-    size = size if pad_to_length is None else max(size, pad_to_length)
-    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
-        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
-
-    def copy_tensor(src, dst):
-        assert dst.numel() == src.numel()
-        if move_eos_to_beginning:
-            if eos_idx is None:
-                # if no eos_idx is specified, then use the last token in src
-                dst[0] = src[-1]
-            else:
-                dst[0] = eos_idx
-            dst[1:] = src[:-1]
-        else:
-            dst.copy_(src)
-
-    if values[0].dim() == 1:
-        res = values[0].new(len(values), size).fill_(pad_idx)
-    elif values[0].dim() == 2:
-        assert move_eos_to_beginning is False
-        res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
-    else:
-        raise NotImplementedError
-
-    for i, v in enumerate(values):
-        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
-    return res
-
-
-def load_indexed_dataset(
-    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
-):
-    """A helper function for loading indexed datasets.
-
-    Args:
-        path (str): path to indexed dataset (e.g., 'data-bin/train')
-        dictionary (~fairseq.data.Dictionary): data dictionary
-        dataset_impl (str, optional): which dataset implementation to use. If
-            not provided, it will be inferred automatically. For legacy indexed
-            data we use the 'cached' implementation by default.
-        combine (bool, optional): automatically load and combine multiple
-            datasets. For example, if *path* is 'data-bin/train', then we will
-            combine 'data-bin/train', 'data-bin/train1', ... and return a
-            single ConcatDataset instance.
-    """
-    import fairseq.data.indexed_dataset as indexed_dataset
-    from fairseq.data.concat_dataset import ConcatDataset
-
-    datasets = []
-    for k in itertools.count():
-        path_k = path + (str(k) if k > 0 else "")
-        try:
-            path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
-        except Exception as e:
-            if "StorageException: [404] Path not found" in str(e):
-                logger.warning(f"path_k: {e} not found")
-            else:
-                raise e
-
-        dataset_impl_k = dataset_impl
-        if dataset_impl_k is None:
-            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
-        dataset = indexed_dataset.make_dataset(
-            path_k,
-            impl=dataset_impl_k or default,
-            fix_lua_indexing=True,
-            dictionary=dictionary,
-        )
-        if dataset is None:
-            break
-        logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
-        datasets.append(dataset)
-        if not combine:
-            break
-    if len(datasets) == 0:
-        return None
-    elif len(datasets) == 1:
-        return datasets[0]
-    else:
-        return ConcatDataset(datasets)
-
-
-@contextlib.contextmanager
-def numpy_seed(seed, *addl_seeds):
-    """Context manager which seeds the NumPy PRNG with the specified seed and
-    restores the state afterward"""
-    if seed is None:
-        yield
-        return
-    if len(addl_seeds) > 0:
-        seed = int(hash((seed, *addl_seeds)) % 1e6)
-    state = np.random.get_state()
-    np.random.seed(seed)
-    try:
-        yield
-    finally:
-        np.random.set_state(state)
-
-
-def collect_filtered(function, iterable, filtered):
-    """
-    Similar to :func:`filter` but collects filtered elements in ``filtered``.
-
-    Args:
-        function (callable): function that returns ``False`` for elements that
-            should be filtered
-        iterable (iterable): iterable to filter
-        filtered (list): list to store filtered elements
-    """
-    for el in iterable:
-        if function(el):
-            yield el
-        else:
-            filtered.append(el)
-
-
-def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
-    def compare_leq(a, b):
-        return a <= b if not isinstance(a, tuple) else max(a) <= b
-
-    def check_size(idx):
-        if isinstance(max_positions, float) or isinstance(max_positions, int):
-            return size_fn(idx) <= max_positions
-        elif isinstance(max_positions, dict):
-            idx_size = size_fn(idx)
-            assert isinstance(idx_size, dict)
-            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
-            return all(
-                all(
-                    a is None or b is None or a <= b
-                    for a, b in zip(idx_size[key], max_positions[key])
-                )
-                for key in intersect_keys
-            )
-        else:
-            # For MultiCorpusSampledDataset, will generalize it later
-            if not isinstance(size_fn(idx), Iterable):
-                return all(size_fn(idx) <= b for b in max_positions)
-            return all(
-                a is None or b is None or a <= b
-                for a, b in zip(size_fn(idx), max_positions)
-            )
-
-    ignored = []
-    itr = collect_filtered(check_size, indices, ignored)
-    indices = np.fromiter(itr, dtype=np.int64, count=-1)
-    return indices, ignored
-
-
-def filter_by_size(indices, dataset, max_positions, raise_exception=False):
-    """
-    [deprecated] Filter indices based on their size.
-    Use `FairseqDataset::filter_indices_by_size` instead.
-
-    Args:
-        indices (List[int]): ordered list of dataset indices
-        dataset (FairseqDataset): fairseq dataset instance
-        max_positions (tuple): filter elements larger than this size.
-            Comparisons are done component-wise.
-        raise_exception (bool, optional): if ``True``, raise an exception if
-            any elements are filtered (default: False).
-    """
-    warnings.warn(
-        "data_utils.filter_by_size is deprecated. "
-        "Use `FairseqDataset::filter_indices_by_size` instead.",
-        stacklevel=2,
-    )
-    if isinstance(max_positions, float) or isinstance(max_positions, int):
-        if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
-            ignored = indices[dataset.sizes[indices] > max_positions].tolist()
-            indices = indices[dataset.sizes[indices] <= max_positions]
-        elif (
-            hasattr(dataset, "sizes")
-            and isinstance(dataset.sizes, list)
-            and len(dataset.sizes) == 1
-        ):
-            ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
-            indices = indices[dataset.sizes[0][indices] <= max_positions]
-        else:
-            indices, ignored = _filter_by_size_dynamic(
-                indices, dataset.size, max_positions
-            )
-    else:
-        indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
-
-    if len(ignored) > 0 and raise_exception:
-        raise Exception(
-            (
-                "Size of sample #{} is invalid (={}) since max_positions={}, "
-                "skip this example with --skip-invalid-size-inputs-valid-test"
-            ).format(ignored[0], dataset.size(ignored[0]), max_positions)
-        )
-    if len(ignored) > 0:
-        logger.warning(
-            (
-                "{} samples have invalid sizes and will be skipped, "
-                "max_positions={}, first few sample ids={}"
-            ).format(len(ignored), max_positions, ignored[:10])
-        )
-    return indices
-
-
-def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
-    """Filter a list of sample indices. Remove those that are longer
-    than specified in max_sizes.
-
-    Args:
-        indices (np.array): original array of sample indices
-        max_sizes (int or list[int] or tuple[int]): max sample size,
-            can be defined separately for src and tgt (then list or tuple)
-
-    Returns:
-        np.array: filtered sample array
-        list: list of removed indices
-    """
-    if max_sizes is None:
-        return indices, []
-    if type(max_sizes) in (int, float):
-        max_src_size, max_tgt_size = max_sizes, max_sizes
-    else:
-        max_src_size, max_tgt_size = max_sizes
-    if tgt_sizes is None:
-        ignored = indices[src_sizes[indices] > max_src_size]
-    else:
-        ignored = indices[
-            (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
-        ]
-    if len(ignored) > 0:
-        if tgt_sizes is None:
-            indices = indices[src_sizes[indices] <= max_src_size]
-        else:
-            indices = indices[
-                (src_sizes[indices] <= max_src_size)
-                & (tgt_sizes[indices] <= max_tgt_size)
-            ]
-    return indices, ignored.tolist()
-
-
-def batch_by_size(
-    indices,
-    num_tokens_fn,
-    num_tokens_vec=None,
-    max_tokens=None,
-    max_sentences=None,
-    required_batch_size_multiple=1,
-    fixed_shapes=None,
-):
-    """
-    Yield mini-batches of indices bucketed by size. Batches may contain
-    sequences of different lengths.
-
-    Args:
-        indices (List[int]): ordered list of dataset indices
-        num_tokens_fn (callable): function that returns the number of tokens at
-            a given index
-        num_tokens_vec (List[int], optional): precomputed vector of the number
-            of tokens for each index in indices (to enable faster batch generation)
-        max_tokens (int, optional): max number of tokens in each batch
-            (default: None).
-        max_sentences (int, optional): max number of sentences in each
-            batch (default: None).
-        required_batch_size_multiple (int, optional): require batch size to
-            be less than N or a multiple of N (default: 1).
-        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
-            only be created with the given shapes. *max_sentences* and
-            *required_batch_size_multiple* will be ignored (default: None).
-    """
-    try:
-        from fairseq.data.data_utils_fast import (
-            batch_by_size_fn,
-            batch_by_size_vec,
-            batch_fixed_shapes_fast,
-        )
-    except ImportError:
-        raise ImportError(
-            "Please build Cython components with: "
-            "`python setup.py build_ext --inplace`"
-        )
-    except ValueError:
-        raise ValueError(
-            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
-        )
-
-    # added int() to avoid TypeError: an integer is required
-    max_tokens = (
-        int(max_tokens) if max_tokens is not None else -1
-    )
-    max_sentences = max_sentences if max_sentences is not None else -1
-    bsz_mult = required_batch_size_multiple
-
-    if not isinstance(indices, np.ndarray):
-        indices = np.fromiter(indices, dtype=np.int64, count=-1)
-
-    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
-        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
-
-    if fixed_shapes is None:
-        if num_tokens_vec is None:
-            return batch_by_size_fn(
-                indices,
-                num_tokens_fn,
-                max_tokens,
-                max_sentences,
-                bsz_mult,
-            )
-        else:
-            return batch_by_size_vec(
-                indices,
-                num_tokens_vec,
-                max_tokens,
-                max_sentences,
-                bsz_mult,
-            )
-
-    else:
-        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
-        sort_order = np.lexsort(
-            [
-                fixed_shapes[:, 1].argsort(),  # length
-                fixed_shapes[:, 0].argsort(),  # bsz
-            ]
-        )
-        fixed_shapes_sorted = fixed_shapes[sort_order]
-        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
-
-
-def post_process(sentence: str, symbol: str):
-    if symbol == "sentencepiece":
-        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
-    elif symbol == "wordpiece":
-        sentence = sentence.replace(" ", "").replace("_", " ").strip()
-    elif symbol == "letter":
-        sentence = sentence.replace(" ", "").replace("|", " ").strip()
-    elif symbol == "silence":
-        import re
-        sentence = sentence.replace("<SIL>", "")
-        sentence = re.sub(' +', ' ', sentence).strip()
-    elif symbol == "_EOW":
-        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
-    elif symbol in {"subword_nmt", "@@ ", "@@"}:
-        if symbol == "subword_nmt":
-            symbol = "@@ "
-        sentence = (sentence + " ").replace(symbol, "").rstrip()
-    elif symbol == "none":
-        pass
-    elif symbol is not None:
-        raise NotImplementedError(f"Unknown post_process option: {symbol}")
-    return sentence
-
-
-def compute_mask_indices(
-    shape: Tuple[int, int],
-    padding_mask: Optional[torch.Tensor],
-    mask_prob: float,
-    mask_length: int,
-    mask_type: str = "static",
-    mask_other: float = 0.0,
-    min_masks: int = 0,
-    no_overlap: bool = False,
-    min_space: int = 0,
-) -> np.ndarray:
-    """
-    Computes random mask spans for a given shape
-
-    Args:
-        shape: the the shape for which to compute masks.
-            should be of size 2 where first element is batch size and 2nd is timesteps
-        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
-        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
-            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
-            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
-        mask_type: how to compute mask lengths
-            static = fixed size
-            uniform = sample from uniform distribution [mask_other, mask_length*2]
-            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
-            poisson = sample from possion distribution with lambda = mask length
-        min_masks: minimum number of masked spans
-        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
-        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
-    """
-
-    bsz, all_sz = shape
-    mask = np.full((bsz, all_sz), False)
-
-    all_num_mask = int(
-        # add a random number for probabilistic rounding
-        mask_prob * all_sz / float(mask_length)
-        + np.random.rand()
-    )
-
-    all_num_mask = max(min_masks, all_num_mask)
-
-    mask_idcs = []
-    for i in range(bsz):
-        if padding_mask is not None:
-            sz = all_sz - padding_mask[i].long().sum().item()
-            num_mask = int(
-                # add a random number for probabilistic rounding
-                mask_prob * sz / float(mask_length)
-                + np.random.rand()
-            )
-            num_mask = max(min_masks, num_mask)
-        else:
-            sz = all_sz
-            num_mask = all_num_mask
-
-        if mask_type == "static":
-            lengths = np.full(num_mask, mask_length)
-        elif mask_type == "uniform":
-            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
-        elif mask_type == "normal":
-            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
-            lengths = [max(1, int(round(x))) for x in lengths]
-        elif mask_type == "poisson":
-            lengths = np.random.poisson(mask_length, size=num_mask)
-            lengths = [int(round(x)) for x in lengths]
-        else:
-            raise Exception("unknown mask selection " + mask_type)
-
-        if sum(lengths) == 0:
-            lengths[0] = min(mask_length, sz - 1)
-
-        if no_overlap:
-            mask_idc = []
-
-            def arrange(s, e, length, keep_length):
-                span_start = np.random.randint(s, e - length)
-                mask_idc.extend(span_start + i for i in range(length))
-
-                new_parts = []
-                if span_start - s - min_space >= keep_length:
-                    new_parts.append((s, span_start - min_space + 1))
-                if e - span_start - keep_length - min_space > keep_length:
|
482 |
-
new_parts.append((span_start + length + min_space, e))
|
483 |
-
return new_parts
|
484 |
-
|
485 |
-
parts = [(0, sz)]
|
486 |
-
min_length = min(lengths)
|
487 |
-
for length in sorted(lengths, reverse=True):
|
488 |
-
lens = np.fromiter(
|
489 |
-
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
490 |
-
np.int,
|
491 |
-
)
|
492 |
-
l_sum = np.sum(lens)
|
493 |
-
if l_sum == 0:
|
494 |
-
break
|
495 |
-
probs = lens / np.sum(lens)
|
496 |
-
c = np.random.choice(len(parts), p=probs)
|
497 |
-
s, e = parts.pop(c)
|
498 |
-
parts.extend(arrange(s, e, length, min_length))
|
499 |
-
mask_idc = np.asarray(mask_idc)
|
500 |
-
else:
|
501 |
-
min_len = min(lengths)
|
502 |
-
if sz - min_len <= num_mask:
|
503 |
-
min_len = sz - num_mask - 1
|
504 |
-
|
505 |
-
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
506 |
-
|
507 |
-
mask_idc = np.asarray(
|
508 |
-
[
|
509 |
-
mask_idc[j] + offset
|
510 |
-
for j in range(len(mask_idc))
|
511 |
-
for offset in range(lengths[j])
|
512 |
-
]
|
513 |
-
)
|
514 |
-
|
515 |
-
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
516 |
-
|
517 |
-
min_len = min([len(m) for m in mask_idcs])
|
518 |
-
for i, mask_idc in enumerate(mask_idcs):
|
519 |
-
if len(mask_idc) > min_len:
|
520 |
-
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
521 |
-
mask[i, mask_idc] = True
|
522 |
-
|
523 |
-
return mask
|
524 |
-
|
525 |
-
|
526 |
-
def get_mem_usage():
|
527 |
-
try:
|
528 |
-
import psutil
|
529 |
-
|
530 |
-
mb = 1024 * 1024
|
531 |
-
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
|
532 |
-
except ImportError:
|
533 |
-
return "N/A"
|
534 |
-
|
535 |
-
|
536 |
-
# lens: torch.LongTensor
|
537 |
-
# returns: torch.BoolTensor
|
538 |
-
def lengths_to_padding_mask(lens):
|
539 |
-
bsz, max_lens = lens.size(0), torch.max(lens).item()
|
540 |
-
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
|
541 |
-
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
|
542 |
-
return mask
|
543 |
-
|
544 |
-
|
545 |
-
# lens: torch.LongTensor
|
546 |
-
# returns: torch.BoolTensor
|
547 |
-
def lengths_to_mask(lens):
|
548 |
-
return ~lengths_to_padding_mask(lens)
|
549 |
-
|
550 |
-
|
551 |
-
def get_buckets(sizes, num_buckets):
|
552 |
-
buckets = np.unique(
|
553 |
-
np.percentile(
|
554 |
-
sizes,
|
555 |
-
np.linspace(0, 100, num_buckets + 1),
|
556 |
-
interpolation='lower',
|
557 |
-
)[1:]
|
558 |
-
)
|
559 |
-
return buckets
|
560 |
-
|
561 |
-
|
562 |
-
def get_bucketed_sizes(orig_sizes, buckets):
|
563 |
-
sizes = np.copy(orig_sizes)
|
564 |
-
assert np.min(sizes) >= 0
|
565 |
-
start_val = -1
|
566 |
-
for end_val in buckets:
|
567 |
-
mask = (sizes > start_val) & (sizes <= end_val)
|
568 |
-
sizes[mask] = end_val
|
569 |
-
start_val = end_val
|
570 |
-
return sizes
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
def _find_extra_valid_paths(dataset_path: str) -> set:
|
575 |
-
paths = utils.split_paths(dataset_path)
|
576 |
-
all_valid_paths = set()
|
577 |
-
for sub_dir in paths:
|
578 |
-
contents = PathManager.ls(sub_dir)
|
579 |
-
valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
|
580 |
-
all_valid_paths |= {os.path.basename(p) for p in valid_paths}
|
581 |
-
# Remove .bin, .idx etc
|
582 |
-
roots = {os.path.splitext(p)[0] for p in all_valid_paths}
|
583 |
-
return roots
|
584 |
-
|
585 |
-
|
586 |
-
def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
|
587 |
-
"""Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
|
588 |
-
if (
|
589 |
-
train_cfg.dataset.ignore_unused_valid_subsets
|
590 |
-
or train_cfg.dataset.combine_valid_subsets
|
591 |
-
or train_cfg.dataset.disable_validation
|
592 |
-
or not hasattr(train_cfg.task, "data")
|
593 |
-
):
|
594 |
-
return
|
595 |
-
other_paths = _find_extra_valid_paths(train_cfg.task.data)
|
596 |
-
specified_subsets = train_cfg.dataset.valid_subset.split(",")
|
597 |
-
ignored_paths = [p for p in other_paths if p not in specified_subsets]
|
598 |
-
if ignored_paths:
|
599 |
-
advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
|
600 |
-
msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
|
601 |
-
raise ValueError(msg)
|
data/file_dataset.py DELETED
@@ -1,107 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import torch
import pickle


class FileDataset:
    def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
        self.file_path = file_path
        assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
        else:
            self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
            assert len(self.dtypes) == len(self.selected_col_ids)

        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print("file {} slice_id {} row count {} total row count {}".format(
            self.file_path, self.slice_id, self.row_count, self.total_row_count)
        )

    def _init_seek_index(self):
        if self.cached_index:
            cache_path = "{}.index".format(self.file_path)
            assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
            print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, "r")
            print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
        self._compute_start_pos_and_row_count()
        print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
            self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)

    def _get_reader(self):
        fp = open(self.file_path, "r")
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print("slice_id {} seek offset {}".format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        if self.data_cnt == self.row_count:
            print("reach the end of datafile, start a new reader")
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip("\n").split(self.separator)
        self.data_cnt += 1
        column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
        return column_l
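
A minimal usage sketch for the FileDataset removed above, assuming a tab-separated file with at least three columns; the file name caption_data.tsv and the column layout are illustrative only.

from data.file_dataset import FileDataset

# read columns 0, 1 and 2 as strings from a TSV shard;
# without torch.distributed initialized, slice_id is 0 and the whole file is read
dataset = FileDataset("caption_data.tsv", selected_col_ids="0,1,2", dtypes="str,str,str")
print(len(dataset))                        # rows assigned to this worker
uniq_id, image_b64, caption = dataset[0]   # sequential read of the next line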
data/mm_data/__init__.py DELETED
File without changes
data/mm_data/caption_dataset.py DELETED
@@ -1,160 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings
import string

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "target": target,
    }

    return batch


class CaptionDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=128,
        max_tgt_length=30,
        patch_image_size=224,
        imagenet_default_mean_and_std=False,
        scst=False
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.scst = scst

        self.transtab = str.maketrans({key: None for key in string.punctuation})

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = " what does the image describe?"
        elif type(bpe).__name__ == 'BertBPE':
            self.prompt = "图片描述了什么内容?"

    def __getitem__(self, index):
        uniq_id, image, caption = self.dataset[index]

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        if self.split == 'train' and not self.scst:
            caption = caption.translate(self.transtab).strip()
            caption_token_list = caption.strip().split()
            tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
        else:
            caption = ' '.join(caption.strip().split())
            caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
            tgt_caption = '&&'.join(caption_list)
        src_item = self.encode_text(self.prompt)
        tgt_item = self.encode_text(" {}".format(tgt_caption))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item
        }
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/ocr_dataset.py DELETED
@@ -1,210 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings
import random
import functools

import numpy as np
import torch
import base64
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

from PIL import Image, ImageFile

from zhconv import convert
import unicodedata

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "target": target,
    }

    return batch


def ocr_resize(img, patch_image_size, is_document=False, split='train'):
    img = img.convert("RGB")
    width, height = img.size

    if is_document:
        new_height, new_width = 64, 1920
    else:
        if width >= height:
            new_width = max(64, patch_image_size)
            new_height = max(64, int(patch_image_size * (height / width)))
            if split != 'train':
                top = int((patch_image_size - new_height) // 2)
            else:
                top = random.randint(0, patch_image_size - new_height)
            bottom = patch_image_size - new_height - top
            left, right = 0, 0
        else:
            new_height = max(64, patch_image_size)
            new_width = max(64, int(patch_image_size * (width / height)))
            if split != 'train':
                left = int((patch_image_size - new_width) // 2)
            else:
                left = random.randint(0, patch_image_size - new_width)
            right = patch_image_size - new_width - left
            top, bottom = 0, 0

    img_new = F.resize(
        img,
        [new_height, new_width],
        interpolation=InterpolationMode.BICUBIC,
    )

    if is_document:
        img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
        img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
        new_width, new_height = img_new.size
        top = random.randint(0, patch_image_size - new_height)
        bottom = patch_image_size - new_height - top
        left, right = 0, 0

    img_new = F.pad(img_new, padding=[left, top, right, bottom], padding_mode="edge")
    assert img_new.size == (patch_image_size, patch_image_size)

    return img_new


class OcrDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=224,
        imagenet_default_mean_and_std=False,
        is_document=False,
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose(
            [
                lambda image: ocr_resize(
                    image, patch_image_size, is_document=is_document, split=split,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

        self.bpe = bpe
        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = " what are the texts on the image?"
        elif type(bpe).__name__ == 'BertBPE':
            self.prompt = "图片上的文字是什么?"

    def __getitem__(self, index):
        uniq_id, image, caption = self.dataset[index]

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans"))
        if type(self.bpe).__name__ == 'GPT2BPE':
            caption_token_list = caption.lower().strip().split()
            tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
        elif type(self.bpe).__name__ == 'BertBPE':
            tgt_caption = caption[: self.max_tgt_length].lower()
        src_item = self.encode_text(self.prompt)
        tgt_item = self.encode_text(" {}".format(tgt_caption))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
        }
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data required for the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/refcoco_dataset.py DELETED
@@ -1,174 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings

import numpy as np
import torch
import base64
import utils.transforms as T

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
    h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
    region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "target": target,
        "w_resize_ratios": w_resize_ratios,
        "h_resize_ratios": h_resize_ratios,
        "region_coords": region_coords
    }

    return batch


class RefcocoDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=512,
        imagenet_default_mean_and_std=False,
        num_bins=1000,
        max_image_size=512
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.num_bins = num_bins

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        # for positioning
        self.positioning_transform = T.Compose([
            T.RandomResize([patch_image_size], max_size=patch_image_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
        ])

        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = ' which region does the text " {} " describe?'
        elif type(bpe).__name__ == 'BertBPE':
            self.prompt = '这段文字" {} "描述的是哪个区域?'

    def __getitem__(self, index):
        uniq_id, base64_str, text, region_coord = self.dataset[index]

        image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
        w, h = image.size
        boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
        x0, y0, x1, y1 = region_coord.strip().split(',')
        region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
        boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
        boxes_target["labels"] = np.array([0])
        boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])

        patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
        resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
        patch_mask = torch.tensor([True])
        quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
        quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
        quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
        quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
        region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
        src_caption = self.pre_caption(text, self.max_src_length)
        src_item = self.encode_text(self.prompt.format(src_caption))
        tgt_item = self.encode_text(region_coord, use_bpe=False)

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "w_resize_ratio": resize_w / w,
            "h_resize_ratio": resize_h / h,
            "region_coord": region
        }
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/snli_ve_dataset.py DELETED
@@ -1,203 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    ref_dict = None
    if samples[0].get("ref_dict", None) is not None:
        ref_dict = np.array([s['ref_dict'] for s in samples])

    constraint_masks = None
    if samples[0].get("constraint_mask", None) is not None:
        constraint_masks = merge("constraint_mask")

    decoder_prompts = None
    if samples[0].get("decoder_prompt", None) is not None:
        decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        )
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "decoder_prompts": decoder_prompts,
        "target": target
    }

    return batch


class SnliVeDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=224,
        add_caption=False,
        constraint_trie=None,
        imagenet_default_mean_and_std=False,
        prompt_type="none"
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        self.add_caption = add_caption
        self.constraint_trie = constraint_trie
        self.prompt_type = prompt_type

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __getitem__(self, index):
        uniq_id, image, hypothesis, caption, label = self.dataset[index]
        if label == 'contradiction':
            label = 'no'
        elif label == 'entailment':
            label = 'yes'
        elif label == 'neutral':
            label = 'maybe'
        else:
            raise NotImplementedError

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        hypothesis = self.pre_caption(hypothesis, self.max_src_length)
        src_item = self.encode_text(' does the image describe " {} "?'.format(hypothesis))
        tgt_item = self.encode_text(" {}".format(label))
        ref_dict = {label: 1.0}

        if self.add_caption:
            caption = self.pre_caption(caption, self.max_src_length)
            src_item = self.encode_text(' can image and text1 " {} " imply text2 " {} "?'.format(caption, hypothesis))

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        if self.prompt_type == 'none':
            prev_output_item = torch.cat([self.bos_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = self.bos_item
        elif self.prompt_type == 'src':
            prev_output_item = torch.cat([src_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item
        elif self.prompt_type == 'prev_output':
            prev_output_item = torch.cat([src_item[:-1], tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item[:-1]
        else:
            raise NotImplementedError
        target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "decoder_prompt": decoder_prompt,
            "ref_dict": ref_dict,
        }
        if self.constraint_trie is not None:
            constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
            start_idx = len(target_item) - len(tgt_item) - 1
            for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
                constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
                constraint_mask[i][constraint_nodes] = True
            example["constraint_mask"] = constraint_mask
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/mm_data/vqa_gen_dataset.py DELETED
@@ -1,218 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from io import BytesIO

import logging
import warnings

import numpy as np
import torch
import base64
from torchvision import transforms

from PIL import Image, ImageFile

from data import data_utils
from data.ofa_dataset import OFADataset

ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def collate(samples, pad_idx, eos_idx):
    if len(samples) == 0:
        return {}

    def merge(key):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    id = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

    conf = None
    if samples[0].get("conf", None) is not None:
        conf = torch.cat([s['conf'] for s in samples], dim=0)

    ref_dict = None
    if samples[0].get("ref_dict", None) is not None:
        ref_dict = np.array([s['ref_dict'] for s in samples])

    constraint_masks = None
    if samples[0].get("constraint_mask", None) is not None:
        constraint_masks = merge("constraint_mask")

    decoder_prompts = None
    if samples[0].get("decoder_prompt", None) is not None:
        decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])

    prefix_tokens = None
    if samples[0].get("decoder_prompt", None) is not None:
        prefix_tokens = merge("decoder_prompt")
        prefix_tokens = prefix_tokens[:, 1:]

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        )
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "conf": conf,
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "decoder_prompts": decoder_prompts,
        "target": target,
        "prefix_tokens": prefix_tokens
    }

    return batch


class VqaGenDataset(OFADataset):
    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=128,
        max_object_length=30,
        max_tgt_length=30,
        patch_image_size=224,
        add_object=False,
        constraint_trie=None,
        imagenet_default_mean_and_std=False,
        prompt_type="none"
    ):
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_object_length = max_object_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size

        self.add_object = add_object
        self.constraint_trie = constraint_trie
        self.prompt_type = prompt_type

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __getitem__(self, index):
        item = self.dataset[index]
        if len(item) == 5:
            uniq_id, image, question, ref, predict_objects = item
        else:
            uniq_id, image, question, ref, predict_objects, caption = item

        image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
        patch_image = self.patch_resize_transform(image)
        patch_mask = torch.tensor([True])

        question = self.pre_question(question, self.max_src_length)
        question = question + '?' if not question.endswith('?') else question
        src_item = self.encode_text(' {}'.format(question))

        ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
        answer = max(ref_dict, key=ref_dict.get)
        conf = torch.tensor([ref_dict[answer]])
        tgt_item = self.encode_text(" {}".format(answer))

        if self.add_object and predict_objects is not None:
            predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])
            predict_object_item = self.encode_text(" object: {}".format(predict_object_seq))
            src_item = torch.cat([src_item, predict_object_item])

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        if self.prompt_type == 'none':
            prev_output_item = torch.cat([self.bos_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = self.bos_item
        elif self.prompt_type == 'src':
            prev_output_item = torch.cat([src_item, tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item
        elif self.prompt_type == 'prev_output':
            prev_output_item = torch.cat([src_item[:-1], tgt_item])
            target_item = torch.cat([prev_output_item[1:], self.eos_item])
            decoder_prompt = src_item[:-1]
        else:
            raise NotImplementedError
        target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "decoder_prompt": decoder_prompt,
            "ref_dict": ref_dict,
            "conf": conf,
        }
        if self.constraint_trie is not None:
            constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
            start_idx = len(target_item) - len(tgt_item) - 1
            for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
                constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
                constraint_mask[i][constraint_nodes] = True
            example["constraint_mask"] = constraint_mask
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/nlg_data/summary_dataset.py
DELETED
@@ -1,131 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    target_strs = np.array([s["target_str"] for s in samples])
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "target": target,
-        "target_strs": target_strs
-    }
-
-    return batch
-
-
-class SummaryDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        code_dict_size=8192,
-        num_bins=1000,
-        max_src_length=512,
-        max_tgt_length=128,
-        noise_ratio=0.0
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.code_dict_size = code_dict_size
-        self.num_bins = num_bins
-        self.noise_ratio = noise_ratio
-
-        if type(bpe).__name__ == 'GPT2BPE':
-            self.prompt = ' what is the summary of article " {} "?'
-        elif type(bpe).__name__ == 'BertBPE':
-            self.prompt = "{} 请用一个句子简单总结上文:"
-
-    def __getitem__(self, index):
-        source, target = self.dataset[index]
-        target_str = target.lower()
-
-        source = self.pre_caption(source, max_words=self.max_src_length)
-        target = self.pre_caption(target, max_words=self.max_tgt_length)
-        source = source.replace('<unk>', 'unk')
-        target = target.replace('<unk>', 'unk')
-
-        src_item = self.encode_text(
-            self.prompt.format(source),
-            length=self.max_src_length
-        )
-        tgt_item = self.encode_text('{}'.format(target))
-        noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        target_item = torch.cat([tgt_item, self.eos_item])
-        prev_output_item = torch.cat([self.bos_item, noise_tgt_item])
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "target_str": target_str
-        }
-        return example
-
-    def add_noise_to_tgt(self, target, p):
-        noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
-        target[noise_indices] = torch.randint(
-            4, len(self.src_dict) - self.code_dict_size - self.num_bins, size=(noise_indices.sum(),)
-        )
-        return target
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
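add_noise_to_tgt above corrupts a fraction noise_ratio of the decoder-input tokens with random vocabulary ids before they are fed back as prev_output_tokens. A minimal standalone sketch of the same idea, with hypothetical vocabulary bounds in place of the fairseq dictionary arithmetic:

# Illustrative sketch only; vocab_low/vocab_high are assumed values.
import torch

def add_noise_to_tgt(target, p, vocab_low=4, vocab_high=50000):
    # draw a Bernoulli(p) mask over positions and overwrite those tokens with random ids
    noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
    target[noise_indices] = torch.randint(vocab_low, vocab_high, size=(int(noise_indices.sum()),))
    return target

tokens = torch.arange(10, 20)                 # a made-up target sequence
noised = add_noise_to_tgt(tokens.clone(), p=0.2)
print(tokens)
print(noised)
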
data/nlu_data/cola_dataset.py
DELETED
@@ -1,138 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class COLADataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        sentence, label = self.dataset[index]
-        if label == '0':
-            label = 'no'
-        elif label == '1':
-            label = 'yes'
-        else:
-            raise NotImplementedError
-
-        sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(' is the text " {} " grammatically correct?'.format(sentence))
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
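The prompt_type branches above (also used by the other GLUE-style datasets in this commit) decide what the decoder receives as its prompt and which positions contribute to the loss; every target position except the final label token is overwritten with padding. A small self-contained sketch with made-up token ids, assuming the same three variants:

# Illustrative sketch only; token ids and lengths are invented.
import torch

bos, eos, pad = 0, 2, 1
src_item = torch.tensor([bos, 11, 12, 13, eos])   # encoded prompt
tgt_item = torch.tensor([7])                       # single-token label (" yes" / " no")

def build(prompt_type):
    if prompt_type == 'none':
        prev_output, target = torch.tensor([bos]), tgt_item.clone()
    elif prompt_type == 'src':
        prev_output = src_item.clone()
        target = torch.cat([prev_output[1:], tgt_item])
    elif prompt_type == 'prev_output':
        prev_output = src_item[:-1].clone()
        target = torch.cat([prev_output[1:], tgt_item])
    else:
        raise NotImplementedError
    target = target.clone()
    target[:-1] = pad                  # loss is computed on the final label token only
    return prev_output, target

for pt in ('none', 'src', 'prev_output'):
    print(pt, *build(pt))
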
data/nlu_data/mnli_dataset.py
DELETED
@@ -1,143 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class MNLIDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        sentence1, sentence2, label = self.dataset[index]
-        if label == '0':
-            label = 'maybe'
-        elif label == '1':
-            label = 'yes'
-        elif label == '2':
-            label = 'no'
-        else:
-            raise NotImplementedError
-
-        sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
-        sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(
-            ' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2)
-        )
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/nlu_data/mrpc_dataset.py
DELETED
@@ -1,141 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class MRPCDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        sentence1, sentence2, label = self.dataset[index]
-        if label == '0':
-            label = 'no'
-        elif label == '1':
-            label = 'yes'
-        else:
-            raise NotImplementedError
-
-        sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
-        sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(
-            ' does text1 " {} " and text2 " {} " have the same semantics?'.format(sentence1, sentence2),
-        )
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/nlu_data/qnli_dataset.py
DELETED
@@ -1,141 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class QNLIDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        question, sentence, label = self.dataset[index]
-        if label == '0' or label == 'not_entailment':
-            label = 'no'
-        elif label == '1' or label == 'entailment':
-            label = 'yes'
-        else:
-            raise NotImplementedError
-
-        question = ' '.join(question.lower().strip().split()[:self.max_src_length])
-        sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(
-            ' does " {} " contain the answer to question " {} "?'.format(sentence, question)
-        )
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/nlu_data/qqp_dataset.py
DELETED
@@ -1,141 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class QQPDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        question1, question2, label = self.dataset[index]
-        if label == '0':
-            label = 'no'
-        elif label == '1':
-            label = 'yes'
-        else:
-            raise NotImplementedError
-
-        question1 = ' '.join(question1.lower().strip().split()[:self.max_src_length])
-        question2 = ' '.join(question2.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(
-            ' is question " {} " and question " {} " equivalent?'.format(question1, question2)
-        )
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/nlu_data/rte_dataset.py
DELETED
@@ -1,141 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class RTEDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        sentence1, sentence2, label = self.dataset[index]
-        if label == 'not_entailment':
-            label = 'no'
-        elif label == 'entailment':
-            label = 'yes'
-        else:
-            raise NotImplementedError
-
-        sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
-        sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(
-            ' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2),
-        )
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/nlu_data/sst2_dataset.py
DELETED
@@ -1,138 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import warnings
-import torch
-import numpy as np
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    ref_dict = None
-    if samples[0].get("ref_dict", None) is not None:
-        ref_dict = np.array([s['ref_dict'] for s in samples])
-
-    constraint_masks = None
-    if samples[0].get("constraint_mask", None) is not None:
-        constraint_masks = merge("constraint_mask")
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor(
-            [s["target"].ne(pad_idx).long().sum() for s in samples]
-        )
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "ref_dict": ref_dict,
-        "constraint_masks": constraint_masks,
-        "target": target,
-    }
-
-    return batch
-
-
-class SST2Dataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=512,
-        max_tgt_length=30,
-        constraint_trie=None,
-        prompt_type="none"
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.constraint_trie = constraint_trie
-        self.prompt_type = prompt_type
-
-    def __getitem__(self, index):
-        sentence, label = self.dataset[index]
-        if label == '0':
-            label = 'negative'
-        elif label == '1':
-            label = 'positive'
-        else:
-            raise NotImplementedError
-
-        sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
-        src_item = self.encode_text(' is the sentiment of text " {} " positive or negative?'.format(sentence))
-        tgt_item = self.encode_text(" {}".format(label))
-        assert tgt_item.size(0) == 1
-        ref_dict = {label: 1.0}
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        if self.prompt_type == 'none':
-            prev_output_item = self.bos_item
-            target_item = tgt_item
-        elif self.prompt_type == 'src':
-            prev_output_item = src_item.clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        elif self.prompt_type == 'prev_output':
-            prev_output_item = src_item[:-1].clone()
-            target_item = torch.cat([prev_output_item[1:], tgt_item])
-        else:
-            raise NotImplementedError
-        target_item[:-1] = self.tgt_dict.pad()
-
-        example = {
-            "source": src_item,
-            "target": target_item,
-            "prev_output_tokens": prev_output_item,
-            "ref_dict": ref_dict,
-        }
-        if self.constraint_trie is not None:
-            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
-            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
-            constraint_mask[-1][constraint_nodes] = True
-            example["constraint_mask"] = constraint_mask
-        return example
-
-    def collater(self, samples, pad_to_length=None):
-        """Merge a list of samples to form a mini-batch.
-        Args:
-            samples (List[dict]): samples to collate
-        Returns:
-            dict: a mini-batch containing the data of the task
-        """
-        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)

data/ofa_dataset.py
DELETED
@@ -1,79 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import logging
-import re
-import torch.utils.data
-from fairseq.data import FairseqDataset
-
-logger = logging.getLogger(__name__)
-
-
-class OFADataset(FairseqDataset):
-    def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
-        self.split = split
-        self.dataset = dataset
-        self.bpe = bpe
-        self.src_dict = src_dict
-        self.tgt_dict = tgt_dict
-
-        self.bos = src_dict.bos()
-        self.eos = src_dict.eos()
-        self.pad = src_dict.pad()
-        self.bos_item = torch.LongTensor([self.bos])
-        self.eos_item = torch.LongTensor([self.eos])
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
-        s = self.tgt_dict.encode_line(
-            line=self.bpe.encode(text) if use_bpe else text,
-            add_if_not_exist=False,
-            append_eos=False
-        ).long()
-        if length is not None:
-            s = s[:length]
-        if append_bos:
-            s = torch.cat([self.bos_item, s])
-        if append_eos:
-            s = torch.cat([s, self.eos_item])
-        return s
-
-    def pre_question(self, question, max_ques_words=None):
-        question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
-
-        question = re.sub(
-            r"\s{2,}",
-            ' ',
-            question,
-        )
-        question = question.rstrip('\n')
-        question = question.strip(' ')
-
-        # truncate question
-        question_words = question.split(' ')
-        if max_ques_words is not None and len(question_words) > max_ques_words:
-            question = ' '.join(question_words[:max_ques_words])
-
-        return question
-
-    def pre_caption(self, caption, max_words=None):
-        caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
-
-        caption = re.sub(
-            r"\s{2,}",
-            ' ',
-            caption,
-        )
-        caption = caption.rstrip('\n')
-        caption = caption.strip(' ')
-
-        # truncate caption
-        caption_words = caption.split(' ')
-        if max_words is not None and len(caption_words) > max_words:
-            caption = ' '.join(caption_words[:max_words])
-
-        return caption
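pre_caption above (and the analogous pre_question) is plain text normalization: lowercasing, stripping leading punctuation, collapsing whitespace, and word-level truncation. A standalone restatement that runs without fairseq, with an invented input string purely for illustration:

# Illustrative sketch only; the example caption is made up.
import re

def pre_caption(caption, max_words=None):
    caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
    caption = re.sub(r"\s{2,}", ' ', caption).rstrip('\n').strip(' ')
    words = caption.split(' ')
    if max_words is not None and len(words) > max_words:
        caption = ' '.join(words[:max_words])
    return caption

print(pre_caption("  A man -- riding a horse / on the beach!  ", max_words=6))
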
data/pretrain_data/unify_dataset.py
DELETED
@@ -1,636 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-from io import BytesIO
-
-import math
-import logging
-import random
-import warnings
-
-import numpy as np
-import torch
-import base64
-from torchvision import transforms
-
-from PIL import Image, ImageFile
-
-from data import data_utils
-from data.ofa_dataset import OFADataset
-from utils.vision_helper import RandomAugment
-import utils.transforms as T
-
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-ImageFile.MAX_IMAGE_PIXELS = None
-Image.MAX_IMAGE_PIXELS = None
-
-logger = logging.getLogger(__name__)
-warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
-
-
-def get_whole_word_mask(bpe, dictionary):
-    if bpe is not None:
-
-        def is_beginning_of_word(i):
-            if i < dictionary.nspecial:
-                # special elements are always considered beginnings
-                return True
-            tok = dictionary[i]
-            if tok.startswith("madeupword"):
-                return True
-            try:
-                return bpe.is_beginning_of_word(tok)
-            except ValueError:
-                return True
-
-        mask_whole_words = torch.ByteTensor(
-            list(map(is_beginning_of_word, range(len(dictionary))))
-        )
-        return mask_whole_words
-    return None
-
-
-def collate(samples, pad_idx, eos_idx):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples],
-            pad_idx,
-            eos_idx=eos_idx,
-        )
-
-    id = np.array([s["id"] for s in samples])
-    src_tokens = merge("source")
-    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
-
-    patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
-    patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
-
-    code_masks = None
-    if samples[0].get("code_mask", None) is not None:
-        code_masks = torch.cat([sample['code_mask'] for sample in samples])
-
-    conf = torch.cat([s['conf'] for s in samples], dim=0)
-
-    prev_output_tokens = None
-    target = None
-    if samples[0].get("target", None) is not None:
-        target = merge("target")
-        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
-        ntokens = tgt_lengths.sum().item()
-
-        if samples[0].get("prev_output_tokens", None) is not None:
-            prev_output_tokens = merge("prev_output_tokens")
-    else:
-        ntokens = src_lengths.sum().item()
-
-    batch = {
-        "id": id,
-        "nsentences": len(samples),
-        "ntokens": ntokens,
-        "net_input": {
-            "src_tokens": src_tokens,
-            "src_lengths": src_lengths,
-            "patch_images": patch_images,
-            "patch_masks": patch_masks,
-            "code_masks": code_masks,
-            "prev_output_tokens": prev_output_tokens
-        },
-        "target": target,
-        "conf": conf
-    }
-
-    return batch
-
-
-class UnifyDataset(OFADataset):
-    def __init__(
-        self,
-        split,
-        dataset,
-        bpe,
-        src_dict,
-        tgt_dict=None,
-        max_src_length=128,
-        max_tgt_length=30,
-        seed=7,
-        code_dict_size=8192,
-        num_bins=1000,
-        patch_image_size=384,
-        code_image_size=128,
-        pure_text_dataset=None,
-        pure_image_dataset=None,
-        detection_dataset=None,
-        all_object_list=None,
-        all_caption_list=None,
-        type2ans_dict=None,
-        ans2type_dict=None,
-        max_image_size=512,
-        mask_ratio=0.3,
-        random_ratio=0.0,
-        keep_ratio=0.0,
-        mask_length="span-poisson",
-        poisson_lambda=3.0,
-        replace_length=1
-    ):
-        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
-        self.max_src_length = max_src_length
-        self.max_tgt_length = max_tgt_length
-        self.seed = seed
-        self.code_dict_size = code_dict_size
-        self.num_bins = num_bins
-        self.patch_image_size = patch_image_size
-        self.code_image_size = code_image_size
-
-        self.pure_text_dataset = pure_text_dataset
-        self.pure_image_dataset = pure_image_dataset
-        self.detection_dataset = detection_dataset
-        self.epoch = 0
-
-        self.all_object_list = all_object_list
-        self.all_caption_list = all_caption_list
-        self.type2ans_dict = type2ans_dict
-        self.ans2type_dict = ans2type_dict
-
-        self.mask_ratio = mask_ratio
-        self.random_ratio = random_ratio
-        self.keep_ratio = keep_ratio
-        self.mask_length = mask_length
-        self.poisson_lambda = poisson_lambda
-        self.replace_length = replace_length
-        if self.replace_length not in [-1, 0, 1]:
-            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
-        if self.mask_length not in ["subword", "word", "span-poisson"]:
-            raise ValueError(f"invalid arg: mask-length={self.mask_length}")
-        if self.mask_length == "subword" and self.replace_length not in [0, 1]:
-            raise ValueError(f"if using subwords, use replace-length=1 or 0")
-
-        self.mask_idx = src_dict.index("<mask>")
-        self.mask_whole_word = (
-            get_whole_word_mask(self.bpe, self.src_dict)
-            if self.mask_length != "subword"
-            else None
-        )
-        self.mask_span_distribution = None
-        if self.mask_length == "span-poisson":
-            _lambda = self.poisson_lambda
-            lambda_to_the_k = 1
-            e_to_the_minus_lambda = math.exp(-_lambda)
-            k_factorial = 1
-            ps = []
-            for k in range(0, 128):
-                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
-                lambda_to_the_k *= _lambda
-                k_factorial *= k + 1
-                if ps[-1] < 0.0000001:
-                    break
-            ps = torch.FloatTensor(ps)
-            self.mask_span_distribution = torch.distributions.Categorical(ps)
-
-        self.pos_tgt_item = self.encode_text(" yes")
-        self.neg_tgt_item = self.encode_text(" no")
-
-        self.mask_left = self.mask_top = int(0.5 * self.code_image_size)
-        self.mask_right = self.mask_bottom = int(1.5 * self.code_image_size)
-        self.mask_ids = [
-            i*self.code_image_size*2+j
-            for i in range(self.code_image_size*2) for j in range(self.code_image_size*2)
-            if not (self.mask_left <= i < self.mask_right and self.mask_top <= j < self.mask_bottom)
-        ]
-
-        scales = np.arange(patch_image_size, 481).tolist()
-
-        # for image-text pair
-        self.patch_resize_transform = transforms.Compose([
-            T.RandomResize(scales, max_size=672),
-            transforms.CenterCrop(patch_image_size),
-            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
-                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ])
-        # for pure image
-        self.patch_crop_transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ])
-        # for detection
-        self.detection_transform = T.Compose([
-            T.RandomHorizontalFlip(),
-            T.LargeScaleJitter(output_size=self.code_image_size*2, aug_scale_min=1.0, aug_scale_max=1.5),
-            T.ToTensor(),
-            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
-        ])
-        # for visual grounding
-        self.visual_grounding_transform = T.Compose([
-            T.RandomResize(scales, max_size=672),
-            T.ObjectCenterCrop((patch_image_size, patch_image_size)),
-            T.ToTensor(),
-            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
-        ])
-
-    def set_epoch(self, epoch, **unused):
-        self.epoch = epoch
-
-    def get_negative_caption(self, caption, gt_objects):
-        prob = random.random()
-        if gt_objects is not None and gt_objects != '' and prob > 0.6:
-            gt_object = random.choice(gt_objects.strip().split('&&'))
-            negative_object = random.choice(self.all_object_list[:-1])
-            negative_object = self.all_object_list[-1] if negative_object == gt_object else negative_object
-            negative_caption = caption.replace(gt_object, negative_object)
-        else:
-            negative_caption = random.choice(self.all_caption_list)
-        return negative_caption
-
-    def get_negative_answer(self, answer, conf):
-        prob = random.random()
-        if conf > (prob + 0.1) and answer in self.ans2type_dict:
-            negative_answer_type = self.ans2type_dict[answer]
-            if negative_answer_type == 'how many' and answer.isdigit() and prob > 0.5:
-                negative_answer = int(answer) + random.choice([-1, 1]) if answer != 0 else 1
-            else:
-                negative_answer_list = self.type2ans_dict[negative_answer_type]
-                negative_answer = random.choice(negative_answer_list[:-1])
-                negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
-            return negative_answer
-
-        negative_answer_list = self.type2ans_dict['other']
-        negative_answer = random.choice(negative_answer_list[:-1])
-        negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
-        return negative_answer
-
-    def process_image_text_pair(self, index):
-        uniq_id, image, caption, question, refs, gt_objects, dataset_name, type = self.dataset[index]
-
-        image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
-        patch_image = self.patch_resize_transform(image) if type != 'visual_grounding' else None
-        patch_mask = torch.tensor([True])
-        conf = torch.tensor([1.0])
-        if type == 'caption':
-            tgt_caption = self.pre_caption(caption, self.max_tgt_length)
-            pos_src_caption = self.pre_caption(caption, self.max_src_length)
-            neg_src_caption = self.pre_caption(self.get_negative_caption(caption, gt_objects), self.max_src_length)
-            src_item = self.encode_text(" what does the image describe?")
-            tgt_item = self.encode_text(" {}".format(tgt_caption))
-            pos_src_item = self.encode_text(' does the image describe " {} "?'.format(pos_src_caption))
-            neg_src_item = self.encode_text(' does the image describe " {} "?'.format(neg_src_caption))
-        elif type == 'qa':
-            question = self.pre_question(question, self.max_src_length)
-            ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in refs.split('&&')}
-            answer = max(ref_dict, key=ref_dict.get)
-            conf = ref_dict[answer]
-            src_item = self.encode_text(" {}".format(question))
-            tgt_item = self.encode_text(" {}".format(answer))
-            conf = torch.tensor([conf])
-            pos_src_item = self.encode_text(' what is the answer to question " {} ". is " {} "?'.format(question, answer))
-            neg_src_item = self.encode_text(
-                ' what is the answer to question " {} ". is " {} "?'.format(question, self.get_negative_answer(answer, conf))
-            )
-        elif type == 'visual_grounding':
-            conf = torch.tensor([1.0])
-            w, h = image.size
-            boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
-            x0, y0, x1, y1 = refs.strip().split(',')
-            boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
-            boxes_target["labels"] = np.array([0])
-            boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
-            patch_image, boxes_target = self.visual_grounding_transform(image, boxes_target)
-            quant_x0 = "<bin_{}>".format(int((boxes_target["boxes"][0][0] * (self.num_bins - 1)).round()))
-            quant_y0 = "<bin_{}>".format(int((boxes_target["boxes"][0][1] * (self.num_bins - 1)).round()))
-            quant_x1 = "<bin_{}>".format(int((boxes_target["boxes"][0][2] * (self.num_bins - 1)).round()))
-            quant_y1 = "<bin_{}>".format(int((boxes_target["boxes"][0][3] * (self.num_bins - 1)).round()))
-            region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
-            src_caption = self.pre_caption(caption, self.max_src_length)
-            src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
-            tgt_item = self.encode_text(region_coord, use_bpe=False)
-        else:
-            logger.info('type {} is not implemented'.format(type))
-            raise NotImplementedError
-
-        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
-        target_item = torch.cat([tgt_item, self.eos_item])
-        prev_output_item = torch.cat([self.bos_item, tgt_item])
-        pos_src_item = torch.cat([self.bos_item, pos_src_item, self.eos_item]) if type != 'visual_grounding' else None
-        neg_src_item = torch.cat([self.bos_item, neg_src_item, self.eos_item]) if type != 'visual_grounding' else None
|
320 |
-
|
321 |
-
if type == 'caption' and dataset_name == 'cc12m':
|
322 |
-
target_item[:2] = self.src_dict.pad()
|
323 |
-
target_item[-1] = self.eos_item
|
324 |
-
|
325 |
-
example = {
|
326 |
-
"id": uniq_id,
|
327 |
-
"source": src_item,
|
328 |
-
"patch_image": patch_image,
|
329 |
-
"patch_mask": patch_mask,
|
330 |
-
"target": target_item,
|
331 |
-
"prev_output_tokens": prev_output_item,
|
332 |
-
"conf": conf,
|
333 |
-
}
|
334 |
-
|
335 |
-
examples = [example]
|
336 |
-
prob = random.random()
|
337 |
-
if type == 'visual_grounding':
|
338 |
-
region_example = example.copy()
|
339 |
-
region_prefix_item = self.encode_text(' what does the region describe? region:')
|
340 |
-
region_coord_item = self.encode_text('{}'.format(region_coord), use_bpe=False)
|
341 |
-
region_src_item = torch.cat([region_prefix_item, region_coord_item])
|
342 |
-
region_tgt_item = self.encode_text(' {}'.format(self.pre_caption(caption, self.max_tgt_length)))
|
343 |
-
region_example["source"] = torch.cat([self.bos_item, region_src_item, self.eos_item])
|
344 |
-
region_example["target"] = torch.cat([region_tgt_item, self.eos_item])
|
345 |
-
region_example["prev_output_tokens"] = torch.cat([self.bos_item, region_tgt_item])
|
346 |
-
region_example["conf"] = torch.tensor([1.0])
|
347 |
-
examples.append(region_example)
|
348 |
-
elif prob >= 0.5 and self.split == 'train':
|
349 |
-
pos_example = example.copy()
|
350 |
-
pos_example["source"] = pos_src_item
|
351 |
-
pos_example["target"] = torch.cat([self.pos_tgt_item, self.eos_item])
|
352 |
-
pos_example["prev_output_tokens"] = torch.cat([self.bos_item, self.pos_tgt_item])
|
353 |
-
examples.append(pos_example)
|
354 |
-
elif self.split == 'train':
|
355 |
-
neg_example = example.copy()
|
356 |
-
neg_example["source"] = neg_src_item
|
357 |
-
neg_example["target"] = torch.cat([self.neg_tgt_item, self.eos_item])
|
358 |
-
neg_example["prev_output_tokens"] = torch.cat([self.bos_item, self.neg_tgt_item])
|
359 |
-
examples.append(neg_example)
|
360 |
-
return examples
|
361 |
-
|
362 |
-
def process_pure_text(self, index):
|
363 |
-
patch_image = torch.zeros((3, self.code_image_size*2, self.code_image_size*2))
|
364 |
-
patch_mask = torch.tensor([False])
|
365 |
-
code_mask = torch.tensor([False])
|
366 |
-
conf = torch.tensor([2.0])
|
367 |
-
|
368 |
-
examples = []
|
369 |
-
for _ in range(2):
|
370 |
-
uniq_id, text = self.pure_text_dataset[index]
|
371 |
-
text = text.strip().lower()
|
372 |
-
text_item = self.encode_text(" {}".format(text), length=512)
|
373 |
-
text_item = text_item[-256:]
|
374 |
-
text_item = torch.cat([self.bos_item, text_item, self.eos_item])
|
375 |
-
mask_text_item = self.add_whole_word_mask(text_item.clone(), self.mask_ratio)
|
376 |
-
prefix_item = self.encode_text(' what is the complete text of " "?')
|
377 |
-
src_item = torch.cat([prefix_item[:-2], mask_text_item[1:-1], prefix_item[-2:]])
|
378 |
-
tgt_item = text_item[1:-1]
|
379 |
-
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
380 |
-
target_item = torch.cat([tgt_item, self.eos_item])
|
381 |
-
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
382 |
-
example = {
|
383 |
-
"id": uniq_id,
|
384 |
-
"source": src_item,
|
385 |
-
"patch_image": patch_image,
|
386 |
-
"patch_mask": patch_mask,
|
387 |
-
"code_mask": code_mask,
|
388 |
-
"target": target_item,
|
389 |
-
"prev_output_tokens": prev_output_item,
|
390 |
-
"conf": conf,
|
391 |
-
}
|
392 |
-
examples.append(example)
|
393 |
-
|
394 |
-
return examples
|
395 |
-
|
396 |
-
def process_pure_image(self, index):
|
397 |
-
image_id, image, code = self.pure_image_dataset[index]
|
398 |
-
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
399 |
-
patch_image = self.patch_crop_transform(image)
|
400 |
-
patch_image[:, self.mask_top:self.mask_bottom, self.mask_left:self.mask_right] = 0
|
401 |
-
patch_mask = torch.tensor([True])
|
402 |
-
src_item = self.encode_text(" what is the image in the middle part?")
|
403 |
-
image_code = torch.LongTensor([int(num) for num in code.strip().split()])
|
404 |
-
tgt_item = image_code + len(self.src_dict) - self.code_dict_size - self.num_bins
|
405 |
-
code_mask = torch.tensor([True])
|
406 |
-
conf = torch.tensor([2.0])
|
407 |
-
|
408 |
-
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
409 |
-
target_item = torch.cat([tgt_item, self.eos_item])
|
410 |
-
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
411 |
-
|
412 |
-
example = {
|
413 |
-
"id": image_id,
|
414 |
-
"source": src_item,
|
415 |
-
"patch_image": patch_image,
|
416 |
-
"patch_mask": patch_mask,
|
417 |
-
"code_mask": code_mask,
|
418 |
-
"target": target_item,
|
419 |
-
"prev_output_tokens": prev_output_item,
|
420 |
-
"conf": conf,
|
421 |
-
}
|
422 |
-
return [example]
|
423 |
-
|
424 |
-
def process_detection(self, index):
|
425 |
-
image_id, image, label = self.detection_dataset[index]
|
426 |
-
image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
|
427 |
-
|
428 |
-
w, h = image.size
|
429 |
-
boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
|
430 |
-
label_list = label.strip().split('&&')
|
431 |
-
for label in label_list:
|
432 |
-
x0, y0, x1, y1, cat_id, cat = label.strip().split(',', 5)
|
433 |
-
boxes_target["boxes"].append([float(x0), float(y0), float(x1), float(y1)])
|
434 |
-
boxes_target["labels"].append(cat)
|
435 |
-
boxes_target["area"].append((float(x1) - float(x0)) * (float(y1) - float(y0)))
|
436 |
-
boxes_target["boxes"] = torch.tensor(boxes_target["boxes"])
|
437 |
-
boxes_target["labels"] = np.array(boxes_target["labels"])
|
438 |
-
boxes_target["area"] = torch.tensor(boxes_target["area"])
|
439 |
-
|
440 |
-
patch_image, boxes_target = self.detection_transform(image, boxes_target)
|
441 |
-
patch_mask = torch.tensor([True])
|
442 |
-
code_mask = torch.tensor([False])
|
443 |
-
conf = torch.tensor([2.0])
|
444 |
-
|
445 |
-
quant_boxes = []
|
446 |
-
for i, box in enumerate(boxes_target["boxes"]):
|
447 |
-
quant_boxes.extend(["<bin_{}>".format(int((pos * (self.num_bins - 1)).round())) for pos in box[:4]])
|
448 |
-
quant_boxes.append(self.bpe.encode(' {}'.format(boxes_target["labels"][i])))
|
449 |
-
src_item = self.encode_text(' what are the objects in the image?')
|
450 |
-
tgt_item = self.encode_text(' '.join(quant_boxes), use_bpe=False)
|
451 |
-
|
452 |
-
src_item = torch.cat([self.bos_item, src_item, self.eos_item])
|
453 |
-
target_item = torch.cat([tgt_item, self.eos_item])
|
454 |
-
prev_output_item = torch.cat([self.bos_item, tgt_item])
|
455 |
-
|
456 |
-
example = {
|
457 |
-
"id": image_id,
|
458 |
-
"source": src_item,
|
459 |
-
"patch_image": patch_image,
|
460 |
-
"patch_mask": patch_mask,
|
461 |
-
"code_mask": code_mask,
|
462 |
-
"target": target_item,
|
463 |
-
"prev_output_tokens": prev_output_item,
|
464 |
-
"conf": conf,
|
465 |
-
}
|
466 |
-
return [example]
|
467 |
-
|
468 |
-
def __getitem__(self, index):
|
469 |
-
with data_utils.numpy_seed(self.seed, self.epoch):
|
470 |
-
pair_samples = self.process_image_text_pair(index)
|
471 |
-
extra_samples = []
|
472 |
-
if self.split == 'train' and self.dataset.data_cnt % 8 == 0:
|
473 |
-
extra_samples += self.process_pure_text(0) if self.pure_text_dataset else []
|
474 |
-
extra_samples += self.process_pure_image(0) if self.pure_image_dataset else []
|
475 |
-
extra_samples += self.process_detection(0) if self.detection_dataset else []
|
476 |
-
return pair_samples, extra_samples
|
477 |
-
|
478 |
-
def word_starts(self, source):
|
479 |
-
if self.mask_whole_word is not None:
|
480 |
-
is_word_start = self.mask_whole_word.gather(0, source)
|
481 |
-
else:
|
482 |
-
is_word_start = torch.ones(source.size())
|
483 |
-
is_word_start[0] = 0
|
484 |
-
is_word_start[-1] = 0
|
485 |
-
return is_word_start
|
486 |
-
|
487 |
-
def add_whole_word_mask(self, source, p):
|
488 |
-
is_word_start = self.word_starts(source)
|
489 |
-
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
|
490 |
-
num_inserts = 0
|
491 |
-
if num_to_mask == 0:
|
492 |
-
return source
|
493 |
-
|
494 |
-
if self.mask_span_distribution is not None:
|
495 |
-
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
496 |
-
|
497 |
-
# Make sure we have enough to mask
|
498 |
-
cum_length = torch.cumsum(lengths, 0)
|
499 |
-
while cum_length[-1] < num_to_mask:
|
500 |
-
lengths = torch.cat(
|
501 |
-
[
|
502 |
-
lengths,
|
503 |
-
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
504 |
-
],
|
505 |
-
dim=0,
|
506 |
-
)
|
507 |
-
cum_length = torch.cumsum(lengths, 0)
|
508 |
-
|
509 |
-
# Trim to masking budget
|
510 |
-
i = 0
|
511 |
-
while cum_length[i] < num_to_mask:
|
512 |
-
i += 1
|
513 |
-
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
514 |
-
num_to_mask = i + 1
|
515 |
-
lengths = lengths[:num_to_mask]
|
516 |
-
|
517 |
-
# Handle 0-length mask (inserts) separately
|
518 |
-
lengths = lengths[lengths > 0]
|
519 |
-
num_inserts = num_to_mask - lengths.size(0)
|
520 |
-
num_to_mask -= num_inserts
|
521 |
-
if num_to_mask == 0:
|
522 |
-
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
523 |
-
|
524 |
-
assert (lengths > 0).all()
|
525 |
-
else:
|
526 |
-
lengths = torch.ones((num_to_mask,)).long()
|
527 |
-
assert is_word_start[-1] == 0
|
528 |
-
word_starts = is_word_start.nonzero(as_tuple=False)
|
529 |
-
indices = word_starts[
|
530 |
-
torch.randperm(word_starts.size(0))[:num_to_mask]
|
531 |
-
].squeeze(1)
|
532 |
-
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
533 |
-
|
534 |
-
source_length = source.size(0)
|
535 |
-
assert source_length - 1 not in indices
|
536 |
-
to_keep = torch.ones(source_length, dtype=torch.bool)
|
537 |
-
is_word_start[
|
538 |
-
-1
|
539 |
-
] = 255 # acts as a long length, so spans don't go over the end of doc
|
540 |
-
if self.replace_length == 0:
|
541 |
-
to_keep[indices] = 0
|
542 |
-
else:
|
543 |
-
# keep index, but replace it with [MASK]
|
544 |
-
source[indices] = self.mask_idx
|
545 |
-
source[indices[mask_random]] = torch.randint(
|
546 |
-
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
|
547 |
-
)
|
548 |
-
|
549 |
-
if self.mask_span_distribution is not None:
|
550 |
-
assert len(lengths.size()) == 1
|
551 |
-
assert lengths.size() == indices.size()
|
552 |
-
lengths -= 1
|
553 |
-
while indices.size(0) > 0:
|
554 |
-
assert lengths.size() == indices.size()
|
555 |
-
lengths -= is_word_start[indices + 1].long()
|
556 |
-
uncompleted = lengths >= 0
|
557 |
-
indices = indices[uncompleted] + 1
|
558 |
-
mask_random = mask_random[uncompleted]
|
559 |
-
lengths = lengths[uncompleted]
|
560 |
-
if self.replace_length != -1:
|
561 |
-
# delete token
|
562 |
-
to_keep[indices] = 0
|
563 |
-
else:
|
564 |
-
# keep index, but replace it with [MASK]
|
565 |
-
source[indices] = self.mask_idx
|
566 |
-
source[indices[mask_random]] = torch.randint(
|
567 |
-
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
|
568 |
-
)
|
569 |
-
else:
|
570 |
-
# A bit faster when all lengths are 1
|
571 |
-
while indices.size(0) > 0:
|
572 |
-
uncompleted = is_word_start[indices + 1] == 0
|
573 |
-
indices = indices[uncompleted] + 1
|
574 |
-
mask_random = mask_random[uncompleted]
|
575 |
-
if self.replace_length != -1:
|
576 |
-
# delete token
|
577 |
-
to_keep[indices] = 0
|
578 |
-
else:
|
579 |
-
# keep index, but replace it with [MASK]
|
580 |
-
source[indices] = self.mask_idx
|
581 |
-
source[indices[mask_random]] = torch.randint(
|
582 |
-
4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
|
583 |
-
)
|
584 |
-
|
585 |
-
assert source_length - 1 not in indices
|
586 |
-
|
587 |
-
source = source[to_keep]
|
588 |
-
|
589 |
-
if num_inserts > 0:
|
590 |
-
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
591 |
-
|
592 |
-
return source
|
593 |
-
|
594 |
-
def add_insertion_noise(self, tokens, p):
|
595 |
-
if p == 0.0:
|
596 |
-
return tokens
|
597 |
-
|
598 |
-
num_tokens = len(tokens)
|
599 |
-
n = int(math.ceil(num_tokens * p))
|
600 |
-
|
601 |
-
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
602 |
-
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
603 |
-
noise_mask[noise_indices] = 1
|
604 |
-
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
605 |
-
|
606 |
-
num_random = int(math.ceil(n * self.random_ratio))
|
607 |
-
result[noise_indices[num_random:]] = self.mask_idx
|
608 |
-
result[noise_indices[:num_random]] = torch.randint(
|
609 |
-
low=4, high=len(self.tgt_dict)-self.code_dict_size-self.num_bins, size=(num_random,)
|
610 |
-
)
|
611 |
-
|
612 |
-
result[~noise_mask] = tokens
|
613 |
-
|
614 |
-
assert (result >= 0).all()
|
615 |
-
return result
|
616 |
-
|
617 |
-
def collater(self, samples, pad_to_length=None):
|
618 |
-
"""Merge samples of different tasks to form two mini-batches.
|
619 |
-
Args:
|
620 |
-
samples (List[Tuple]): samples to collate
|
621 |
-
Returns:
|
622 |
-
Tuple[dict]: two mini-batch containing the data of different tasks
|
623 |
-
"""
|
624 |
-
|
625 |
-
samples_v1 = [] # containing image-text pairs
|
626 |
-
samples_v2 = [] # containing detection data, text data and image data
|
627 |
-
for sample_tuple in samples:
|
628 |
-
samples_v1 += sample_tuple[0]
|
629 |
-
samples_v2 += sample_tuple[1]
|
630 |
-
if samples_v2 != []:
|
631 |
-
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
|
632 |
-
res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
|
633 |
-
return res_v1, res_v2
|
634 |
-
else:
|
635 |
-
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
|
636 |
-
return res_v1
|
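Both the visual-grounding and detection branches above turn normalized box coordinates into discrete `<bin_k>` location tokens via `int((pos * (self.num_bins - 1)).round())`. A minimal standalone sketch of that quantization step (the `num_bins` value and the helper name are illustrative, not taken from this commit):

```python
# Sketch of the <bin_k> coordinate quantization used in process_image_text_pair
# and process_detection above. num_bins=1000 is only an assumed example value.
def quantize_box(box, num_bins=1000):
    """Map normalized [0, 1] box coordinates to '<bin_k>' location tokens."""
    return " ".join(
        "<bin_{}>".format(int(round(pos * (num_bins - 1)))) for pos in box
    )

print(quantize_box([0.25, 0.10, 0.75, 0.90]))
# -> <bin_250> <bin_100> <bin_749> <bin_899>
```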
datasets.md
DELETED
@@ -1,44 +0,0 @@
# Datasets

We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will soon provide scripts for you to do so.

## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip"> A small subset of the pretraining data </a>

The pretraining datasets used in OFA are all publicly available. Here we provide the public links to these data. It is recommended that you download the data from the links first and then process the downloaded datasets into a format similar to the examples we provide.
- _CC12M_: https://github.com/google-research-datasets/conceptual-12m
- _CC3M_: https://github.com/google-research-datasets/conceptual-captions
- _SBU_: https://www.cs.virginia.edu/~vicente/sbucaptions
- _COCO_: https://cocodataset.org/#home
- _VG_: https://visualgenome.org/
- _VQAv2_: https://visualqa.org/
- _GQA_: https://cs.stanford.edu/people/dorarad/gqa/about.html
- _RefCOCO_/_RefCOCO+_/_RefCOCOg_: https://github.com/lichengunc/refer
- _OpenImages_: https://storage.googleapis.com/openimages/web/index.html
- _Object365_: https://www.objects365.org/overview.html
- _YFCC100M (subset)_: https://github.com/openai/CLIP/blob/main/data/yfcc100m.md
- _ImageNet-21K_: https://image-net.org/index.php
- _Pile_: https://pile.eleuther.ai

## Vision & Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/vqa_data/vqa_data.zip"> Dataset for VQAv2 </a> (we have also provided chunked parts of the dataset files for more convenient downloading, please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/snli_ve_data/snli_ve_data.zip"> Dataset for SNLI-VE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen.zip"> Dataset for Text-to-Image Generation </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen_origin_id.zip"> Dataset for Text-to-Image Generation (with original id) </a>

## Vision Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/imagenet_1k_data/imagenet_1k_data.zip"> Dataset for ImageNet-1K </a>

## Language Tasks
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/cola_data.zip"> Dataset for COLA </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mnli_data.zip"> Dataset for MNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mrpc_data.zip"> Dataset for MRPC </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qnli_data.zip"> Dataset for QNLI </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qqp_data.zip"> Dataset for QQP </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/rte_data.zip"> Dataset for RTE </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/sst2_data.zip"> Dataset for SST2 </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/gigaword_data/gigaword_data.zip"> Dataset for Gigaword </a>
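The dataset readers in this commit decode each image column with `base64.urlsafe_b64decode`, so the preprocessed files above store images as base64 strings. A hedged sketch of producing such a string for your own data (the file name is a placeholder and the exact TSV column layout is not specified here):

```python
import base64
from io import BytesIO

from PIL import Image


def image_to_base64(path: str) -> str:
    """Encode an image file as the urlsafe base64 string expected by the dataset readers."""
    buffer = BytesIO()
    Image.open(path).convert("RGB").save(buffer, format="JPEG")
    return base64.urlsafe_b64encode(buffer.getvalue()).decode("utf-8")


# Round trip matches the decoding used in unify_dataset.py above.
encoded = image_to_base64("example.jpg")  # placeholder file name
decoded = Image.open(BytesIO(base64.urlsafe_b64decode(encoded))).convert("RGB")
```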
evaluate.py
DELETED
@@ -1,160 +0,0 @@
#!/usr/bin/env python3 -u
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import logging
import os
import sys

import numpy as np
import torch
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step, merge_results
from utils.zero_shot_utils import zero_shot_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t


def main(cfg: DictConfig, **kwargs):
    utils.import_user_module(cfg.common)

    reset_logging()
    logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    overrides = eval(cfg.common_eval.model_overrides)
    # Deal with beam-search / all-candidate VQA eval
    if cfg.task._name == "vqa_gen":
        overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"

    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    if kwargs["zero_shot"]:
        task = tasks.setup_task(cfg.task)
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=(cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )
    else:
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(cfg.common_eval.path),
            arg_overrides=overrides,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=(cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Move models to GPU
    for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
        if kwargs['ema_eval']:
            logger.info("loading EMA weights from {}".format(ckpt_path))
            model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    score_sum = torch.FloatTensor([0]).cuda()
    score_cnt = torch.FloatTensor([0]).cuda()
    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
        with torch.no_grad():
            if kwargs["zero_shot"]:
                result, scores = zero_shot_step(task, generator, models, sample)
            else:
                result, scores = eval_step(task, generator, models, sample, **kwargs)
        results += result
        score_sum += sum(scores) if scores is not None else 0
        score_cnt += len(scores) if scores is not None else 0
        progress.log({"sentences": sample["nsentences"]})

    merge_results(task, cfg, logger, score_cnt, score_sum, results)


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
    parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
    parser.add_argument("--zero-shot", action='store_true')
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(
        cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot
    )


if __name__ == "__main__":
    cli_main()
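For reference, a hedged sketch of driving `cli_main()` programmatically; every path, the task name, and the batch size below are placeholders rather than values taken from this repository, and all remaining fairseq generation options keep their defaults:

```python
# Illustrative invocation of evaluate.py; all values are placeholders.
import sys

sys.argv = [
    "evaluate.py",
    "path/to/dataset",              # positional data argument of the fairseq generation parser
    "--task=caption",               # assumed task name
    "--path=path/to/checkpoint.pt",
    "--batch-size=16",
    "--ema-eval",                   # custom flag defined in cli_main() above
]
cli_main()
```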
ezocr/LICENSE
DELETED
@@ -1,201 +0,0 @@
ezocr/README.md
DELETED
@@ -1,49 +0,0 @@
# EasyOCR Lite

Text-detection code extracted from EasyOCR, further adapted for Chinese, with defects fixed.

## Installation

Python 3.8 or later is required.


First, install PyTorch following the official PyTorch instructions.

```
pip install -e .
```

## Usage

```python
from easyocrlite import ReaderLite

reader = ReaderLite()
results = reader.process('my_awesome_handwriting.png')
```

The returned result is a list of bounding boxes and the corresponding image regions.
See the [demo](./demo.ipynb) for more details.


## Acknowledgements

This is a modified implementation based on [EasyOCR](https://github.com/JaidedAI/EasyOCR). The original EasyOCR acknowledgements follow:

This project is based on research and code from several papers and open-source repositories.

All deep learning execution is based on [Pytorch](https://pytorch.org). :heart:

Detection execution uses the CRAFT algorithm from this [official repository](https://github.com/clovaai/CRAFT-pytorch) and their [paper](https://arxiv.org/abs/1904.01941) (Thanks @YoungminBaek from [@clovaai](https://github.com/clovaai)). We also use their pretrained model. Training script is provided by [@gmuffiness](https://github.com/gmuffiness).

The recognition model is a CRNN ([paper](https://arxiv.org/abs/1507.05717)). It is composed of 3 main components: feature extraction (we are currently using [Resnet](https://arxiv.org/abs/1512.03385) and VGG), sequence labeling ([LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf)) and decoding ([CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf)). The training pipeline for recognition execution is a modified version of the [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) framework. (Thanks [@ku21fan](https://github.com/ku21fan) from [@clovaai](https://github.com/clovaai)) This repository is a gem that deserves more recognition.

Beam search code is based on this [repository](https://github.com/githubharald/CTCDecoder) and his [blog](https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7). (Thanks [@githubharald](https://github.com/githubharald))

Data synthesis is based on [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). (Thanks [@Belval](https://github.com/Belval))

And a good read about CTC from distill.pub [here](https://distill.pub/2017/ctc/).


## License (Note!)
Apache 2.0
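As the README notes, `ReaderLite.process` returns a list pairing each detected text box with its cropped image region. A small hedged usage sketch (the image name is a placeholder; each box is a quadrilateral of four `(x, y)` corner points and each crop is a NumPy array, per the `reader.py` shown below):

```python
from easyocrlite import ReaderLite

reader = ReaderLite()
results = reader.process("my_awesome_handwriting.png")  # placeholder image path

for box, crop in results:
    x0, y0 = box[0]  # first corner point of the detected region
    print("text region near ({}, {}) with crop shape {}".format(x0, y0, crop.shape))
```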
ezocr/build/lib/easyocrlite/__init__.py
DELETED
@@ -1 +0,0 @@
from easyocrlite.reader import ReaderLite
ezocr/build/lib/easyocrlite/reader.py
DELETED
@@ -1,272 +0,0 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import torch
from PIL import Image, ImageEnhance

from easyocrlite.model import CRAFT

from easyocrlite.utils.download_utils import prepare_model
from easyocrlite.utils.image_utils import (
    adjust_result_coordinates,
    boxed_transform,
    normalize_mean_variance,
    resize_aspect_ratio,
)
from easyocrlite.utils.detect_utils import (
    extract_boxes,
    extract_regions_from_boxes,
    box_expand,
    greedy_merge,
)
from easyocrlite.types import BoxTuple, RegionTuple
import easyocrlite.utils.utils as utils

logger = logging.getLogger(__name__)

MODULE_PATH = (
    os.environ.get("EASYOCR_MODULE_PATH")
    or os.environ.get("MODULE_PATH")
    or os.path.expanduser("~/.EasyOCR/")
)


class ReaderLite(object):
    def __init__(
        self,
        gpu=True,
        model_storage_directory=None,
        download_enabled=True,
        verbose=True,
        quantize=True,
        cudnn_benchmark=False,
    ):

        self.verbose = verbose

        model_storage_directory = Path(
            model_storage_directory
            if model_storage_directory
            else MODULE_PATH + "/model"
        )
        self.detector_path = prepare_model(
            model_storage_directory, download_enabled, verbose
        )

        self.quantize = quantize
        self.cudnn_benchmark = cudnn_benchmark
        if gpu is False:
            self.device = "cpu"
            if verbose:
                logger.warning(
                    "Using CPU. Note: This module is much faster with a GPU."
                )
        elif not torch.cuda.is_available():
            self.device = "cpu"
            if verbose:
                logger.warning(
                    "CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU."
                )
        elif gpu is True:
            self.device = "cuda"
        else:
            self.device = gpu

        self.detector = CRAFT()

        state_dict = torch.load(self.detector_path, map_location=self.device)
        if list(state_dict.keys())[0].startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}

        self.detector.load_state_dict(state_dict)

        if self.device == "cpu":
            if self.quantize:
                try:
                    torch.quantization.quantize_dynamic(
                        self.detector, dtype=torch.qint8, inplace=True
                    )
                except:
                    pass
        else:
            self.detector = torch.nn.DataParallel(self.detector).to(self.device)
            import torch.backends.cudnn as cudnn

            cudnn.benchmark = self.cudnn_benchmark

        self.detector.eval()

    def process(
        self,
        image_path: str,
        max_size: int = 960,
        expand_ratio: float = 1.0,
        sharp: float = 1.0,
        contrast: float = 1.0,
        text_confidence: float = 0.7,
        text_threshold: float = 0.4,
        link_threshold: float = 0.4,
        slope_ths: float = 0.1,
        ratio_ths: float = 0.5,
        center_ths: float = 0.5,
        dim_ths: float = 0.5,
        space_ths: float = 1.0,
        add_margin: float = 0.1,
        min_size: float = 0.01,
    ) -> Tuple[BoxTuple, list[np.ndarray]]:

        image = Image.open(image_path).convert('RGB')

        tensor, inverse_ratio = self.preprocess(
            image, max_size, expand_ratio, sharp, contrast
        )

        scores = self.forward_net(tensor)

        boxes = self.detect(scores, text_confidence, text_threshold, link_threshold)

        image = np.array(image)
        region_list, box_list = self.postprocess(
            image,
            boxes,
            inverse_ratio,
            slope_ths,
            ratio_ths,
            center_ths,
            dim_ths,
            space_ths,
            add_margin,
            min_size,
        )

        # get cropped image
        image_list = []
        for region in region_list:
            x_min, x_max, y_min, y_max = region
            crop_img = image[y_min:y_max, x_min:x_max, :]
            image_list.append(
                (
                    ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)),
                    crop_img,
                )
            )

        for box in box_list:
            transformed_img = boxed_transform(image, np.array(box, dtype="float32"))
            image_list.append((box, transformed_img))

        # sort by top left point
        image_list = sorted(image_list, key=lambda x: (x[0][0][1], x[0][0][0]))

        return image_list

    def preprocess(
        self,
        image: Image.Image,
        max_size: int,
        expand_ratio: float = 1.0,
        sharp: float = 1.0,
        contrast: float = 1.0,
    ) -> torch.Tensor:
        if sharp != 1:
            enhancer = ImageEnhance.Sharpness(image)
            image = enhancer.enhance(sharp)
        if contrast != 1:
            enhancer = ImageEnhance.Contrast(image)
            image = enhancer.enhance(contrast)

        image = np.array(image)

        image, target_ratio = resize_aspect_ratio(
            image, max_size, interpolation=cv2.INTER_LINEAR, expand_ratio=expand_ratio
        )
        inverse_ratio = 1 / target_ratio

        x = np.transpose(normalize_mean_variance(image), (2, 0, 1))

        x = torch.tensor(np.array([x]), device=self.device)

        return x, inverse_ratio

    @torch.no_grad()
    def forward_net(self, tensor: torch.Tensor) -> torch.Tensor:
        scores, feature = self.detector(tensor)
        return scores[0]

    def detect(
        self,
        scores: torch.Tensor,
        text_confidence: float = 0.7,
        text_threshold: float = 0.4,
        link_threshold: float = 0.4,
    ) -> list[BoxTuple]:
        # make score and link map
        score_text = scores[:, :, 0].cpu().data.numpy()
        score_link = scores[:, :, 1].cpu().data.numpy()
        # extract box
        boxes, _ = extract_boxes(
            score_text, score_link, text_confidence, text_threshold, link_threshold
        )
        return boxes

    def postprocess(
        self,
        image: np.ndarray,
        boxes: list[BoxTuple],
        inverse_ratio: float,
        slope_ths: float = 0.1,
        ratio_ths: float = 0.5,
        center_ths: float = 0.5,
        dim_ths: float = 0.5,
        space_ths: float = 1.0,
        add_margin: float = 0.1,
        min_size: int = 0,
    ) -> Tuple[list[RegionTuple], list[BoxTuple]]:

        # coordinate adjustment
        boxes = adjust_result_coordinates(boxes, inverse_ratio)

        max_y, max_x, _ = image.shape

        # extract region and merge
        region_list, box_list = extract_regions_from_boxes(boxes, slope_ths)

        region_list = greedy_merge(
            region_list,
            ratio_ths=ratio_ths,
            center_ths=center_ths,
            dim_ths=dim_ths,
            space_ths=space_ths,
            verbose=0
        )

        # add margin
        region_list = [
            region.expand(add_margin, (max_x, max_y)).as_tuple()
            for region in region_list
        ]

        box_list = [box_expand(box, add_margin, (max_x, max_y)) for box in box_list]

        # filter by size
        if min_size:
            if min_size < 1:
                min_size = int(min(max_y, max_x) * min_size)

            region_list = [
                i for i in region_list if max(i[1] - i[0], i[3] - i[2]) > min_size
            ]
            box_list = [
                i
                for i in box_list
                if max(utils.diff([c[0] for c in i]), utils.diff([c[1] for c in i]))
                > min_size
            ]

        return region_list, box_list
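As a small worked example of the size filter at the end of `postprocess`: a fractional `min_size` is converted to pixels using the shorter image side, so on an illustrative 1280x720 image the `process()` default of 0.01 becomes a 7-pixel threshold.

```python
# Worked example of the fractional min_size conversion in ReaderLite.postprocess.
# The image dimensions are illustrative.
max_y, max_x = 720, 1280             # image.shape is (height, width, channels)
min_size = 0.01
if min_size < 1:
    min_size = int(min(max_y, max_x) * min_size)
print(min_size)                      # -> 7; regions whose longer side is at most 7 px are dropped
```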
ezocr/build/lib/easyocrlite/types.py
DELETED
@@ -1,5 +0,0 @@
-from typing import Tuple
-
-Point = Tuple[int, int]
-BoxTuple = Tuple[Point, Point, Point, Point]
-RegionTuple = Tuple[int, int, int, int]
ezocr/easyocrlite.egg-info/PKG-INFO
DELETED
@@ -1,8 +0,0 @@
-Metadata-Version: 2.1
-Name: easyocrlite
-Version: 0.0.1
-License: Apache License 2.0
-Keywords: ocr optical character recognition deep learning neural network
-Classifier: Development Status :: 5 - Production/Stable
-Requires-Python: >=3.7
-License-File: LICENSE
ezocr/easyocrlite.egg-info/SOURCES.txt
DELETED
@@ -1,11 +0,0 @@
-LICENSE
-README.md
-setup.py
-easyocrlite/__init__.py
-easyocrlite/reader.py
-easyocrlite/types.py
-easyocrlite.egg-info/PKG-INFO
-easyocrlite.egg-info/SOURCES.txt
-easyocrlite.egg-info/dependency_links.txt
-easyocrlite.egg-info/requires.txt
-easyocrlite.egg-info/top_level.txt
ezocr/easyocrlite.egg-info/dependency_links.txt
DELETED
@@ -1 +0,0 @@
-
ezocr/easyocrlite.egg-info/requires.txt
DELETED
@@ -1,5 +0,0 @@
-torch
-torchvision>=0.5
-opencv-python-headless<=4.5.4.60
-numpy
-Pillow
ezocr/easyocrlite.egg-info/top_level.txt
DELETED
@@ -1 +0,0 @@
-easyocrlite
ezocr/easyocrlite/__init__.py
DELETED
@@ -1 +0,0 @@
-from easyocrlite.reader import ReaderLite
ezocr/easyocrlite/model/__init__.py
DELETED
@@ -1 +0,0 @@
-from .craft import CRAFT
ezocr/easyocrlite/model/craft.py
DELETED
@@ -1,174 +0,0 @@
-"""
-Copyright (c) 2019-present NAVER Corp.
-MIT License
-"""
-from __future__ import annotations
-
-from collections import namedtuple
-from typing import Iterable, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-from packaging import version
-from torchvision import models
-
-VGGOutputs = namedtuple(
-    "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
-)
-
-def init_weights(modules: Iterable[nn.Module]):
-    for m in modules:
-        if isinstance(m, nn.Conv2d):
-            nn.init.xavier_uniform_(m.weight)
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)
-        elif isinstance(m, nn.BatchNorm2d):
-            nn.init.constant_(m.weight, 1.0)
-            nn.init.zeros_(m.bias)
-        elif isinstance(m, nn.Linear):
-            nn.init.normal_(m.weight, 0, 0.01)
-            nn.init.zeros_(m.bias)
-
-
-class VGG16_BN(nn.Module):
-    def __init__(self, pretrained: bool=True, freeze: bool=True):
-        super().__init__()
-        if version.parse(torchvision.__version__) >= version.parse("0.13"):
-            vgg_pretrained_features = models.vgg16_bn(
-                weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
-            ).features
-        else:  # torchvision.__version__ < 0.13
-            models.vgg.model_urls["vgg16_bn"] = models.vgg.model_urls[
-                "vgg16_bn"
-            ].replace("https://", "http://")
-            vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
-
-        self.slice1 = torch.nn.Sequential()
-        self.slice2 = torch.nn.Sequential()
-        self.slice3 = torch.nn.Sequential()
-        self.slice4 = torch.nn.Sequential()
-        self.slice5 = torch.nn.Sequential()
-        for x in range(12):  # conv2_2
-            self.slice1.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(12, 19):  # conv3_3
-            self.slice2.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(19, 29):  # conv4_3
-            self.slice3.add_module(str(x), vgg_pretrained_features[x])
-        for x in range(29, 39):  # conv5_3
-            self.slice4.add_module(str(x), vgg_pretrained_features[x])
-
-        # fc6, fc7 without atrous conv
-        self.slice5 = torch.nn.Sequential(
-            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
-            nn.Conv2d(1024, 1024, kernel_size=1),
-        )
-
-        if not pretrained:
-            init_weights(self.slice1.modules())
-            init_weights(self.slice2.modules())
-            init_weights(self.slice3.modules())
-            init_weights(self.slice4.modules())
-
-        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7
-
-        if freeze:
-            for param in self.slice1.parameters():  # only first conv
-                param.requires_grad = False
-
-    def forward(self, x: torch.Tensor) -> VGGOutputs:
-        h = self.slice1(x)
-        h_relu2_2 = h
-        h = self.slice2(h)
-        h_relu3_2 = h
-        h = self.slice3(h)
-        h_relu4_3 = h
-        h = self.slice4(h)
-        h_relu5_3 = h
-        h = self.slice5(h)
-        h_fc7 = h
-
-        out = VGGOutputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
-        return out
-
-
-class DoubleConv(nn.Module):
-    def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
-            nn.BatchNorm2d(mid_ch),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
-            nn.BatchNorm2d(out_ch),
-            nn.ReLU(inplace=True),
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.conv(x)
-        return x
-
-
-class CRAFT(nn.Module):
-    def __init__(self, pretrained: bool=False, freeze: bool=False):
-        super(CRAFT, self).__init__()
-
-        """ Base network """
-        self.basenet = VGG16_BN(pretrained, freeze)
-
-        """ U network """
-        self.upconv1 = DoubleConv(1024, 512, 256)
-        self.upconv2 = DoubleConv(512, 256, 128)
-        self.upconv3 = DoubleConv(256, 128, 64)
-        self.upconv4 = DoubleConv(128, 64, 32)
-
-        num_class = 2
-        self.conv_cls = nn.Sequential(
-            nn.Conv2d(32, 32, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, 32, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, 16, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(16, 16, kernel_size=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(16, num_class, kernel_size=1),
-        )
-
-        init_weights(self.upconv1.modules())
-        init_weights(self.upconv2.modules())
-        init_weights(self.upconv3.modules())
-        init_weights(self.upconv4.modules())
-        init_weights(self.conv_cls.modules())
-
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Base network"""
-        sources = self.basenet(x)
-
-        """ U network """
-        y = torch.cat([sources[0], sources[1]], dim=1)
-        y = self.upconv1(y)
-
-        y = F.interpolate(
-            y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
-        )
-        y = torch.cat([y, sources[2]], dim=1)
-        y = self.upconv2(y)
-
-        y = F.interpolate(
-            y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
-        )
-        y = torch.cat([y, sources[3]], dim=1)
-        y = self.upconv3(y)
-
-        y = F.interpolate(
-            y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
-        )
-        y = torch.cat([y, sources[4]], dim=1)
-        feature = self.upconv4(y)
-
-        y = self.conv_cls(feature)
-
-        return y.permute(0, 2, 3, 1), feature
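Note (added for clarity, not part of the commit): the deleted CRAFT module above is a VGG16-BN U-Net that predicts two score maps per location, a character-region score and an affinity (link) score, at roughly half the input resolution. A minimal sketch of driving it directly, assuming the package were still importable and using a random tensor as a stand-in for a preprocessed image:

    import torch
    from easyocrlite.model import CRAFT

    detector = CRAFT(pretrained=False).eval()   # real use loads the downloaded checkpoint weights
    x = torch.randn(1, 3, 480, 640)             # placeholder for a normalized RGB batch (N, C, H, W)
    with torch.no_grad():
        scores, feature = detector(x)           # scores: (N, H/2, W/2, 2) after the final permute
    score_text = scores[0, :, :, 0]             # character-region score map
    score_link = scores[0, :, :, 1]             # affinity (link) score map joining characters into words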
ezocr/easyocrlite/reader.py
DELETED
@@ -1,271 +0,0 @@
-from __future__ import annotations
-
-import logging
-import os
-from pathlib import Path
-from typing import Tuple
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image, ImageEnhance
-
-from easyocrlite.model import CRAFT
-
-from easyocrlite.utils.download_utils import prepare_model
-from easyocrlite.utils.image_utils import (
-    adjust_result_coordinates,
-    boxed_transform,
-    normalize_mean_variance,
-    resize_aspect_ratio,
-)
-from easyocrlite.utils.detect_utils import (
-    extract_boxes,
-    extract_regions_from_boxes,
-    box_expand,
-    greedy_merge,
-)
-from easyocrlite.types import BoxTuple, RegionTuple
-import easyocrlite.utils.utils as utils
-
-logger = logging.getLogger(__name__)
-
-MODULE_PATH = (
-    os.environ.get("EASYOCR_MODULE_PATH")
-    or os.environ.get("MODULE_PATH")
-    or os.path.expanduser("~/.EasyOCR/")
-)
-
-
-class ReaderLite(object):
-    def __init__(
-        self,
-        gpu=True,
-        model_storage_directory=None,
-        download_enabled=True,
-        verbose=True,
-        quantize=True,
-        cudnn_benchmark=False,
-    ):
-
-        self.verbose = verbose
-
-        model_storage_directory = Path(
-            model_storage_directory
-            if model_storage_directory
-            else MODULE_PATH + "/model"
-        )
-        self.detector_path = prepare_model(
-            model_storage_directory, download_enabled, verbose
-        )
-
-        self.quantize = quantize
-        self.cudnn_benchmark = cudnn_benchmark
-        if gpu is False:
-            self.device = "cpu"
-            if verbose:
-                logger.warning(
-                    "Using CPU. Note: This module is much faster with a GPU."
-                )
-        elif not torch.cuda.is_available():
-            self.device = "cpu"
-            if verbose:
-                logger.warning(
-                    "CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU."
-                )
-        elif gpu is True:
-            self.device = "cuda"
-        else:
-            self.device = gpu
-
-        self.detector = CRAFT()
-
-        state_dict = torch.load(self.detector_path, map_location=self.device)
-        if list(state_dict.keys())[0].startswith("module"):
-            state_dict = {k[7:]: v for k, v in state_dict.items()}
-
-        self.detector.load_state_dict(state_dict)
-
-        if self.device == "cpu":
-            if self.quantize:
-                try:
-                    torch.quantization.quantize_dynamic(
-                        self.detector, dtype=torch.qint8, inplace=True
-                    )
-                except:
-                    pass
-        else:
-            self.detector = torch.nn.DataParallel(self.detector).to(self.device)
-            import torch.backends.cudnn as cudnn
-
-            cudnn.benchmark = self.cudnn_benchmark
-
-        self.detector.eval()
-
-    def process(
-        self,
-        image_path: str,
-        max_size: int = 960,
-        expand_ratio: float = 1.0,
-        sharp: float = 1.0,
-        contrast: float = 1.0,
-        text_confidence: float = 0.7,
-        text_threshold: float = 0.4,
-        link_threshold: float = 0.4,
-        slope_ths: float = 0.1,
-        ratio_ths: float = 0.5,
-        center_ths: float = 0.5,
-        dim_ths: float = 0.5,
-        space_ths: float = 1.0,
-        add_margin: float = 0.1,
-        min_size: float = 0.01,
-    ) -> Tuple[BoxTuple, list[np.ndarray]]:
-
-        image = Image.open(image_path).convert('RGB')
-        tensor, inverse_ratio = self.preprocess(
-            image, max_size, expand_ratio, sharp, contrast
-        )
-
-        scores = self.forward_net(tensor)
-
-        boxes = self.detect(scores, text_confidence, text_threshold, link_threshold)
-
-        image = np.array(image)
-        region_list, box_list = self.postprocess(
-            image,
-            boxes,
-            inverse_ratio,
-            slope_ths,
-            ratio_ths,
-            center_ths,
-            dim_ths,
-            space_ths,
-            add_margin,
-            min_size,
-        )
-
-        # get cropped image
-        image_list = []
-        for region in region_list:
-            x_min, x_max, y_min, y_max = region
-            crop_img = image[y_min:y_max, x_min:x_max, :]
-            image_list.append(
-                (
-                    ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)),
-                    crop_img,
-                )
-            )
-
-        for box in box_list:
-            transformed_img = boxed_transform(image, np.array(box, dtype="float32"))
-            image_list.append((box, transformed_img))
-
-        # sort by top left point
-        image_list = sorted(image_list, key=lambda x: (x[0][0][1], x[0][0][0]))
-
-        return image_list
-
-    def preprocess(
-        self,
-        image: Image.Image,
-        max_size: int,
-        expand_ratio: float = 1.0,
-        sharp: float = 1.0,
-        contrast: float = 1.0,
-    ) -> torch.Tensor:
-        if sharp != 1:
-            enhancer = ImageEnhance.Sharpness(image)
-            image = enhancer.enhance(sharp)
-        if contrast != 1:
-            enhancer = ImageEnhance.Contrast(image)
-            image = enhancer.enhance(contrast)
-
-        image = np.array(image)
-
-        image, target_ratio = resize_aspect_ratio(
-            image, max_size, interpolation=cv2.INTER_LINEAR, expand_ratio=expand_ratio
-        )
-        inverse_ratio = 1 / target_ratio
-
-        x = np.transpose(normalize_mean_variance(image), (2, 0, 1))
-
-        x = torch.tensor(np.array([x]), device=self.device)
-
-        return x, inverse_ratio
-
-    @torch.no_grad()
-    def forward_net(self, tensor: torch.Tensor) -> torch.Tensor:
-        scores, feature = self.detector(tensor)
-        return scores[0]
-
-    def detect(
-        self,
-        scores: torch.Tensor,
-        text_confidence: float = 0.7,
-        text_threshold: float = 0.4,
-        link_threshold: float = 0.4,
-    ) -> list[BoxTuple]:
-        # make score and link map
-        score_text = scores[:, :, 0].cpu().data.numpy()
-        score_link = scores[:, :, 1].cpu().data.numpy()
-        # extract box
-        boxes, _ = extract_boxes(
-            score_text, score_link, text_confidence, text_threshold, link_threshold
-        )
-        return boxes
-
-    def postprocess(
-        self,
-        image: np.ndarray,
-        boxes: list[BoxTuple],
-        inverse_ratio: float,
-        slope_ths: float = 0.1,
-        ratio_ths: float = 0.5,
-        center_ths: float = 0.5,
-        dim_ths: float = 0.5,
-        space_ths: float = 1.0,
-        add_margin: float = 0.1,
-        min_size: int = 0,
-    ) -> Tuple[list[RegionTuple], list[BoxTuple]]:
-
-        # coordinate adjustment
-        boxes = adjust_result_coordinates(boxes, inverse_ratio)
-
-        max_y, max_x, _ = image.shape
-
-        # extract region and merge
-        region_list, box_list = extract_regions_from_boxes(boxes, slope_ths)
-
-        region_list = greedy_merge(
-            region_list,
-            ratio_ths=ratio_ths,
-            center_ths=center_ths,
-            dim_ths=dim_ths,
-            space_ths=space_ths,
-            verbose=0
-        )
-
-        # add margin
-        region_list = [
-            region.expand(add_margin, (max_x, max_y)).as_tuple()
-            for region in region_list
-        ]
-
-        box_list = [box_expand(box, add_margin, (max_x, max_y)) for box in box_list]
-
-        # filter by size
-        if min_size:
-            if min_size < 1:
-                min_size = int(min(max_y, max_x) * min_size)
-
-            region_list = [
-                i for i in region_list if max(i[1] - i[0], i[3] - i[2]) > min_size
-            ]
-            box_list = [
-                i
-                for i in box_list
-                if max(utils.diff([c[0] for c in i]), utils.diff([c[1] for c in i]))
-                > min_size
-            ]
-
-        return region_list, box_list
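Note (added for clarity, not part of the commit): ReaderLite.process above returns a list of (box, crop) pairs, where each box is a 4-point quadrilateral in the original image and each crop is the corresponding patch, sorted by top-left corner; these crops are presumably what app.py fed to the recognition model before this service change. A minimal usage sketch, assuming the package were still installed and using a placeholder image path:

    from easyocrlite import ReaderLite

    reader = ReaderLite(gpu=False)                      # CPU fallback; applies dynamic int8 quantization to CRAFT
    results = reader.process("example.jpg", max_size=960)
    for box, crop in results:
        (x0, y0), *_ = box                              # top-left corner of the detected text region
        print(f"text region at ({x0}, {y0}), crop shape {crop.shape}")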
ezocr/easyocrlite/types.py
DELETED
@@ -1,5 +0,0 @@
-from typing import Tuple
-
-Point = Tuple[int, int]
-BoxTuple = Tuple[Point, Point, Point, Point]
-RegionTuple = Tuple[int, int, int, int]