JustinLin610 committed on
Commit
085ecd3
1 Parent(s): 38764eb

change the service

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. LICENSE +0 -201
  2. README.md +0 -27
  3. README_EncouragingLoss.md +0 -34
  4. app.py +1 -2
  5. benpao.jpeg +0 -0
  6. checkpoints.md +0 -37
  7. checkpoints_cn.md +0 -82
  8. colab.md +0 -9
  9. criterions/__init__.py +0 -4
  10. criterions/clip_scst_loss.py +0 -277
  11. criterions/label_smoothed_cross_entropy.py +0 -343
  12. criterions/label_smoothed_encouraging_loss.py +0 -395
  13. criterions/scst_loss.py +0 -281
  14. data/__init__.py +0 -0
  15. data/cv_data/image_classify_dataset.py +0 -196
  16. data/data_utils.py +0 -601
  17. data/file_dataset.py +0 -107
  18. data/mm_data/__init__.py +0 -0
  19. data/mm_data/caption_dataset.py +0 -160
  20. data/mm_data/ocr_dataset.py +0 -210
  21. data/mm_data/refcoco_dataset.py +0 -174
  22. data/mm_data/snli_ve_dataset.py +0 -203
  23. data/mm_data/vqa_gen_dataset.py +0 -218
  24. data/nlg_data/summary_dataset.py +0 -131
  25. data/nlu_data/cola_dataset.py +0 -138
  26. data/nlu_data/mnli_dataset.py +0 -143
  27. data/nlu_data/mrpc_dataset.py +0 -141
  28. data/nlu_data/qnli_dataset.py +0 -141
  29. data/nlu_data/qqp_dataset.py +0 -141
  30. data/nlu_data/rte_dataset.py +0 -141
  31. data/nlu_data/sst2_dataset.py +0 -138
  32. data/ofa_dataset.py +0 -79
  33. data/pretrain_data/unify_dataset.py +0 -636
  34. datasets.md +0 -44
  35. evaluate.py +0 -160
  36. ezocr/LICENSE +0 -201
  37. ezocr/README.md +0 -49
  38. ezocr/build/lib/easyocrlite/__init__.py +0 -1
  39. ezocr/build/lib/easyocrlite/reader.py +0 -272
  40. ezocr/build/lib/easyocrlite/types.py +0 -5
  41. ezocr/easyocrlite.egg-info/PKG-INFO +0 -8
  42. ezocr/easyocrlite.egg-info/SOURCES.txt +0 -11
  43. ezocr/easyocrlite.egg-info/dependency_links.txt +0 -1
  44. ezocr/easyocrlite.egg-info/requires.txt +0 -5
  45. ezocr/easyocrlite.egg-info/top_level.txt +0 -1
  46. ezocr/easyocrlite/__init__.py +0 -1
  47. ezocr/easyocrlite/model/__init__.py +0 -1
  48. ezocr/easyocrlite/model/craft.py +0 -174
  49. ezocr/easyocrlite/reader.py +0 -271
  50. ezocr/easyocrlite/types.py +0 -5
LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright 1999-2022 Alibaba Group Holding Ltd.
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
README.md DELETED
@@ -1,27 +0,0 @@
1
- ---
2
- title: Chinese OCR
3
- emoji: 📖
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 3.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
- # Configuration
12
- `title`: _string_
13
- OFA Image Caption
14
- `emoji`: _string_
15
- 🖼
16
- `colorFrom`: _string_
17
- red
18
- `colorTo`: _string_
19
- indigo
20
- `sdk`: _string_
21
- gradio
22
- `app_file`: _string_
23
- app.py
24
-
25
-
26
- `pinned`: _boolean_
27
- false
README_EncouragingLoss.md DELETED
@@ -1,34 +0,0 @@
1
- # Finetuning with Encouraging Loss (EL)
2
- Below we provide methods for finetuning with the label-smoothed encouraging loss proposed in [_Well-classified Examples are Underestimated in Classification with Deep Neural Networks_](https://arxiv.org/pdf/2110.06537.pdf) on different downstream tasks.
3
- The implementation is in [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
4
- You can set the `--criterion` to `adjust_label_smoothed_encouraging_loss` to use it. This criterion has a hyper-parameter `--log-end`.
5
- `--log-end < 1` results in an approximated, conservative version of the full encouraging loss.
6
- A higher log_end more strongly counteracts gradient vanishing, strengthens the fit to the data, and increases the growth rate of the margin, but it also yields a larger gradient norm, which poses challenges for the existing optimization setup.
7
- We recommend a higher log_end for settings that already reach higher performance; 0.75 or 0.5 is a good first try.
8
- ## Image Captioning
9
- We provide procedures for image captioning with EL below. The preprocessing is identical to the default setting.
10
-
11
- <details>
12
- <summary><b>Finetuning</b></summary>
13
- <p>
14
- We provide two scripts for stage 1.
15
- </p>
16
- <pre>
17
- cd run_scripts/caption
18
- nohup sh train_caption_stage1_el.sh > train_stage1_el.out & # stage 1, train with encouraging loss, expected cider 1.403
19
- nohup sh train_caption_stage1_el_db.sh > train_stage1_el.out & # stage 1, train with encouraging loss, and drop best examples, expected cider 1.404
20
- </pre>
21
- </details>
22
-
23
- ## Referring Expression Comprehension
24
- We provide procedures for referring expression comprehension with EL below. The preprocessing is identical to the default setting.
25
- <details>
26
- <summary><b>Finetuning</b></summary>
27
- <pre>
28
- cd run_scripts/refcoco
29
- nohup sh train_refcoco_el.sh > train_refcoco_el.out & # finetune for refcoco
30
- nohup sh train_refcocoplus_el.sh > train_refcocoplus_el.out & # finetune for refcoco+
31
- nohup sh train_refcocog_el.sh > train_refcocog_el.out & # finetune for refcocog
32
- </pre>
33
- </details>
34
- Evaluation is also the same as the default setting.
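
For orientation, here is a minimal, self-contained sketch of the encouraging-loss idea the criterion above implements, as described in the paper: cross entropy plus a log(1 − p) bonus on the target class that keeps gradients alive for well-classified tokens, with the bonus replaced by its tangent beyond `log_end` in the conservative variant. This is an illustration only, not the repository's criterion in `criterions/label_smoothed_encouraging_loss.py` (which also handles label smoothing, R-Drop, constraint masks, and sample dropping); the function name and the sum reduction are arbitrary choices.

```python
# Illustrative sketch of encouraging loss with a conservative bonus capped at log_end.
# NOT the implementation in criterions/label_smoothed_encouraging_loss.py.
import math
import torch
import torch.nn.functional as F

def encouraging_loss(logits: torch.Tensor, target: torch.Tensor, log_end: float = 0.75) -> torch.Tensor:
    """Cross entropy plus a log(1 - p) bonus on the target class.

    For p > log_end the bonus is replaced by its tangent at log_end, so the
    gradient norm stays bounded (the 'conservative' variant the README refers to).
    """
    lprobs = F.log_softmax(logits, dim=-1)
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)  # -log p(target)
    p = torch.exp(-nll)                                                   # p(target)
    bonus = torch.log(torch.clamp(1.0 - p, min=1e-5))                     # log(1 - p)
    if log_end < 1.0:
        # First-order (tangent) continuation of log(1 - p) beyond p = log_end.
        tangent = math.log(1.0 - log_end) - (p - log_end) / (1.0 - log_end)
        bonus = torch.where(p > log_end, tangent, bonus)
    return (nll + bonus).sum()
```

Setting `log_end = 1.0` in this sketch corresponds to the full (uncapped) bonus; values below 1 trade some of the encouraging effect for a bounded gradient norm.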
app.py CHANGED
@@ -193,8 +193,7 @@ description = "Gradio Demo for Chinese OCR based on OFA-Base. "\
193
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
194
  "Repo</a></p> "
195
  examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
196
- ['qiaodaima.png'], ['benpao.jpeg'],
197
- ['xsd.jpg', 'General']]
198
  io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
199
  outputs=[gr.outputs.Image(type='pil', label='Image'),
200
  gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
 
193
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
194
  "Repo</a></p> "
195
  examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
196
+ ['qiaodaima.png'], ['xsd.jpg']]
 
197
  io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
198
  outputs=[gr.outputs.Image(type='pil', label='Image'),
199
  gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
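
A side note on the `examples` argument touched in this hunk: each row in a Gradio `examples` list supplies one value per input component, so an interface with a single Image input takes single-element rows. Below is a minimal sketch of the pattern using the same gradio 3.x calls as `app.py`; the `ocr` stub and its dummy output are placeholders, not the Space's actual model code.

```python
# Minimal sketch of the app.py wiring (gradio 3.x API). The ocr() body is a
# placeholder; the real app runs OFA-based text detection and recognition.
import gradio as gr
import pandas as pd
from PIL import Image

def ocr(image_path):
    image = Image.open(image_path)
    results = pd.DataFrame({'Box ID': [0], 'Text': ['placeholder']})
    return image, results

io = gr.Interface(
    fn=ocr,
    inputs=gr.inputs.Image(type='filepath', label='Image'),
    outputs=[gr.outputs.Image(type='pil', label='Image'),
             gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
    # One value per input component per row: a single Image input means
    # single-element rows such as ['xsd.jpg'].
    examples=[['shupai.png'], ['qiaodaima.png'], ['xsd.jpg']],
)

if __name__ == '__main__':
    io.launch()
```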
benpao.jpeg DELETED
Binary file (6.38 kB)
 
checkpoints.md DELETED
@@ -1,37 +0,0 @@
1
- # Checkpoints
2
-
3
- We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.
4
-
5
- ## Pretraining
6
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
7
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a> (~470M parameters)
8
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"> Pre-trained checkpoint (OFA-Base) </a> (~180M parameters)
9
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt"> Pre-trained checkpoint (OFA-Medium) </a> (~93M parameters)
10
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt"> Pre-trained checkpoint (OFA-Tiny) </a> (~33M parameters)
11
-
12
- ## Finetuning (OFA-Huge)
13
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_huge_best.pt"> Finetuned checkpoint for Caption on COCO </a>
14
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_huge_best.pt"> Finetuned checkpoint for VQAv2 </a>
15
-
16
- ## Finetuning (OFA-Large)
17
-
18
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
19
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_stage1_best.pt"> Finetuned checkpoint for Caption on COCO During Stage1 Finetuning </a>
20
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
21
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
22
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
23
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_large_best.pt"> Finetuned checkpoint for VQAv2 </a>
24
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_large_best.pt"> Finetuned checkpoint for SNLI-VE </a>
25
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_large_best.zip"> Finetuned checkpoint for Text-to-Image Generation on COCO && CLIP checkpoint && VQGAN checkpoint </a>
26
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/imagenet_1k_large_best.pt"> Finetuned checkpoint for ImageNet-1K </a>
27
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/gigaword_large_best.pt"> Finetuned checkpoint for Gigaword </a>
28
-
29
-
30
- ## Finetuning (OFA-Base)
31
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_base_best.pt"> Finetuned base checkpoint for Caption on COCO </a>
32
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_base_best.pt"> Finetuned base checkpoint for RefCOCO </a>
33
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_base_best.pt"> Finetuned base checkpoint for RefCOCO+ </a>
34
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_base_best.pt"> Finetuned base checkpoint for RefCOCOg </a>
35
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_base_best.pt"> Finetuned base checkpoint for VQAv2 </a>
36
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_base_best.pt"> Finetuned base checkpoint for SNLI-VE </a>
37
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_base_best.pt"> Finetuned base checkpoint for Text-to-Image Generation on COCO </a>
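
For convenience, a minimal sketch of fetching one of the checkpoints listed above; the URL is taken from the list, the local filename and helper function are arbitrary, and any of the other links can be substituted.

```python
# Download a released OFA checkpoint from the OSS links listed above.
# Plain urllib keeps the sketch dependency-free; swap in any other URL as needed.
import urllib.request

CKPT_URL = "https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"

def download(url: str, dest: str) -> None:
    print(f"Downloading {url} -> {dest}")
    urllib.request.urlretrieve(url, dest)

if __name__ == "__main__":
    download(CKPT_URL, "ofa_base.pt")
```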
checkpoints_cn.md DELETED
@@ -1,82 +0,0 @@
1
- # Checkpoints (OFA-CN)
2
-
3
- We provide checkpoints of OFA-CN, which is the Chinese version of OFA. We provide Base-size and Large-size models, including pretrained and finetuned models on image captioning and referring expression comprehension. Note that we translated the texts in the RefCOCO(-/+/g) datasets and finetuned OFA-CN on them. We plan to release the related new datasets in the near future.
4
- <br>
5
-
6
- ## Checkpoints
7
- Below we provide the links for downloading the Chinese OFA checkpoints.
8
-
9
- ### Pretraining
10
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
11
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt"> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)
12
-
13
- ### Finetuning (OFA-Large)
14
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
15
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
16
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
17
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
18
-
19
- ### Finetuning (OFA-Base)
20
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
21
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
22
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
23
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
24
- <br>
25
-
26
- ## Model Card
27
- Below we provide the basic information of the base-size and large-size OFA-CN.
28
-
29
- <table border="1" width="100%">
30
- <tr align="center">
31
- <th>Model</th><th>#Params</th><th>Backbone</th><th>Hidden Size</th><th>Intermediate Size</th><th>#Heads</th><th>#Enc. Layers</th><th>#Dec. Layers</th>
32
- </tr>
33
- <tr align="center">
34
- <td>OFA<sub>Base</sub></td><td>160M</td><td>ResNet101</td><td>768</td><td>3072</td><td>12</td><td>6</td><td>6</td>
35
- </tr>
36
- <tr align="center">
37
- <td>OFA<sub>Large</sub></td><td>443M</td><td>ResNet152</td><td>1024</td><td>4096</td><td>16</td><td>12</td><td>12</td>
38
- </tr>
39
- </tr>
40
- </table>
41
- <br>
42
-
43
- ## Results
44
- Below we provide the results of OFA-CN and the baselines for comparison.
45
-
46
- ### [MUGE Caption](https://tianchi.aliyun.com/muge)
47
- <table border="1" width="100%">
48
- <tr align="center">
49
- <td>Model</td><td>BLEU@4</td><td>ROUGE-L</td><td>CIDEr-D</td>
50
- </tr>
51
- <tr align="center">
52
- <td>Trm </td><td>7.33</td><td>51.51</td><td>11.00</td>
53
- </tr>
54
- <tr align="center">
55
- <td>M6</td><td>16.19</td><td>55.06</td><td>30.75</td>
56
- </tr>
57
- <tr align="center">
58
- <td>OFA<sub>Base</sub></td><td>26.23</td><td>58.95</td><td>50.70</td>
59
- </tr>
60
- <tr align="center">
61
- <td>OFA<sub>Large</sub></td><td><b>27.32</b></td><td><b>59.20</b></td><td><b>53.51</b></td>
62
- </tr>
63
- </table>
64
-
65
- ### RefCOCO-CN Series
66
- <table border="1" width="100%">
67
- <tr align="center">
68
- <td>Model</td><td>RefCOCO(val/testA/testB)</td><td>RefCOCO+(val/testA/testB)</td><td>RefCOCOg(val/test-u)</td>
69
- </tr>
70
- <tr align="center">
71
- <td>OFA<sub>Base</sub>(random-init)</td><td>30.13/35.07/25.03</td><td>17.89/20.90/15.83</td><td>20.30/20.45</td>
72
- </tr>
73
- <tr align="center">
74
- <td>OFA<sub>Base</sub></td><td>82.18/86.07/<b>76.68</b></td><td>69.38/77.26/60.14</td><td><b>73.57/72.53</b></td>
75
- </tr>
76
- <tr align="center">
77
- <td>OFA<sub>Large</sub></td><td><b>82.84/86.54</b>/76.50</td><td><b>71.30/78.56/61.85</b></td><td>71.96/71.30</td>
78
- </tr>
79
- </table>
80
- <br>
81
-
82
-
colab.md DELETED
@@ -1,9 +0,0 @@
1
- # Colab Notebooks
2
-
3
- We provide Colab notebooks for different downstream tasks so that you can try OFA. See below.
4
-
5
- * [Image Captioning in Huggingface Transformers](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
6
- * [Generic Interface](https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m?usp=sharing#scrollTo=s9Vni6YUZOpC) (using different instructions to perform various tasks with just one model.)
7
- * [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
8
- * [Referring Expression Comprehension](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
9
- * [Open-Domain Visual Question Answering](https://colab.research.google.com/drive/1lsMsF-Vum3MVyXwSVF5E-Y23rHFvj_3y?usp=sharing)
criterions/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- from .scst_loss import ScstRewardCriterion
2
- from .label_smoothed_cross_entropy import AdjustLabelSmoothedCrossEntropyCriterion
3
- from .clip_scst_loss import ClipScstRewardCriterion
4
- from .label_smoothed_encouraging_loss import AdjustLabelSmoothedEncouragingLossCriterion
criterions/clip_scst_loss.py DELETED
@@ -1,277 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import math
7
- from dataclasses import dataclass, field
8
- from typing import Optional
9
- from PIL import Image
10
- from torchvision import transforms
11
-
12
- import torch
13
- import numpy as np
14
- from fairseq import metrics
15
- from fairseq.data import data_utils
16
- from fairseq.criterions import FairseqCriterion, register_criterion
17
- from fairseq.dataclass import FairseqDataclass
18
- from fairseq import utils
19
- from omegaconf import II
20
-
21
- from models import clip
22
-
23
-
24
- def custom_to_pil(x):
25
- x = x.detach().cpu()
26
- x = torch.clamp(x, -1., 1.)
27
- x = (x + 1.) / 2.
28
- x = x.permute(1, 2, 0).numpy()
29
- x = (255 * x).astype(np.uint8)
30
- x = Image.fromarray(x)
31
- if not x.mode == "RGB":
32
- x = x.convert("RGB")
33
- return x
34
-
35
-
36
- def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
37
- loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
38
- if ignore_index is not None:
39
- pad_mask = target.eq(ignore_index)
40
- loss.masked_fill_(pad_mask, 0.0)
41
- ntokens = (~pad_mask).sum()
42
- else:
43
- loss = loss.squeeze(-1)
44
- ntokens = target.numel()
45
- if reduce:
46
- loss = loss.sum()
47
- return loss, ntokens
48
-
49
-
50
- @dataclass
51
- class ClipScstRewardCriterionConfig(FairseqDataclass):
52
- ignore_prefix_size: int = field(
53
- default=0,
54
- metadata={"help": "Ignore first N tokens"},
55
- )
56
- sentence_avg: bool = II("optimization.sentence_avg")
57
- constraint_range: Optional[str] = field(
58
- default=None,
59
- metadata={"help": "constraint range"}
60
- )
61
-
62
-
63
- @register_criterion(
64
- "clip_scst_reward_criterion", dataclass=ClipScstRewardCriterionConfig
65
- )
66
- class ClipScstRewardCriterion(FairseqCriterion):
67
- CLIP_REWARD_WEIGHT = 2.5
68
-
69
- def __init__(
70
- self,
71
- task,
72
- sentence_avg,
73
- ignore_prefix_size=0,
74
- constraint_range=None
75
- ):
76
- super().__init__(task)
77
- self.sentence_avg = sentence_avg
78
- self.ignore_prefix_size = ignore_prefix_size
79
-
80
- self.constraint_start = None
81
- self.constraint_end = None
82
- if constraint_range is not None:
83
- constraint_start, constraint_end = constraint_range.split(',')
84
- self.constraint_start = int(constraint_start)
85
- self.constraint_end = int(constraint_end)
86
-
87
- def forward(self, model, sample, update_num=0, reduce=True):
88
- """Compute the loss for the given sample.
89
-
90
- Returns a tuple with three elements:
91
- 1) the loss
92
- 2) the sample size, which is used as the denominator for the gradient
93
- 3) logging outputs to display while training
94
- """
95
- loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
96
-
97
- sample_size = (
98
- nsentences if self.sentence_avg else ntokens
99
- )
100
- logging_output = {
101
- "loss": loss.data,
102
- "score": score,
103
- "ntokens": ntokens,
104
- "nsentences": nsentences,
105
- "sample_size": sample_size,
106
- }
107
- return loss, sample_size, logging_output
108
-
109
- def _calculate_clip_scores(self, gen_res, gt_text, device):
110
- '''
111
- gen_res: generated images, list of Image
112
- gt_text: input captions.
113
- device: device for clip model
114
- '''
115
- batch_size = len(gt_text)
116
- gen_res_size = len(gen_res)
117
- img_per_seq = gen_res_size // batch_size
118
-
119
- hyp_images = torch.stack(
120
- [self.task.clip_preprocess(gen_image) for gen_image in gen_res], dim=0
121
- ).to(device)
122
-
123
- clip_input = clip.tokenize([text for text in gt_text]).to(device)
124
- with torch.no_grad():
125
- image_features = self.task.clip_model.encode_image(hyp_images)
126
- text_features = self.task.clip_model.encode_text(clip_input)
127
- image_features /= image_features.norm(dim=-1, keepdim=True)
128
- text_features /= text_features.norm(dim=-1, keepdim=True)
129
- image_features = image_features.view(batch_size, img_per_seq, -1)
130
- text_features = text_features.view(batch_size, 1, -1)
131
- ti_similarity = image_features @ text_features.transpose(1, 2)
132
- ti_similarity = ti_similarity.view(-1)
133
-
134
- scores = self.CLIP_REWARD_WEIGHT * ti_similarity
135
- return scores
136
-
137
- def get_generator_out(self, model, sample):
138
- model.eval()
139
- with torch.no_grad():
140
- self.task.scst_generator.model.eval()
141
- gen_out = self.task.scst_generator.generate([model], sample)
142
-
143
- gen_target = []
144
- gen_res = []
145
- gt_text = []
146
- for i in range(len(gen_out)):
147
- with torch.no_grad():
148
- tokens = torch.stack([item['tokens'][:-1] for item in gen_out[i]], dim=0)
149
- tokens += -len(self.task.src_dict) + self.task.cfg.code_dict_size + self.task.cfg.num_bins
150
- images = self.task.image_tokenizer.decode_code(
151
- tokens.view(-1, self.task.cfg.code_image_size // 8, self.task.cfg.code_image_size // 8)
152
- )
153
- images = [custom_to_pil(image) for image in images]
154
-
155
- gen_target += [item['tokens'] for item in gen_out[i]]
156
- gen_res += images
157
- gt_text.append(
158
- self.task.bpe.decode(
159
- self.task.tgt_dict.string(
160
- utils.strip_pad(sample['net_input']['src_tokens'][i], self.padding_idx).cpu().int()
161
- )
162
- )[38:] # remove task instruction.
163
- )
164
-
165
- return gen_target, gen_res, gt_text
166
-
167
- def get_reward_and_scores(self, gen_res, gt_text, device):
168
- batch_size = len(gt_text)
169
- gen_res_size = len(gen_res)
170
- img_per_sample = gen_res_size // batch_size
171
-
172
- scores = self._calculate_clip_scores(gen_res, gt_text, device)
173
- sc_ = scores.reshape(batch_size, img_per_sample)
174
- baseline = (sc_.sum(1, keepdim=True) - sc_) / (sc_.shape[1] - 1)
175
- # sample - baseline
176
- reward = scores.reshape(batch_size, img_per_sample)
177
- reward = reward - baseline
178
- reward = reward.view(-1)
179
-
180
- return reward, scores
181
-
182
- def get_net_output(self, model, sample, gen_target):
183
- def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
184
- return data_utils.collate_tokens(
185
- sample_list,
186
- pad_idx=self.padding_idx,
187
- eos_idx=eos,
188
- left_pad=False,
189
- move_eos_to_beginning=move_eos_to_beginning,
190
- )
191
-
192
- batch_size = len(sample["target"])
193
- gen_target_size = len(gen_target)
194
- img_per_sample = gen_target_size // batch_size
195
-
196
- model.train()
197
- sample_src_tokens = torch.repeat_interleave(
198
- sample['net_input']['src_tokens'], img_per_sample, dim=0
199
- )
200
- sample_src_lengths = torch.repeat_interleave(
201
- sample['net_input']['src_lengths'], img_per_sample, dim=0
202
- )
203
- sample_code_masks = torch.repeat_interleave(
204
- sample['net_input']['code_masks'], img_per_sample, dim=0
205
- )
206
- gen_prev_output_tokens = torch.as_tensor(
207
- merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
208
- device=sample["target"].device, dtype=torch.int64
209
- )
210
- gen_target_tokens = torch.as_tensor(
211
- merge(gen_target), device=sample["target"].device, dtype=torch.int64
212
- )
213
- net_output = model(
214
- src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
215
- code_masks=sample_code_masks, prev_output_tokens=gen_prev_output_tokens
216
- )
217
-
218
- return net_output, gen_target_tokens
219
-
220
- def get_lprobs_and_target(self, model, net_output, gen_target):
221
- if self.constraint_start is not None and self.constraint_end is not None:
222
- net_output[0][:, :, 4:self.constraint_start] = -math.inf
223
- net_output[0][:, :, self.constraint_end:] = -math.inf
224
- lprobs = model.get_normalized_probs(net_output, log_probs=True)
225
- if self.ignore_prefix_size > 0:
226
- if getattr(lprobs, "batch_first", False):
227
- lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
228
- gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
229
- else:
230
- lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
231
- gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
232
- return lprobs, gen_target
233
-
234
- def compute_loss(self, model, sample, reduce=True):
235
- gen_target, gen_res, gt_text = self.get_generator_out(model, sample)
236
- reward, scores = self.get_reward_and_scores(gen_res, gt_text, device=sample["target"].device)
237
- net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
238
- gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
239
- loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
240
- nsentences = gen_target_tokens.size(0)
241
-
242
- return loss, scores.sum(), ntokens, nsentences
243
-
244
- @classmethod
245
- def reduce_metrics(cls, logging_outputs) -> None:
246
- """Aggregate logging outputs from data parallel training."""
247
- loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
248
- score_sum = sum(log.get("score", 0) for log in logging_outputs)
249
- ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
250
- nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
251
- sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
252
-
253
- metrics.log_scalar(
254
- "loss", loss_sum / sample_size, sample_size, round=3
255
- )
256
- metrics.log_scalar(
257
- "score", score_sum / nsentences, nsentences, round=3
258
- )
259
-
260
- metrics.log_scalar(
261
- "ntokens", ntokens, 1, round=3
262
- )
263
- metrics.log_scalar(
264
- "nsentences", nsentences, 1, round=3
265
- )
266
- metrics.log_scalar(
267
- "sample_size", sample_size, 1, round=3
268
- )
269
-
270
- @staticmethod
271
- def logging_outputs_can_be_summed() -> bool:
272
- """
273
- Whether the logging outputs returned by `forward` can be summed
274
- across workers prior to calling `reduce_metrics`. Setting this
275
- to True will improves distributed training speed.
276
- """
277
- return True
criterions/label_smoothed_cross_entropy.py DELETED
@@ -1,343 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import math
7
- from dataclasses import dataclass, field
8
- from typing import Optional
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- import numpy as np
13
- from fairseq import metrics, utils
14
- from fairseq.criterions import FairseqCriterion, register_criterion
15
- from fairseq.dataclass import FairseqDataclass
16
- from omegaconf import II
17
-
18
-
19
- @dataclass
20
- class AdjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
21
- label_smoothing: float = field(
22
- default=0.0,
23
- metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
24
- )
25
- report_accuracy: bool = field(
26
- default=False,
27
- metadata={"help": "report accuracy metric"},
28
- )
29
- ignore_prefix_size: int = field(
30
- default=0,
31
- metadata={"help": "Ignore first N tokens"},
32
- )
33
- ignore_eos: bool = field(
34
- default=False,
35
- metadata={"help": "Ignore eos token"},
36
- )
37
- sentence_avg: bool = II("optimization.sentence_avg")
38
- drop_worst_ratio: float = field(
39
- default=0.0,
40
- metadata={"help": "ratio for discarding bad samples"},
41
- )
42
- drop_worst_after: int = field(
43
- default=0,
44
- metadata={"help": "steps for discarding bad samples"},
45
- )
46
- use_rdrop: bool = field(
47
- default=False, metadata={"help": "use R-Drop"}
48
- )
49
- reg_alpha: float = field(
50
- default=1.0, metadata={"help": "weight for R-Drop"}
51
- )
52
- sample_patch_num: int = field(
53
- default=196, metadata={"help": "sample patches for v1"}
54
- )
55
- constraint_range: Optional[str] = field(
56
- default=None,
57
- metadata={"help": "constraint range"}
58
- )
59
-
60
-
61
- def construct_rdrop_sample(x):
62
- if isinstance(x, dict):
63
- for key in x:
64
- x[key] = construct_rdrop_sample(x[key])
65
- return x
66
- elif isinstance(x, torch.Tensor):
67
- return x.repeat(2, *([1] * (x.dim()-1)))
68
- elif isinstance(x, int):
69
- return x * 2
70
- elif isinstance(x, np.ndarray):
71
- return x.repeat(2)
72
- else:
73
- raise NotImplementedError
74
-
75
-
76
- def kl_loss(p, q):
77
- p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
78
- q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
79
- loss = (p_loss + q_loss) / 2
80
- return loss
81
-
82
-
83
- def label_smoothed_nll_loss(
84
- lprobs, target, epsilon, update_num, reduce=True,
85
- drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
86
- constraint_masks=None, constraint_start=None, constraint_end=None
87
- ):
88
- if target.dim() == lprobs.dim() - 1:
89
- target = target.unsqueeze(-1)
90
- nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
91
- if constraint_masks is not None:
92
- smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
93
- eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
94
- elif constraint_start is not None and constraint_end is not None:
95
- constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
96
- smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
97
- eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
98
- else:
99
- smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
100
- eps_i = epsilon / (lprobs.size(-1) - 1)
101
- loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
102
- if drop_worst_ratio > 0 and update_num > drop_worst_after:
103
- if use_rdrop:
104
- true_batch_size = loss.size(0) // 2
105
- _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
106
- loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
107
- nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
108
- lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
109
- else:
110
- loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
111
- nll_loss = nll_loss[indices]
112
- lprobs = lprobs[indices]
113
-
114
- ntokens = loss.numel()
115
- nll_loss = nll_loss.sum()
116
- loss = loss.sum()
117
- if use_rdrop:
118
- true_batch_size = lprobs.size(0) // 2
119
- p = lprobs[:true_batch_size]
120
- q = lprobs[true_batch_size:]
121
- if constraint_start is not None and constraint_end is not None:
122
- constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
123
- p = p[:, constraint_range]
124
- q = q[:, constraint_range]
125
- loss += kl_loss(p, q) * reg_alpha
126
-
127
- return loss, nll_loss, ntokens
128
-
129
-
130
- @register_criterion(
131
- "adjust_label_smoothed_cross_entropy", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
132
- )
133
- class AdjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
134
- def __init__(
135
- self,
136
- task,
137
- sentence_avg,
138
- label_smoothing,
139
- ignore_prefix_size=0,
140
- ignore_eos=False,
141
- report_accuracy=False,
142
- drop_worst_ratio=0,
143
- drop_worst_after=0,
144
- use_rdrop=False,
145
- reg_alpha=1.0,
146
- sample_patch_num=196,
147
- constraint_range=None
148
- ):
149
- super().__init__(task)
150
- self.sentence_avg = sentence_avg
151
- self.eps = label_smoothing
152
- self.ignore_prefix_size = ignore_prefix_size
153
- self.ignore_eos = ignore_eos
154
- self.report_accuracy = report_accuracy
155
- self.drop_worst_ratio = drop_worst_ratio
156
- self.drop_worst_after = drop_worst_after
157
- self.use_rdrop = use_rdrop
158
- self.reg_alpha = reg_alpha
159
- self.sample_patch_num = sample_patch_num
160
-
161
- self.constraint_start = None
162
- self.constraint_end = None
163
- if constraint_range is not None:
164
- constraint_start, constraint_end = constraint_range.split(',')
165
- self.constraint_start = int(constraint_start)
166
- self.constraint_end = int(constraint_end)
167
-
168
- def forward(self, model, sample, update_num=0, reduce=True):
169
- """Compute the loss for the given sample.
170
-
171
- Returns a tuple with three elements:
172
- 1) the loss
173
- 2) the sample size, which is used as the denominator for the gradient
174
- 3) logging outputs to display while training
175
- """
176
- if isinstance(sample, list):
177
- if self.sample_patch_num > 0:
178
- sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
179
- loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
180
- loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
181
- loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
182
- sample_size = 1
183
- logging_output = {
184
- "loss": loss.data,
185
- "loss_v1": loss_v1.data,
186
- "loss_v2": loss_v2.data,
187
- "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
188
- "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
189
- "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
190
- "sample_size": 1,
191
- "sample_size_v1": sample_size_v1,
192
- "sample_size_v2": sample_size_v2,
193
- }
194
- return loss, sample_size, logging_output
195
-
196
- if self.use_rdrop:
197
- construct_rdrop_sample(sample)
198
-
199
- net_output = model(**sample["net_input"])
200
- loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
201
- sample_size = (
202
- sample["target"].size(0) if self.sentence_avg else ntokens
203
- )
204
- logging_output = {
205
- "loss": loss.data,
206
- "nll_loss": nll_loss.data,
207
- "ntokens": sample["ntokens"],
208
- "nsentences": sample["nsentences"],
209
- "sample_size": sample_size,
210
- }
211
- if self.report_accuracy:
212
- n_correct, total = self.compute_accuracy(model, net_output, sample)
213
- logging_output["n_correct"] = utils.item(n_correct.data)
214
- logging_output["total"] = utils.item(total.data)
215
- return loss, sample_size, logging_output
216
-
217
- def get_lprobs_and_target(self, model, net_output, sample):
218
- conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
219
- constraint_masks = None
220
- if "constraint_masks" in sample and sample["constraint_masks"] is not None:
221
- constraint_masks = sample["constraint_masks"]
222
- net_output[0].masked_fill_(~constraint_masks, -math.inf)
223
- if self.constraint_start is not None and self.constraint_end is not None:
224
- net_output[0][:, :, 4:self.constraint_start] = -math.inf
225
- net_output[0][:, :, self.constraint_end:] = -math.inf
226
- lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
227
- target = model.get_targets(sample, net_output)
228
- if self.ignore_prefix_size > 0:
229
- lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
230
- target = target[:, self.ignore_prefix_size :].contiguous()
231
- if constraint_masks is not None:
232
- constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
233
- if self.ignore_eos:
234
- bsz, seq_len, embed_dim = lprobs.size()
235
- eos_indices = target.eq(self.task.tgt_dict.eos())
236
- lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
237
- target = target[~eos_indices].reshape(bsz, seq_len-1)
238
- if constraint_masks is not None:
239
- constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
240
- if constraint_masks is not None:
241
- constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
242
- return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
243
-
244
- def compute_loss(self, model, net_output, sample, update_num, reduce=True):
245
- lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
246
- if constraint_masks is not None:
247
- constraint_masks = constraint_masks[target != self.padding_idx]
248
- lprobs = lprobs[target != self.padding_idx]
249
- target = target[target != self.padding_idx]
250
- loss, nll_loss, ntokens = label_smoothed_nll_loss(
251
- lprobs,
252
- target,
253
- self.eps,
254
- update_num,
255
- reduce=reduce,
256
- drop_worst_ratio=self.drop_worst_ratio,
257
- drop_worst_after=self.drop_worst_after,
258
- use_rdrop=self.use_rdrop,
259
- reg_alpha=self.reg_alpha,
260
- constraint_masks=constraint_masks,
261
- constraint_start=self.constraint_start,
262
- constraint_end=self.constraint_end
263
- )
264
- return loss, nll_loss, ntokens
265
-
266
- def compute_accuracy(self, model, net_output, sample):
267
- lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
268
- mask = target.ne(self.padding_idx)
269
- n_correct = torch.sum(
270
- lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
271
- )
272
- total = torch.sum(mask)
273
- return n_correct, total
274
-
275
- @classmethod
276
- def reduce_metrics(cls, logging_outputs) -> None:
277
- """Aggregate logging outputs from data parallel training."""
278
- loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
279
- loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
280
- loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
281
- nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
282
- ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
283
- nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
284
- sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
285
- sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
286
- sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
287
-
288
- metrics.log_scalar(
289
- "loss", loss_sum / sample_size, sample_size, round=3
290
- )
291
- metrics.log_scalar(
292
- "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
293
- )
294
- metrics.log_scalar(
295
- "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
296
- )
297
- metrics.log_scalar(
298
- "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
299
- )
300
- metrics.log_derived(
301
- "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
302
- )
303
-
304
- metrics.log_scalar(
305
- "ntokens", ntokens, 1, round=3
306
- )
307
- metrics.log_scalar(
308
- "nsentences", nsentences, 1, round=3
309
- )
310
- metrics.log_scalar(
311
- "sample_size", sample_size, 1, round=3
312
- )
313
- metrics.log_scalar(
314
- "sample_size_v1", sample_size_v1, 1, round=3
315
- )
316
- metrics.log_scalar(
317
- "sample_size_v2", sample_size_v2, 1, round=3
318
- )
319
-
320
- total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
321
- if total > 0:
322
- metrics.log_scalar("total", total)
323
- n_correct = utils.item(
324
- sum(log.get("n_correct", 0) for log in logging_outputs)
325
- )
326
- metrics.log_scalar("n_correct", n_correct)
327
- metrics.log_derived(
328
- "accuracy",
329
- lambda meters: round(
330
- meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
331
- )
332
- if meters["total"].sum > 0
333
- else float("nan"),
334
- )
335
-
336
- @staticmethod
337
- def logging_outputs_can_be_summed() -> bool:
338
- """
339
- Whether the logging outputs returned by `forward` can be summed
340
- across workers prior to calling `reduce_metrics`. Setting this
341
- to True will improves distributed training speed.
342
- """
343
- return True
criterions/label_smoothed_encouraging_loss.py DELETED
@@ -1,395 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import math
7
- from dataclasses import dataclass, field
8
- from typing import Optional
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- import numpy as np
13
- from fairseq import metrics, utils
14
- from fairseq.criterions import FairseqCriterion, register_criterion
15
- from fairseq.dataclass import FairseqDataclass
16
- from omegaconf import II
17
-
18
-
19
- @dataclass
20
- class AdjustLabelSmoothedEncouragingLossConfig(FairseqDataclass):
21
- label_smoothing: float = field(
22
- default=0.0,
23
- metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
24
- )
25
- report_accuracy: bool = field(
26
- default=False,
27
- metadata={"help": "report accuracy metric"},
28
- )
29
- ignore_prefix_size: int = field(
30
- default=0,
31
- metadata={"help": "Ignore first N tokens"},
32
- )
33
- ignore_eos: bool = field(
34
- default=False,
35
- metadata={"help": "Ignore eos token"},
36
- )
37
- sentence_avg: bool = II("optimization.sentence_avg")
38
- drop_worst_ratio: float = field(
39
- default=0.0,
40
- metadata={"help": "ratio for discarding bad samples"},
41
- )
42
- drop_worst_after: int = field(
43
- default=0,
44
- metadata={"help": "steps for discarding bad samples"},
45
- )
46
- use_rdrop: bool = field(
47
- default=False, metadata={"help": "use R-Drop"}
48
- )
49
- reg_alpha: float = field(
50
- default=1.0, metadata={"help": "weight for R-Drop"}
51
- )
52
- sample_patch_num: int = field(
53
- default=196, metadata={"help": "sample patches for v1"}
54
- )
55
- constraint_range: Optional[str] = field(
56
- default=None,
57
- metadata={"help": "constraint range"}
58
- )
59
- log_end: float = field(
60
- default=0.75,
61
- metadata={"help": "higher log_end is for cases with higher performance,"
62
- " we recommend 0.75 or 0.5 as your first try."}
63
- )
64
- drop_best_ratio: float = field(
65
- default=0.0,
66
- metadata={"help": "ratio for discarding best samples"},
67
- )
68
- drop_best_after: int = field(
69
- default=0,
70
- metadata={"help": "steps for discarding best samples"},
71
- )
72
-
73
-
74
-
75
- def construct_rdrop_sample(x):
76
- if isinstance(x, dict):
77
- for key in x:
78
- x[key] = construct_rdrop_sample(x[key])
79
- return x
80
- elif isinstance(x, torch.Tensor):
81
- return x.repeat(2, *([1] * (x.dim()-1)))
82
- elif isinstance(x, int):
83
- return x * 2
84
- elif isinstance(x, np.ndarray):
85
- return x.repeat(2)
86
- else:
87
- raise NotImplementedError
88
-
89
-
90
- def kl_loss(p, q):
91
- p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
92
- q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
93
- loss = (p_loss + q_loss) / 2
94
- return loss
95
-
96
-
97
- def label_smoothed_nll_loss(
98
- lprobs, target, epsilon, update_num, reduce=True,
99
- drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
100
- constraint_masks=None, constraint_start=None, constraint_end=None, drop_best_ratio=0.0,
101
- drop_best_after=0,
102
- ):
103
- if target.dim() == lprobs.dim() - 1:
104
- target = target.unsqueeze(-1)
105
- nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
106
- if constraint_masks is not None:
107
- smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
108
- eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
109
- elif constraint_start is not None and constraint_end is not None:
110
- constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
111
- smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
112
- eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
113
- else:
114
- smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
115
- eps_i = epsilon / (lprobs.size(-1) - 1)
116
- loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
117
- if drop_worst_ratio > 0 and update_num > drop_worst_after:
118
- if use_rdrop:
119
- true_batch_size = loss.size(0) // 2
120
- _, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
121
- loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
122
- nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
123
- lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
124
- else:
125
- loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
126
- nll_loss = nll_loss[indices]
127
- lprobs = lprobs[indices]
128
- target = target[indices]
129
- if update_num > drop_best_after:
130
- loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_best_ratio)), largest=True)
131
- nll_loss = nll_loss[indices]
132
- lprobs = lprobs[indices]
133
- target = target[indices]
134
-
135
- ntokens = loss.numel()
136
- nll_loss = nll_loss.sum()
137
- loss = loss.sum()
138
- if use_rdrop:
139
- true_batch_size = lprobs.size(0) // 2
140
- p = lprobs[:true_batch_size]
141
- q = lprobs[true_batch_size:]
142
- if constraint_start is not None and constraint_end is not None:
143
- constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
144
- p = p[:, constraint_range]
145
- q = q[:, constraint_range]
146
- loss += kl_loss(p, q) * reg_alpha
147
-
148
- return loss, nll_loss, ntokens, lprobs, target
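The deleted criterion wires an R-Drop regulariser into this loss: `construct_rdrop_sample` duplicates every tensor in the batch, the duplicated batch is run through independent dropout, and `kl_loss` pulls the two resulting distributions together. Below is a minimal, hedged sketch of that term only; the `toy_logits_*` tensors stand in for the two stochastic forward passes and are invented for illustration.

```python
# Hedged sketch (not the repository's code) of the symmetric-KL R-Drop term.
import torch
import torch.nn.functional as F

def symmetric_kl(p_lprobs: torch.Tensor, q_lprobs: torch.Tensor) -> torch.Tensor:
    # F.kl_div(input=log q, target=p) computes KL(p || q); averaging both
    # directions symmetrises the penalty, as kl_loss() does above.
    p_loss = F.kl_div(p_lprobs, q_lprobs.exp(), reduction="sum")
    q_loss = F.kl_div(q_lprobs, p_lprobs.exp(), reduction="sum")
    return 0.5 * (p_loss + q_loss)

toy_logits_a = torch.randn(4, 10)                       # first stochastic pass
toy_logits_b = toy_logits_a + 0.1 * torch.randn(4, 10)  # second stochastic pass
reg = symmetric_kl(F.log_softmax(toy_logits_a, dim=-1),
                   F.log_softmax(toy_logits_b, dim=-1))
loss_with_rdrop_term = reg  # scaled by reg_alpha and added to the CE loss
```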
149
-
150
-
151
- @register_criterion(
152
- "adjust_label_smoothed_encouraging_loss", dataclass=AdjustLabelSmoothedEncouragingLossConfig
153
- )
154
- class AdjustLabelSmoothedEncouragingLossCriterion(FairseqCriterion):
155
- def __init__(
156
- self,
157
- task,
158
- sentence_avg,
159
- label_smoothing,
160
- ignore_prefix_size=0,
161
- ignore_eos=False,
162
- report_accuracy=False,
163
- drop_worst_ratio=0,
164
- drop_worst_after=0,
165
- use_rdrop=False,
166
- reg_alpha=1.0,
167
- sample_patch_num=196,
168
- constraint_range=None,
169
- log_end=0.75,
170
- drop_best_ratio=0.0,
171
- drop_best_after=0,
172
- ):
173
- super().__init__(task)
174
- self.sentence_avg = sentence_avg
175
- self.eps = label_smoothing
176
- self.ignore_prefix_size = ignore_prefix_size
177
- self.ignore_eos = ignore_eos
178
- self.report_accuracy = report_accuracy
179
- self.drop_worst_ratio = drop_worst_ratio
180
- self.drop_worst_after = drop_worst_after
181
- self.use_rdrop = use_rdrop
182
- self.reg_alpha = reg_alpha
183
- self.sample_patch_num = sample_patch_num
184
-
185
- self.constraint_start = None
186
- self.constraint_end = None
187
- if constraint_range is not None:
188
- constraint_start, constraint_end = constraint_range.split(',')
189
- self.constraint_start = int(constraint_start)
190
- self.constraint_end = int(constraint_end)
191
- self.log_end = log_end
192
- self.drop_best_ratio = drop_best_ratio
193
- self.drop_best_after = drop_best_after
194
- print('el, self.log_end=', self.log_end)
195
- # @staticmethod
196
- # def add_args(parser):
197
- # """Add criterion-specific arguments to the parser."""
198
- # # fmt: off
199
- # parser.add_argument('--log_end', type=float, default=1.0)
200
-
201
- def forward(self, model, sample, update_num=0, reduce=True):
202
- """Compute the loss for the given sample.
203
-
204
- Returns a tuple with three elements:
205
- 1) the loss
206
- 2) the sample size, which is used as the denominator for the gradient
207
- 3) logging outputs to display while training
208
- """
209
- if isinstance(sample, list):
210
- if self.sample_patch_num > 0:
211
- sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
212
- loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
213
- loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
214
- loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
215
- sample_size = 1
216
- logging_output = {
217
- "loss": loss.data,
218
- "loss_v1": loss_v1.data,
219
- "loss_v2": loss_v2.data,
220
- "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
221
- "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
222
- "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
223
- "sample_size": 1,
224
- "sample_size_v1": sample_size_v1,
225
- "sample_size_v2": sample_size_v2,
226
- }
227
- return loss, sample_size, logging_output
228
-
229
- if self.use_rdrop:
230
- construct_rdrop_sample(sample)
231
-
232
- net_output = model(**sample["net_input"])
233
- loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
234
- sample_size = (
235
- sample["target"].size(0) if self.sentence_avg else ntokens
236
- )
237
- logging_output = {
238
- "loss": loss.data,
239
- "nll_loss": nll_loss.data,
240
- "ntokens": sample["ntokens"],
241
- "nsentences": sample["nsentences"],
242
- "sample_size": sample_size,
243
- }
244
- if self.report_accuracy:
245
- n_correct, total = self.compute_accuracy(model, net_output, sample)
246
- logging_output["n_correct"] = utils.item(n_correct.data)
247
- logging_output["total"] = utils.item(total.data)
248
- return loss, sample_size, logging_output
249
-
250
- def get_lprobs_and_target(self, model, net_output, sample):
251
- conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
252
- constraint_masks = None
253
- if "constraint_masks" in sample and sample["constraint_masks"] is not None:
254
- constraint_masks = sample["constraint_masks"]
255
- net_output[0].masked_fill_(~constraint_masks, -math.inf)
256
- if self.constraint_start is not None and self.constraint_end is not None:
257
- net_output[0][:, :, 4:self.constraint_start] = -math.inf
258
- net_output[0][:, :, self.constraint_end:] = -math.inf
259
- lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
260
- target = model.get_targets(sample, net_output)
261
- if self.ignore_prefix_size > 0:
262
- lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
263
- target = target[:, self.ignore_prefix_size :].contiguous()
264
- if constraint_masks is not None:
265
- constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
266
- if self.ignore_eos:
267
- bsz, seq_len, embed_dim = lprobs.size()
268
- eos_indices = target.eq(self.task.tgt_dict.eos())
269
- lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
270
- target = target[~eos_indices].reshape(bsz, seq_len-1)
271
- if constraint_masks is not None:
272
- constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
273
- if constraint_masks is not None:
274
- constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
275
- return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
276
-
277
- def compute_loss(self, model, net_output, sample, update_num, reduce=True):
278
- lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
279
- if constraint_masks is not None:
280
- constraint_masks = constraint_masks[target != self.padding_idx]
281
- lprobs = lprobs[target != self.padding_idx]
282
- target = target[target != self.padding_idx]
283
- loss, nll_loss, ntokens, lprobs, target = label_smoothed_nll_loss(
284
- lprobs,
285
- target,
286
- self.eps,
287
- update_num,
288
- reduce=reduce,
289
- drop_worst_ratio=self.drop_worst_ratio,
290
- drop_worst_after=self.drop_worst_after,
291
- use_rdrop=self.use_rdrop,
292
- reg_alpha=self.reg_alpha,
293
- constraint_masks=constraint_masks,
294
- constraint_start=self.constraint_start,
295
- constraint_end=self.constraint_end
296
- )
297
- # for encouraging loss
298
- probs = torch.exp(lprobs)
299
- bonus = torch.log(torch.clamp((torch.ones_like(probs) - probs), min=1e-5)) # likelihood bonus
300
- log_end = self.log_end
301
- if log_end != 1.0: # e.g. 0.9
302
- y_log_end = torch.log(torch.ones_like(probs) - log_end)
303
- bonus_after_log_end = 1 / (log_end - torch.ones_like(probs)) * (probs - log_end) + y_log_end
304
- # the linear piece passes through (log_end, log(1 - log_end)) and continues with slope 1 / (log_end - 1)
305
- bonus = torch.where(probs > log_end, bonus_after_log_end, bonus)
306
- c_loss = F.nll_loss(
307
- -bonus,
308
- target.view(-1),
309
- reduction='sum',
310
- )
311
- smoothing_c_loss = bonus.sum(dim=-1)
312
- smoothing_c_loss = smoothing_c_loss.sum()
313
- c_loss = c_loss * (1 - self.eps) + (self.eps / lprobs.size(-1)) * smoothing_c_loss
314
- loss = loss + c_loss
315
- # end for encouraging loss
316
- return loss, nll_loss, ntokens
317
-
318
- def compute_accuracy(self, model, net_output, sample):
319
- lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
320
- mask = target.ne(self.padding_idx)
321
- n_correct = torch.sum(
322
- lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
323
- )
324
- total = torch.sum(mask)
325
- return n_correct, total
326
-
327
- @classmethod
328
- def reduce_metrics(cls, logging_outputs) -> None:
329
- """Aggregate logging outputs from data parallel training."""
330
- loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
331
- loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
332
- loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
333
- nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
334
- ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
335
- nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
336
- sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
337
- sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
338
- sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
339
-
340
- metrics.log_scalar(
341
- "loss", loss_sum / sample_size, sample_size, round=3
342
- )
343
- metrics.log_scalar(
344
- "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
345
- )
346
- metrics.log_scalar(
347
- "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
348
- )
349
- metrics.log_scalar(
350
- "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
351
- )
352
- metrics.log_derived(
353
- "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
354
- )
355
-
356
- metrics.log_scalar(
357
- "ntokens", ntokens, 1, round=3
358
- )
359
- metrics.log_scalar(
360
- "nsentences", nsentences, 1, round=3
361
- )
362
- metrics.log_scalar(
363
- "sample_size", sample_size, 1, round=3
364
- )
365
- metrics.log_scalar(
366
- "sample_size_v1", sample_size_v1, 1, round=3
367
- )
368
- metrics.log_scalar(
369
- "sample_size_v2", sample_size_v2, 1, round=3
370
- )
371
-
372
- total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
373
- if total > 0:
374
- metrics.log_scalar("total", total)
375
- n_correct = utils.item(
376
- sum(log.get("n_correct", 0) for log in logging_outputs)
377
- )
378
- metrics.log_scalar("n_correct", n_correct)
379
- metrics.log_derived(
380
- "accuracy",
381
- lambda meters: round(
382
- meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
383
- )
384
- if meters["total"].sum > 0
385
- else float("nan"),
386
- )
387
-
388
- @staticmethod
389
- def logging_outputs_can_be_summed() -> bool:
390
- """
391
- Whether the logging outputs returned by `forward` can be summed
392
- across workers prior to calling `reduce_metrics`. Setting this
393
- to True will improve distributed training speed.
394
- """
395
- return True
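For reference, a self-contained sketch of the encouraging-loss bonus that `compute_loss` above adds on top of the label-smoothed NLL: a `-log(1 - p)` bonus per token, replaced by its linear extension once the token probability exceeds `log_end`. Tensor names and sizes are invented for the example; this is a sketch assuming only `torch`, not the repository's API.

```python
import torch
import torch.nn.functional as F

def encouraging_bonus(lprobs: torch.Tensor, target: torch.Tensor,
                      log_end: float = 0.75, eps: float = 0.1) -> torch.Tensor:
    probs = lprobs.exp()
    # log(1 - p), clamped for numerical stability
    bonus = torch.log(torch.clamp(1.0 - probs, min=1e-5))
    if log_end != 1.0:
        # beyond log_end, continue with a straight line through
        # (log_end, log(1 - log_end)) with slope 1 / (log_end - 1)
        y_end = torch.log(torch.tensor(1.0 - log_end))
        linear = (probs - log_end) / (log_end - 1.0) + y_end
        bonus = torch.where(probs > log_end, linear, bonus)
    # bonus gathered at the gold tokens plus the label-smoothing share
    c_loss = F.nll_loss(-bonus, target, reduction="sum")
    smooth = bonus.sum(dim=-1).sum()
    return c_loss * (1.0 - eps) + (eps / lprobs.size(-1)) * smooth

toy_lprobs = torch.log_softmax(torch.randn(6, 12), dim=-1)  # 6 tokens, vocab 12
toy_target = torch.randint(0, 12, (6,))
print(encouraging_bonus(toy_lprobs, toy_target))
```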
criterions/scst_loss.py DELETED
@@ -1,281 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import math
7
- import string
8
- from dataclasses import dataclass, field
9
- from collections import OrderedDict
10
- from typing import Optional
11
-
12
- import torch
13
- from fairseq import metrics, utils
14
- from fairseq.criterions import FairseqCriterion, register_criterion
15
- from fairseq.dataclass import FairseqDataclass
16
- from omegaconf import II
17
-
18
- from data import data_utils
19
- from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
20
-
21
-
22
- def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
23
- loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
24
- if ignore_index is not None:
25
- pad_mask = target.eq(ignore_index)
26
- loss.masked_fill_(pad_mask, 0.0)
27
- ntokens = (~pad_mask).sum()
28
- else:
29
- loss = loss.squeeze(-1)
30
- ntokens = target.numel()
31
- if reduce:
32
- loss = loss.sum()
33
- return loss, ntokens
34
-
35
-
36
- @dataclass
37
- class ScstRewardCriterionConfig(FairseqDataclass):
38
- scst_cider_cached_tokens: str = field(
39
- default="coco-train-words.p",
40
- metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
41
- )
42
- ignore_prefix_size: int = field(
43
- default=0,
44
- metadata={"help": "Ignore first N tokens"},
45
- )
46
- sentence_avg: bool = II("optimization.sentence_avg")
47
- constraint_range: Optional[str] = field(
48
- default=None,
49
- metadata={"help": "constraint range"}
50
- )
51
-
52
-
53
- @register_criterion(
54
- "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
55
- )
56
- class ScstRewardCriterion(FairseqCriterion):
57
- CIDER_REWARD_WEIGHT = 1
58
-
59
- def __init__(
60
- self,
61
- task,
62
- scst_cider_cached_tokens,
63
- sentence_avg,
64
- ignore_prefix_size=0,
65
- constraint_range=None
66
- ):
67
- super().__init__(task)
68
- self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
69
- self.sentence_avg = sentence_avg
70
- self.ignore_prefix_size = ignore_prefix_size
71
- self.transtab = str.maketrans({key: None for key in string.punctuation})
72
-
73
- self.constraint_start = None
74
- self.constraint_end = None
75
- if constraint_range is not None:
76
- constraint_start, constraint_end = constraint_range.split(',')
77
- self.constraint_start = int(constraint_start)
78
- self.constraint_end = int(constraint_end)
79
-
80
- def forward(self, model, sample, update_num=0, reduce=True):
81
- """Compute the loss for the given sample.
82
-
83
- Returns a tuple with three elements:
84
- 1) the loss
85
- 2) the sample size, which is used as the denominator for the gradient
86
- 3) logging outputs to display while training
87
- """
88
- loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
89
-
90
- sample_size = (
91
- nsentences if self.sentence_avg else ntokens
92
- )
93
- logging_output = {
94
- "loss": loss.data,
95
- "score": score,
96
- "ntokens": ntokens,
97
- "nsentences": nsentences,
98
- "sample_size": sample_size,
99
- }
100
- return loss, sample_size, logging_output
101
-
102
- def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
103
- '''
104
- gen_res: generated captions, list of str
105
- gt_idx: list of int, of the same length as gen_res
106
- gt_res: ground truth captions, list of list of str.
107
- gen_res[i] corresponds to gt_res[gt_idx[i]]
108
- Each image can have multiple ground truth captions
109
- '''
110
- gen_res_size = len(gen_res)
111
-
112
- res = OrderedDict()
113
- for i in range(gen_res_size):
114
- res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
115
-
116
- gts = OrderedDict()
117
- gt_res_ = [
118
- [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
119
- for i in range(len(gt_res))
120
- ]
121
- for i in range(gen_res_size):
122
- gts[i] = gt_res_[gt_idx[i]]
123
-
124
- res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
125
- _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
126
- scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
127
- return scores
128
-
129
- @classmethod
130
- def _wrap_sentence(cls, s):
131
- # ensure the sentence ends with <eos> token
132
- # in order to keep consistent with cider_cached_tokens
133
- r = s.strip()
134
- if r.endswith('.'):
135
- r = r[:-1]
136
- r += ' <eos>'
137
- return r
138
-
139
- def get_generator_out(self, model, sample):
140
- def decode(toks):
141
- hypo = toks.int().cpu()
142
- hypo_str = self.task.tgt_dict.string(hypo)
143
- hypo_str = self.task.bpe.decode(hypo_str).strip()
144
- return hypo, hypo_str
145
-
146
- model.eval()
147
- with torch.no_grad():
148
- self.task.scst_generator.model.eval()
149
- gen_out = self.task.scst_generator.generate([model], sample)
150
-
151
- gen_target = []
152
- gen_res = []
153
- gt_res = []
154
- for i in range(len(gen_out)):
155
- for j in range(len(gen_out[i])):
156
- hypo, hypo_str = decode(gen_out[i][j]["tokens"])
157
- gen_target.append(hypo)
158
- gen_res.append(hypo_str)
159
- gt_res.append(
160
- decode(utils.strip_pad(sample["target"][i], self.padding_idx))[1].split('&&')
161
- )
162
-
163
- return gen_target, gen_res, gt_res
164
-
165
- def get_reward_and_scores(self, gen_res, gt_res, device):
166
- batch_size = len(gt_res)
167
- gen_res_size = len(gen_res)
168
- seq_per_img = gen_res_size // batch_size
169
-
170
- gt_idx = [i // seq_per_img for i in range(gen_res_size)]
171
- scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
172
- sc_ = scores.reshape(batch_size, seq_per_img)
173
- baseline = (sc_.sum(1, keepdims=True) - sc_) / (sc_.shape[1] - 1)
174
- # sample - baseline
175
- reward = scores.reshape(batch_size, seq_per_img)
176
- reward = reward - baseline
177
- reward = reward.reshape(gen_res_size)
178
- reward = torch.as_tensor(reward, device=device, dtype=torch.float64)
179
-
180
- return reward, scores
181
-
182
- def get_net_output(self, model, sample, gen_target):
183
- def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
184
- return data_utils.collate_tokens(
185
- sample_list,
186
- pad_idx=self.padding_idx,
187
- eos_idx=eos,
188
- left_pad=False,
189
- move_eos_to_beginning=move_eos_to_beginning,
190
- )
191
-
192
- batch_size = len(sample["target"])
193
- gen_target_size = len(gen_target)
194
- seq_per_img = gen_target_size // batch_size
195
-
196
- model.train()
197
- sample_src_tokens = torch.repeat_interleave(
198
- sample['net_input']['src_tokens'], seq_per_img, dim=0
199
- )
200
- sample_src_lengths = torch.repeat_interleave(
201
- sample['net_input']['src_lengths'], seq_per_img, dim=0
202
- )
203
- sample_patch_images = torch.repeat_interleave(
204
- sample['net_input']['patch_images'], seq_per_img, dim=0
205
- )
206
- sample_patch_masks = torch.repeat_interleave(
207
- sample['net_input']['patch_masks'], seq_per_img, dim=0
208
- )
209
- gen_prev_output_tokens = torch.as_tensor(
210
- merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
211
- device=sample["target"].device, dtype=torch.int64
212
- )
213
- gen_target_tokens = torch.as_tensor(
214
- merge(gen_target), device=sample["target"].device, dtype=torch.int64
215
- )
216
- net_output = model(
217
- src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
218
- patch_images=sample_patch_images, patch_masks=sample_patch_masks,
219
- prev_output_tokens=gen_prev_output_tokens
220
- )
221
-
222
- return net_output, gen_target_tokens
223
-
224
- def get_lprobs_and_target(self, model, net_output, gen_target):
225
- if self.constraint_start is not None and self.constraint_end is not None:
226
- net_output[0][:, :, 4:self.constraint_start] = -math.inf
227
- net_output[0][:, :, self.constraint_end:] = -math.inf
228
- lprobs = model.get_normalized_probs(net_output, log_probs=True)
229
- if self.ignore_prefix_size > 0:
230
- if getattr(lprobs, "batch_first", False):
231
- lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
232
- gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
233
- else:
234
- lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
235
- gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
236
- return lprobs, gen_target
237
-
238
- def compute_loss(self, model, sample, reduce=True):
239
- gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
240
- reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device)
241
- net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
242
- gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
243
- loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
244
- nsentences = gen_target_tokens.size(0)
245
-
246
- return loss, scores.sum(), ntokens, nsentences
247
-
248
- @classmethod
249
- def reduce_metrics(cls, logging_outputs) -> None:
250
- """Aggregate logging outputs from data parallel training."""
251
- loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
252
- score_sum = sum(log.get("score", 0) for log in logging_outputs)
253
- ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
254
- nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
255
- sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
256
-
257
- metrics.log_scalar(
258
- "loss", loss_sum / sample_size, sample_size, round=3
259
- )
260
- metrics.log_scalar(
261
- "score", score_sum / nsentences, nsentences, round=3
262
- )
263
-
264
- metrics.log_scalar(
265
- "ntokens", ntokens, 1, round=3
266
- )
267
- metrics.log_scalar(
268
- "nsentences", nsentences, 1, round=3
269
- )
270
- metrics.log_scalar(
271
- "sample_size", sample_size, 1, round=3
272
- )
273
-
274
- @staticmethod
275
- def logging_outputs_can_be_summed() -> bool:
276
- """
277
- Whether the logging outputs returned by `forward` can be summed
278
- across workers prior to calling `reduce_metrics`. Setting this
279
- to True will improve distributed training speed.
280
- """
281
- return True
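The reward used by `get_reward_and_scores` above is a leave-one-out baseline: each sampled caption's CIDEr score minus the mean score of the other samples drawn for the same image, so rewards within one image sum to zero. A small illustrative sketch with made-up numbers:

```python
import numpy as np

def leave_one_out_reward(scores: np.ndarray, batch_size: int, seq_per_img: int) -> np.ndarray:
    sc = scores.reshape(batch_size, seq_per_img)
    # baseline for sample j = mean of the remaining seq_per_img - 1 samples
    baseline = (sc.sum(1, keepdims=True) - sc) / (sc.shape[1] - 1)
    return (sc - baseline).reshape(-1)

scores = np.array([0.9, 0.5, 0.1, 0.4, 0.4, 0.7])   # 2 images x 3 samples each
print(leave_one_out_reward(scores, batch_size=2, seq_per_img=3))
# [ 0.6  0.  -0.6  -0.15 -0.15  0.3 ] -> better-than-average samples get positive reward
```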
data/__init__.py DELETED
File without changes
data/cv_data/image_classify_dataset.py DELETED
@@ -1,196 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
- import functools
11
-
12
- import numpy as np
13
- import torch
14
- import base64
15
- from torchvision import transforms
16
- from timm.data import create_transform
17
- from utils.vision_helper import RandomAugment
18
-
19
- from PIL import Image, ImageFile
20
-
21
- from data import data_utils
22
- from data.ofa_dataset import OFADataset
23
-
24
- ImageFile.LOAD_TRUNCATED_IMAGES = True
25
- ImageFile.MAX_IMAGE_PIXELS = None
26
- Image.MAX_IMAGE_PIXELS = None
27
-
28
- logger = logging.getLogger(__name__)
29
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
30
-
31
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
32
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
33
-
34
- def collate(samples, pad_idx, eos_idx):
35
- if len(samples) == 0:
36
- return {}
37
-
38
- def merge(key):
39
- return data_utils.collate_tokens(
40
- [s[key] for s in samples],
41
- pad_idx,
42
- eos_idx=eos_idx,
43
- )
44
-
45
- id = np.array([s["id"] for s in samples])
46
- src_tokens = merge("source")
47
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
48
-
49
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
50
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
51
-
52
- conf = None
53
- if samples[0].get("conf", None) is not None:
54
- conf = torch.cat([s['conf'] for s in samples], dim=0)
55
-
56
- ref_dict = None
57
- if samples[0].get("ref_dict", None) is not None:
58
- ref_dict = np.array([s['ref_dict'] for s in samples])
59
-
60
- constraint_masks = None
61
- if samples[0].get("constraint_mask", None) is not None:
62
- constraint_masks = merge("constraint_mask")
63
-
64
- prev_output_tokens = None
65
- target = None
66
- if samples[0].get("target", None) is not None:
67
- target = merge("target")
68
- tgt_lengths = torch.LongTensor(
69
- [s["target"].ne(pad_idx).long().sum() for s in samples]
70
- )
71
- ntokens = tgt_lengths.sum().item()
72
-
73
- if samples[0].get("prev_output_tokens", None) is not None:
74
- prev_output_tokens = merge("prev_output_tokens")
75
- else:
76
- ntokens = src_lengths.sum().item()
77
-
78
- batch = {
79
- "id": id,
80
- "nsentences": len(samples),
81
- "ntokens": ntokens,
82
- "net_input": {
83
- "src_tokens": src_tokens,
84
- "src_lengths": src_lengths,
85
- "patch_images": patch_images,
86
- "patch_masks": patch_masks,
87
- "prev_output_tokens": prev_output_tokens
88
- },
89
- "conf": conf,
90
- "ref_dict": ref_dict,
91
- "constraint_masks": constraint_masks,
92
- "target": target,
93
- }
94
-
95
- return batch
96
-
97
-
98
- class ImageClassifyDataset(OFADataset):
99
- def __init__(
100
- self,
101
- split,
102
- dataset,
103
- bpe,
104
- src_dict,
105
- tgt_dict=None,
106
- max_src_length=128,
107
- max_tgt_length=30,
108
- patch_image_size=224,
109
- constraint_trie=None,
110
- imagenet_default_mean_and_std=False
111
- ):
112
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
113
- self.max_src_length = max_src_length
114
- self.max_tgt_length = max_tgt_length
115
- self.patch_image_size = patch_image_size
116
-
117
- self.constraint_trie = constraint_trie
118
-
119
- if imagenet_default_mean_and_std:
120
- mean = IMAGENET_DEFAULT_MEAN
121
- std = IMAGENET_DEFAULT_STD
122
- else:
123
- mean = [0.5, 0.5, 0.5]
124
- std = [0.5, 0.5, 0.5]
125
-
126
- if self.split != 'train':
127
- self.patch_resize_transform = transforms.Compose([
128
- lambda image: image.convert("RGB"),
129
- transforms.Resize([patch_image_size, patch_image_size], interpolation=Image.BICUBIC),
130
- transforms.ToTensor(),
131
- transforms.Normalize(mean=mean, std=std),
132
- ])
133
- logger.info("val split, do not use random augmentation.")
134
- else:
135
- self.patch_resize_transform = create_transform(
136
- input_size=patch_image_size,
137
- is_training=True,
138
- color_jitter=0.4,
139
- auto_augment='rand-m9-mstd0.5-inc1',
140
- interpolation='bicubic',
141
- re_prob=0.25,
142
- re_mode='pixel',
143
- re_count=1,
144
- mean=mean,
145
- std=std,
146
- )
147
- self.patch_resize_transform = transforms.Compose(functools.reduce(lambda x, y:x + y, [
148
- [lambda image: image.convert("RGB"),],
149
- self.patch_resize_transform.transforms[:2],
150
- [self.patch_resize_transform.transforms[2]],
151
- [RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), ],
152
- self.patch_resize_transform.transforms[3:],
153
- ]))
154
- logger.info("train split, use random augmentation.")
155
-
156
- def __getitem__(self, index):
157
- image, label_name = self.dataset[index]
158
-
159
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
160
- patch_image = self.patch_resize_transform(image)
161
- patch_mask = torch.tensor([True])
162
-
163
- src_item = self.encode_text(' what does the image describe?')
164
- tgt_item = self.encode_text(" {}".format(label_name))
165
- ref_dict = {label_name: 1.0}
166
-
167
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
168
- target_item = torch.cat([tgt_item, self.eos_item])
169
- prev_output_item = torch.cat([self.bos_item, tgt_item])
170
-
171
- example = {
172
- "id": index,
173
- "source": src_item,
174
- "patch_image": patch_image,
175
- "patch_mask": patch_mask,
176
- "target": target_item,
177
- "prev_output_tokens": prev_output_item,
178
- "ref_dict": ref_dict,
179
- }
180
- if self.constraint_trie is not None:
181
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
182
- for i in range(len(prev_output_item)):
183
- constraint_prefix_token = prev_output_item[:i+1].tolist()
184
- constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
185
- constraint_mask[i][constraint_nodes] = True
186
- example["constraint_mask"] = constraint_mask
187
- return example
188
-
189
- def collater(self, samples, pad_to_length=None):
190
- """Merge a list of samples to form a mini-batch.
191
- Args:
192
- samples (List[dict]): samples to collate
193
- Returns:
194
- dict: a mini-batch containing the data of the task
195
- """
196
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
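`__getitem__` above turns a prefix trie over the legal label sequences into a per-step boolean mask over the vocabulary. A toy sketch of that idea follows; `TinyTrie` is a stand-in for illustration, not the repository's actual trie class.

```python
import torch

class TinyTrie:
    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def get_next_layer(self, prefix):
        # ids that may legally follow the given prefix
        node = self.root
        for tok in prefix:
            node = node.get(tok, {})
        return list(node.keys())

vocab_size = 10
trie = TinyTrie([[2, 5, 7], [2, 6, 7]])      # two allowed label token sequences
prefix = [2]                                  # tokens generated so far
mask = torch.zeros(vocab_size, dtype=torch.bool)
mask[trie.get_next_layer(prefix)] = True      # only ids 5 and 6 remain legal
print(mask.nonzero(as_tuple=True)[0])         # tensor([5, 6])
```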
data/data_utils.py DELETED
@@ -1,601 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- try:
7
- from collections.abc import Iterable
8
- except ImportError:
9
- from collections import Iterable
10
- import contextlib
11
- import itertools
12
- import logging
13
- import re
14
- import warnings
15
- from typing import Optional, Tuple
16
-
17
- import numpy as np
18
- import torch
19
-
20
- from fairseq.file_io import PathManager
21
- from fairseq import utils
22
- import os
23
-
24
- logger = logging.getLogger(__name__)
25
-
26
-
27
- def infer_language_pair(path):
28
- """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
29
- src, dst = None, None
30
- for filename in PathManager.ls(path):
31
- parts = filename.split(".")
32
- if len(parts) >= 3 and len(parts[1].split("-")) == 2:
33
- return parts[1].split("-")
34
- return src, dst
35
-
36
-
37
- def collate_tokens(
38
- values,
39
- pad_idx,
40
- eos_idx=None,
41
- left_pad=False,
42
- move_eos_to_beginning=False,
43
- pad_to_length=None,
44
- pad_to_multiple=1,
45
- pad_to_bsz=None,
46
- ):
47
- """Convert a list of 1d tensors into a padded 2d tensor."""
48
- size = max(v.size(0) for v in values)
49
- size = size if pad_to_length is None else max(size, pad_to_length)
50
- if pad_to_multiple != 1 and size % pad_to_multiple != 0:
51
- size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
52
-
53
- def copy_tensor(src, dst):
54
- assert dst.numel() == src.numel()
55
- if move_eos_to_beginning:
56
- if eos_idx is None:
57
- # if no eos_idx is specified, then use the last token in src
58
- dst[0] = src[-1]
59
- else:
60
- dst[0] = eos_idx
61
- dst[1:] = src[:-1]
62
- else:
63
- dst.copy_(src)
64
-
65
- if values[0].dim() == 1:
66
- res = values[0].new(len(values), size).fill_(pad_idx)
67
- elif values[0].dim() == 2:
68
- assert move_eos_to_beginning is False
69
- res = values[0].new(len(values), size, values[0].size(1)).fill_(pad_idx)
70
- else:
71
- raise NotImplementedError
72
-
73
- for i, v in enumerate(values):
74
- copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
75
- return res
76
-
77
-
78
- def load_indexed_dataset(
79
- path, dictionary=None, dataset_impl=None, combine=False, default="cached"
80
- ):
81
- """A helper function for loading indexed datasets.
82
-
83
- Args:
84
- path (str): path to indexed dataset (e.g., 'data-bin/train')
85
- dictionary (~fairseq.data.Dictionary): data dictionary
86
- dataset_impl (str, optional): which dataset implementation to use. If
87
- not provided, it will be inferred automatically. For legacy indexed
88
- data we use the 'cached' implementation by default.
89
- combine (bool, optional): automatically load and combine multiple
90
- datasets. For example, if *path* is 'data-bin/train', then we will
91
- combine 'data-bin/train', 'data-bin/train1', ... and return a
92
- single ConcatDataset instance.
93
- """
94
- import fairseq.data.indexed_dataset as indexed_dataset
95
- from fairseq.data.concat_dataset import ConcatDataset
96
-
97
- datasets = []
98
- for k in itertools.count():
99
- path_k = path + (str(k) if k > 0 else "")
100
- try:
101
- path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
102
- except Exception as e:
103
- if "StorageException: [404] Path not found" in str(e):
104
- logger.warning(f"path_k: {e} not found")
105
- else:
106
- raise e
107
-
108
- dataset_impl_k = dataset_impl
109
- if dataset_impl_k is None:
110
- dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
111
- dataset = indexed_dataset.make_dataset(
112
- path_k,
113
- impl=dataset_impl_k or default,
114
- fix_lua_indexing=True,
115
- dictionary=dictionary,
116
- )
117
- if dataset is None:
118
- break
119
- logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k))
120
- datasets.append(dataset)
121
- if not combine:
122
- break
123
- if len(datasets) == 0:
124
- return None
125
- elif len(datasets) == 1:
126
- return datasets[0]
127
- else:
128
- return ConcatDataset(datasets)
129
-
130
-
131
- @contextlib.contextmanager
132
- def numpy_seed(seed, *addl_seeds):
133
- """Context manager which seeds the NumPy PRNG with the specified seed and
134
- restores the state afterward"""
135
- if seed is None:
136
- yield
137
- return
138
- if len(addl_seeds) > 0:
139
- seed = int(hash((seed, *addl_seeds)) % 1e6)
140
- state = np.random.get_state()
141
- np.random.seed(seed)
142
- try:
143
- yield
144
- finally:
145
- np.random.set_state(state)
146
-
147
-
148
- def collect_filtered(function, iterable, filtered):
149
- """
150
- Similar to :func:`filter` but collects filtered elements in ``filtered``.
151
-
152
- Args:
153
- function (callable): function that returns ``False`` for elements that
154
- should be filtered
155
- iterable (iterable): iterable to filter
156
- filtered (list): list to store filtered elements
157
- """
158
- for el in iterable:
159
- if function(el):
160
- yield el
161
- else:
162
- filtered.append(el)
163
-
164
-
165
- def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
166
- def compare_leq(a, b):
167
- return a <= b if not isinstance(a, tuple) else max(a) <= b
168
-
169
- def check_size(idx):
170
- if isinstance(max_positions, float) or isinstance(max_positions, int):
171
- return size_fn(idx) <= max_positions
172
- elif isinstance(max_positions, dict):
173
- idx_size = size_fn(idx)
174
- assert isinstance(idx_size, dict)
175
- intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
176
- return all(
177
- all(
178
- a is None or b is None or a <= b
179
- for a, b in zip(idx_size[key], max_positions[key])
180
- )
181
- for key in intersect_keys
182
- )
183
- else:
184
- # For MultiCorpusSampledDataset, will generalize it later
185
- if not isinstance(size_fn(idx), Iterable):
186
- return all(size_fn(idx) <= b for b in max_positions)
187
- return all(
188
- a is None or b is None or a <= b
189
- for a, b in zip(size_fn(idx), max_positions)
190
- )
191
-
192
- ignored = []
193
- itr = collect_filtered(check_size, indices, ignored)
194
- indices = np.fromiter(itr, dtype=np.int64, count=-1)
195
- return indices, ignored
196
-
197
-
198
- def filter_by_size(indices, dataset, max_positions, raise_exception=False):
199
- """
200
- [deprecated] Filter indices based on their size.
201
- Use `FairseqDataset::filter_indices_by_size` instead.
202
-
203
- Args:
204
- indices (List[int]): ordered list of dataset indices
205
- dataset (FairseqDataset): fairseq dataset instance
206
- max_positions (tuple): filter elements larger than this size.
207
- Comparisons are done component-wise.
208
- raise_exception (bool, optional): if ``True``, raise an exception if
209
- any elements are filtered (default: False).
210
- """
211
- warnings.warn(
212
- "data_utils.filter_by_size is deprecated. "
213
- "Use `FairseqDataset::filter_indices_by_size` instead.",
214
- stacklevel=2,
215
- )
216
- if isinstance(max_positions, float) or isinstance(max_positions, int):
217
- if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
218
- ignored = indices[dataset.sizes[indices] > max_positions].tolist()
219
- indices = indices[dataset.sizes[indices] <= max_positions]
220
- elif (
221
- hasattr(dataset, "sizes")
222
- and isinstance(dataset.sizes, list)
223
- and len(dataset.sizes) == 1
224
- ):
225
- ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
226
- indices = indices[dataset.sizes[0][indices] <= max_positions]
227
- else:
228
- indices, ignored = _filter_by_size_dynamic(
229
- indices, dataset.size, max_positions
230
- )
231
- else:
232
- indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
233
-
234
- if len(ignored) > 0 and raise_exception:
235
- raise Exception(
236
- (
237
- "Size of sample #{} is invalid (={}) since max_positions={}, "
238
- "skip this example with --skip-invalid-size-inputs-valid-test"
239
- ).format(ignored[0], dataset.size(ignored[0]), max_positions)
240
- )
241
- if len(ignored) > 0:
242
- logger.warning(
243
- (
244
- "{} samples have invalid sizes and will be skipped, "
245
- "max_positions={}, first few sample ids={}"
246
- ).format(len(ignored), max_positions, ignored[:10])
247
- )
248
- return indices
249
-
250
-
251
- def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
252
- """Filter a list of sample indices. Remove those that are longer
253
- than specified in max_sizes.
254
-
255
- Args:
256
- indices (np.array): original array of sample indices
257
- max_sizes (int or list[int] or tuple[int]): max sample size,
258
- can be defined separately for src and tgt (then list or tuple)
259
-
260
- Returns:
261
- np.array: filtered sample array
262
- list: list of removed indices
263
- """
264
- if max_sizes is None:
265
- return indices, []
266
- if type(max_sizes) in (int, float):
267
- max_src_size, max_tgt_size = max_sizes, max_sizes
268
- else:
269
- max_src_size, max_tgt_size = max_sizes
270
- if tgt_sizes is None:
271
- ignored = indices[src_sizes[indices] > max_src_size]
272
- else:
273
- ignored = indices[
274
- (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
275
- ]
276
- if len(ignored) > 0:
277
- if tgt_sizes is None:
278
- indices = indices[src_sizes[indices] <= max_src_size]
279
- else:
280
- indices = indices[
281
- (src_sizes[indices] <= max_src_size)
282
- & (tgt_sizes[indices] <= max_tgt_size)
283
- ]
284
- return indices, ignored.tolist()
285
-
286
-
287
- def batch_by_size(
288
- indices,
289
- num_tokens_fn,
290
- num_tokens_vec=None,
291
- max_tokens=None,
292
- max_sentences=None,
293
- required_batch_size_multiple=1,
294
- fixed_shapes=None,
295
- ):
296
- """
297
- Yield mini-batches of indices bucketed by size. Batches may contain
298
- sequences of different lengths.
299
-
300
- Args:
301
- indices (List[int]): ordered list of dataset indices
302
- num_tokens_fn (callable): function that returns the number of tokens at
303
- a given index
304
- num_tokens_vec (List[int], optional): precomputed vector of the number
305
- of tokens for each index in indices (to enable faster batch generation)
306
- max_tokens (int, optional): max number of tokens in each batch
307
- (default: None).
308
- max_sentences (int, optional): max number of sentences in each
309
- batch (default: None).
310
- required_batch_size_multiple (int, optional): require batch size to
311
- be less than N or a multiple of N (default: 1).
312
- fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
313
- only be created with the given shapes. *max_sentences* and
314
- *required_batch_size_multiple* will be ignored (default: None).
315
- """
316
- try:
317
- from fairseq.data.data_utils_fast import (
318
- batch_by_size_fn,
319
- batch_by_size_vec,
320
- batch_fixed_shapes_fast,
321
- )
322
- except ImportError:
323
- raise ImportError(
324
- "Please build Cython components with: "
325
- "`python setup.py build_ext --inplace`"
326
- )
327
- except ValueError:
328
- raise ValueError(
329
- "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
330
- )
331
-
332
- # added int() to avoid TypeError: an integer is required
333
- max_tokens = (
334
- int(max_tokens) if max_tokens is not None else -1
335
- )
336
- max_sentences = max_sentences if max_sentences is not None else -1
337
- bsz_mult = required_batch_size_multiple
338
-
339
- if not isinstance(indices, np.ndarray):
340
- indices = np.fromiter(indices, dtype=np.int64, count=-1)
341
-
342
- if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
343
- num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
344
-
345
- if fixed_shapes is None:
346
- if num_tokens_vec is None:
347
- return batch_by_size_fn(
348
- indices,
349
- num_tokens_fn,
350
- max_tokens,
351
- max_sentences,
352
- bsz_mult,
353
- )
354
- else:
355
- return batch_by_size_vec(
356
- indices,
357
- num_tokens_vec,
358
- max_tokens,
359
- max_sentences,
360
- bsz_mult,
361
- )
362
-
363
- else:
364
- fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
365
- sort_order = np.lexsort(
366
- [
367
- fixed_shapes[:, 1].argsort(), # length
368
- fixed_shapes[:, 0].argsort(), # bsz
369
- ]
370
- )
371
- fixed_shapes_sorted = fixed_shapes[sort_order]
372
- return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
373
-
374
-
375
- def post_process(sentence: str, symbol: str):
376
- if symbol == "sentencepiece":
377
- sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
378
- elif symbol == "wordpiece":
379
- sentence = sentence.replace(" ", "").replace("_", " ").strip()
380
- elif symbol == "letter":
381
- sentence = sentence.replace(" ", "").replace("|", " ").strip()
382
- elif symbol == "silence":
383
- import re
384
- sentence = sentence.replace("<SIL>", "")
385
- sentence = re.sub(' +', ' ', sentence).strip()
386
- elif symbol == "_EOW":
387
- sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
388
- elif symbol in {"subword_nmt", "@@ ", "@@"}:
389
- if symbol == "subword_nmt":
390
- symbol = "@@ "
391
- sentence = (sentence + " ").replace(symbol, "").rstrip()
392
- elif symbol == "none":
393
- pass
394
- elif symbol is not None:
395
- raise NotImplementedError(f"Unknown post_process option: {symbol}")
396
- return sentence
397
-
398
-
399
- def compute_mask_indices(
400
- shape: Tuple[int, int],
401
- padding_mask: Optional[torch.Tensor],
402
- mask_prob: float,
403
- mask_length: int,
404
- mask_type: str = "static",
405
- mask_other: float = 0.0,
406
- min_masks: int = 0,
407
- no_overlap: bool = False,
408
- min_space: int = 0,
409
- ) -> np.ndarray:
410
- """
411
- Computes random mask spans for a given shape
412
-
413
- Args:
414
- shape: the shape for which to compute masks.
415
- should be of size 2 where first element is batch size and 2nd is timesteps
416
- padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
417
- mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
418
- number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
419
- however due to overlaps, the actual number will be smaller (unless no_overlap is True)
420
- mask_type: how to compute mask lengths
421
- static = fixed size
422
- uniform = sample from uniform distribution [mask_other, mask_length*2]
423
- normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
424
- poisson = sample from poisson distribution with lambda = mask_length
425
- min_masks: minimum number of masked spans
426
- no_overlap: if true, use a recursive algorithm that prevents spans from overlapping
427
- min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
428
- """
429
-
430
- bsz, all_sz = shape
431
- mask = np.full((bsz, all_sz), False)
432
-
433
- all_num_mask = int(
434
- # add a random number for probabilistic rounding
435
- mask_prob * all_sz / float(mask_length)
436
- + np.random.rand()
437
- )
438
-
439
- all_num_mask = max(min_masks, all_num_mask)
440
-
441
- mask_idcs = []
442
- for i in range(bsz):
443
- if padding_mask is not None:
444
- sz = all_sz - padding_mask[i].long().sum().item()
445
- num_mask = int(
446
- # add a random number for probabilistic rounding
447
- mask_prob * sz / float(mask_length)
448
- + np.random.rand()
449
- )
450
- num_mask = max(min_masks, num_mask)
451
- else:
452
- sz = all_sz
453
- num_mask = all_num_mask
454
-
455
- if mask_type == "static":
456
- lengths = np.full(num_mask, mask_length)
457
- elif mask_type == "uniform":
458
- lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
459
- elif mask_type == "normal":
460
- lengths = np.random.normal(mask_length, mask_other, size=num_mask)
461
- lengths = [max(1, int(round(x))) for x in lengths]
462
- elif mask_type == "poisson":
463
- lengths = np.random.poisson(mask_length, size=num_mask)
464
- lengths = [int(round(x)) for x in lengths]
465
- else:
466
- raise Exception("unknown mask selection " + mask_type)
467
-
468
- if sum(lengths) == 0:
469
- lengths[0] = min(mask_length, sz - 1)
470
-
471
- if no_overlap:
472
- mask_idc = []
473
-
474
- def arrange(s, e, length, keep_length):
475
- span_start = np.random.randint(s, e - length)
476
- mask_idc.extend(span_start + i for i in range(length))
477
-
478
- new_parts = []
479
- if span_start - s - min_space >= keep_length:
480
- new_parts.append((s, span_start - min_space + 1))
481
- if e - span_start - keep_length - min_space > keep_length:
482
- new_parts.append((span_start + length + min_space, e))
483
- return new_parts
484
-
485
- parts = [(0, sz)]
486
- min_length = min(lengths)
487
- for length in sorted(lengths, reverse=True):
488
- lens = np.fromiter(
489
- (e - s if e - s >= length + min_space else 0 for s, e in parts),
490
- np.int64,
491
- )
492
- l_sum = np.sum(lens)
493
- if l_sum == 0:
494
- break
495
- probs = lens / np.sum(lens)
496
- c = np.random.choice(len(parts), p=probs)
497
- s, e = parts.pop(c)
498
- parts.extend(arrange(s, e, length, min_length))
499
- mask_idc = np.asarray(mask_idc)
500
- else:
501
- min_len = min(lengths)
502
- if sz - min_len <= num_mask:
503
- min_len = sz - num_mask - 1
504
-
505
- mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
506
-
507
- mask_idc = np.asarray(
508
- [
509
- mask_idc[j] + offset
510
- for j in range(len(mask_idc))
511
- for offset in range(lengths[j])
512
- ]
513
- )
514
-
515
- mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
516
-
517
- min_len = min([len(m) for m in mask_idcs])
518
- for i, mask_idc in enumerate(mask_idcs):
519
- if len(mask_idc) > min_len:
520
- mask_idc = np.random.choice(mask_idc, min_len, replace=False)
521
- mask[i, mask_idc] = True
522
-
523
- return mask
524
-
525
-
526
- def get_mem_usage():
527
- try:
528
- import psutil
529
-
530
- mb = 1024 * 1024
531
- return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
532
- except ImportError:
533
- return "N/A"
534
-
535
-
536
- # lens: torch.LongTensor
537
- # returns: torch.BoolTensor
538
- def lengths_to_padding_mask(lens):
539
- bsz, max_lens = lens.size(0), torch.max(lens).item()
540
- mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
541
- mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
542
- return mask
543
-
544
-
545
- # lens: torch.LongTensor
546
- # returns: torch.BoolTensor
547
- def lengths_to_mask(lens):
548
- return ~lengths_to_padding_mask(lens)
549
-
550
-
551
- def get_buckets(sizes, num_buckets):
552
- buckets = np.unique(
553
- np.percentile(
554
- sizes,
555
- np.linspace(0, 100, num_buckets + 1),
556
- interpolation='lower',
557
- )[1:]
558
- )
559
- return buckets
560
-
561
-
562
- def get_bucketed_sizes(orig_sizes, buckets):
563
- sizes = np.copy(orig_sizes)
564
- assert np.min(sizes) >= 0
565
- start_val = -1
566
- for end_val in buckets:
567
- mask = (sizes > start_val) & (sizes <= end_val)
568
- sizes[mask] = end_val
569
- start_val = end_val
570
- return sizes
571
-
572
-
573
-
574
- def _find_extra_valid_paths(dataset_path: str) -> set:
575
- paths = utils.split_paths(dataset_path)
576
- all_valid_paths = set()
577
- for sub_dir in paths:
578
- contents = PathManager.ls(sub_dir)
579
- valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
580
- all_valid_paths |= {os.path.basename(p) for p in valid_paths}
581
- # Remove .bin, .idx etc
582
- roots = {os.path.splitext(p)[0] for p in all_valid_paths}
583
- return roots
584
-
585
-
586
- def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
587
- """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
588
- if (
589
- train_cfg.dataset.ignore_unused_valid_subsets
590
- or train_cfg.dataset.combine_valid_subsets
591
- or train_cfg.dataset.disable_validation
592
- or not hasattr(train_cfg.task, "data")
593
- ):
594
- return
595
- other_paths = _find_extra_valid_paths(train_cfg.task.data)
596
- specified_subsets = train_cfg.dataset.valid_subset.split(",")
597
- ignored_paths = [p for p in other_paths if p not in specified_subsets]
598
- if ignored_paths:
599
- advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
600
- msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
601
- raise ValueError(msg)
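Most of the dataset collators in this commit funnel through `collate_tokens` above. Below is a minimal sketch of its common path only (1-D tensors, right padding, optional rotation of EOS to the front to build decoder inputs); `pad_to_length` / `pad_to_multiple` handling is deliberately omitted.

```python
import torch

def pad_1d_batch(values, pad_idx, eos_idx=None, move_eos_to_beginning=False):
    size = max(v.size(0) for v in values)
    out = values[0].new_full((len(values), size), pad_idx)
    for i, v in enumerate(values):
        if move_eos_to_beginning:
            out[i, 0] = eos_idx if eos_idx is not None else v[-1]
            out[i, 1:v.size(0)] = v[:-1]
        else:
            out[i, : v.size(0)] = v
    return out

tokens = [torch.tensor([5, 6, 2]), torch.tensor([7, 2])]   # 2 = <eos>, 1 = <pad>
print(pad_1d_batch(tokens, pad_idx=1))
# tensor([[5, 6, 2],
#         [7, 2, 1]])
print(pad_1d_batch(tokens, pad_idx=1, eos_idx=0, move_eos_to_beginning=True))
# tensor([[0, 5, 6],
#         [0, 7, 1]])
```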
data/file_dataset.py DELETED
@@ -1,107 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import os
7
- import torch
8
- import pickle
9
-
10
-
11
- class FileDataset:
12
- def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
13
- self.file_path = file_path
14
- assert os.path.exists(self.file_path), "Error: the local datafile {} does not exist!".format(self.file_path)
15
-
16
- self.separator = separator
17
- if selected_col_ids is None:
18
- # default to all fields
19
- self.selected_col_ids = list(
20
- range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
21
- else:
22
- self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
23
- if dtypes is None:
24
- # default to str
25
- self.dtypes = [str for col_id in self.selected_col_ids]
26
- else:
27
- self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
28
- assert len(self.dtypes) == len(self.selected_col_ids)
29
-
30
- self.data_cnt = 0
31
- try:
32
- self.slice_id = torch.distributed.get_rank()
33
- self.slice_count = torch.distributed.get_world_size()
34
- except Exception:
35
- self.slice_id = 0
36
- self.slice_count = 1
37
- self.cached_index = cached_index
38
- self._init_seek_index()
39
- self._reader = self._get_reader()
40
- print("file {} slice_id {} row count {} total row count {}".format(
41
- self.file_path, self.slice_id, self.row_count, self.total_row_count)
42
- )
43
-
44
- def _init_seek_index(self):
45
- if self.cached_index:
46
- cache_path = "{}.index".format(self.file_path)
47
- assert os.path.exists(cache_path), "cache file {} does not exist!".format(cache_path)
48
- self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
49
- print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
50
- self.file_path, self.slice_id))
51
- else:
52
- # make an iteration over the file to get row_count and line_idx-to-offset mapping
53
- fp = open(self.file_path, "r")
54
- print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
55
- self.file_path, self.slice_id))
56
- self.total_row_count = 0
57
- offset = 0
58
- self.lineid_to_offset = []
59
- for line in fp:
60
- self.lineid_to_offset.append(offset)
61
- self.total_row_count += 1
62
- offset += len(line.encode('utf-8'))
63
- self._compute_start_pos_and_row_count()
64
- print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
65
- self.file_path, self.slice_id))
66
-
67
- def _compute_start_pos_and_row_count(self):
68
- self.row_count = self.total_row_count // self.slice_count
69
- if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
70
- self.row_count += 1
71
- self.start_pos = self.row_count * self.slice_id
72
- else:
73
- self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
74
-
75
- def _get_reader(self):
76
- fp = open(self.file_path, "r")
77
- fp.seek(self.lineid_to_offset[self.start_pos])
78
- return fp
79
-
80
- def _seek(self, offset=0):
81
- try:
82
- print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
83
- self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
84
- self.data_cnt = offset
85
- except Exception:
86
- print("slice_id {} seek offset {}".format(self.slice_id, offset))
87
- self._reader.seek(self.lineid_to_offset[offset])
88
- self.data_cnt = offset
89
-
90
- def __del__(self):
91
- self._reader.close()
92
-
93
- def __len__(self):
94
- return self.row_count
95
-
96
- def get_total_row_count(self):
97
- return self.total_row_count
98
-
99
- def __getitem__(self, index):
100
- if self.data_cnt == self.row_count:
101
- print("reach the end of datafile, start a new reader")
102
- self.data_cnt = 0
103
- self._reader = self._get_reader()
104
- column_l = self._reader.readline().rstrip("\n").split(self.separator)
105
- self.data_cnt += 1
106
- column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
107
- return column_l
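`FileDataset` above avoids loading the whole TSV by recording each line's byte offset once and letting every distributed shard `seek()` straight to its first row. A simplified sketch of that pattern follows; it uses a plainer contiguous split than the class's remainder-balancing logic, and the path in the usage line is illustrative only.

```python
def build_line_offsets(path):
    offsets, pos = [], 0
    with open(path, "rb") as fp:
        for line in fp:
            offsets.append(pos)     # byte offset where this row starts
            pos += len(line)
    return offsets

def read_shard(path, offsets, shard_id, num_shards, separator="\t"):
    rows = len(offsets) // num_shards
    start = shard_id * rows
    end = len(offsets) if shard_id == num_shards - 1 else start + rows
    with open(path, "rb") as fp:
        fp.seek(offsets[start])     # jump directly to this shard's first row
        return [fp.readline().decode("utf-8").rstrip("\n").split(separator)
                for _ in range(end - start)]

# usage (hypothetical file):
# rows = read_shard("train.tsv", build_line_offsets("train.tsv"), shard_id=0, num_shards=4)
```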
data/mm_data/__init__.py DELETED
File without changes
data/mm_data/caption_dataset.py DELETED
@@ -1,160 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
- import string
11
-
12
- import numpy as np
13
- import torch
14
- import base64
15
- from torchvision import transforms
16
-
17
- from PIL import Image, ImageFile
18
-
19
- from data import data_utils
20
- from data.ofa_dataset import OFADataset
21
-
22
- ImageFile.LOAD_TRUNCATED_IMAGES = True
23
- ImageFile.MAX_IMAGE_PIXELS = None
24
- Image.MAX_IMAGE_PIXELS = None
25
-
26
- logger = logging.getLogger(__name__)
27
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
28
-
29
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
30
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
31
-
32
-
33
- def collate(samples, pad_idx, eos_idx):
34
- if len(samples) == 0:
35
- return {}
36
-
37
- def merge(key):
38
- return data_utils.collate_tokens(
39
- [s[key] for s in samples],
40
- pad_idx,
41
- eos_idx=eos_idx,
42
- )
43
-
44
- id = np.array([s["id"] for s in samples])
45
- src_tokens = merge("source")
46
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
47
-
48
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
49
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
50
-
51
- prev_output_tokens = None
52
- target = None
53
- if samples[0].get("target", None) is not None:
54
- target = merge("target")
55
- tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
56
- ntokens = tgt_lengths.sum().item()
57
-
58
- if samples[0].get("prev_output_tokens", None) is not None:
59
- prev_output_tokens = merge("prev_output_tokens")
60
- else:
61
- ntokens = src_lengths.sum().item()
62
-
63
- batch = {
64
- "id": id,
65
- "nsentences": len(samples),
66
- "ntokens": ntokens,
67
- "net_input": {
68
- "src_tokens": src_tokens,
69
- "src_lengths": src_lengths,
70
- "patch_images": patch_images,
71
- "patch_masks": patch_masks,
72
- "prev_output_tokens": prev_output_tokens
73
- },
74
- "target": target,
75
- }
76
-
77
- return batch
78
-
79
-
80
- class CaptionDataset(OFADataset):
81
- def __init__(
82
- self,
83
- split,
84
- dataset,
85
- bpe,
86
- src_dict,
87
- tgt_dict=None,
88
- max_src_length=128,
89
- max_tgt_length=30,
90
- patch_image_size=224,
91
- imagenet_default_mean_and_std=False,
92
- scst=False
93
- ):
94
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
95
- self.max_src_length = max_src_length
96
- self.max_tgt_length = max_tgt_length
97
- self.patch_image_size = patch_image_size
98
- self.scst = scst
99
-
100
- self.transtab = str.maketrans({key: None for key in string.punctuation})
101
-
102
- if imagenet_default_mean_and_std:
103
- mean = IMAGENET_DEFAULT_MEAN
104
- std = IMAGENET_DEFAULT_STD
105
- else:
106
- mean = [0.5, 0.5, 0.5]
107
- std = [0.5, 0.5, 0.5]
108
-
109
- self.patch_resize_transform = transforms.Compose([
110
- lambda image: image.convert("RGB"),
111
- transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
112
- transforms.ToTensor(),
113
- transforms.Normalize(mean=mean, std=std),
114
- ])
115
-
116
- if type(bpe).__name__ == 'GPT2BPE':
117
- self.prompt = " what does the image describe?"
118
- elif type(bpe).__name__ == 'BertBPE':
119
- self.prompt = "图片描述了什么内容?"
120
-
121
- def __getitem__(self, index):
122
- uniq_id, image, caption = self.dataset[index]
123
-
124
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
125
- patch_image = self.patch_resize_transform(image)
126
- patch_mask = torch.tensor([True])
127
-
128
- if self.split == 'train' and not self.scst:
129
- caption = caption.translate(self.transtab).strip()
130
- caption_token_list = caption.strip().split()
131
- tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
132
- else:
133
- caption = ' '.join(caption.strip().split())
134
- caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
135
- tgt_caption = '&&'.join(caption_list)
136
- src_item = self.encode_text(self.prompt)
137
- tgt_item = self.encode_text(" {}".format(tgt_caption))
138
-
139
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
140
- target_item = torch.cat([tgt_item, self.eos_item])
141
- prev_output_item = torch.cat([self.bos_item, tgt_item])
142
-
143
- example = {
144
- "id": uniq_id,
145
- "source": src_item,
146
- "patch_image": patch_image,
147
- "patch_mask": patch_mask,
148
- "target": target_item,
149
- "prev_output_tokens": prev_output_item
150
- }
151
- return example
152
-
153
- def collater(self, samples, pad_to_length=None):
154
- """Merge a list of samples to form a mini-batch.
155
- Args:
156
- samples (List[dict]): samples to collate
157
- Returns:
158
- dict: a mini-batch containing the data of the task
159
- """
160
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
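For reference, a minimal sketch of the sequence layout that __getitem__ above produces, using toy token ids in place of real BPE output (the ids are assumptions, not values from the actual dictionary):

    import torch

    bos, eos = torch.tensor([0]), torch.tensor([2])   # toy special-token ids
    src = torch.tensor([10, 11, 12])                   # encoded prompt
    tgt = torch.tensor([20, 21])                       # encoded caption

    source = torch.cat([bos, src, eos])                # [0, 10, 11, 12, 2]
    target = torch.cat([tgt, eos])                     # [20, 21, 2]
    prev_output_tokens = torch.cat([bos, tgt])         # [0, 20, 21], teacher-forcing input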
data/mm_data/ocr_dataset.py DELETED
@@ -1,210 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
- import random
11
- import functools
12
-
13
- import numpy as np  # needed for np.array in collate()
- import torch
14
- import base64
15
- from torchvision import transforms
16
- from torchvision.transforms import InterpolationMode
17
- from torchvision.transforms import functional as F
18
-
19
- from PIL import Image, ImageFile
20
-
21
- from zhconv import convert
22
- import unicodedata
23
-
24
- from data import data_utils
25
- from data.ofa_dataset import OFADataset
26
-
27
- ImageFile.LOAD_TRUNCATED_IMAGES = True
28
- ImageFile.MAX_IMAGE_PIXELS = None
29
- Image.MAX_IMAGE_PIXELS = None
30
-
31
- logger = logging.getLogger(__name__)
32
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
33
-
34
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
35
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
36
-
37
-
38
- def collate(samples, pad_idx, eos_idx):
39
- if len(samples) == 0:
40
- return {}
41
-
42
- def merge(key):
43
- return data_utils.collate_tokens(
44
- [s[key] for s in samples],
45
- pad_idx,
46
- eos_idx=eos_idx,
47
- )
48
-
49
- id = np.array([s["id"] for s in samples])
50
- src_tokens = merge("source")
51
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
52
-
53
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
54
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
55
-
56
- prev_output_tokens = None
57
- target = None
58
- if samples[0].get("target", None) is not None:
59
- target = merge("target")
60
- tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
61
- ntokens = tgt_lengths.sum().item()
62
-
63
- if samples[0].get("prev_output_tokens", None) is not None:
64
- prev_output_tokens = merge("prev_output_tokens")
65
- else:
66
- ntokens = src_lengths.sum().item()
67
-
68
- batch = {
69
- "id": id,
70
- "nsentences": len(samples),
71
- "ntokens": ntokens,
72
- "net_input": {
73
- "src_tokens": src_tokens,
74
- "src_lengths": src_lengths,
75
- "patch_images": patch_images,
76
- "patch_masks": patch_masks,
77
- "prev_output_tokens": prev_output_tokens
78
- },
79
- "target": target,
80
- }
81
-
82
- return batch
83
-
84
-
85
- def ocr_resize(img, patch_image_size, is_document=False, split='train'):
86
- img = img.convert("RGB")
87
- width, height = img.size
88
-
89
- if is_document:
90
- new_height, new_width = 64, 1920
91
- else:
92
- if width >= height:
93
- new_width = max(64, patch_image_size)
94
- new_height = max(64, int(patch_image_size * (height / width)))
95
- if split != 'train':
96
- top = int((patch_image_size - new_height) // 2)
97
- else:
98
- top = random.randint(0, patch_image_size - new_height)
99
- bottom = patch_image_size - new_height - top
100
- left, right = 0, 0
101
- else:
102
- new_height = max(64, patch_image_size)
103
- new_width = max(64, int(patch_image_size * (width / height)))
104
- if split != 'train':
105
- left = int((patch_image_size - new_width) // 2)
106
- else:
107
- left = random.randint(0, patch_image_size - new_width)
108
- right = patch_image_size - new_width - left
109
- top, bottom = 0, 0
110
-
111
- img_new = F.resize(
112
- img,
113
- [new_height, new_width],
114
- interpolation=InterpolationMode.BICUBIC,
115
- )
116
-
117
- if is_document:
118
- img_split = transforms.ToTensor()(img_new).chunk(4, dim=-1)
119
- img_new = transforms.ToPILImage()(torch.cat(img_split, dim=-2))
120
- new_width, new_height = img_new.size
121
- top = random.randint(0, patch_image_size - new_height)
122
- bottom = patch_image_size - new_height - top
123
- left, right = 0, 0
124
-
125
- img_new = F.pad(img_new, padding=[left, top, right, bottom], padding_mode="edge")
126
- assert img_new.size == (patch_image_size, patch_image_size)
127
-
128
- return img_new
129
-
130
-
131
- class OcrDataset(OFADataset):
132
- def __init__(
133
- self,
134
- split,
135
- dataset,
136
- bpe,
137
- src_dict,
138
- tgt_dict=None,
139
- max_src_length=80,
140
- max_tgt_length=30,
141
- patch_image_size=224,
142
- imagenet_default_mean_and_std=False,
143
- is_document=False,
144
- ):
145
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
146
- self.max_src_length = max_src_length
147
- self.max_tgt_length = max_tgt_length
148
- self.patch_image_size = patch_image_size
149
-
150
- if imagenet_default_mean_and_std:
151
- mean = IMAGENET_DEFAULT_MEAN
152
- std = IMAGENET_DEFAULT_STD
153
- else:
154
- mean = [0.5, 0.5, 0.5]
155
- std = [0.5, 0.5, 0.5]
156
-
157
- self.patch_resize_transform = transforms.Compose(
158
- [
159
- lambda image: ocr_resize(
160
- image, patch_image_size, is_document=is_document, split=split,
161
- ),
162
- transforms.ToTensor(),
163
- transforms.Normalize(mean=mean, std=std),
164
- ]
165
- )
166
-
167
- self.bpe = bpe
168
- if type(bpe).__name__ == 'GPT2BPE':
169
- self.prompt = " what are the texts on the image?"
170
- elif type(bpe).__name__ == 'BertBPE':
171
- self.prompt = "图片上的文字是什么?"
172
-
173
- def __getitem__(self, index):
174
- uniq_id, image, caption = self.dataset[index]
175
-
176
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
177
- patch_image = self.patch_resize_transform(image)
178
- patch_mask = torch.tensor([True])
179
-
180
- caption = unicodedata.normalize("NFKC", convert(caption, "zh-hans"))
181
- if type(self.bpe).__name__ == 'GPT2BPE':
182
- caption_token_list = caption.lower().strip().split()
183
- tgt_caption = ' '.join(caption_token_list[:self.max_tgt_length])
184
- elif type(self.bpe).__name__ == 'BertBPE':
185
- tgt_caption = caption[: self.max_tgt_length].lower()
186
- src_item = self.encode_text(self.prompt)
187
- tgt_item = self.encode_text(" {}".format(tgt_caption))
188
-
189
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
190
- target_item = torch.cat([tgt_item, self.eos_item])
191
- prev_output_item = torch.cat([self.bos_item, tgt_item])
192
-
193
- example = {
194
- "id": uniq_id,
195
- "source": src_item,
196
- "patch_image": patch_image,
197
- "patch_mask": patch_mask,
198
- "target": target_item,
199
- "prev_output_tokens": prev_output_item,
200
- }
201
- return example
202
-
203
- def collater(self, samples, pad_to_length=None):
204
- """Merge a list of samples to form a mini-batch.
205
- Args:
206
- samples (List[dict]): samples to collate
207
- Returns:
208
- dict: a mini-batch containing the data required for the task
209
- """
210
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
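A short sketch of the size computation in the non-document branch of ocr_resize above: the longer side is scaled to patch_image_size, the shorter side keeps the aspect ratio with a floor of 64 pixels, and the remainder is padded back to a square:

    def ocr_target_size(width, height, patch_image_size=224):
        # mirrors the non-document branch of ocr_resize above
        if width >= height:
            new_width = max(64, patch_image_size)
            new_height = max(64, int(patch_image_size * (height / width)))
        else:
            new_height = max(64, patch_image_size)
            new_width = max(64, int(patch_image_size * (width / height)))
        return new_width, new_height

    print(ocr_target_size(400, 100))  # (224, 64): 160 px of vertical padding restores 224x224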
data/mm_data/refcoco_dataset.py DELETED
@@ -1,174 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
-
11
- import numpy as np
12
- import torch
13
- import base64
14
- import utils.transforms as T
15
-
16
- from PIL import Image, ImageFile
17
-
18
- from data import data_utils
19
- from data.ofa_dataset import OFADataset
20
-
21
- ImageFile.LOAD_TRUNCATED_IMAGES = True
22
- ImageFile.MAX_IMAGE_PIXELS = None
23
- Image.MAX_IMAGE_PIXELS = None
24
-
25
- logger = logging.getLogger(__name__)
26
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
-
28
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
-
31
-
32
- def collate(samples, pad_idx, eos_idx):
33
- if len(samples) == 0:
34
- return {}
35
-
36
- def merge(key):
37
- return data_utils.collate_tokens(
38
- [s[key] for s in samples],
39
- pad_idx,
40
- eos_idx=eos_idx,
41
- )
42
-
43
- id = np.array([s["id"] for s in samples])
44
- src_tokens = merge("source")
45
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
-
47
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
-
50
- w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
51
- h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
52
- region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)
53
-
54
- prev_output_tokens = None
55
- target = None
56
- if samples[0].get("target", None) is not None:
57
- target = merge("target")
58
- tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
59
- ntokens = tgt_lengths.sum().item()
60
-
61
- if samples[0].get("prev_output_tokens", None) is not None:
62
- prev_output_tokens = merge("prev_output_tokens")
63
- else:
64
- ntokens = src_lengths.sum().item()
65
-
66
- batch = {
67
- "id": id,
68
- "nsentences": len(samples),
69
- "ntokens": ntokens,
70
- "net_input": {
71
- "src_tokens": src_tokens,
72
- "src_lengths": src_lengths,
73
- "patch_images": patch_images,
74
- "patch_masks": patch_masks,
75
- "prev_output_tokens": prev_output_tokens
76
- },
77
- "target": target,
78
- "w_resize_ratios": w_resize_ratios,
79
- "h_resize_ratios": h_resize_ratios,
80
- "region_coords": region_coords
81
- }
82
-
83
- return batch
84
-
85
-
86
- class RefcocoDataset(OFADataset):
87
- def __init__(
88
- self,
89
- split,
90
- dataset,
91
- bpe,
92
- src_dict,
93
- tgt_dict=None,
94
- max_src_length=80,
95
- max_tgt_length=30,
96
- patch_image_size=512,
97
- imagenet_default_mean_and_std=False,
98
- num_bins=1000,
99
- max_image_size=512
100
- ):
101
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
102
- self.max_src_length = max_src_length
103
- self.max_tgt_length = max_tgt_length
104
- self.patch_image_size = patch_image_size
105
- self.num_bins = num_bins
106
-
107
- if imagenet_default_mean_and_std:
108
- mean = IMAGENET_DEFAULT_MEAN
109
- std = IMAGENET_DEFAULT_STD
110
- else:
111
- mean = [0.5, 0.5, 0.5]
112
- std = [0.5, 0.5, 0.5]
113
-
114
- # for positioning
115
- self.positioning_transform = T.Compose([
116
- T.RandomResize([patch_image_size], max_size=patch_image_size),
117
- T.ToTensor(),
118
- T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
119
- ])
120
-
121
- if type(bpe).__name__ == 'GPT2BPE':
122
- self.prompt = ' which region does the text " {} " describe?'
123
- elif type(bpe).__name__ == 'BertBPE':
124
- self.prompt = '这段文字" {} "描述的是哪个区域?'
125
-
126
- def __getitem__(self, index):
127
- uniq_id, base64_str, text, region_coord = self.dataset[index]
128
-
129
- image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
130
- w, h = image.size
131
- boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
132
- x0, y0, x1, y1 = region_coord.strip().split(',')
133
- region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
134
- boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
135
- boxes_target["labels"] = np.array([0])
136
- boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
137
-
138
- patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
139
- resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
140
- patch_mask = torch.tensor([True])
141
- quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
142
- quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
143
- quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
144
- quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
145
- region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
146
- src_caption = self.pre_caption(text, self.max_src_length)
147
- src_item = self.encode_text(self.prompt.format(src_caption))
148
- tgt_item = self.encode_text(region_coord, use_bpe=False)
149
-
150
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
151
- target_item = torch.cat([tgt_item, self.eos_item])
152
- prev_output_item = torch.cat([self.bos_item, tgt_item])
153
-
154
- example = {
155
- "id": uniq_id,
156
- "source": src_item,
157
- "patch_image": patch_image,
158
- "patch_mask": patch_mask,
159
- "target": target_item,
160
- "prev_output_tokens": prev_output_item,
161
- "w_resize_ratio": resize_w / w,
162
- "h_resize_ratio": resize_h / h,
163
- "region_coord": region
164
- }
165
- return example
166
-
167
- def collater(self, samples, pad_to_length=None):
168
- """Merge a list of samples to form a mini-batch.
169
- Args:
170
- samples (List[dict]): samples to collate
171
- Returns:
172
- dict: a mini-batch containing the data of the task
173
- """
174
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
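A minimal sketch of the coordinate quantization used above: the box coordinates produced by the positioning transform (already scaled into [0, 1]) are mapped onto num_bins location tokens. The input values below are illustrative:

    def quantize_box(norm_box, num_bins=1000):
        # norm_box = (x0, y0, x1, y1), each in [0, 1]
        return " ".join("<bin_{}>".format(int(round(v * (num_bins - 1)))) for v in norm_box)

    print(quantize_box([0.1, 0.2, 0.5, 0.8]))
    # <bin_100> <bin_200> <bin_500> <bin_799>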
data/mm_data/snli_ve_dataset.py DELETED
@@ -1,203 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
-
11
- import numpy as np
12
- import torch
13
- import base64
14
- from torchvision import transforms
15
-
16
- from PIL import Image, ImageFile
17
-
18
- from data import data_utils
19
- from data.ofa_dataset import OFADataset
20
-
21
- ImageFile.LOAD_TRUNCATED_IMAGES = True
22
- ImageFile.MAX_IMAGE_PIXELS = None
23
- Image.MAX_IMAGE_PIXELS = None
24
-
25
- logger = logging.getLogger(__name__)
26
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
-
28
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
-
31
-
32
- def collate(samples, pad_idx, eos_idx):
33
- if len(samples) == 0:
34
- return {}
35
-
36
- def merge(key):
37
- return data_utils.collate_tokens(
38
- [s[key] for s in samples],
39
- pad_idx,
40
- eos_idx=eos_idx,
41
- )
42
-
43
- id = np.array([s["id"] for s in samples])
44
- src_tokens = merge("source")
45
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
-
47
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
-
50
- ref_dict = None
51
- if samples[0].get("ref_dict", None) is not None:
52
- ref_dict = np.array([s['ref_dict'] for s in samples])
53
-
54
- constraint_masks = None
55
- if samples[0].get("constraint_mask", None) is not None:
56
- constraint_masks = merge("constraint_mask")
57
-
58
- decoder_prompts = None
59
- if samples[0].get("decoder_prompt", None) is not None:
60
- decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])
61
-
62
- prev_output_tokens = None
63
- target = None
64
- if samples[0].get("target", None) is not None:
65
- target = merge("target")
66
- tgt_lengths = torch.LongTensor(
67
- [s["target"].ne(pad_idx).long().sum() for s in samples]
68
- )
69
- ntokens = tgt_lengths.sum().item()
70
-
71
- if samples[0].get("prev_output_tokens", None) is not None:
72
- prev_output_tokens = merge("prev_output_tokens")
73
- else:
74
- ntokens = src_lengths.sum().item()
75
-
76
- batch = {
77
- "id": id,
78
- "nsentences": len(samples),
79
- "ntokens": ntokens,
80
- "net_input": {
81
- "src_tokens": src_tokens,
82
- "src_lengths": src_lengths,
83
- "patch_images": patch_images,
84
- "patch_masks": patch_masks,
85
- "prev_output_tokens": prev_output_tokens
86
- },
87
- "ref_dict": ref_dict,
88
- "constraint_masks": constraint_masks,
89
- "decoder_prompts": decoder_prompts,
90
- "target": target
91
- }
92
-
93
- return batch
94
-
95
-
96
- class SnliVeDataset(OFADataset):
97
- def __init__(
98
- self,
99
- split,
100
- dataset,
101
- bpe,
102
- src_dict,
103
- tgt_dict=None,
104
- max_src_length=80,
105
- max_tgt_length=30,
106
- patch_image_size=224,
107
- add_caption=False,
108
- constraint_trie=None,
109
- imagenet_default_mean_and_std=False,
110
- prompt_type="none"
111
- ):
112
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
113
- self.max_src_length = max_src_length
114
- self.max_tgt_length = max_tgt_length
115
- self.patch_image_size = patch_image_size
116
-
117
- self.add_caption = add_caption
118
- self.constraint_trie = constraint_trie
119
- self.prompt_type = prompt_type
120
-
121
- if imagenet_default_mean_and_std:
122
- mean = IMAGENET_DEFAULT_MEAN
123
- std = IMAGENET_DEFAULT_STD
124
- else:
125
- mean = [0.5, 0.5, 0.5]
126
- std = [0.5, 0.5, 0.5]
127
-
128
- self.patch_resize_transform = transforms.Compose([
129
- lambda image: image.convert("RGB"),
130
- transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
131
- transforms.ToTensor(),
132
- transforms.Normalize(mean=mean, std=std),
133
- ])
134
-
135
- def __getitem__(self, index):
136
- uniq_id, image, hypothesis, caption, label = self.dataset[index]
137
- if label == 'contradiction':
138
- label = 'no'
139
- elif label == 'entailment':
140
- label = 'yes'
141
- elif label == 'neutral':
142
- label = 'maybe'
143
- else:
144
- raise NotImplementedError
145
-
146
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
147
- patch_image = self.patch_resize_transform(image)
148
- patch_mask = torch.tensor([True])
149
-
150
- hypothesis = self.pre_caption(hypothesis, self.max_src_length)
151
- src_item = self.encode_text(' does the image describe " {} "?'.format(hypothesis))
152
- tgt_item = self.encode_text(" {}".format(label))
153
- ref_dict = {label: 1.0}
154
-
155
- if self.add_caption:
156
- caption = self.pre_caption(caption, self.max_src_length)
157
- src_item = self.encode_text(' can image and text1 " {} " imply text2 " {} "?'.format(caption, hypothesis))
158
-
159
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
160
- if self.prompt_type == 'none':
161
- prev_output_item = torch.cat([self.bos_item, tgt_item])
162
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
163
- decoder_prompt = self.bos_item
164
- elif self.prompt_type == 'src':
165
- prev_output_item = torch.cat([src_item, tgt_item])
166
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
167
- decoder_prompt = src_item
168
- elif self.prompt_type == 'prev_output':
169
- prev_output_item = torch.cat([src_item[:-1], tgt_item])
170
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
171
- decoder_prompt = src_item[:-1]
172
- else:
173
- raise NotImplementedError
174
- target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()
175
-
176
- example = {
177
- "id": uniq_id,
178
- "source": src_item,
179
- "patch_image": patch_image,
180
- "patch_mask": patch_mask,
181
- "target": target_item,
182
- "prev_output_tokens": prev_output_item,
183
- "decoder_prompt": decoder_prompt,
184
- "ref_dict": ref_dict,
185
- }
186
- if self.constraint_trie is not None:
187
- constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
188
- start_idx = len(target_item) - len(tgt_item) - 1
189
- for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
190
- constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
191
- constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
192
- constraint_mask[i][constraint_nodes] = True
193
- example["constraint_mask"] = constraint_mask
194
- return example
195
-
196
- def collater(self, samples, pad_to_length=None):
197
- """Merge a list of samples to form a mini-batch.
198
- Args:
199
- samples (List[dict]): samples to collate
200
- Returns:
201
- dict: a mini-batch containing the data of the task
202
- """
203
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
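A small sketch of two of the prompt_type layouts handled above, with toy ids standing in for the encoded source and the single-token label (all ids are assumptions):

    import torch

    bos, eos, pad = torch.tensor([0]), torch.tensor([2]), 1
    src_item = torch.cat([bos, torch.tensor([10, 11]), eos])   # [0, 10, 11, 2]
    tgt_item = torch.tensor([30])                               # e.g. " yes"

    # prompt_type == 'none': the decoder sees only <bos> before predicting the label
    prev_none = torch.cat([bos, tgt_item])                      # [0, 30]
    target_none = torch.cat([prev_none[1:], eos])               # [30, 2]

    # prompt_type == 'src': the full source (with <eos>) is replayed in the decoder
    prev_src = torch.cat([src_item, tgt_item])                  # [0, 10, 11, 2, 30]
    target_src = torch.cat([prev_src[1:], eos])                 # [10, 11, 2, 30, 2]
    target_src[:-len(tgt_item) - 1] = pad                       # [1, 1, 1, 30, 2]: loss only on label + <eos>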
data/mm_data/vqa_gen_dataset.py DELETED
@@ -1,218 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import logging
9
- import warnings
10
-
11
- import numpy as np
12
- import torch
13
- import base64
14
- from torchvision import transforms
15
-
16
- from PIL import Image, ImageFile
17
-
18
- from data import data_utils
19
- from data.ofa_dataset import OFADataset
20
-
21
- ImageFile.LOAD_TRUNCATED_IMAGES = True
22
- ImageFile.MAX_IMAGE_PIXELS = None
23
- Image.MAX_IMAGE_PIXELS = None
24
-
25
- logger = logging.getLogger(__name__)
26
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
27
-
28
- IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
29
- IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
30
-
31
-
32
- def collate(samples, pad_idx, eos_idx):
33
- if len(samples) == 0:
34
- return {}
35
-
36
- def merge(key):
37
- return data_utils.collate_tokens(
38
- [s[key] for s in samples],
39
- pad_idx,
40
- eos_idx=eos_idx,
41
- )
42
-
43
- id = np.array([s["id"] for s in samples])
44
- src_tokens = merge("source")
45
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
46
-
47
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
48
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
49
-
50
- conf = None
51
- if samples[0].get("conf", None) is not None:
52
- conf = torch.cat([s['conf'] for s in samples], dim=0)
53
-
54
- ref_dict = None
55
- if samples[0].get("ref_dict", None) is not None:
56
- ref_dict = np.array([s['ref_dict'] for s in samples])
57
-
58
- constraint_masks = None
59
- if samples[0].get("constraint_mask", None) is not None:
60
- constraint_masks = merge("constraint_mask")
61
-
62
- decoder_prompts = None
63
- if samples[0].get("decoder_prompt", None) is not None:
64
- decoder_prompts = np.array([s['decoder_prompt'].tolist() for s in samples])
65
-
66
- prefix_tokens = None
67
- if samples[0].get("decoder_prompt", None) is not None:
68
- prefix_tokens = merge("decoder_prompt")
69
- prefix_tokens = prefix_tokens[:, 1:]
70
-
71
- prev_output_tokens = None
72
- target = None
73
- if samples[0].get("target", None) is not None:
74
- target = merge("target")
75
- tgt_lengths = torch.LongTensor(
76
- [s["target"].ne(pad_idx).long().sum() for s in samples]
77
- )
78
- ntokens = tgt_lengths.sum().item()
79
-
80
- if samples[0].get("prev_output_tokens", None) is not None:
81
- prev_output_tokens = merge("prev_output_tokens")
82
- else:
83
- ntokens = src_lengths.sum().item()
84
-
85
- batch = {
86
- "id": id,
87
- "nsentences": len(samples),
88
- "ntokens": ntokens,
89
- "net_input": {
90
- "src_tokens": src_tokens,
91
- "src_lengths": src_lengths,
92
- "patch_images": patch_images,
93
- "patch_masks": patch_masks,
94
- "prev_output_tokens": prev_output_tokens
95
- },
96
- "conf": conf,
97
- "ref_dict": ref_dict,
98
- "constraint_masks": constraint_masks,
99
- "decoder_prompts": decoder_prompts,
100
- "target": target,
101
- "prefix_tokens": prefix_tokens
102
- }
103
-
104
- return batch
105
-
106
-
107
- class VqaGenDataset(OFADataset):
108
- def __init__(
109
- self,
110
- split,
111
- dataset,
112
- bpe,
113
- src_dict,
114
- tgt_dict=None,
115
- max_src_length=128,
116
- max_object_length=30,
117
- max_tgt_length=30,
118
- patch_image_size=224,
119
- add_object=False,
120
- constraint_trie=None,
121
- imagenet_default_mean_and_std=False,
122
- prompt_type="none"
123
- ):
124
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
125
- self.max_src_length = max_src_length
126
- self.max_object_length = max_object_length
127
- self.max_tgt_length = max_tgt_length
128
- self.patch_image_size = patch_image_size
129
-
130
- self.add_object = add_object
131
- self.constraint_trie = constraint_trie
132
- self.prompt_type = prompt_type
133
-
134
- if imagenet_default_mean_and_std:
135
- mean = IMAGENET_DEFAULT_MEAN
136
- std = IMAGENET_DEFAULT_STD
137
- else:
138
- mean = [0.5, 0.5, 0.5]
139
- std = [0.5, 0.5, 0.5]
140
-
141
- self.patch_resize_transform = transforms.Compose([
142
- lambda image: image.convert("RGB"),
143
- transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
144
- transforms.ToTensor(),
145
- transforms.Normalize(mean=mean, std=std),
146
- ])
147
-
148
- def __getitem__(self, index):
149
- item = self.dataset[index]
150
- if len(item) == 5:
151
- uniq_id, image, question, ref, predict_objects = item
152
- else:
153
- uniq_id, image, question, ref, predict_objects, caption = item
154
-
155
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image)))
156
- patch_image = self.patch_resize_transform(image)
157
- patch_mask = torch.tensor([True])
158
-
159
- question = self.pre_question(question, self.max_src_length)
160
- question = question + '?' if not question.endswith('?') else question
161
- src_item = self.encode_text(' {}'.format(question))
162
-
163
- ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
164
- answer = max(ref_dict, key=ref_dict.get)
165
- conf = torch.tensor([ref_dict[answer]])
166
- tgt_item = self.encode_text(" {}".format(answer))
167
-
168
- if self.add_object and predict_objects is not None:
169
- predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])
170
- predict_object_item = self.encode_text(" object: {}".format(predict_object_seq))
171
- src_item = torch.cat([src_item, predict_object_item])
172
-
173
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
174
- if self.prompt_type == 'none':
175
- prev_output_item = torch.cat([self.bos_item, tgt_item])
176
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
177
- decoder_prompt = self.bos_item
178
- elif self.prompt_type == 'src':
179
- prev_output_item = torch.cat([src_item, tgt_item])
180
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
181
- decoder_prompt = src_item
182
- elif self.prompt_type == 'prev_output':
183
- prev_output_item = torch.cat([src_item[:-1], tgt_item])
184
- target_item = torch.cat([prev_output_item[1:], self.eos_item])
185
- decoder_prompt = src_item[:-1]
186
- else:
187
- raise NotImplementedError
188
- target_item[:-len(tgt_item)-1] = self.tgt_dict.pad()
189
-
190
- example = {
191
- "id": uniq_id,
192
- "source": src_item,
193
- "patch_image": patch_image,
194
- "patch_mask": patch_mask,
195
- "target": target_item,
196
- "prev_output_tokens": prev_output_item,
197
- "decoder_prompt": decoder_prompt,
198
- "ref_dict": ref_dict,
199
- "conf": conf,
200
- }
201
- if self.constraint_trie is not None:
202
- constraint_mask = torch.zeros((len(target_item), len(self.tgt_dict))).bool()
203
- start_idx = len(target_item) - len(tgt_item) - 1
204
- for i in range(len(target_item)-len(tgt_item)-1, len(target_item)):
205
- constraint_prefix_token = [self.tgt_dict.bos()] + target_item[start_idx:i].tolist()
206
- constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
207
- constraint_mask[i][constraint_nodes] = True
208
- example["constraint_mask"] = constraint_mask
209
- return example
210
-
211
- def collater(self, samples, pad_to_length=None):
212
- """Merge a list of samples to form a mini-batch.
213
- Args:
214
- samples (List[dict]): samples to collate
215
- Returns:
216
- dict: a mini-batch containing the data of the task
217
- """
218
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
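A short sketch of how the ref field above is parsed: each candidate answer is stored as confidence|!+answer, candidates are joined with &&, and the highest-confidence answer becomes the training target (the example string is made up for illustration):

    ref = "1.0|!+two&&0.6|!+2&&0.3|!+a pair"
    ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
    answer = max(ref_dict, key=ref_dict.get)   # "two"
    conf = ref_dict[answer]                    # 1.0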
data/nlg_data/summary_dataset.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- prev_output_tokens = None
33
- target = None
34
- if samples[0].get("target", None) is not None:
35
- target = merge("target")
36
- tgt_lengths = torch.LongTensor(
37
- [s["target"].ne(pad_idx).long().sum() for s in samples]
38
- )
39
- ntokens = tgt_lengths.sum().item()
40
-
41
- if samples[0].get("prev_output_tokens", None) is not None:
42
- prev_output_tokens = merge("prev_output_tokens")
43
- else:
44
- ntokens = src_lengths.sum().item()
45
-
46
- target_strs = np.array([s["target_str"] for s in samples])
47
-
48
- batch = {
49
- "nsentences": len(samples),
50
- "ntokens": ntokens,
51
- "net_input": {
52
- "src_tokens": src_tokens,
53
- "src_lengths": src_lengths,
54
- "prev_output_tokens": prev_output_tokens
55
- },
56
- "target": target,
57
- "target_strs": target_strs
58
- }
59
-
60
- return batch
61
-
62
-
63
- class SummaryDataset(OFADataset):
64
- def __init__(
65
- self,
66
- split,
67
- dataset,
68
- bpe,
69
- src_dict,
70
- tgt_dict=None,
71
- code_dict_size=8192,
72
- num_bins=1000,
73
- max_src_length=512,
74
- max_tgt_length=128,
75
- noise_ratio=0.0
76
- ):
77
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
78
- self.max_src_length = max_src_length
79
- self.max_tgt_length = max_tgt_length
80
- self.code_dict_size = code_dict_size
81
- self.num_bins = num_bins
82
- self.noise_ratio = noise_ratio
83
-
84
- if type(bpe).__name__ == 'GPT2BPE':
85
- self.prompt = ' what is the summary of article " {} "?'
86
- elif type(bpe).__name__ == 'BertBPE':
87
- self.prompt = "{} 请用一个句子简单总结上文:"
88
-
89
- def __getitem__(self, index):
90
- source, target = self.dataset[index]
91
- target_str = target.lower()
92
-
93
- source = self.pre_caption(source, max_words=self.max_src_length)
94
- target = self.pre_caption(target, max_words=self.max_tgt_length)
95
- source = source.replace('<unk>', 'unk')
96
- target = target.replace('<unk>', 'unk')
97
-
98
- src_item = self.encode_text(
99
- self.prompt.format(source),
100
- length=self.max_src_length
101
- )
102
- tgt_item = self.encode_text('{}'.format(target))
103
- noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)
104
-
105
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
106
- target_item = torch.cat([tgt_item, self.eos_item])
107
- prev_output_item = torch.cat([self.bos_item, noise_tgt_item])
108
-
109
- example = {
110
- "source": src_item,
111
- "target": target_item,
112
- "prev_output_tokens": prev_output_item,
113
- "target_str": target_str
114
- }
115
- return example
116
-
117
- def add_noise_to_tgt(self, target, p):
118
- noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
119
- target[noise_indices] = torch.randint(
120
- 4, len(self.src_dict) - self.code_dict_size - self.num_bins, size=(noise_indices.sum(),)
121
- )
122
- return target
123
-
124
- def collater(self, samples, pad_to_length=None):
125
- """Merge a list of samples to form a mini-batch.
126
- Args:
127
- samples (List[dict]): samples to collate
128
- Returns:
129
- dict: a mini-batch containing the data of the task
130
- """
131
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
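A minimal, standalone version of the decoder-input noising in add_noise_to_tgt above: each target token is replaced with probability p by a random ordinary vocabulary id, skipping the first four special ids and the trailing code/bin tokens. The vocabulary sizes below are illustrative placeholders, not the real dictionary sizes:

    import torch

    def add_noise_to_tgt(target, p, vocab_size=60000, code_dict_size=8192, num_bins=1000):
        noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
        target[noise_indices] = torch.randint(
            4, vocab_size - code_dict_size - num_bins, size=(noise_indices.sum(),)
        )
        return target

    print(add_noise_to_tgt(torch.tensor([100, 101, 102, 103]), p=0.5))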
data/nlu_data/cola_dataset.py DELETED
@@ -1,138 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class COLADataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- sentence, label = self.dataset[index]
91
- if label == '0':
92
- label = 'no'
93
- elif label == '1':
94
- label = 'yes'
95
- else:
96
- raise NotImplementedError
97
-
98
- sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
99
- src_item = self.encode_text(' is the text " {} " grammatically correct?'.format(sentence))
100
- tgt_item = self.encode_text(" {}".format(label))
101
- assert tgt_item.size(0) == 1
102
- ref_dict = {label: 1.0}
103
-
104
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
105
- if self.prompt_type == 'none':
106
- prev_output_item = self.bos_item
107
- target_item = tgt_item
108
- elif self.prompt_type == 'src':
109
- prev_output_item = src_item.clone()
110
- target_item = torch.cat([prev_output_item[1:], tgt_item])
111
- elif self.prompt_type == 'prev_output':
112
- prev_output_item = src_item[:-1].clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- else:
115
- raise NotImplementedError
116
- target_item[:-1] = self.tgt_dict.pad()
117
-
118
- example = {
119
- "source": src_item,
120
- "target": target_item,
121
- "prev_output_tokens": prev_output_item,
122
- "ref_dict": ref_dict,
123
- }
124
- if self.constraint_trie is not None:
125
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
126
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
127
- constraint_mask[-1][constraint_nodes] = True
128
- example["constraint_mask"] = constraint_mask
129
- return example
130
-
131
- def collater(self, samples, pad_to_length=None):
132
- """Merge a list of samples to form a mini-batch.
133
- Args:
134
- samples (List[dict]): samples to collate
135
- Returns:
136
- dict: a mini-batch containing the data of the task
137
- """
138
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
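A small sketch of the constraint-mask idea shared by the GLUE-style datasets above and below (COLA here; the MNLI/MRPC/QNLI/QQP variants that follow use the same pattern): at the position that predicts the label, only the vocabulary ids reachable from the answer trie stay unmasked. The ToyTrie below is a stand-in for the real constraint_trie, not its actual implementation:

    import torch

    class ToyTrie:
        # stand-in for constraint_trie: maps a prefix to the allowed next-token ids
        def __init__(self, allowed_after_bos):
            self.allowed_after_bos = allowed_after_bos

        def get_next_layer(self, prefix_tokens):
            return self.allowed_after_bos if prefix_tokens == [0] else []

    vocab_size, bos_list = 50, [0]                        # toy sizes/ids
    trie = ToyTrie(allowed_after_bos=[30, 31])            # e.g. ids of " yes" / " no"

    prev_output_item = torch.tensor([0])                  # prompt_type == 'none': just <bos>
    constraint_mask = torch.zeros((len(prev_output_item), vocab_size)).bool()
    constraint_mask[-1][trie.get_next_layer(bos_list)] = True   # only " yes" / " no" remain allowed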
data/nlu_data/mnli_dataset.py DELETED
@@ -1,143 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class MNLIDataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- sentence1, sentence2, label = self.dataset[index]
91
- if label == '0':
92
- label = 'maybe'
93
- elif label == '1':
94
- label = 'yes'
95
- elif label == '2':
96
- label = 'no'
97
- else:
98
- raise NotImplementedError
99
-
100
- sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
101
- sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
102
- src_item = self.encode_text(
103
- ' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2)
104
- )
105
- tgt_item = self.encode_text(" {}".format(label))
106
- assert tgt_item.size(0) == 1
107
- ref_dict = {label: 1.0}
108
-
109
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
110
- if self.prompt_type == 'none':
111
- prev_output_item = self.bos_item
112
- target_item = tgt_item
113
- elif self.prompt_type == 'src':
114
- prev_output_item = src_item.clone()
115
- target_item = torch.cat([prev_output_item[1:], tgt_item])
116
- elif self.prompt_type == 'prev_output':
117
- prev_output_item = src_item[:-1].clone()
118
- target_item = torch.cat([prev_output_item[1:], tgt_item])
119
- else:
120
- raise NotImplementedError
121
- target_item[:-1] = self.tgt_dict.pad()
122
-
123
- example = {
124
- "source": src_item,
125
- "target": target_item,
126
- "prev_output_tokens": prev_output_item,
127
- "ref_dict": ref_dict,
128
- }
129
- if self.constraint_trie is not None:
130
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
131
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
132
- constraint_mask[-1][constraint_nodes] = True
133
- example["constraint_mask"] = constraint_mask
134
- return example
135
-
136
- def collater(self, samples, pad_to_length=None):
137
- """Merge a list of samples to form a mini-batch.
138
- Args:
139
- samples (List[dict]): samples to collate
140
- Returns:
141
- dict: a mini-batch containing the data of the task
142
- """
143
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/nlu_data/mrpc_dataset.py DELETED
@@ -1,141 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class MRPCDataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- sentence1, sentence2, label = self.dataset[index]
91
- if label == '0':
92
- label = 'no'
93
- elif label == '1':
94
- label = 'yes'
95
- else:
96
- raise NotImplementedError
97
-
98
- sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
99
- sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
100
- src_item = self.encode_text(
101
- ' does text1 " {} " and text2 " {} " have the same semantics?'.format(sentence1, sentence2),
102
- )
103
- tgt_item = self.encode_text(" {}".format(label))
104
- assert tgt_item.size(0) == 1
105
- ref_dict = {label: 1.0}
106
-
107
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
108
- if self.prompt_type == 'none':
109
- prev_output_item = self.bos_item
110
- target_item = tgt_item
111
- elif self.prompt_type == 'src':
112
- prev_output_item = src_item.clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- elif self.prompt_type == 'prev_output':
115
- prev_output_item = src_item[:-1].clone()
116
- target_item = torch.cat([prev_output_item[1:], tgt_item])
117
- else:
118
- raise NotImplementedError
119
- target_item[:-1] = self.tgt_dict.pad()
120
-
121
- example = {
122
- "source": src_item,
123
- "target": target_item,
124
- "prev_output_tokens": prev_output_item,
125
- "ref_dict": ref_dict,
126
- }
127
- if self.constraint_trie is not None:
128
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
129
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
130
- constraint_mask[-1][constraint_nodes] = True
131
- example["constraint_mask"] = constraint_mask
132
- return example
133
-
134
- def collater(self, samples, pad_to_length=None):
135
- """Merge a list of samples to form a mini-batch.
136
- Args:
137
- samples (List[dict]): samples to collate
138
- Returns:
139
- dict: a mini-batch containing the data of the task
140
- """
141
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
data/nlu_data/qnli_dataset.py DELETED
@@ -1,141 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class QNLIDataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- question, sentence, label = self.dataset[index]
91
- if label == '0' or label == 'not_entailment':
92
- label = 'no'
93
- elif label == '1' or label == 'entailment':
94
- label = 'yes'
95
- else:
96
- raise NotImplementedError
97
-
98
- question = ' '.join(question.lower().strip().split()[:self.max_src_length])
99
- sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
100
- src_item = self.encode_text(
101
- ' does " {} " contain the answer to question " {} "?'.format(sentence, question)
102
- )
103
- tgt_item = self.encode_text(" {}".format(label))
104
- assert tgt_item.size(0) == 1
105
- ref_dict = {label: 1.0}
106
-
107
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
108
- if self.prompt_type == 'none':
109
- prev_output_item = self.bos_item
110
- target_item = tgt_item
111
- elif self.prompt_type == 'src':
112
- prev_output_item = src_item.clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- elif self.prompt_type == 'prev_output':
115
- prev_output_item = src_item[:-1].clone()
116
- target_item = torch.cat([prev_output_item[1:], tgt_item])
117
- else:
118
- raise NotImplementedError
119
- target_item[:-1] = self.tgt_dict.pad()
120
-
121
- example = {
122
- "source": src_item,
123
- "target": target_item,
124
- "prev_output_tokens": prev_output_item,
125
- "ref_dict": ref_dict,
126
- }
127
- if self.constraint_trie is not None:
128
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
129
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
130
- constraint_mask[-1][constraint_nodes] = True
131
- example["constraint_mask"] = constraint_mask
132
- return example
133
-
134
- def collater(self, samples, pad_to_length=None):
135
- """Merge a list of samples to form a mini-batch.
136
- Args:
137
- samples (List[dict]): samples to collate
138
- Returns:
139
- dict: a mini-batch containing the data of the task
140
- """
141
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
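The QNLI dataset above and the QQP, RTE, and SST-2 datasets that follow all build their decoder inputs from the same prompt_type branch. The snippet below is a minimal illustrative sketch (not code from this repository; token ids are made up) of what that branch produces: only the final label position is left unpadded, so the loss falls on the single classification token.

import torch

bos, pad, eos = 0, 1, 2
src_item = torch.tensor([bos, 11, 12, 13, eos])   # encoded prompt, already wrapped in BOS/EOS
tgt_item = torch.tensor([21])                     # single-token label such as " yes"

def build(prompt_type):
    if prompt_type == "none":
        prev_output = torch.tensor([bos])
        target = tgt_item.clone()
    elif prompt_type == "src":
        prev_output = src_item.clone()
        target = torch.cat([prev_output[1:], tgt_item])
    elif prompt_type == "prev_output":
        prev_output = src_item[:-1].clone()
        target = torch.cat([prev_output[1:], tgt_item])
    else:
        raise NotImplementedError
    target[:-1] = pad  # mask every position except the label token
    return prev_output, target

for pt in ("none", "src", "prev_output"):
    print(pt, build(pt))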
 
data/nlu_data/qqp_dataset.py DELETED
@@ -1,141 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class QQPDataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- question1, question2, label = self.dataset[index]
91
- if label == '0':
92
- label = 'no'
93
- elif label == '1':
94
- label = 'yes'
95
- else:
96
- raise NotImplementedError
97
-
98
- question1 = ' '.join(question1.lower().strip().split()[:self.max_src_length])
99
- question2 = ' '.join(question2.lower().strip().split()[:self.max_src_length])
100
- src_item = self.encode_text(
101
- ' is question " {} " and question " {} " equivalent?'.format(question1, question2)
102
- )
103
- tgt_item = self.encode_text(" {}".format(label))
104
- assert tgt_item.size(0) == 1
105
- ref_dict = {label: 1.0}
106
-
107
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
108
- if self.prompt_type == 'none':
109
- prev_output_item = self.bos_item
110
- target_item = tgt_item
111
- elif self.prompt_type == 'src':
112
- prev_output_item = src_item.clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- elif self.prompt_type == 'prev_output':
115
- prev_output_item = src_item[:-1].clone()
116
- target_item = torch.cat([prev_output_item[1:], tgt_item])
117
- else:
118
- raise NotImplementedError
119
- target_item[:-1] = self.tgt_dict.pad()
120
-
121
- example = {
122
- "source": src_item,
123
- "target": target_item,
124
- "prev_output_tokens": prev_output_item,
125
- "ref_dict": ref_dict,
126
- }
127
- if self.constraint_trie is not None:
128
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
129
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
130
- constraint_mask[-1][constraint_nodes] = True
131
- example["constraint_mask"] = constraint_mask
132
- return example
133
-
134
- def collater(self, samples, pad_to_length=None):
135
- """Merge a list of samples to form a mini-batch.
136
- Args:
137
- samples (List[dict]): samples to collate
138
- Returns:
139
- dict: a mini-batch containing the data of the task
140
- """
141
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
 
data/nlu_data/rte_dataset.py DELETED
@@ -1,141 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class RTEDataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- sentence1, sentence2, label = self.dataset[index]
91
- if label == 'not_entailment':
92
- label = 'no'
93
- elif label == 'entailment':
94
- label = 'yes'
95
- else:
96
- raise NotImplementedError
97
-
98
- sentence1 = ' '.join(sentence1.lower().strip().split()[:self.max_src_length])
99
- sentence2 = ' '.join(sentence2.lower().strip().split()[:self.max_src_length])
100
- src_item = self.encode_text(
101
- ' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2),
102
- )
103
- tgt_item = self.encode_text(" {}".format(label))
104
- assert tgt_item.size(0) == 1
105
- ref_dict = {label: 1.0}
106
-
107
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
108
- if self.prompt_type == 'none':
109
- prev_output_item = self.bos_item
110
- target_item = tgt_item
111
- elif self.prompt_type == 'src':
112
- prev_output_item = src_item.clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- elif self.prompt_type == 'prev_output':
115
- prev_output_item = src_item[:-1].clone()
116
- target_item = torch.cat([prev_output_item[1:], tgt_item])
117
- else:
118
- raise NotImplementedError
119
- target_item[:-1] = self.tgt_dict.pad()
120
-
121
- example = {
122
- "source": src_item,
123
- "target": target_item,
124
- "prev_output_tokens": prev_output_item,
125
- "ref_dict": ref_dict,
126
- }
127
- if self.constraint_trie is not None:
128
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
129
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
130
- constraint_mask[-1][constraint_nodes] = True
131
- example["constraint_mask"] = constraint_mask
132
- return example
133
-
134
- def collater(self, samples, pad_to_length=None):
135
- """Merge a list of samples to form a mini-batch.
136
- Args:
137
- samples (List[dict]): samples to collate
138
- Returns:
139
- dict: a mini-batch containing the data of the task
140
- """
141
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
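Each of these datasets delegates batching to a shared collate() helper, which relies on data_utils.collate_tokens to pad variable-length token sequences into a single matrix. The sketch below assumes the default right-padding behaviour; pad_batch and the token ids are illustrative, not repository code.

import torch

def pad_batch(seqs, pad_idx):
    # stack 1-D token tensors of different lengths into a [batch, max_len] matrix
    max_len = max(s.size(0) for s in seqs)
    out = seqs[0].new_full((len(seqs), max_len), pad_idx)
    for i, s in enumerate(seqs):
        out[i, : s.size(0)] = s
    return out

batch = pad_batch([torch.tensor([0, 5, 6, 2]), torch.tensor([0, 7, 2])], pad_idx=1)
print(batch)                    # tensor([[0, 5, 6, 2], [0, 7, 2, 1]])
print(batch.ne(1).sum(dim=1))   # per-sample lengths, as src_lengths is computed in collate()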
 
data/nlu_data/sst2_dataset.py DELETED
@@ -1,138 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import warnings
8
- import torch
9
- import numpy as np
10
-
11
- from data import data_utils
12
- from data.ofa_dataset import OFADataset
13
-
14
- logger = logging.getLogger(__name__)
15
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
16
-
17
-
18
- def collate(samples, pad_idx, eos_idx):
19
- if len(samples) == 0:
20
- return {}
21
-
22
- def merge(key):
23
- return data_utils.collate_tokens(
24
- [s[key] for s in samples],
25
- pad_idx,
26
- eos_idx=eos_idx,
27
- )
28
-
29
- src_tokens = merge("source")
30
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
31
-
32
- ref_dict = None
33
- if samples[0].get("ref_dict", None) is not None:
34
- ref_dict = np.array([s['ref_dict'] for s in samples])
35
-
36
- constraint_masks = None
37
- if samples[0].get("constraint_mask", None) is not None:
38
- constraint_masks = merge("constraint_mask")
39
-
40
- prev_output_tokens = None
41
- target = None
42
- if samples[0].get("target", None) is not None:
43
- target = merge("target")
44
- tgt_lengths = torch.LongTensor(
45
- [s["target"].ne(pad_idx).long().sum() for s in samples]
46
- )
47
- ntokens = tgt_lengths.sum().item()
48
-
49
- if samples[0].get("prev_output_tokens", None) is not None:
50
- prev_output_tokens = merge("prev_output_tokens")
51
- else:
52
- ntokens = src_lengths.sum().item()
53
-
54
- batch = {
55
- "nsentences": len(samples),
56
- "ntokens": ntokens,
57
- "net_input": {
58
- "src_tokens": src_tokens,
59
- "src_lengths": src_lengths,
60
- "prev_output_tokens": prev_output_tokens
61
- },
62
- "ref_dict": ref_dict,
63
- "constraint_masks": constraint_masks,
64
- "target": target,
65
- }
66
-
67
- return batch
68
-
69
-
70
- class SST2Dataset(OFADataset):
71
- def __init__(
72
- self,
73
- split,
74
- dataset,
75
- bpe,
76
- src_dict,
77
- tgt_dict=None,
78
- max_src_length=512,
79
- max_tgt_length=30,
80
- constraint_trie=None,
81
- prompt_type="none"
82
- ):
83
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
84
- self.max_src_length = max_src_length
85
- self.max_tgt_length = max_tgt_length
86
- self.constraint_trie = constraint_trie
87
- self.prompt_type = prompt_type
88
-
89
- def __getitem__(self, index):
90
- sentence, label = self.dataset[index]
91
- if label == '0':
92
- label = 'negative'
93
- elif label == '1':
94
- label = 'positive'
95
- else:
96
- raise NotImplementedError
97
-
98
- sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
99
- src_item = self.encode_text(' is the sentiment of text " {} " positive or negative?'.format(sentence))
100
- tgt_item = self.encode_text(" {}".format(label))
101
- assert tgt_item.size(0) == 1
102
- ref_dict = {label: 1.0}
103
-
104
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
105
- if self.prompt_type == 'none':
106
- prev_output_item = self.bos_item
107
- target_item = tgt_item
108
- elif self.prompt_type == 'src':
109
- prev_output_item = src_item.clone()
110
- target_item = torch.cat([prev_output_item[1:], tgt_item])
111
- elif self.prompt_type == 'prev_output':
112
- prev_output_item = src_item[:-1].clone()
113
- target_item = torch.cat([prev_output_item[1:], tgt_item])
114
- else:
115
- raise NotImplementedError
116
- target_item[:-1] = self.tgt_dict.pad()
117
-
118
- example = {
119
- "source": src_item,
120
- "target": target_item,
121
- "prev_output_tokens": prev_output_item,
122
- "ref_dict": ref_dict,
123
- }
124
- if self.constraint_trie is not None:
125
- constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
126
- constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
127
- constraint_mask[-1][constraint_nodes] = True
128
- example["constraint_mask"] = constraint_mask
129
- return example
130
-
131
- def collater(self, samples, pad_to_length=None):
132
- """Merge a list of samples to form a mini-batch.
133
- Args:
134
- samples (List[dict]): samples to collate
135
- Returns:
136
- dict: a mini-batch containing the data of the task
137
- """
138
- return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
 
data/ofa_dataset.py DELETED
@@ -1,79 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- import logging
7
- import re
8
- import torch.utils.data
9
- from fairseq.data import FairseqDataset
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
- class OFADataset(FairseqDataset):
15
- def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
16
- self.split = split
17
- self.dataset = dataset
18
- self.bpe = bpe
19
- self.src_dict = src_dict
20
- self.tgt_dict = tgt_dict
21
-
22
- self.bos = src_dict.bos()
23
- self.eos = src_dict.eos()
24
- self.pad = src_dict.pad()
25
- self.bos_item = torch.LongTensor([self.bos])
26
- self.eos_item = torch.LongTensor([self.eos])
27
-
28
- def __len__(self):
29
- return len(self.dataset)
30
-
31
- def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
32
- s = self.tgt_dict.encode_line(
33
- line=self.bpe.encode(text) if use_bpe else text,
34
- add_if_not_exist=False,
35
- append_eos=False
36
- ).long()
37
- if length is not None:
38
- s = s[:length]
39
- if append_bos:
40
- s = torch.cat([self.bos_item, s])
41
- if append_eos:
42
- s = torch.cat([s, self.eos_item])
43
- return s
44
-
45
- def pre_question(self, question, max_ques_words=None):
46
- question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
47
-
48
- question = re.sub(
49
- r"\s{2,}",
50
- ' ',
51
- question,
52
- )
53
- question = question.rstrip('\n')
54
- question = question.strip(' ')
55
-
56
- # truncate question
57
- question_words = question.split(' ')
58
- if max_ques_words is not None and len(question_words) > max_ques_words:
59
- question = ' '.join(question_words[:max_ques_words])
60
-
61
- return question
62
-
63
- def pre_caption(self, caption, max_words=None):
64
- caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
65
-
66
- caption = re.sub(
67
- r"\s{2,}",
68
- ' ',
69
- caption,
70
- )
71
- caption = caption.rstrip('\n')
72
- caption = caption.strip(' ')
73
-
74
- # truncate caption
75
- caption_words = caption.split(' ')
76
- if max_words is not None and len(caption_words) > max_words:
77
- caption = ' '.join(caption_words[:max_words])
78
-
79
- return caption
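For reference, the normalization in pre_question and pre_caption above amounts to lowercasing, stripping leading punctuation, replacing '-' and '/' with spaces, collapsing repeated whitespace, and truncating to a word budget. A standalone sketch of the same steps (illustrative only, with a made-up input string):

import re

def normalize(text, max_words=None):
    text = (text.lower().lstrip(",.!?*#:;~")
            .replace('-', ' ').replace('/', ' ').replace('<person>', 'person'))
    text = re.sub(r"\s{2,}", ' ', text).rstrip('\n').strip(' ')
    words = text.split(' ')
    if max_words is not None and len(words) > max_words:
        text = ' '.join(words[:max_words])
    return text

print(normalize("A man--riding a horse/near <person> on the  beach!", max_words=6))
# -> 'a man riding a horse near'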
 
data/pretrain_data/unify_dataset.py DELETED
@@ -1,636 +0,0 @@
1
- # Copyright 2022 The OFA-Sys Team.
2
- # All rights reserved.
3
- # This source code is licensed under the Apache 2.0 license
4
- # found in the LICENSE file in the root directory.
5
-
6
- from io import BytesIO
7
-
8
- import math
9
- import logging
10
- import random
11
- import warnings
12
-
13
- import numpy as np
14
- import torch
15
- import base64
16
- from torchvision import transforms
17
-
18
- from PIL import Image, ImageFile
19
-
20
- from data import data_utils
21
- from data.ofa_dataset import OFADataset
22
- from utils.vision_helper import RandomAugment
23
- import utils.transforms as T
24
-
25
- ImageFile.LOAD_TRUNCATED_IMAGES = True
26
- ImageFile.MAX_IMAGE_PIXELS = None
27
- Image.MAX_IMAGE_PIXELS = None
28
-
29
- logger = logging.getLogger(__name__)
30
- warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
31
-
32
-
33
- def get_whole_word_mask(bpe, dictionary):
34
- if bpe is not None:
35
-
36
- def is_beginning_of_word(i):
37
- if i < dictionary.nspecial:
38
- # special elements are always considered beginnings
39
- return True
40
- tok = dictionary[i]
41
- if tok.startswith("madeupword"):
42
- return True
43
- try:
44
- return bpe.is_beginning_of_word(tok)
45
- except ValueError:
46
- return True
47
-
48
- mask_whole_words = torch.ByteTensor(
49
- list(map(is_beginning_of_word, range(len(dictionary))))
50
- )
51
- return mask_whole_words
52
- return None
53
-
54
-
55
- def collate(samples, pad_idx, eos_idx):
56
- if len(samples) == 0:
57
- return {}
58
-
59
- def merge(key):
60
- return data_utils.collate_tokens(
61
- [s[key] for s in samples],
62
- pad_idx,
63
- eos_idx=eos_idx,
64
- )
65
-
66
- id = np.array([s["id"] for s in samples])
67
- src_tokens = merge("source")
68
- src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])
69
-
70
- patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
71
- patch_masks = torch.cat([sample['patch_mask'] for sample in samples])
72
-
73
- code_masks = None
74
- if samples[0].get("code_mask", None) is not None:
75
- code_masks = torch.cat([sample['code_mask'] for sample in samples])
76
-
77
- conf = torch.cat([s['conf'] for s in samples], dim=0)
78
-
79
- prev_output_tokens = None
80
- target = None
81
- if samples[0].get("target", None) is not None:
82
- target = merge("target")
83
- tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
84
- ntokens = tgt_lengths.sum().item()
85
-
86
- if samples[0].get("prev_output_tokens", None) is not None:
87
- prev_output_tokens = merge("prev_output_tokens")
88
- else:
89
- ntokens = src_lengths.sum().item()
90
-
91
- batch = {
92
- "id": id,
93
- "nsentences": len(samples),
94
- "ntokens": ntokens,
95
- "net_input": {
96
- "src_tokens": src_tokens,
97
- "src_lengths": src_lengths,
98
- "patch_images": patch_images,
99
- "patch_masks": patch_masks,
100
- "code_masks": code_masks,
101
- "prev_output_tokens": prev_output_tokens
102
- },
103
- "target": target,
104
- "conf": conf
105
- }
106
-
107
- return batch
108
-
109
-
110
- class UnifyDataset(OFADataset):
111
- def __init__(
112
- self,
113
- split,
114
- dataset,
115
- bpe,
116
- src_dict,
117
- tgt_dict=None,
118
- max_src_length=128,
119
- max_tgt_length=30,
120
- seed=7,
121
- code_dict_size=8192,
122
- num_bins=1000,
123
- patch_image_size=384,
124
- code_image_size=128,
125
- pure_text_dataset=None,
126
- pure_image_dataset=None,
127
- detection_dataset=None,
128
- all_object_list=None,
129
- all_caption_list=None,
130
- type2ans_dict=None,
131
- ans2type_dict=None,
132
- max_image_size=512,
133
- mask_ratio=0.3,
134
- random_ratio=0.0,
135
- keep_ratio=0.0,
136
- mask_length="span-poisson",
137
- poisson_lambda=3.0,
138
- replace_length=1
139
- ):
140
- super().__init__(split, dataset, bpe, src_dict, tgt_dict)
141
- self.max_src_length = max_src_length
142
- self.max_tgt_length = max_tgt_length
143
- self.seed = seed
144
- self.code_dict_size = code_dict_size
145
- self.num_bins = num_bins
146
- self.patch_image_size = patch_image_size
147
- self.code_image_size = code_image_size
148
-
149
- self.pure_text_dataset = pure_text_dataset
150
- self.pure_image_dataset = pure_image_dataset
151
- self.detection_dataset = detection_dataset
152
- self.epoch = 0
153
-
154
- self.all_object_list = all_object_list
155
- self.all_caption_list = all_caption_list
156
- self.type2ans_dict = type2ans_dict
157
- self.ans2type_dict = ans2type_dict
158
-
159
- self.mask_ratio = mask_ratio
160
- self.random_ratio = random_ratio
161
- self.keep_ratio = keep_ratio
162
- self.mask_length = mask_length
163
- self.poisson_lambda = poisson_lambda
164
- self.replace_length = replace_length
165
- if self.replace_length not in [-1, 0, 1]:
166
- raise ValueError(f"invalid arg: replace_length={self.replace_length}")
167
- if self.mask_length not in ["subword", "word", "span-poisson"]:
168
- raise ValueError(f"invalid arg: mask-length={self.mask_length}")
169
- if self.mask_length == "subword" and self.replace_length not in [0, 1]:
170
- raise ValueError(f"if using subwords, use replace-length=1 or 0")
171
-
172
- self.mask_idx = src_dict.index("<mask>")
173
- self.mask_whole_word = (
174
- get_whole_word_mask(self.bpe, self.src_dict)
175
- if self.mask_length != "subword"
176
- else None
177
- )
178
- self.mask_span_distribution = None
179
- if self.mask_length == "span-poisson":
180
- _lambda = self.poisson_lambda
181
- lambda_to_the_k = 1
182
- e_to_the_minus_lambda = math.exp(-_lambda)
183
- k_factorial = 1
184
- ps = []
185
- for k in range(0, 128):
186
- ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
187
- lambda_to_the_k *= _lambda
188
- k_factorial *= k + 1
189
- if ps[-1] < 0.0000001:
190
- break
191
- ps = torch.FloatTensor(ps)
192
- self.mask_span_distribution = torch.distributions.Categorical(ps)
193
-
194
- self.pos_tgt_item = self.encode_text(" yes")
195
- self.neg_tgt_item = self.encode_text(" no")
196
-
197
- self.mask_left = self.mask_top = int(0.5 * self.code_image_size)
198
- self.mask_right = self.mask_bottom = int(1.5 * self.code_image_size)
199
- self.mask_ids = [
200
- i*self.code_image_size*2+j
201
- for i in range(self.code_image_size*2) for j in range(self.code_image_size*2)
202
- if not (self.mask_left <= i < self.mask_right and self.mask_top <= j < self.mask_bottom)
203
- ]
204
-
205
- scales = np.arange(patch_image_size, 481).tolist()
206
-
207
- # for image-text pair
208
- self.patch_resize_transform = transforms.Compose([
209
- T.RandomResize(scales, max_size=672),
210
- transforms.CenterCrop(patch_image_size),
211
- RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
212
- 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
213
- transforms.ToTensor(),
214
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
215
- ])
216
- # for pure image
217
- self.patch_crop_transform = transforms.Compose([
218
- transforms.ToTensor(),
219
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
220
- ])
221
- # for detection
222
- self.detection_transform = T.Compose([
223
- T.RandomHorizontalFlip(),
224
- T.LargeScaleJitter(output_size=self.code_image_size*2, aug_scale_min=1.0, aug_scale_max=1.5),
225
- T.ToTensor(),
226
- T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
227
- ])
228
- # for visual grounding
229
- self.visual_grounding_transform = T.Compose([
230
- T.RandomResize(scales, max_size=672),
231
- T.ObjectCenterCrop((patch_image_size, patch_image_size)),
232
- T.ToTensor(),
233
- T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_image_size=max_image_size)
234
- ])
235
-
236
- def set_epoch(self, epoch, **unused):
237
- self.epoch = epoch
238
-
239
- def get_negative_caption(self, caption, gt_objects):
240
- prob = random.random()
241
- if gt_objects is not None and gt_objects != '' and prob > 0.6:
242
- gt_object = random.choice(gt_objects.strip().split('&&'))
243
- negative_object = random.choice(self.all_object_list[:-1])
244
- negative_object = self.all_object_list[-1] if negative_object == gt_object else negative_object
245
- negative_caption = caption.replace(gt_object, negative_object)
246
- else:
247
- negative_caption = random.choice(self.all_caption_list)
248
- return negative_caption
249
-
250
- def get_negative_answer(self, answer, conf):
251
- prob = random.random()
252
- if conf > (prob + 0.1) and answer in self.ans2type_dict:
253
- negative_answer_type = self.ans2type_dict[answer]
254
- if negative_answer_type == 'how many' and answer.isdigit() and prob > 0.5:
255
- negative_answer = int(answer) + random.choice([-1, 1]) if answer != '0' else 1
256
- else:
257
- negative_answer_list = self.type2ans_dict[negative_answer_type]
258
- negative_answer = random.choice(negative_answer_list[:-1])
259
- negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
260
- return negative_answer
261
-
262
- negative_answer_list = self.type2ans_dict['other']
263
- negative_answer = random.choice(negative_answer_list[:-1])
264
- negative_answer = negative_answer_list[-1] if negative_answer == answer else negative_answer
265
- return negative_answer
266
-
267
- def process_image_text_pair(self, index):
268
- uniq_id, image, caption, question, refs, gt_objects, dataset_name, type = self.dataset[index]
269
-
270
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
271
- patch_image = self.patch_resize_transform(image) if type != 'visual_grounding' else None
272
- patch_mask = torch.tensor([True])
273
- conf = torch.tensor([1.0])
274
- if type == 'caption':
275
- tgt_caption = self.pre_caption(caption, self.max_tgt_length)
276
- pos_src_caption = self.pre_caption(caption, self.max_src_length)
277
- neg_src_caption = self.pre_caption(self.get_negative_caption(caption, gt_objects), self.max_src_length)
278
- src_item = self.encode_text(" what does the image describe?")
279
- tgt_item = self.encode_text(" {}".format(tgt_caption))
280
- pos_src_item = self.encode_text(' does the image describe " {} "?'.format(pos_src_caption))
281
- neg_src_item = self.encode_text(' does the image describe " {} "?'.format(neg_src_caption))
282
- elif type == 'qa':
283
- question = self.pre_question(question, self.max_src_length)
284
- ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in refs.split('&&')}
285
- answer = max(ref_dict, key=ref_dict.get)
286
- conf = ref_dict[answer]
287
- src_item = self.encode_text(" {}".format(question))
288
- tgt_item = self.encode_text(" {}".format(answer))
289
- conf = torch.tensor([conf])
290
- pos_src_item = self.encode_text(' what is the answer to question " {} ". is " {} "?'.format(question, answer))
291
- neg_src_item = self.encode_text(
292
- ' what is the answer to question " {} ". is " {} "?'.format(question, self.get_negative_answer(answer, conf))
293
- )
294
- elif type == 'visual_grounding':
295
- conf = torch.tensor([1.0])
296
- w, h = image.size
297
- boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
298
- x0, y0, x1, y1 = refs.strip().split(',')
299
- boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
300
- boxes_target["labels"] = np.array([0])
301
- boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
302
- patch_image, boxes_target = self.visual_grounding_transform(image, boxes_target)
303
- quant_x0 = "<bin_{}>".format(int((boxes_target["boxes"][0][0] * (self.num_bins - 1)).round()))
304
- quant_y0 = "<bin_{}>".format(int((boxes_target["boxes"][0][1] * (self.num_bins - 1)).round()))
305
- quant_x1 = "<bin_{}>".format(int((boxes_target["boxes"][0][2] * (self.num_bins - 1)).round()))
306
- quant_y1 = "<bin_{}>".format(int((boxes_target["boxes"][0][3] * (self.num_bins - 1)).round()))
307
- region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
308
- src_caption = self.pre_caption(caption, self.max_src_length)
309
- src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
310
- tgt_item = self.encode_text(region_coord, use_bpe=False)
311
- else:
312
- logger.info('type {} is not implemented'.format(type))
313
- raise NotImplementedError
314
-
315
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
316
- target_item = torch.cat([tgt_item, self.eos_item])
317
- prev_output_item = torch.cat([self.bos_item, tgt_item])
318
- pos_src_item = torch.cat([self.bos_item, pos_src_item, self.eos_item]) if type != 'visual_grounding' else None
319
- neg_src_item = torch.cat([self.bos_item, neg_src_item, self.eos_item]) if type != 'visual_grounding' else None
320
-
321
- if type == 'caption' and dataset_name == 'cc12m':
322
- target_item[:2] = self.src_dict.pad()
323
- target_item[-1] = self.eos_item
324
-
325
- example = {
326
- "id": uniq_id,
327
- "source": src_item,
328
- "patch_image": patch_image,
329
- "patch_mask": patch_mask,
330
- "target": target_item,
331
- "prev_output_tokens": prev_output_item,
332
- "conf": conf,
333
- }
334
-
335
- examples = [example]
336
- prob = random.random()
337
- if type == 'visual_grounding':
338
- region_example = example.copy()
339
- region_prefix_item = self.encode_text(' what does the region describe? region:')
340
- region_coord_item = self.encode_text('{}'.format(region_coord), use_bpe=False)
341
- region_src_item = torch.cat([region_prefix_item, region_coord_item])
342
- region_tgt_item = self.encode_text(' {}'.format(self.pre_caption(caption, self.max_tgt_length)))
343
- region_example["source"] = torch.cat([self.bos_item, region_src_item, self.eos_item])
344
- region_example["target"] = torch.cat([region_tgt_item, self.eos_item])
345
- region_example["prev_output_tokens"] = torch.cat([self.bos_item, region_tgt_item])
346
- region_example["conf"] = torch.tensor([1.0])
347
- examples.append(region_example)
348
- elif prob >= 0.5 and self.split == 'train':
349
- pos_example = example.copy()
350
- pos_example["source"] = pos_src_item
351
- pos_example["target"] = torch.cat([self.pos_tgt_item, self.eos_item])
352
- pos_example["prev_output_tokens"] = torch.cat([self.bos_item, self.pos_tgt_item])
353
- examples.append(pos_example)
354
- elif self.split == 'train':
355
- neg_example = example.copy()
356
- neg_example["source"] = neg_src_item
357
- neg_example["target"] = torch.cat([self.neg_tgt_item, self.eos_item])
358
- neg_example["prev_output_tokens"] = torch.cat([self.bos_item, self.neg_tgt_item])
359
- examples.append(neg_example)
360
- return examples
361
-
362
- def process_pure_text(self, index):
363
- patch_image = torch.zeros((3, self.code_image_size*2, self.code_image_size*2))
364
- patch_mask = torch.tensor([False])
365
- code_mask = torch.tensor([False])
366
- conf = torch.tensor([2.0])
367
-
368
- examples = []
369
- for _ in range(2):
370
- uniq_id, text = self.pure_text_dataset[index]
371
- text = text.strip().lower()
372
- text_item = self.encode_text(" {}".format(text), length=512)
373
- text_item = text_item[-256:]
374
- text_item = torch.cat([self.bos_item, text_item, self.eos_item])
375
- mask_text_item = self.add_whole_word_mask(text_item.clone(), self.mask_ratio)
376
- prefix_item = self.encode_text(' what is the complete text of " "?')
377
- src_item = torch.cat([prefix_item[:-2], mask_text_item[1:-1], prefix_item[-2:]])
378
- tgt_item = text_item[1:-1]
379
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
380
- target_item = torch.cat([tgt_item, self.eos_item])
381
- prev_output_item = torch.cat([self.bos_item, tgt_item])
382
- example = {
383
- "id": uniq_id,
384
- "source": src_item,
385
- "patch_image": patch_image,
386
- "patch_mask": patch_mask,
387
- "code_mask": code_mask,
388
- "target": target_item,
389
- "prev_output_tokens": prev_output_item,
390
- "conf": conf,
391
- }
392
- examples.append(example)
393
-
394
- return examples
395
-
396
- def process_pure_image(self, index):
397
- image_id, image, code = self.pure_image_dataset[index]
398
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
399
- patch_image = self.patch_crop_transform(image)
400
- patch_image[:, self.mask_top:self.mask_bottom, self.mask_left:self.mask_right] = 0
401
- patch_mask = torch.tensor([True])
402
- src_item = self.encode_text(" what is the image in the middle part?")
403
- image_code = torch.LongTensor([int(num) for num in code.strip().split()])
404
- tgt_item = image_code + len(self.src_dict) - self.code_dict_size - self.num_bins
405
- code_mask = torch.tensor([True])
406
- conf = torch.tensor([2.0])
407
-
408
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
409
- target_item = torch.cat([tgt_item, self.eos_item])
410
- prev_output_item = torch.cat([self.bos_item, tgt_item])
411
-
412
- example = {
413
- "id": image_id,
414
- "source": src_item,
415
- "patch_image": patch_image,
416
- "patch_mask": patch_mask,
417
- "code_mask": code_mask,
418
- "target": target_item,
419
- "prev_output_tokens": prev_output_item,
420
- "conf": conf,
421
- }
422
- return [example]
423
-
424
- def process_detection(self, index):
425
- image_id, image, label = self.detection_dataset[index]
426
- image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
427
-
428
- w, h = image.size
429
- boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
430
- label_list = label.strip().split('&&')
431
- for label in label_list:
432
- x0, y0, x1, y1, cat_id, cat = label.strip().split(',', 5)
433
- boxes_target["boxes"].append([float(x0), float(y0), float(x1), float(y1)])
434
- boxes_target["labels"].append(cat)
435
- boxes_target["area"].append((float(x1) - float(x0)) * (float(y1) - float(y0)))
436
- boxes_target["boxes"] = torch.tensor(boxes_target["boxes"])
437
- boxes_target["labels"] = np.array(boxes_target["labels"])
438
- boxes_target["area"] = torch.tensor(boxes_target["area"])
439
-
440
- patch_image, boxes_target = self.detection_transform(image, boxes_target)
441
- patch_mask = torch.tensor([True])
442
- code_mask = torch.tensor([False])
443
- conf = torch.tensor([2.0])
444
-
445
- quant_boxes = []
446
- for i, box in enumerate(boxes_target["boxes"]):
447
- quant_boxes.extend(["<bin_{}>".format(int((pos * (self.num_bins - 1)).round())) for pos in box[:4]])
448
- quant_boxes.append(self.bpe.encode(' {}'.format(boxes_target["labels"][i])))
449
- src_item = self.encode_text(' what are the objects in the image?')
450
- tgt_item = self.encode_text(' '.join(quant_boxes), use_bpe=False)
451
-
452
- src_item = torch.cat([self.bos_item, src_item, self.eos_item])
453
- target_item = torch.cat([tgt_item, self.eos_item])
454
- prev_output_item = torch.cat([self.bos_item, tgt_item])
455
-
456
- example = {
457
- "id": image_id,
458
- "source": src_item,
459
- "patch_image": patch_image,
460
- "patch_mask": patch_mask,
461
- "code_mask": code_mask,
462
- "target": target_item,
463
- "prev_output_tokens": prev_output_item,
464
- "conf": conf,
465
- }
466
- return [example]
467
-
468
- def __getitem__(self, index):
469
- with data_utils.numpy_seed(self.seed, self.epoch):
470
- pair_samples = self.process_image_text_pair(index)
471
- extra_samples = []
472
- if self.split == 'train' and self.dataset.data_cnt % 8 == 0:
473
- extra_samples += self.process_pure_text(0) if self.pure_text_dataset else []
474
- extra_samples += self.process_pure_image(0) if self.pure_image_dataset else []
475
- extra_samples += self.process_detection(0) if self.detection_dataset else []
476
- return pair_samples, extra_samples
477
-
478
- def word_starts(self, source):
479
- if self.mask_whole_word is not None:
480
- is_word_start = self.mask_whole_word.gather(0, source)
481
- else:
482
- is_word_start = torch.ones(source.size())
483
- is_word_start[0] = 0
484
- is_word_start[-1] = 0
485
- return is_word_start
486
-
487
- def add_whole_word_mask(self, source, p):
488
- is_word_start = self.word_starts(source)
489
- num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
490
- num_inserts = 0
491
- if num_to_mask == 0:
492
- return source
493
-
494
- if self.mask_span_distribution is not None:
495
- lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
496
-
497
- # Make sure we have enough to mask
498
- cum_length = torch.cumsum(lengths, 0)
499
- while cum_length[-1] < num_to_mask:
500
- lengths = torch.cat(
501
- [
502
- lengths,
503
- self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
504
- ],
505
- dim=0,
506
- )
507
- cum_length = torch.cumsum(lengths, 0)
508
-
509
- # Trim to masking budget
510
- i = 0
511
- while cum_length[i] < num_to_mask:
512
- i += 1
513
- lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
514
- num_to_mask = i + 1
515
- lengths = lengths[:num_to_mask]
516
-
517
- # Handle 0-length mask (inserts) separately
518
- lengths = lengths[lengths > 0]
519
- num_inserts = num_to_mask - lengths.size(0)
520
- num_to_mask -= num_inserts
521
- if num_to_mask == 0:
522
- return self.add_insertion_noise(source, num_inserts / source.size(0))
523
-
524
- assert (lengths > 0).all()
525
- else:
526
- lengths = torch.ones((num_to_mask,)).long()
527
- assert is_word_start[-1] == 0
528
- word_starts = is_word_start.nonzero(as_tuple=False)
529
- indices = word_starts[
530
- torch.randperm(word_starts.size(0))[:num_to_mask]
531
- ].squeeze(1)
532
- mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
533
-
534
- source_length = source.size(0)
535
- assert source_length - 1 not in indices
536
- to_keep = torch.ones(source_length, dtype=torch.bool)
537
- is_word_start[
538
- -1
539
- ] = 255 # acts as a long length, so spans don't go over the end of doc
540
- if self.replace_length == 0:
541
- to_keep[indices] = 0
542
- else:
543
- # keep index, but replace it with [MASK]
544
- source[indices] = self.mask_idx
545
- source[indices[mask_random]] = torch.randint(
546
- 4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
547
- )
548
-
549
- if self.mask_span_distribution is not None:
550
- assert len(lengths.size()) == 1
551
- assert lengths.size() == indices.size()
552
- lengths -= 1
553
- while indices.size(0) > 0:
554
- assert lengths.size() == indices.size()
555
- lengths -= is_word_start[indices + 1].long()
556
- uncompleted = lengths >= 0
557
- indices = indices[uncompleted] + 1
558
- mask_random = mask_random[uncompleted]
559
- lengths = lengths[uncompleted]
560
- if self.replace_length != -1:
561
- # delete token
562
- to_keep[indices] = 0
563
- else:
564
- # keep index, but replace it with [MASK]
565
- source[indices] = self.mask_idx
566
- source[indices[mask_random]] = torch.randint(
567
- 4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
568
- )
569
- else:
570
- # A bit faster when all lengths are 1
571
- while indices.size(0) > 0:
572
- uncompleted = is_word_start[indices + 1] == 0
573
- indices = indices[uncompleted] + 1
574
- mask_random = mask_random[uncompleted]
575
- if self.replace_length != -1:
576
- # delete token
577
- to_keep[indices] = 0
578
- else:
579
- # keep index, but replace it with [MASK]
580
- source[indices] = self.mask_idx
581
- source[indices[mask_random]] = torch.randint(
582
- 4, len(self.tgt_dict) - self.code_dict_size - self.num_bins, size=(mask_random.sum(),)
583
- )
584
-
585
- assert source_length - 1 not in indices
586
-
587
- source = source[to_keep]
588
-
589
- if num_inserts > 0:
590
- source = self.add_insertion_noise(source, num_inserts / source.size(0))
591
-
592
- return source
593
-
594
- def add_insertion_noise(self, tokens, p):
595
- if p == 0.0:
596
- return tokens
597
-
598
- num_tokens = len(tokens)
599
- n = int(math.ceil(num_tokens * p))
600
-
601
- noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
602
- noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
603
- noise_mask[noise_indices] = 1
604
- result = torch.LongTensor(n + len(tokens)).fill_(-1)
605
-
606
- num_random = int(math.ceil(n * self.random_ratio))
607
- result[noise_indices[num_random:]] = self.mask_idx
608
- result[noise_indices[:num_random]] = torch.randint(
609
- low=4, high=len(self.tgt_dict)-self.code_dict_size-self.num_bins, size=(num_random,)
610
- )
611
-
612
- result[~noise_mask] = tokens
613
-
614
- assert (result >= 0).all()
615
- return result
616
-
617
- def collater(self, samples, pad_to_length=None):
618
- """Merge samples of different tasks to form two mini-batches.
619
- Args:
620
- samples (List[Tuple]): samples to collate
621
- Returns:
622
- Tuple[dict]: two mini-batches containing the data of different tasks
623
- """
624
-
625
- samples_v1 = [] # containing image-text pairs
626
- samples_v2 = [] # containing detection data, text data and image data
627
- for sample_tuple in samples:
628
- samples_v1 += sample_tuple[0]
629
- samples_v2 += sample_tuple[1]
630
- if samples_v2 != []:
631
- res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
632
- res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
633
- return res_v1, res_v2
634
- else:
635
- res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
636
- return res_v1
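The span-poisson branch in the UnifyDataset constructor builds a truncated Poisson distribution over mask span lengths, with p(k) proportional to exp(-lambda) * lambda^k / k!, computed iteratively and cut off once the tail term becomes negligible. A self-contained sketch of that computation (illustrative, not repository code):

import math
import torch

def span_poisson(lam=3.0, max_k=128, eps=1e-7):
    ps, lam_k, k_fact = [], 1.0, 1.0
    for k in range(max_k):
        ps.append(math.exp(-lam) * lam_k / k_fact)  # Poisson term for span length k
        lam_k *= lam
        k_fact *= k + 1
        if ps[-1] < eps:
            break
    # Categorical renormalizes, so the truncated tail mass is spread over the kept lengths
    return torch.distributions.Categorical(torch.FloatTensor(ps))

dist = span_poisson(3.0)
print(dist.sample(sample_shape=(10,)))  # sampled whole-word mask span lengths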
 
datasets.md DELETED
@@ -1,44 +0,0 @@
1
- # Datasets
2
-
3
- We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will soon provide scripts for doing so.
4
-
5
- ## Pretraining
6
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip"> A small subset of the pretraining data </a>
7
-
8
- The pretraining datasets used in OFA are all publicly available. Here we provide the public links to these data; we recommend downloading the data from these links first and then processing the downloaded datasets into a format similar to the examples we provide.
9
- - _CC12M_: https://github.com/google-research-datasets/conceptual-12m
10
- - _CC3M_: https://github.com/google-research-datasets/conceptual-captions
11
- - _SBU_: https://www.cs.virginia.edu/~vicente/sbucaptions
12
- - _COCO_: https://cocodataset.org/#home
13
- - _VG_: https://visualgenome.org/
14
- - _VQAv2_: https://visualqa.org/
15
- - _GQA_: https://cs.stanford.edu/people/dorarad/gqa/about.html
16
- - _RefCOCO_/_RefCOCO+_/_RefCOCOg_: https://github.com/lichengunc/refer
17
- - _OpenImages_: https://storage.googleapis.com/openimages/web/index.html
18
- - _Object365_: https://www.objects365.org/overview.html
19
- - _YFCC100M (subset)_: https://github.com/openai/CLIP/blob/main/data/yfcc100m.md
20
- - _ImageNet-21K_: https://image-net.org/index.php
21
- - _Pile_: https://pile.eleuther.ai
22
-
23
- ## Vision & Language Tasks
24
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
25
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
26
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
27
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
28
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/vqa_data/vqa_data.zip"> Dataset for VQAv2 </a> (we have also provided chunked parts of the dataset files for more convenient downloading, please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>)
29
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/snli_ve_data/snli_ve_data.zip"> Dataset for SNLI-VE </a>
30
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen.zip"> Dataset for Text-to-Image Generation </a>
31
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/coco_image_gen_data/coco_image_gen_origin_id.zip"> Dataset for Text-to-Image Generation (with original id) </a>
32
-
33
- ## Vision Tasks
34
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/imagenet_1k_data/imagenet_1k_data.zip"> Dataset for ImageNet-1K </a>
35
-
36
- ## Language Tasks
37
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/cola_data.zip"> Dataset for COLA </a>
38
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mnli_data.zip"> Dataset for MNLI </a>
39
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/mrpc_data.zip"> Dataset for MRPC </a>
40
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qnli_data.zip"> Dataset for QNLI </a>
41
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/qqp_data.zip"> Dataset for QQP </a>
42
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/rte_data.zip"> Dataset for RTE </a>
43
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/glue_data/sst2_data.zip"> Dataset for SST2 </a>
44
- * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/gigaword_data/gigaword_data.zip"> Dataset for Gigaword </a>
 
evaluate.py DELETED
@@ -1,160 +0,0 @@
1
- #!/usr/bin/env python3 -u
2
- # Copyright 2022 The OFA-Sys Team.
3
- # All rights reserved.
4
- # This source code is licensed under the Apache 2.0 license
5
- # found in the LICENSE file in the root directory.
6
-
7
- import logging
8
- import os
9
- import sys
10
-
11
- import numpy as np
12
- import torch
13
- from fairseq import distributed_utils, options, tasks, utils
14
- from fairseq.dataclass.utils import convert_namespace_to_omegaconf
15
- from fairseq.logging import progress_bar
16
- from fairseq.utils import reset_logging
17
- from omegaconf import DictConfig
18
-
19
- from utils import checkpoint_utils
20
- from utils.eval_utils import eval_step, merge_results
21
- from utils.zero_shot_utils import zero_shot_step
22
-
23
- logging.basicConfig(
24
- format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
25
- datefmt="%Y-%m-%d %H:%M:%S",
26
- level=os.environ.get("LOGLEVEL", "INFO").upper(),
27
- stream=sys.stdout,
28
- )
29
- logger = logging.getLogger("ofa.evaluate")
30
-
31
-
32
- def apply_half(t):
33
- if t.dtype is torch.float32:
34
- return t.to(dtype=torch.half)
35
- return t
36
-
37
-
38
- def main(cfg: DictConfig, **kwargs):
39
- utils.import_user_module(cfg.common)
40
-
41
- reset_logging()
42
- logger.info(cfg)
43
-
44
- assert (
45
- cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
46
- ), "Must specify batch size either with --max-tokens or --batch-size"
47
-
48
- # Fix seed for stochastic decoding
49
- if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
50
- np.random.seed(cfg.common.seed)
51
- utils.set_torch_seed(cfg.common.seed)
52
-
53
- use_fp16 = cfg.common.fp16
54
- use_cuda = torch.cuda.is_available() and not cfg.common.cpu
55
-
56
- if use_cuda:
57
- torch.cuda.set_device(cfg.distributed_training.device_id)
58
-
59
- # Load ensemble
60
- overrides = eval(cfg.common_eval.model_overrides)
61
- # Deal with beam-search / all-candidate VQA eval
62
- if cfg.task._name == "vqa_gen":
63
- overrides['val_inference_type'] = "beamsearch" if kwargs['beam_search_vqa_eval'] else "allcand"
64
-
65
- logger.info("loading model(s) from {}".format(cfg.common_eval.path))
66
- if kwargs["zero_shot"]:
67
- task = tasks.setup_task(cfg.task)
68
- models, saved_cfg = checkpoint_utils.load_model_ensemble(
69
- utils.split_paths(cfg.common_eval.path),
70
- arg_overrides=overrides,
71
- task=task,
72
- suffix=cfg.checkpoint.checkpoint_suffix,
73
- strict=(cfg.checkpoint.checkpoint_shard_count == 1),
74
- num_shards=cfg.checkpoint.checkpoint_shard_count,
75
- )
76
- else:
77
- models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
78
- utils.split_paths(cfg.common_eval.path),
79
- arg_overrides=overrides,
80
- suffix=cfg.checkpoint.checkpoint_suffix,
81
- strict=(cfg.checkpoint.checkpoint_shard_count == 1),
82
- num_shards=cfg.checkpoint.checkpoint_shard_count,
83
- )
84
-
85
- # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
86
- task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
87
-
88
- # Move models to GPU
89
- for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
90
- if kwargs['ema_eval']:
91
- logger.info("loading EMA weights from {}".format(ckpt_path))
92
- model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
93
- model.eval()
94
- if use_fp16:
95
- model.half()
96
- if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
97
- model.cuda()
98
- model.prepare_for_inference_(cfg)
99
-
100
- # Load dataset (possibly sharded)
101
- itr = task.get_batch_iterator(
102
- dataset=task.dataset(cfg.dataset.gen_subset),
103
- max_tokens=cfg.dataset.max_tokens,
104
- max_sentences=cfg.dataset.batch_size,
105
- max_positions=utils.resolve_max_positions(
106
- task.max_positions(), *[m.max_positions() for m in models]
107
- ),
108
- ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
109
- required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
110
- seed=cfg.common.seed,
111
- num_shards=cfg.distributed_training.distributed_world_size,
112
- shard_id=cfg.distributed_training.distributed_rank,
113
- num_workers=cfg.dataset.num_workers,
114
- data_buffer_size=cfg.dataset.data_buffer_size,
115
- ).next_epoch_itr(shuffle=False)
116
- progress = progress_bar.progress_bar(
117
- itr,
118
- log_format=cfg.common.log_format,
119
- log_interval=cfg.common.log_interval,
120
- default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
121
- )
122
-
123
- # Initialize generator
124
- generator = task.build_generator(models, cfg.generation)
125
-
126
- results = []
127
- score_sum = torch.FloatTensor([0]).cuda()
128
- score_cnt = torch.FloatTensor([0]).cuda()
129
- for sample in progress:
130
- if "net_input" not in sample:
131
- continue
132
- sample = utils.move_to_cuda(sample) if use_cuda else sample
133
- sample = utils.apply_to_sample(apply_half, sample) if cfg.common.fp16 else sample
134
- with torch.no_grad():
135
- if kwargs["zero_shot"]:
136
- result, scores = zero_shot_step(task, generator, models, sample)
137
- else:
138
- result, scores = eval_step(task, generator, models, sample, **kwargs)
139
- results += result
140
- score_sum += sum(scores) if scores is not None else 0
141
- score_cnt += len(scores) if scores is not None else 0
142
- progress.log({"sentences": sample["nsentences"]})
143
-
144
- merge_results(task, cfg, logger, score_cnt, score_sum, results)
145
-
146
-
147
- def cli_main():
148
- parser = options.get_generation_parser()
149
- parser.add_argument("--ema-eval", action='store_true', help="Use EMA weights to make evaluation.")
150
- parser.add_argument("--beam-search-vqa-eval", action='store_true', help="Use beam search for vqa evaluation (faster inference speed but sub-optimal result), if not specified, we compute scores for each answer in the candidate set, which is slower but can obtain best result.")
151
- parser.add_argument("--zero-shot", action='store_true')
152
- args = options.parse_args_and_arch(parser)
153
- cfg = convert_namespace_to_omegaconf(args)
154
- distributed_utils.call_main(
155
- cfg, main, ema_eval=args.ema_eval, beam_search_vqa_eval=args.beam_search_vqa_eval, zero_shot=args.zero_shot
156
- )
157
-
158
-
159
- if __name__ == "__main__":
160
- cli_main()
 
ezocr/LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
ezocr/README.md DELETED
@@ -1,49 +0,0 @@
1
- # EasyOCR Lite
2
-
3
- Text localization code extracted from EasyOCR, further adapted for Chinese, with defects fixed.
4
-
5
- ## Installation
6
-
7
- Python 3.8 or later is required.
8
-
9
-
10
- First, install PyTorch following the official PyTorch instructions.
11
-
12
- ```
13
- pip install -e .
14
- ```
15
-
16
- ## Usage
17
-
18
- ```python
19
- from easyocrlite import ReaderLite
20
-
21
- reader = ReaderLite()
22
- results = reader.process('my_awesome_handwriting.png')
23
- ```
24
-
25
- The return value is a list of bounding boxes and the corresponding image regions.
26
- See the [demo](./demo.ipynb) for more details.
27
-
28
-
29
- ## Acknowledgements
30
-
31
- Modified from [EasyOCR](https://github.com/JaidedAI/EasyOCR). The original EasyOCR acknowledgements follow:
32
-
33
- This project is based on research and code from several papers and open-source repositories.
34
-
35
- All deep learning execution is based on [Pytorch](https://pytorch.org). :heart:
36
-
37
- Detection execution uses the CRAFT algorithm from this [official repository](https://github.com/clovaai/CRAFT-pytorch) and their [paper](https://arxiv.org/abs/1904.01941) (Thanks @YoungminBaek from [@clovaai](https://github.com/clovaai)). We also use their pretrained model. Training script is provided by [@gmuffiness](https://github.com/gmuffiness).
38
-
39
- The recognition model is a CRNN ([paper](https://arxiv.org/abs/1507.05717)). It is composed of 3 main components: feature extraction (we are currently using [Resnet](https://arxiv.org/abs/1512.03385)) and VGG, sequence labeling ([LSTM](https://www.bioinf.jku.at/publications/older/2604.pdf)) and decoding ([CTC](https://www.cs.toronto.edu/~graves/icml_2006.pdf)). The training pipeline for recognition execution is a modified version of the [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) framework. (Thanks [@ku21fan](https://github.com/ku21fan) from [@clovaai](https://github.com/clovaai)) This repository is a gem that deserves more recognition.
40
-
41
- Beam search code is based on this [repository](https://github.com/githubharald/CTCDecoder) and his [blog](https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7). (Thanks [@githubharald](https://github.com/githubharald))
42
-
43
- Data synthesis is based on [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator). (Thanks [@Belval](https://github.com/Belval))
44
-
45
- And a good read about CTC from distill.pub [here](https://distill.pub/2017/ctc/).
46
-
47
-
48
- ## License (Note!)
49
- Apache 2.0
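As a worked example of the usage section above: `ReaderLite.process` returns a list of `(box, crop)` pairs sorted by the top-left corner, where `box` holds the four corner points of a detected text region and `crop` is the corresponding image patch as a NumPy array (see `reader.py` below). A minimal sketch, with `my_awesome_handwriting.png` as a placeholder path:

```python
from PIL import Image
from easyocrlite import ReaderLite

reader = ReaderLite(gpu=False)  # force CPU; a GPU is used when gpu=True and CUDA is available
results = reader.process("my_awesome_handwriting.png")

for i, (box, crop) in enumerate(results):
    x0, y0 = box[0]  # top-left corner of the detected region
    print(f"region {i}: top-left=({x0}, {y0}), crop size={crop.shape[:2]}")
    Image.fromarray(crop).save(f"region_{i}.png")
```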
 
 
ezocr/build/lib/easyocrlite/__init__.py DELETED
@@ -1 +0,0 @@
1
- from easyocrlite.reader import ReaderLite
 
 
ezocr/build/lib/easyocrlite/reader.py DELETED
@@ -1,272 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- import os
5
- from pathlib import Path
6
- from typing import Tuple
7
-
8
- import cv2
9
- import numpy as np
10
- import torch
11
- from PIL import Image, ImageEnhance
12
-
13
- from easyocrlite.model import CRAFT
14
-
15
- from easyocrlite.utils.download_utils import prepare_model
16
- from easyocrlite.utils.image_utils import (
17
- adjust_result_coordinates,
18
- boxed_transform,
19
- normalize_mean_variance,
20
- resize_aspect_ratio,
21
- )
22
- from easyocrlite.utils.detect_utils import (
23
- extract_boxes,
24
- extract_regions_from_boxes,
25
- box_expand,
26
- greedy_merge,
27
- )
28
- from easyocrlite.types import BoxTuple, RegionTuple
29
- import easyocrlite.utils.utils as utils
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
- MODULE_PATH = (
34
- os.environ.get("EASYOCR_MODULE_PATH")
35
- or os.environ.get("MODULE_PATH")
36
- or os.path.expanduser("~/.EasyOCR/")
37
- )
38
-
39
-
40
- class ReaderLite(object):
41
- def __init__(
42
- self,
43
- gpu=True,
44
- model_storage_directory=None,
45
- download_enabled=True,
46
- verbose=True,
47
- quantize=True,
48
- cudnn_benchmark=False,
49
- ):
50
-
51
- self.verbose = verbose
52
-
53
- model_storage_directory = Path(
54
- model_storage_directory
55
- if model_storage_directory
56
- else MODULE_PATH + "/model"
57
- )
58
- self.detector_path = prepare_model(
59
- model_storage_directory, download_enabled, verbose
60
- )
61
-
62
- self.quantize = quantize
63
- self.cudnn_benchmark = cudnn_benchmark
64
- if gpu is False:
65
- self.device = "cpu"
66
- if verbose:
67
- logger.warning(
68
- "Using CPU. Note: This module is much faster with a GPU."
69
- )
70
- elif not torch.cuda.is_available():
71
- self.device = "cpu"
72
- if verbose:
73
- logger.warning(
74
- "CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU."
75
- )
76
- elif gpu is True:
77
- self.device = "cuda"
78
- else:
79
- self.device = gpu
80
-
81
- self.detector = CRAFT()
82
-
83
- state_dict = torch.load(self.detector_path, map_location=self.device)
84
- if list(state_dict.keys())[0].startswith("module"):
85
- state_dict = {k[7:]: v for k, v in state_dict.items()}
86
-
87
- self.detector.load_state_dict(state_dict)
88
-
89
- if self.device == "cpu":
90
- if self.quantize:
91
- try:
92
- torch.quantization.quantize_dynamic(
93
- self.detector, dtype=torch.qint8, inplace=True
94
- )
95
- except:
96
- pass
97
- else:
98
- self.detector = torch.nn.DataParallel(self.detector).to(self.device)
99
- import torch.backends.cudnn as cudnn
100
-
101
- cudnn.benchmark = self.cudnn_benchmark
102
-
103
- self.detector.eval()
104
-
105
- def process(
106
- self,
107
- image_path: str,
108
- max_size: int = 960,
109
- expand_ratio: float = 1.0,
110
- sharp: float = 1.0,
111
- contrast: float = 1.0,
112
- text_confidence: float = 0.7,
113
- text_threshold: float = 0.4,
114
- link_threshold: float = 0.4,
115
- slope_ths: float = 0.1,
116
- ratio_ths: float = 0.5,
117
- center_ths: float = 0.5,
118
- dim_ths: float = 0.5,
119
- space_ths: float = 1.0,
120
- add_margin: float = 0.1,
121
- min_size: float = 0.01,
122
- ) -> list[Tuple[BoxTuple, np.ndarray]]:
123
-
124
- image = Image.open(image_path).convert('RGB')
125
-
126
- tensor, inverse_ratio = self.preprocess(
127
- image, max_size, expand_ratio, sharp, contrast
128
- )
129
-
130
- scores = self.forward_net(tensor)
131
-
132
- boxes = self.detect(scores, text_confidence, text_threshold, link_threshold)
133
-
134
- image = np.array(image)
135
- region_list, box_list = self.postprocess(
136
- image,
137
- boxes,
138
- inverse_ratio,
139
- slope_ths,
140
- ratio_ths,
141
- center_ths,
142
- dim_ths,
143
- space_ths,
144
- add_margin,
145
- min_size,
146
- )
147
-
148
- # get cropped image
149
- image_list = []
150
- for region in region_list:
151
- x_min, x_max, y_min, y_max = region
152
- crop_img = image[y_min:y_max, x_min:x_max, :]
153
- image_list.append(
154
- (
155
- ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)),
156
- crop_img,
157
- )
158
- )
159
-
160
- for box in box_list:
161
- transformed_img = boxed_transform(image, np.array(box, dtype="float32"))
162
- image_list.append((box, transformed_img))
163
-
164
- # sort by top left point
165
- image_list = sorted(image_list, key=lambda x: (x[0][0][1], x[0][0][0]))
166
-
167
- return image_list
168
-
169
- def preprocess(
170
- self,
171
- image: Image.Image,
172
- max_size: int,
173
- expand_ratio: float = 1.0,
174
- sharp: float = 1.0,
175
- contrast: float = 1.0,
176
- ) -> torch.Tensor:
177
- if sharp != 1:
178
- enhancer = ImageEnhance.Sharpness(image)
179
- image = enhancer.enhance(sharp)
180
- if contrast != 1:
181
- enhancer = ImageEnhance.Contrast(image)
182
- image = enhancer.enhance(contrast)
183
-
184
- image = np.array(image)
185
-
186
- image, target_ratio = resize_aspect_ratio(
187
- image, max_size, interpolation=cv2.INTER_LINEAR, expand_ratio=expand_ratio
188
- )
189
- inverse_ratio = 1 / target_ratio
190
-
191
- x = np.transpose(normalize_mean_variance(image), (2, 0, 1))
192
-
193
- x = torch.tensor(np.array([x]), device=self.device)
194
-
195
- return x, inverse_ratio
196
-
197
- @torch.no_grad()
198
- def forward_net(self, tensor: torch.Tensor) -> torch.Tensor:
199
- scores, feature = self.detector(tensor)
200
- return scores[0]
201
-
202
- def detect(
203
- self,
204
- scores: torch.Tensor,
205
- text_confidence: float = 0.7,
206
- text_threshold: float = 0.4,
207
- link_threshold: float = 0.4,
208
- ) -> list[BoxTuple]:
209
- # make score and link map
210
- score_text = scores[:, :, 0].cpu().data.numpy()
211
- score_link = scores[:, :, 1].cpu().data.numpy()
212
- # extract box
213
- boxes, _ = extract_boxes(
214
- score_text, score_link, text_confidence, text_threshold, link_threshold
215
- )
216
- return boxes
217
-
218
- def postprocess(
219
- self,
220
- image: np.ndarray,
221
- boxes: list[BoxTuple],
222
- inverse_ratio: float,
223
- slope_ths: float = 0.1,
224
- ratio_ths: float = 0.5,
225
- center_ths: float = 0.5,
226
- dim_ths: float = 0.5,
227
- space_ths: float = 1.0,
228
- add_margin: float = 0.1,
229
- min_size: int = 0,
230
- ) -> Tuple[list[RegionTuple], list[BoxTuple]]:
231
-
232
- # coordinate adjustment
233
- boxes = adjust_result_coordinates(boxes, inverse_ratio)
234
-
235
- max_y, max_x, _ = image.shape
236
-
237
- # extract region and merge
238
- region_list, box_list = extract_regions_from_boxes(boxes, slope_ths)
239
-
240
- region_list = greedy_merge(
241
- region_list,
242
- ratio_ths=ratio_ths,
243
- center_ths=center_ths,
244
- dim_ths=dim_ths,
245
- space_ths=space_ths,
246
- verbose=0
247
- )
248
-
249
- # add margin
250
- region_list = [
251
- region.expand(add_margin, (max_x, max_y)).as_tuple()
252
- for region in region_list
253
- ]
254
-
255
- box_list = [box_expand(box, add_margin, (max_x, max_y)) for box in box_list]
256
-
257
- # filter by size
258
- if min_size:
259
- if min_size < 1:
260
- min_size = int(min(max_y, max_x) * min_size)
261
-
262
- region_list = [
263
- i for i in region_list if max(i[1] - i[0], i[3] - i[2]) > min_size
264
- ]
265
- box_list = [
266
- i
267
- for i in box_list
268
- if max(utils.diff([c[0] for c in i]), utils.diff([c[1] for c in i]))
269
- > min_size
270
- ]
271
-
272
- return region_list, box_list
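One detail of `postprocess` above worth spelling out: `min_size` is interpreted as a fraction of the shorter image side when it is below 1, and as an absolute pixel threshold otherwise. A small standalone illustration of that rule (a sketch, not part of the removed file):

```python
def effective_min_size(min_size: float, image_shape: tuple) -> int:
    # Mirrors the branch in ReaderLite.postprocess: fraction of the shorter side vs. absolute pixels.
    max_y, max_x = image_shape[:2]
    if min_size and min_size < 1:
        return int(min(max_y, max_x) * min_size)
    return int(min_size)

print(effective_min_size(0.01, (720, 1280, 3)))  # 7  -> 1% of the 720-pixel side
print(effective_min_size(20, (720, 1280, 3)))    # 20 -> used directly as a pixel threshold
```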
 
 
ezocr/build/lib/easyocrlite/types.py DELETED
@@ -1,5 +0,0 @@
1
- from typing import Tuple
2
-
3
- Point = Tuple[int, int]
4
- BoxTuple = Tuple[Point, Point, Point, Point]
5
- RegionTuple = Tuple[int, int, int, int]
 
 
ezocr/easyocrlite.egg-info/PKG-INFO DELETED
@@ -1,8 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: easyocrlite
3
- Version: 0.0.1
4
- License: Apache License 2.0
5
- Keywords: ocr optical character recognition deep learning neural network
6
- Classifier: Development Status :: 5 - Production/Stable
7
- Requires-Python: >=3.7
8
- License-File: LICENSE
 
 
ezocr/easyocrlite.egg-info/SOURCES.txt DELETED
@@ -1,11 +0,0 @@
1
- LICENSE
2
- README.md
3
- setup.py
4
- easyocrlite/__init__.py
5
- easyocrlite/reader.py
6
- easyocrlite/types.py
7
- easyocrlite.egg-info/PKG-INFO
8
- easyocrlite.egg-info/SOURCES.txt
9
- easyocrlite.egg-info/dependency_links.txt
10
- easyocrlite.egg-info/requires.txt
11
- easyocrlite.egg-info/top_level.txt
 
 
ezocr/easyocrlite.egg-info/dependency_links.txt DELETED
@@ -1 +0,0 @@
1
-
 
 
ezocr/easyocrlite.egg-info/requires.txt DELETED
@@ -1,5 +0,0 @@
1
- torch
2
- torchvision>=0.5
3
- opencv-python-headless<=4.5.4.60
4
- numpy
5
- Pillow
 
 
ezocr/easyocrlite.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
1
- easyocrlite
 
 
ezocr/easyocrlite/__init__.py DELETED
@@ -1 +0,0 @@
1
- from easyocrlite.reader import ReaderLite
 
 
ezocr/easyocrlite/model/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .craft import CRAFT
 
 
ezocr/easyocrlite/model/craft.py DELETED
@@ -1,174 +0,0 @@
1
- """
2
- Copyright (c) 2019-present NAVER Corp.
3
- MIT License
4
- """
5
- from __future__ import annotations
6
-
7
- from collections import namedtuple
8
- from typing import Iterable, Tuple
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- import torchvision
14
- from packaging import version
15
- from torchvision import models
16
-
17
- VGGOutputs = namedtuple(
18
- "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
19
- )
20
-
21
- def init_weights(modules: Iterable[nn.Module]):
22
- for m in modules:
23
- if isinstance(m, nn.Conv2d):
24
- nn.init.xavier_uniform_(m.weight)
25
- if m.bias is not None:
26
- nn.init.zeros_(m.bias)
27
- elif isinstance(m, nn.BatchNorm2d):
28
- nn.init.constant_(m.weight, 1.0)
29
- nn.init.zeros_(m.bias)
30
- elif isinstance(m, nn.Linear):
31
- nn.init.normal_(m.weight, 0, 0.01)
32
- nn.init.zeros_(m.bias)
33
-
34
-
35
- class VGG16_BN(nn.Module):
36
- def __init__(self, pretrained: bool=True, freeze: bool=True):
37
- super().__init__()
38
- if version.parse(torchvision.__version__) >= version.parse("0.13"):
39
- vgg_pretrained_features = models.vgg16_bn(
40
- weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
41
- ).features
42
- else: # torchvision.__version__ < 0.13
43
- models.vgg.model_urls["vgg16_bn"] = models.vgg.model_urls[
44
- "vgg16_bn"
45
- ].replace("https://", "http://")
46
- vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
47
-
48
- self.slice1 = torch.nn.Sequential()
49
- self.slice2 = torch.nn.Sequential()
50
- self.slice3 = torch.nn.Sequential()
51
- self.slice4 = torch.nn.Sequential()
52
- self.slice5 = torch.nn.Sequential()
53
- for x in range(12): # conv2_2
54
- self.slice1.add_module(str(x), vgg_pretrained_features[x])
55
- for x in range(12, 19): # conv3_3
56
- self.slice2.add_module(str(x), vgg_pretrained_features[x])
57
- for x in range(19, 29): # conv4_3
58
- self.slice3.add_module(str(x), vgg_pretrained_features[x])
59
- for x in range(29, 39): # conv5_3
60
- self.slice4.add_module(str(x), vgg_pretrained_features[x])
61
-
62
- # fc6, fc7 without atrous conv
63
- self.slice5 = torch.nn.Sequential(
64
- nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
65
- nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
66
- nn.Conv2d(1024, 1024, kernel_size=1),
67
- )
68
-
69
- if not pretrained:
70
- init_weights(self.slice1.modules())
71
- init_weights(self.slice2.modules())
72
- init_weights(self.slice3.modules())
73
- init_weights(self.slice4.modules())
74
-
75
- init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
76
-
77
- if freeze:
78
- for param in self.slice1.parameters(): # only first conv
79
- param.requires_grad = False
80
-
81
- def forward(self, x: torch.Tensor) -> VGGOutputs:
82
- h = self.slice1(x)
83
- h_relu2_2 = h
84
- h = self.slice2(h)
85
- h_relu3_2 = h
86
- h = self.slice3(h)
87
- h_relu4_3 = h
88
- h = self.slice4(h)
89
- h_relu5_3 = h
90
- h = self.slice5(h)
91
- h_fc7 = h
92
-
93
- out = VGGOutputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
94
- return out
95
-
96
-
97
- class DoubleConv(nn.Module):
98
- def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
99
- super().__init__()
100
- self.conv = nn.Sequential(
101
- nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
102
- nn.BatchNorm2d(mid_ch),
103
- nn.ReLU(inplace=True),
104
- nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
105
- nn.BatchNorm2d(out_ch),
106
- nn.ReLU(inplace=True),
107
- )
108
-
109
- def forward(self, x: torch.Tensor) -> torch.Tensor:
110
- x = self.conv(x)
111
- return x
112
-
113
-
114
- class CRAFT(nn.Module):
115
- def __init__(self, pretrained: bool=False, freeze: bool=False):
116
- super(CRAFT, self).__init__()
117
-
118
- """ Base network """
119
- self.basenet = VGG16_BN(pretrained, freeze)
120
-
121
- """ U network """
122
- self.upconv1 = DoubleConv(1024, 512, 256)
123
- self.upconv2 = DoubleConv(512, 256, 128)
124
- self.upconv3 = DoubleConv(256, 128, 64)
125
- self.upconv4 = DoubleConv(128, 64, 32)
126
-
127
- num_class = 2
128
- self.conv_cls = nn.Sequential(
129
- nn.Conv2d(32, 32, kernel_size=3, padding=1),
130
- nn.ReLU(inplace=True),
131
- nn.Conv2d(32, 32, kernel_size=3, padding=1),
132
- nn.ReLU(inplace=True),
133
- nn.Conv2d(32, 16, kernel_size=3, padding=1),
134
- nn.ReLU(inplace=True),
135
- nn.Conv2d(16, 16, kernel_size=1),
136
- nn.ReLU(inplace=True),
137
- nn.Conv2d(16, num_class, kernel_size=1),
138
- )
139
-
140
- init_weights(self.upconv1.modules())
141
- init_weights(self.upconv2.modules())
142
- init_weights(self.upconv3.modules())
143
- init_weights(self.upconv4.modules())
144
- init_weights(self.conv_cls.modules())
145
-
146
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
147
- """Base network"""
148
- sources = self.basenet(x)
149
-
150
- """ U network """
151
- y = torch.cat([sources[0], sources[1]], dim=1)
152
- y = self.upconv1(y)
153
-
154
- y = F.interpolate(
155
- y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
156
- )
157
- y = torch.cat([y, sources[2]], dim=1)
158
- y = self.upconv2(y)
159
-
160
- y = F.interpolate(
161
- y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
162
- )
163
- y = torch.cat([y, sources[3]], dim=1)
164
- y = self.upconv3(y)
165
-
166
- y = F.interpolate(
167
- y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
168
- )
169
- y = torch.cat([y, sources[4]], dim=1)
170
- feature = self.upconv4(y)
171
-
172
- y = self.conv_cls(feature)
173
-
174
- return y.permute(0, 2, 3, 1), feature
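For orientation, a minimal sketch of a forward pass through the `CRAFT` module above (dummy input, default constructor so no pretrained VGG weights are downloaded; the shapes follow from the VGG slices and the three upsampling steps, so the score maps come out at half the input resolution):

```python
import torch
from easyocrlite.model import CRAFT

model = CRAFT()  # pretrained=False, freeze=False by default
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 768, 768)  # NCHW RGB input
    scores, feature = model(x)

print(scores.shape)   # torch.Size([1, 384, 384, 2]) - NHWC score maps at half resolution
print(feature.shape)  # torch.Size([1, 32, 384, 384]) - shared feature map from upconv4
# Channel 0 of `scores` is the region (text) score and channel 1 the affinity (link) score,
# which is how ReaderLite.detect consumes them (scores[:, :, 0] and scores[:, :, 1]).
```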
 
 
ezocr/easyocrlite/reader.py DELETED
@@ -1,271 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- import os
5
- from pathlib import Path
6
- from typing import Tuple
7
-
8
- import cv2
9
- import numpy as np
10
- import torch
11
- from PIL import Image, ImageEnhance
12
-
13
- from easyocrlite.model import CRAFT
14
-
15
- from easyocrlite.utils.download_utils import prepare_model
16
- from easyocrlite.utils.image_utils import (
17
- adjust_result_coordinates,
18
- boxed_transform,
19
- normalize_mean_variance,
20
- resize_aspect_ratio,
21
- )
22
- from easyocrlite.utils.detect_utils import (
23
- extract_boxes,
24
- extract_regions_from_boxes,
25
- box_expand,
26
- greedy_merge,
27
- )
28
- from easyocrlite.types import BoxTuple, RegionTuple
29
- import easyocrlite.utils.utils as utils
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
- MODULE_PATH = (
34
- os.environ.get("EASYOCR_MODULE_PATH")
35
- or os.environ.get("MODULE_PATH")
36
- or os.path.expanduser("~/.EasyOCR/")
37
- )
38
-
39
-
40
- class ReaderLite(object):
41
- def __init__(
42
- self,
43
- gpu=True,
44
- model_storage_directory=None,
45
- download_enabled=True,
46
- verbose=True,
47
- quantize=True,
48
- cudnn_benchmark=False,
49
- ):
50
-
51
- self.verbose = verbose
52
-
53
- model_storage_directory = Path(
54
- model_storage_directory
55
- if model_storage_directory
56
- else MODULE_PATH + "/model"
57
- )
58
- self.detector_path = prepare_model(
59
- model_storage_directory, download_enabled, verbose
60
- )
61
-
62
- self.quantize = quantize
63
- self.cudnn_benchmark = cudnn_benchmark
64
- if gpu is False:
65
- self.device = "cpu"
66
- if verbose:
67
- logger.warning(
68
- "Using CPU. Note: This module is much faster with a GPU."
69
- )
70
- elif not torch.cuda.is_available():
71
- self.device = "cpu"
72
- if verbose:
73
- logger.warning(
74
- "CUDA not available - defaulting to CPU. Note: This module is much faster with a GPU."
75
- )
76
- elif gpu is True:
77
- self.device = "cuda"
78
- else:
79
- self.device = gpu
80
-
81
- self.detector = CRAFT()
82
-
83
- state_dict = torch.load(self.detector_path, map_location=self.device)
84
- if list(state_dict.keys())[0].startswith("module"):
85
- state_dict = {k[7:]: v for k, v in state_dict.items()}
86
-
87
- self.detector.load_state_dict(state_dict)
88
-
89
- if self.device == "cpu":
90
- if self.quantize:
91
- try:
92
- torch.quantization.quantize_dynamic(
93
- self.detector, dtype=torch.qint8, inplace=True
94
- )
95
- except:
96
- pass
97
- else:
98
- self.detector = torch.nn.DataParallel(self.detector).to(self.device)
99
- import torch.backends.cudnn as cudnn
100
-
101
- cudnn.benchmark = self.cudnn_benchmark
102
-
103
- self.detector.eval()
104
-
105
- def process(
106
- self,
107
- image_path: str,
108
- max_size: int = 960,
109
- expand_ratio: float = 1.0,
110
- sharp: float = 1.0,
111
- contrast: float = 1.0,
112
- text_confidence: float = 0.7,
113
- text_threshold: float = 0.4,
114
- link_threshold: float = 0.4,
115
- slope_ths: float = 0.1,
116
- ratio_ths: float = 0.5,
117
- center_ths: float = 0.5,
118
- dim_ths: float = 0.5,
119
- space_ths: float = 1.0,
120
- add_margin: float = 0.1,
121
- min_size: float = 0.01,
122
- ) -> list[Tuple[BoxTuple, np.ndarray]]:
123
-
124
- image = Image.open(image_path).convert('RGB')
125
- tensor, inverse_ratio = self.preprocess(
126
- image, max_size, expand_ratio, sharp, contrast
127
- )
128
-
129
- scores = self.forward_net(tensor)
130
-
131
- boxes = self.detect(scores, text_confidence, text_threshold, link_threshold)
132
-
133
- image = np.array(image)
134
- region_list, box_list = self.postprocess(
135
- image,
136
- boxes,
137
- inverse_ratio,
138
- slope_ths,
139
- ratio_ths,
140
- center_ths,
141
- dim_ths,
142
- space_ths,
143
- add_margin,
144
- min_size,
145
- )
146
-
147
- # get cropped image
148
- image_list = []
149
- for region in region_list:
150
- x_min, x_max, y_min, y_max = region
151
- crop_img = image[y_min:y_max, x_min:x_max, :]
152
- image_list.append(
153
- (
154
- ((x_min, y_min), (x_max, y_min), (x_max, y_max), (x_min, y_max)),
155
- crop_img,
156
- )
157
- )
158
-
159
- for box in box_list:
160
- transformed_img = boxed_transform(image, np.array(box, dtype="float32"))
161
- image_list.append((box, transformed_img))
162
-
163
- # sort by top left point
164
- image_list = sorted(image_list, key=lambda x: (x[0][0][1], x[0][0][0]))
165
-
166
- return image_list
167
-
168
- def preprocess(
169
- self,
170
- image: Image.Image,
171
- max_size: int,
172
- expand_ratio: float = 1.0,
173
- sharp: float = 1.0,
174
- contrast: float = 1.0,
175
- ) -> torch.Tensor:
176
- if sharp != 1:
177
- enhancer = ImageEnhance.Sharpness(image)
178
- image = enhancer.enhance(sharp)
179
- if contrast != 1:
180
- enhancer = ImageEnhance.Contrast(image)
181
- image = enhancer.enhance(contrast)
182
-
183
- image = np.array(image)
184
-
185
- image, target_ratio = resize_aspect_ratio(
186
- image, max_size, interpolation=cv2.INTER_LINEAR, expand_ratio=expand_ratio
187
- )
188
- inverse_ratio = 1 / target_ratio
189
-
190
- x = np.transpose(normalize_mean_variance(image), (2, 0, 1))
191
-
192
- x = torch.tensor(np.array([x]), device=self.device)
193
-
194
- return x, inverse_ratio
195
-
196
- @torch.no_grad()
197
- def forward_net(self, tensor: torch.Tensor) -> torch.Tensor:
198
- scores, feature = self.detector(tensor)
199
- return scores[0]
200
-
201
- def detect(
202
- self,
203
- scores: torch.Tensor,
204
- text_confidence: float = 0.7,
205
- text_threshold: float = 0.4,
206
- link_threshold: float = 0.4,
207
- ) -> list[BoxTuple]:
208
- # make score and link map
209
- score_text = scores[:, :, 0].cpu().data.numpy()
210
- score_link = scores[:, :, 1].cpu().data.numpy()
211
- # extract box
212
- boxes, _ = extract_boxes(
213
- score_text, score_link, text_confidence, text_threshold, link_threshold
214
- )
215
- return boxes
216
-
217
- def postprocess(
218
- self,
219
- image: np.ndarray,
220
- boxes: list[BoxTuple],
221
- inverse_ratio: float,
222
- slope_ths: float = 0.1,
223
- ratio_ths: float = 0.5,
224
- center_ths: float = 0.5,
225
- dim_ths: float = 0.5,
226
- space_ths: float = 1.0,
227
- add_margin: float = 0.1,
228
- min_size: int = 0,
229
- ) -> Tuple[list[RegionTuple], list[BoxTuple]]:
230
-
231
- # coordinate adjustment
232
- boxes = adjust_result_coordinates(boxes, inverse_ratio)
233
-
234
- max_y, max_x, _ = image.shape
235
-
236
- # extract region and merge
237
- region_list, box_list = extract_regions_from_boxes(boxes, slope_ths)
238
-
239
- region_list = greedy_merge(
240
- region_list,
241
- ratio_ths=ratio_ths,
242
- center_ths=center_ths,
243
- dim_ths=dim_ths,
244
- space_ths=space_ths,
245
- verbose=0
246
- )
247
-
248
- # add margin
249
- region_list = [
250
- region.expand(add_margin, (max_x, max_y)).as_tuple()
251
- for region in region_list
252
- ]
253
-
254
- box_list = [box_expand(box, add_margin, (max_x, max_y)) for box in box_list]
255
-
256
- # filter by size
257
- if min_size:
258
- if min_size < 1:
259
- min_size = int(min(max_y, max_x) * min_size)
260
-
261
- region_list = [
262
- i for i in region_list if max(i[1] - i[0], i[3] - i[2]) > min_size
263
- ]
264
- box_list = [
265
- i
266
- for i in box_list
267
- if max(utils.diff([c[0] for c in i]), utils.diff([c[1] for c in i]))
268
- > min_size
269
- ]
270
-
271
- return region_list, box_list
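A quick note on the `gpu` argument handled in `ReaderLite.__init__` above: `False` always selects the CPU, `True` selects `"cuda"` when CUDA is available (with a CPU fallback and a warning otherwise), and any other value is passed through as the device when CUDA is available. A usage sketch (not part of the removed file):

```python
from easyocrlite import ReaderLite

reader_cpu = ReaderLite(gpu=False)     # always CPU ("Using CPU..." warning when verbose)
reader_auto = ReaderLite(gpu=True)     # "cuda" if available, otherwise CPU with a warning
reader_dev = ReaderLite(gpu="cuda:1")  # used verbatim as the device string (requires CUDA)
```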
 
 
ezocr/easyocrlite/types.py DELETED
@@ -1,5 +0,0 @@
1
- from typing import Tuple
2
-
3
- Point = Tuple[int, int]
4
- BoxTuple = Tuple[Point, Point, Point, Point]
5
- RegionTuple = Tuple[int, int, int, int]
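For reference, the aliases above with illustrative values, matching how they are unpacked in `ReaderLite.process` and `postprocess` (a sketch, not part of the removed file):

```python
from easyocrlite.types import BoxTuple, RegionTuple

box: BoxTuple = ((10, 20), (110, 20), (110, 60), (10, 60))  # four corner points, clockwise from top-left
region: RegionTuple = (10, 110, 20, 60)                      # (x_min, x_max, y_min, y_max)
```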