Upload 63 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- Indic-BERT-v1-master/.gitignore +116 -0
- Indic-BERT-v1-master/LICENSE +9 -0
- Indic-BERT-v1-master/albert/CONTRIBUTING.md +28 -0
- Indic-BERT-v1-master/albert/LICENSE +202 -0
- Indic-BERT-v1-master/albert/README.md +324 -0
- Indic-BERT-v1-master/albert/__init__.py +14 -0
- Indic-BERT-v1-master/albert/albert_glue_fine_tuning_tutorial.ipynb +303 -0
- Indic-BERT-v1-master/albert/classifier_utils.py +1037 -0
- Indic-BERT-v1-master/albert/create_pretraining_data.py +654 -0
- Indic-BERT-v1-master/albert/evaluate.py +0 -0
- Indic-BERT-v1-master/albert/export_checkpoints.py +162 -0
- Indic-BERT-v1-master/albert/export_to_tfhub.py +177 -0
- Indic-BERT-v1-master/albert/fine_tuning_utils.py +85 -0
- Indic-BERT-v1-master/albert/lamb_optimizer.py +148 -0
- Indic-BERT-v1-master/albert/modeling.py +1209 -0
- Indic-BERT-v1-master/albert/modeling_test.py +309 -0
- Indic-BERT-v1-master/albert/optimization.py +204 -0
- Indic-BERT-v1-master/albert/optimization_test.py +50 -0
- Indic-BERT-v1-master/albert/race_utils.py +432 -0
- Indic-BERT-v1-master/albert/requirements.txt +5 -0
- Indic-BERT-v1-master/albert/run_classifier.py +488 -0
- Indic-BERT-v1-master/albert/run_glue.sh +52 -0
- Indic-BERT-v1-master/albert/run_pretraining.py +577 -0
- Indic-BERT-v1-master/albert/run_pretraining_test.py +133 -0
- Indic-BERT-v1-master/albert/run_race.py +458 -0
- Indic-BERT-v1-master/albert/run_squad_v1.py +547 -0
- Indic-BERT-v1-master/albert/run_squad_v2.py +516 -0
- Indic-BERT-v1-master/albert/run_trivial_model_test.sh +27 -0
- Indic-BERT-v1-master/albert/squad_utils.py +1735 -0
- Indic-BERT-v1-master/albert/tokenization.py +465 -0
- Indic-BERT-v1-master/albert/tokenization_test.py +137 -0
- Indic-BERT-v1-master/albert/train.py +0 -0
- Indic-BERT-v1-master/configs/albert_base_config.json +21 -0
- Indic-BERT-v1-master/configs/albert_large_config.json +21 -0
- Indic-BERT-v1-master/docs/advanced-usage.md +45 -0
- Indic-BERT-v1-master/docs/arxiv2020_indicnlp_corpus.pdf +0 -0
- Indic-BERT-v1-master/fine_tune/__init__.py +0 -0
- Indic-BERT-v1-master/fine_tune/cli.py +196 -0
- Indic-BERT-v1-master/fine_tune/data/__init__.py +27 -0
- Indic-BERT-v1-master/fine_tune/data/examples.py +327 -0
- Indic-BERT-v1-master/fine_tune/data/processors.py +521 -0
- Indic-BERT-v1-master/fine_tune/modules/__init__.py +22 -0
- Indic-BERT-v1-master/fine_tune/modules/base.py +397 -0
- Indic-BERT-v1-master/fine_tune/modules/masked_lm.py +155 -0
- Indic-BERT-v1-master/fine_tune/modules/multiple_choice.py +51 -0
- Indic-BERT-v1-master/fine_tune/modules/question_answering.py +0 -0
- Indic-BERT-v1-master/fine_tune/modules/text_classification.py +70 -0
- Indic-BERT-v1-master/fine_tune/modules/token_classification.py +87 -0
- Indic-BERT-v1-master/fine_tune/modules/utils.py +4 -0
- Indic-BERT-v1-master/fine_tune/modules/xsent_retrieval.py +111 -0
Indic-BERT-v1-master/.gitignore
ADDED
@@ -0,0 +1,116 @@
# Initially taken from Github's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
Indic-BERT-v1-master/LICENSE
ADDED
@@ -0,0 +1,9 @@
MIT License

Copyright (c) 2020-present AI4Bharat

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Indic-BERT-v1-master/albert/CONTRIBUTING.md
ADDED
@@ -0,0 +1,28 @@
# How to Contribute

We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
Indic-BERT-v1-master/albert/LICENSE
ADDED
@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
Indic-BERT-v1-master/albert/README.md
ADDED
@@ -0,0 +1,324 @@
ALBERT
======

*************** Changes from Original Implementation ***************

1. Remove sentence order in `run_pretraining.py`
2. Modify the `_is_start_piece_sp` function in `create_pretraining_data.py` to account for non-English languages.

*************** New March 28, 2020 ***************

Add a colab [tutorial](https://github.com/google-research/albert/blob/master/albert_glue_fine_tuning_tutorial.ipynb) to run fine-tuning for GLUE datasets.

*************** New January 7, 2020 ***************

v2 TF-Hub models should be working now with TF 1.15, as we removed the
native Einsum op from the graph. See updated TF-Hub links below.

*************** New December 30, 2019 ***************

Chinese models are released. We would like to thank the [CLUE team](https://github.com/CLUEbenchmark/CLUE) for providing the training data.

- [Base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz)
- [Large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz)
- [Xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz)
- [Xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz)

Version 2 of ALBERT models is released.

- Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/3)]
- Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/3)]
- Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/3)]
- Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/3)]

In this version, we apply 'no dropout', 'additional training data' and 'long training time' strategies to all models. We train ALBERT-base for 10M steps and other models for 3M steps.

The result comparison to the v1 models is as follows:

|                | Average  | SQuAD1.1 | SQuAD2.0 | MNLI     | SST-2    | RACE     |
|----------------|----------|----------|----------|----------|----------|----------|
|V2              |
|ALBERT-base     |82.3      |90.2/83.2 |82.1/79.3 |84.6      |92.9      |66.8      |
|ALBERT-large    |85.7      |91.8/85.2 |84.9/81.8 |86.5      |94.9      |75.2      |
|ALBERT-xlarge   |87.9      |92.9/86.4 |87.9/84.1 |87.9      |95.4      |80.7      |
|ALBERT-xxlarge  |90.9      |94.6/89.1 |89.8/86.9 |90.6      |96.8      |86.8      |
|V1              |
|ALBERT-base     |80.1      |89.3/82.3 |80.0/77.1 |81.6      |90.3      |64.0      |
|ALBERT-large    |82.4      |90.6/83.9 |82.3/79.4 |83.5      |91.7      |68.5      |
|ALBERT-xlarge   |85.5      |92.5/86.1 |86.1/83.1 |86.4      |92.4      |74.8      |
|ALBERT-xxlarge  |91.0      |94.8/89.3 |90.2/87.4 |90.8      |96.9      |86.5      |

The comparison shows that for ALBERT-base, ALBERT-large, and ALBERT-xlarge, v2 is much better than v1, indicating the importance of applying the above three strategies. On average, ALBERT-xxlarge is slightly worse than v1, for two reasons: 1) training an additional 1.5M steps (the only difference between these two models is training for 1.5M vs. 3M steps) did not lead to significant performance improvement; 2) for v1, we did a little bit of hyperparameter search among the parameter sets given by BERT, RoBERTa, and XLNet, whereas for v2 we simply adopt the parameters from v1 except for RACE, where we use a learning rate of 1e-5 and 0 [ALBERT DR](https://arxiv.org/pdf/1909.11942.pdf) (dropout rate for ALBERT in fine-tuning). The original (v1) RACE hyperparameters cause model divergence for v2 models. Given that the downstream tasks are sensitive to the fine-tuning hyperparameters, we should be careful about so-called slight improvements.

ALBERT is "A Lite" version of BERT, a popular unsupervised language
representation learning algorithm. ALBERT uses parameter-reduction techniques
that allow for large-scale configurations, overcome previous memory limitations,
and achieve better behavior with respect to model degradation.

For a technical description of the algorithm, see our paper:

[ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)

Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut

Release Notes
=============

- Initial release: 10/9/2019

Results
=======

Performance of ALBERT on GLUE benchmark results using a single-model setup on
dev:

| Models            | MNLI     | QNLI     | QQP      | RTE      | SST      | MRPC     | CoLA     | STS      |
|-------------------|----------|----------|----------|----------|----------|----------|----------|----------|
| BERT-large        | 86.6     | 92.3     | 91.3     | 70.4     | 93.2     | 88.0     | 60.6     | 90.0     |
| XLNet-large       | 89.8     | 93.9     | 91.8     | 83.8     | 95.6     | 89.2     | 63.6     | 91.8     |
| RoBERTa-large     | 90.2     | 94.7     | **92.2** | 86.6     | 96.4     | **90.9** | 68.0     | 92.4     |
| ALBERT (1M)       | 90.4     | 95.2     | 92.0     | 88.1     | 96.8     | 90.2     | 68.7     | 92.7     |
| ALBERT (1.5M)     | **90.8** | **95.3** | **92.2** | **89.2** | **96.9** | **90.9** | **71.4** | **93.0** |

Performance of ALBERT-xxl on SQuAD and RACE benchmarks using a single-model
setup:

|Models                    | SQuAD1.1 dev  | SQuAD2.0 dev  | SQuAD2.0 test | RACE test (Middle/High) |
|--------------------------|---------------|---------------|---------------|-------------------------|
|BERT-large                | 90.9/84.1     | 81.8/79.0     | 89.1/86.3     | 72.0 (76.6/70.1)        |
|XLNet                     | 94.5/89.0     | 88.8/86.1     | 89.1/86.3     | 81.8 (85.5/80.2)        |
|RoBERTa                   | 94.6/88.9     | 89.4/86.5     | 89.8/86.8     | 83.2 (86.5/81.3)        |
|UPM                       | -             | -             | 89.9/87.2     | -                       |
|XLNet + SG-Net Verifier++ | -             | -             | 90.1/87.2     | -                       |
|ALBERT (1M)               | 94.8/89.2     | 89.9/87.2     | -             | 86.0 (88.2/85.1)        |
|ALBERT (1.5M)             | **94.8/89.3** | **90.2/87.4** | **90.9/88.1** | **86.5 (89.0/85.5)**    |

Pre-trained Models
==================
TF-Hub modules are available:

- Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/1)]
- Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/1)]
- Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/1)]
- Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/1)]

Example usage of the TF-Hub module in code:

```
tags = set()
if is_training:
  tags.add("train")
albert_module = hub.Module("https://tfhub.dev/google/albert_base/1", tags=tags,
                           trainable=True)
albert_inputs = dict(
    input_ids=input_ids,
    input_mask=input_mask,
    segment_ids=segment_ids)
albert_outputs = albert_module(
    inputs=albert_inputs,
    signature="tokens",
    as_dict=True)

# If you want to use the token-level output, use
# albert_outputs["sequence_output"] instead.
output_layer = albert_outputs["pooled_output"]
```

Most of the fine-tuning scripts in this repository support TF-Hub modules
via the `--albert_hub_module_handle` flag.

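For orientation, the two module outputs above differ in shape: `pooled_output` is a per-example summary vector of shape `[batch_size, hidden_size]` (hidden size 768 for ALBERT-base), while `sequence_output` is per-token, `[batch_size, seq_length, hidden_size]`. A minimal TF1 sketch of attaching a classification head, continuing the snippet above (`num_labels` is a hypothetical task parameter, not something defined in this repository):

```
# Hypothetical classification head on top of the pooled output above.
num_labels = 3  # e.g. MNLI's three classes
logits = tf.layers.dense(output_layer, num_labels)  # [batch_size, num_labels]
probabilities = tf.nn.softmax(logits, axis=-1)
```
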
Pre-training Instructions
=========================
To pretrain ALBERT, use `run_pretraining.py`:

```
pip install -r albert/requirements.txt
python -m albert.run_pretraining \
    --input_file=... \
    --output_dir=... \
    --init_checkpoint=... \
    --albert_config_file=... \
    --do_train \
    --do_eval \
    --train_batch_size=4096 \
    --eval_batch_size=64 \
    --max_seq_length=512 \
    --max_predictions_per_seq=20 \
    --optimizer='lamb' \
    --learning_rate=.00176 \
    --num_train_steps=125000 \
    --num_warmup_steps=3125 \
    --save_checkpoints_steps=5000
```

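The `lamb` optimizer above is LAMB, whose layer-wise trust ratio is what keeps the very large `train_batch_size=4096` stable. A minimal NumPy sketch of the core idea (illustrative only; the real implementation, with Adam moments and weight decay, lives in `albert/lamb_optimizer.py`):

```
import numpy as np

def lamb_step(w, update, lr=0.00176, eps=1e-6):
    """Scales a layer's update by the trust ratio ||w|| / ||update||."""
    w_norm = np.linalg.norm(w)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / (u_norm + eps) if w_norm > 0 else 1.0
    return w - lr * trust_ratio * update
```
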
Fine-tuning on GLUE
===================
To fine-tune and evaluate a pretrained ALBERT on GLUE, please see the
convenience script `run_glue.sh`.

Lower-level use cases may want to use the `run_classifier.py` script directly.
The `run_classifier.py` script is used both for fine-tuning and evaluation of
ALBERT on individual GLUE benchmark tasks, such as MNLI:

```
pip install -r albert/requirements.txt
python -m albert.run_classifier \
    --data_dir=... \
    --output_dir=... \
    --init_checkpoint=... \
    --albert_config_file=... \
    --spm_model_file=... \
    --do_train \
    --do_eval \
    --do_predict \
    --do_lower_case \
    --max_seq_length=128 \
    --optimizer=adamw \
    --task_name=MNLI \
    --warmup_step=1000 \
    --learning_rate=3e-5 \
    --train_step=10000 \
    --save_checkpoints_steps=100 \
    --train_batch_size=128
```

Good default flag values for each GLUE task can be found in `run_glue.sh`.

You can fine-tune the model starting from TF-Hub modules instead of raw
checkpoints by setting e.g.
`--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
of `--init_checkpoint`.

You can find the spm_model_file in the tar files or under the assets folder of
the TF-Hub module. The name of the model file is "30k-clean.model".

After evaluation, the script should report some output like this:

```
***** Eval results *****
  global_step = ...
  loss = ...
  masked_lm_accuracy = ...
  masked_lm_loss = ...
  sentence_order_accuracy = ...
  sentence_order_loss = ...
```

Fine-tuning on SQuAD
====================
To fine-tune and evaluate a pretrained model on SQuAD v1, use the
`run_squad_v1.py` script:

```
pip install -r albert/requirements.txt
python -m albert.run_squad_v1 \
    --albert_config_file=... \
    --output_dir=... \
    --train_file=... \
    --predict_file=... \
    --train_feature_file=... \
    --predict_feature_file=... \
    --predict_feature_left_file=... \
    --init_checkpoint=... \
    --spm_model_file=... \
    --do_lower_case \
    --max_seq_length=384 \
    --doc_stride=128 \
    --max_query_length=64 \
    --do_train=true \
    --do_predict=true \
    --train_batch_size=48 \
    --predict_batch_size=8 \
    --learning_rate=5e-5 \
    --num_train_epochs=2.0 \
    --warmup_proportion=.1 \
    --save_checkpoints_steps=5000 \
    --n_best_size=20 \
    --max_answer_length=30
```

You can fine-tune the model starting from TF-Hub modules instead of raw
checkpoints by setting e.g.
`--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
of `--init_checkpoint`.

For SQuAD v2, use the `run_squad_v2.py` script:

```
pip install -r albert/requirements.txt
python -m albert.run_squad_v2 \
    --albert_config_file=... \
    --output_dir=... \
    --train_file=... \
    --predict_file=... \
    --train_feature_file=... \
    --predict_feature_file=... \
    --predict_feature_left_file=... \
    --init_checkpoint=... \
    --spm_model_file=... \
    --do_lower_case \
    --max_seq_length=384 \
    --doc_stride=128 \
    --max_query_length=64 \
    --do_train \
    --do_predict \
    --train_batch_size=48 \
    --predict_batch_size=8 \
    --learning_rate=5e-5 \
    --num_train_epochs=2.0 \
    --warmup_proportion=.1 \
    --save_checkpoints_steps=5000 \
    --n_best_size=20 \
    --max_answer_length=30
```

You can fine-tune the model starting from TF-Hub modules instead of raw
checkpoints by setting e.g.
`--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
of `--init_checkpoint`.

Fine-tuning on RACE
===================
For RACE, use the `run_race.py` script:

```
pip install -r albert/requirements.txt
python -m albert.run_race \
    --albert_config_file=... \
    --output_dir=... \
    --train_file=... \
    --eval_file=... \
    --data_dir=... \
    --init_checkpoint=... \
    --spm_model_file=... \
    --max_seq_length=512 \
    --max_qa_length=128 \
    --do_train \
    --do_eval \
    --train_batch_size=32 \
    --eval_batch_size=8 \
    --learning_rate=1e-5 \
    --train_step=12000 \
    --warmup_step=1000 \
    --save_checkpoints_steps=100
```

You can fine-tune the model starting from TF-Hub modules instead of raw
checkpoints by setting e.g.
`--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
of `--init_checkpoint`.

SentencePiece
=============
Command for generating the sentence piece vocabulary:

```
spm_train \
  --input all.txt --model_prefix=30k-clean --vocab_size=30000 --logtostderr
  --pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1
  --control_symbols=[CLS],[SEP],[MASK]
  --user_defined_symbols="(,),\",-,.,–,£,€"
  --shuffle_input_sentence=true --input_sentence_size=10000000
  --character_coverage=0.99995 --model_type=unigram
```
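
A minimal sketch of loading the resulting model with the `sentencepiece` Python package (assumes `pip install sentencepiece`; the printed ids depend on the training run):

```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("30k-clean.model")

# pad=0 and unk=1 were fixed above; control symbols follow them.
print(sp.piece_to_id("[CLS]"))                               # typically 2
print(sp.encode_as_pieces("ALBERT uses parameter sharing."))  # subword pieces
print(sp.encode_as_ids("ALBERT uses parameter sharing."))     # vocabulary ids
```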
Indic-BERT-v1-master/albert/__init__.py
ADDED
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Indic-BERT-v1-master/albert/albert_glue_fine_tuning_tutorial.ipynb
ADDED
@@ -0,0 +1,303 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "albert_glue_fine_tuning_tutorial",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y8SJfpgTccDB",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "<a href=\"https://colab.research.google.com/github/google-research/albert/blob/master/albert_glue_fine_tuning_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wHQH4OCHZ9bq",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "# @title Copyright 2020 The ALBERT Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rkTLZ3I4_7c_",
        "colab_type": "text"
      },
      "source": [
        "# ALBERT End to End (Fine-tuning + Predicting) with Cloud TPU"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1wtjs1QDb3DX",
        "colab_type": "text"
      },
      "source": [
        "## Overview\n",
        "\n",
        "ALBERT is \"A Lite\" version of BERT, a popular unsupervised language representation learning algorithm. ALBERT uses parameter-reduction techniques that allow for large-scale configurations, overcome previous memory limitations, and achieve better behavior with respect to model degradation.\n",
        "\n",
        "For a technical description of the algorithm, see our paper:\n",
        "\n",
        "https://arxiv.org/abs/1909.11942\n",
        "\n",
        "Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut\n",
        "\n",
        "This Colab demonstrates using a free Colab Cloud TPU to fine-tune GLUE tasks built on top of pretrained ALBERT models and \n",
        "run predictions on the tuned model. The Colab demonstrates loading pretrained ALBERT models from both [TF Hub](https://www.tensorflow.org/hub) and checkpoints.\n",
        "\n",
        "**Note:** You will need a GCP (Google Compute Engine) account and a GCS (Google Cloud \n",
        "Storage) bucket for this Colab to run.\n",
        "\n",
        "Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) for how to create a GCP account and GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.\n",
        "\n",
        "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File > View on GitHub**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ld-JXlueIuPH",
        "colab_type": "text"
      },
      "source": [
        "## Instructions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "POkof5uHaQ_c",
        "colab_type": "text"
      },
      "source": [
        "<h3><a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a> Train on TPU</h3>\n",
        "\n",
        "   1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage and fill in the BUCKET parameter in the \"Parameters\" section below.\n",
        "\n",
        "   1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "   1. Click Runtime again and select **Runtime > Run All** (Watch out: the \"Colab-only auth for this notebook and the TPU\" cell requires user input). You can also run the cells manually with Shift-ENTER."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UdMmwCJFaT8F",
        "colab_type": "text"
      },
      "source": [
        "### Set up your TPU environment\n",
        "\n",
        "In this section, you perform the following tasks:\n",
        "\n",
        "*   Set up a Colab TPU running environment\n",
        "*   Verify that you are connected to a TPU device\n",
        "*   Upload your credentials to TPU to access your GCS bucket."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "191zq3ZErihP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# TODO(lanzhzh): Add support for 2.x.\n",
        "%tensorflow_version 1.x\n",
        "import os\n",
        "import pprint\n",
        "import json\n",
        "import tensorflow as tf\n",
        "\n",
        "assert \"COLAB_TPU_ADDR\" in os.environ, \"ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!\"\n",
        "TPU_ADDRESS = \"grpc://\" + os.environ[\"COLAB_TPU_ADDR\"]\n",
        "TPU_TOPOLOGY = \"2x2\"\n",
        "print(\"TPU address is\", TPU_ADDRESS)\n",
        "\n",
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "with tf.Session(TPU_ADDRESS) as session:\n",
        "  print('TPU devices:')\n",
        "  pprint.pprint(session.list_devices())\n",
        "\n",
        "  # Upload credentials to TPU.\n",
        "  with open('/content/adc.json', 'r') as f:\n",
        "    auth_info = json.load(f)\n",
        "  tf.contrib.cloud.configure_gcs(session, credentials=auth_info)\n",
        "  # Now credentials are set for all future sessions on this TPU."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HUBP35oCDmbF",
        "colab_type": "text"
      },
      "source": [
        "### Prepare and import ALBERT modules\n",
        "\n",
        "With your environment configured, you can now prepare and import the ALBERT modules. The following step clones the source code from GitHub."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7wzwke0sxS6W",
        "colab_type": "code",
        "colab": {},
        "cellView": "code"
      },
      "source": [
        "# TODO(lanzhzh): Add pip support\n",
        "import sys\n",
        "\n",
        "!test -d albert || git clone https://github.com/google-research/albert albert\n",
        "if not 'albert' in sys.path:\n",
        "  sys.path += ['albert']\n",
        "\n",
        "!pip install sentencepiece\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RRu1aKO1D7-Z",
        "colab_type": "text"
      },
      "source": [
        "### Prepare for training\n",
        "\n",
        "This next section of code performs the following tasks:\n",
        "\n",
        "*   Specify GS bucket, create output directory for model checkpoints and eval results.\n",
        "*   Specify task and download training data.\n",
        "*   Specify ALBERT pretrained model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tYkaAlJNfhul",
        "colab_type": "code",
        "colab": {},
        "cellView": "form"
      },
      "source": [
        "# Please find the full list of tasks and their fine-tuning hyperparameters\n",
        "# here https://github.com/google-research/albert/blob/master/run_glue.sh\n",
        "\n",
        "BUCKET = \"albert_tutorial_glue\" #@param { type: \"string\" }\n",
        "TASK = 'MRPC' #@param {type:\"string\"}\n",
        "# Available pretrained model checkpoints:\n",
        "#   base, large, xlarge, xxlarge\n",
        "ALBERT_MODEL = 'base' #@param {type:\"string\"}\n",
        "\n",
        "TASK_DATA_DIR = 'glue_data'\n",
        "\n",
        "BASE_DIR = \"gs://\" + BUCKET\n",
        "if not BASE_DIR or BASE_DIR == \"gs://\":\n",
        "  raise ValueError(\"You must enter a BUCKET.\")\n",
        "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
        "MODELS_DIR = os.path.join(BASE_DIR, \"models\")\n",
        "OUTPUT_DIR = 'gs://{}/albert-tfhub/models/{}'.format(BUCKET, TASK)\n",
        "tf.gfile.MakeDirs(OUTPUT_DIR)\n",
        "print('***** Model output directory: {} *****'.format(OUTPUT_DIR))\n",
        "\n",
        "# Download glue data.\n",
        "! test -d download_glue_repo || git clone https://gist.github.com/60c2bdb54d156a41194446737ce03e2e.git download_glue_repo\n",
        "!python download_glue_repo/download_glue_data.py --data_dir=$TASK_DATA_DIR --tasks=$TASK\n",
        "print('***** Task data directory: {} *****'.format(TASK_DATA_DIR))\n",
        "\n",
        "ALBERT_MODEL_HUB = 'https://tfhub.dev/google/albert_' + ALBERT_MODEL + '/3'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hcpfl4N2EdOk",
        "colab_type": "text"
      },
      "source": [
        "Now let's run the fine-tuning scripts. If you use the default MRPC task, this should be finished in around 10 minutes and you will get an accuracy of around 86.5."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o8qXPxv8-kBO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "os.environ['TFHUB_CACHE_DIR'] = OUTPUT_DIR\n",
        "!python -m albert.run_classifier \\\n",
        "  --data_dir=\"glue_data/\" \\\n",
        "  --output_dir=$OUTPUT_DIR \\\n",
        "  --albert_hub_module_handle=$ALBERT_MODEL_HUB \\\n",
        "  --spm_model_file=\"from_tf_hub\" \\\n",
        "  --do_train=True \\\n",
        "  --do_eval=True \\\n",
        "  --do_predict=False \\\n",
        "  --max_seq_length=512 \\\n",
        "  --optimizer=adamw \\\n",
        "  --task_name=$TASK \\\n",
        "  --warmup_step=200 \\\n",
        "  --learning_rate=2e-5 \\\n",
        "  --train_step=800 \\\n",
        "  --save_checkpoints_steps=100 \\\n",
        "  --train_batch_size=32 \\\n",
        "  --tpu_name=$TPU_ADDRESS \\\n",
        "  --use_tpu=True"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
Indic-BERT-v1-master/albert/classifier_utils.py
ADDED
@@ -0,0 +1,1037 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Utility functions for GLUE classification tasks."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
# from __future__ import google_type_annotations
|
20 |
+
from __future__ import print_function
|
21 |
+
import collections
|
22 |
+
import csv
|
23 |
+
import os
|
24 |
+
from albert import fine_tuning_utils
|
25 |
+
from albert import modeling
|
26 |
+
from albert import optimization
|
27 |
+
from albert import tokenization
|
28 |
+
import tensorflow.compat.v1 as tf
|
29 |
+
from tensorflow.contrib import data as contrib_data
|
30 |
+
from tensorflow.contrib import metrics as contrib_metrics
|
31 |
+
from tensorflow.contrib import tpu as contrib_tpu
|
32 |
+
|
33 |
+
|
34 |
+
class InputExample(object):
|
35 |
+
"""A single training/test example for simple sequence classification."""
|
36 |
+
|
37 |
+
def __init__(self, guid, text_a, text_b=None, label=None):
|
38 |
+
"""Constructs a InputExample.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
guid: Unique id for the example.
|
42 |
+
text_a: string. The untokenized text of the first sequence. For single
|
43 |
+
sequence tasks, only this sequence must be specified.
|
44 |
+
text_b: (Optional) string. The untokenized text of the second sequence.
|
45 |
+
Must be specified only for sequence pair tasks.
|
46 |
+
label: (Optional) string. The label of the example. This should be
|
47 |
+
specified for train and dev examples, but not for test examples.
|
48 |
+
"""
|
49 |
+
self.guid = guid
|
50 |
+
self.text_a = text_a
|
51 |
+
self.text_b = text_b
|
52 |
+
self.label = label
|
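A minimal sketch of constructing an `InputExample` for a sentence-pair task (the guid, texts, and label below are invented for illustration):

# Hypothetical usage, not part of this file:
example = InputExample(
    guid="train-1",
    text_a="The company posted record profits.",
    text_b="Profits hit an all-time high.",
    label="1")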
53 |
+
|
54 |
+
|
55 |
+
class PaddingInputExample(object):
|
56 |
+
"""Fake example so the num input examples is a multiple of the batch size.
|
57 |
+
|
58 |
+
When running eval/predict on the TPU, we need to pad the number of examples
|
59 |
+
to be a multiple of the batch size, because the TPU requires a fixed batch
|
60 |
+
size. The alternative is to drop the last batch, which is bad because it means
|
61 |
+
the entire output data won't be generated.
|
62 |
+
|
63 |
+
We use this class instead of `None` because treating `None` as padding
|
64 |
+
batches could cause silent errors.
|
65 |
+
"""
|
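A sketch of how `PaddingInputExample` is typically consumed: the caller pads the eval set up to a multiple of the batch size before feature conversion. The loop below is illustrative caller-side code (`eval_examples` and `eval_batch_size` are assumed to exist in the caller, e.g. a run_classifier script):

# Illustrative caller-side padding, assuming eval_examples / eval_batch_size:
while len(eval_examples) % eval_batch_size != 0:
  eval_examples.append(PaddingInputExample())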
66 |
+
|
67 |
+
|
68 |
+
class InputFeatures(object):
|
69 |
+
"""A single set of features of data."""
|
70 |
+
|
71 |
+
def __init__(self,
|
72 |
+
input_ids,
|
73 |
+
input_mask,
|
74 |
+
segment_ids,
|
75 |
+
label_id,
|
76 |
+
guid=None,
|
77 |
+
example_id=None,
|
78 |
+
is_real_example=True):
|
79 |
+
self.input_ids = input_ids
|
80 |
+
self.input_mask = input_mask
|
81 |
+
self.segment_ids = segment_ids
|
82 |
+
self.label_id = label_id
|
83 |
+
self.example_id = example_id
|
84 |
+
self.guid = guid
|
85 |
+
self.is_real_example = is_real_example
|
86 |
+
|
87 |
+
|
88 |
+
class DataProcessor(object):
|
89 |
+
"""Base class for data converters for sequence classification data sets."""
|
90 |
+
|
91 |
+
def __init__(self, use_spm, do_lower_case):
|
92 |
+
super(DataProcessor, self).__init__()
|
93 |
+
self.use_spm = use_spm
|
94 |
+
self.do_lower_case = do_lower_case
|
95 |
+
|
96 |
+
def get_train_examples(self, data_dir):
|
97 |
+
"""Gets a collection of `InputExample`s for the train set."""
|
98 |
+
raise NotImplementedError()
|
99 |
+
|
100 |
+
def get_dev_examples(self, data_dir):
|
101 |
+
"""Gets a collection of `InputExample`s for the dev set."""
|
102 |
+
raise NotImplementedError()
|
103 |
+
|
104 |
+
def get_test_examples(self, data_dir):
|
105 |
+
"""Gets a collection of `InputExample`s for prediction."""
|
106 |
+
raise NotImplementedError()
|
107 |
+
|
108 |
+
def get_labels(self):
|
109 |
+
"""Gets the list of labels for this data set."""
|
110 |
+
raise NotImplementedError()
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def _read_tsv(cls, input_file, quotechar=None):
|
114 |
+
"""Reads a tab separated value file."""
|
115 |
+
with tf.gfile.Open(input_file, "r") as f:
|
116 |
+
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
117 |
+
lines = []
|
118 |
+
for line in reader:
|
119 |
+
lines.append(line)
|
120 |
+
return lines
|
121 |
+
|
122 |
+
def process_text(self, text):
|
123 |
+
if self.use_spm:
|
124 |
+
return tokenization.preprocess_text(text, lower=self.do_lower_case)
|
125 |
+
else:
|
126 |
+
return tokenization.convert_to_unicode(text)
|
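Supporting a new corpus means subclassing `DataProcessor`; a minimal sketch for a hypothetical two-column TSV of `label<TAB>sentence` (file name, columns, and labels are invented):

class ToyProcessor(DataProcessor):
  """Hypothetical processor for a two-column TSV: label <tab> sentence."""

  def get_train_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      examples.append(InputExample(
          guid=guid, text_a=self.process_text(line[1]), label=line[0]))
    return examples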
127 |
+
|
128 |
+
|
129 |
+
class MnliProcessor(DataProcessor):
|
130 |
+
"""Processor for the MultiNLI data set (GLUE version)."""
|
131 |
+
|
132 |
+
def get_train_examples(self, data_dir):
|
133 |
+
"""See base class."""
|
134 |
+
return self._create_examples(
|
135 |
+
self._read_tsv(os.path.join(data_dir, "MNLI", "train.tsv")), "train")
|
136 |
+
|
137 |
+
def get_dev_examples(self, data_dir):
|
138 |
+
"""See base class."""
|
139 |
+
return self._create_examples(
|
140 |
+
self._read_tsv(os.path.join(data_dir, "MNLI", "dev_matched.tsv")),
|
141 |
+
"dev_matched")
|
142 |
+
|
143 |
+
def get_test_examples(self, data_dir):
|
144 |
+
"""See base class."""
|
145 |
+
return self._create_examples(
|
146 |
+
self._read_tsv(os.path.join(data_dir, "MNLI", "test_matched.tsv")),
|
147 |
+
"test")
|
148 |
+
|
149 |
+
def get_labels(self):
|
150 |
+
"""See base class."""
|
151 |
+
return ["contradiction", "entailment", "neutral"]
|
152 |
+
|
153 |
+
def _create_examples(self, lines, set_type):
|
154 |
+
"""Creates examples for the training and dev sets."""
|
155 |
+
examples = []
|
156 |
+
for (i, line) in enumerate(lines):
|
157 |
+
if i == 0:
|
158 |
+
continue
|
159 |
+
# Note(mingdachen): We will rely on this guid for GLUE submission.
|
160 |
+
guid = self.process_text(line[0])
|
161 |
+
text_a = self.process_text(line[8])
|
162 |
+
text_b = self.process_text(line[9])
|
163 |
+
if set_type == "test":
|
164 |
+
label = "contradiction"
|
165 |
+
else:
|
166 |
+
label = self.process_text(line[-1])
|
167 |
+
examples.append(
|
168 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
169 |
+
return examples
|
170 |
+
|
171 |
+
|
172 |
+
class MisMnliProcessor(MnliProcessor):
|
173 |
+
"""Processor for the Mismatched MultiNLI data set (GLUE version)."""
|
174 |
+
|
175 |
+
def get_dev_examples(self, data_dir):
|
176 |
+
"""See base class."""
|
177 |
+
return self._create_examples(
|
178 |
+
self._read_tsv(os.path.join(data_dir, "MNLI", "dev_mismatched.tsv")),
|
179 |
+
"dev")
|
180 |
+
|
181 |
+
def get_test_examples(self, data_dir):
|
182 |
+
"""See base class."""
|
183 |
+
return self._create_examples(
|
184 |
+
self._read_tsv(os.path.join(data_dir, "MNLI", "test_mismatched.tsv")),
|
185 |
+
"test")
|
186 |
+
|
187 |
+
|
188 |
+
class MrpcProcessor(DataProcessor):
|
189 |
+
"""Processor for the MRPC data set (GLUE version)."""
|
190 |
+
|
191 |
+
def get_train_examples(self, data_dir):
|
192 |
+
"""See base class."""
|
193 |
+
return self._create_examples(
|
194 |
+
self._read_tsv(os.path.join(data_dir, "MRPC", "train.tsv")), "train")
|
195 |
+
|
196 |
+
def get_dev_examples(self, data_dir):
|
197 |
+
"""See base class."""
|
198 |
+
return self._create_examples(
|
199 |
+
self._read_tsv(os.path.join(data_dir, "MRPC", "dev.tsv")), "dev")
|
200 |
+
|
201 |
+
def get_test_examples(self, data_dir):
|
202 |
+
"""See base class."""
|
203 |
+
return self._create_examples(
|
204 |
+
self._read_tsv(os.path.join(data_dir, "MRPC", "test.tsv")), "test")
|
205 |
+
|
206 |
+
def get_labels(self):
|
207 |
+
"""See base class."""
|
208 |
+
return ["0", "1"]
|
209 |
+
|
210 |
+
def _create_examples(self, lines, set_type):
|
211 |
+
"""Creates examples for the training and dev sets."""
|
212 |
+
examples = []
|
213 |
+
for (i, line) in enumerate(lines):
|
214 |
+
if i == 0:
|
215 |
+
continue
|
216 |
+
guid = "%s-%s" % (set_type, i)
|
217 |
+
text_a = self.process_text(line[3])
|
218 |
+
text_b = self.process_text(line[4])
|
219 |
+
if set_type == "test":
|
220 |
+
guid = line[0]
|
221 |
+
label = "0"
|
222 |
+
else:
|
223 |
+
label = self.process_text(line[0])
|
224 |
+
examples.append(
|
225 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
226 |
+
return examples
|
227 |
+
|
228 |
+
|
229 |
+
class ColaProcessor(DataProcessor):
|
230 |
+
"""Processor for the CoLA data set (GLUE version)."""
|
231 |
+
|
232 |
+
def get_train_examples(self, data_dir):
|
233 |
+
"""See base class."""
|
234 |
+
return self._create_examples(
|
235 |
+
self._read_tsv(os.path.join(data_dir, "CoLA", "train.tsv")), "train")
|
236 |
+
|
237 |
+
def get_dev_examples(self, data_dir):
|
238 |
+
"""See base class."""
|
239 |
+
return self._create_examples(
|
240 |
+
self._read_tsv(os.path.join(data_dir, "CoLA", "dev.tsv")), "dev")
|
241 |
+
|
242 |
+
def get_test_examples(self, data_dir):
|
243 |
+
"""See base class."""
|
244 |
+
return self._create_examples(
|
245 |
+
self._read_tsv(os.path.join(data_dir, "CoLA", "test.tsv")), "test")
|
246 |
+
|
247 |
+
def get_labels(self):
|
248 |
+
"""See base class."""
|
249 |
+
return ["0", "1"]
|
250 |
+
|
251 |
+
def _create_examples(self, lines, set_type):
|
252 |
+
"""Creates examples for the training and dev sets."""
|
253 |
+
examples = []
|
254 |
+
for (i, line) in enumerate(lines):
|
255 |
+
# Only the test set has a header
|
256 |
+
if set_type == "test" and i == 0:
|
257 |
+
continue
|
258 |
+
guid = "%s-%s" % (set_type, i)
|
259 |
+
if set_type == "test":
|
260 |
+
guid = line[0]
|
261 |
+
text_a = self.process_text(line[1])
|
262 |
+
label = "0"
|
263 |
+
else:
|
264 |
+
text_a = self.process_text(line[3])
|
265 |
+
label = self.process_text(line[1])
|
266 |
+
examples.append(
|
267 |
+
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
268 |
+
return examples
|
269 |
+
|
270 |
+
|
271 |
+
class Sst2Processor(DataProcessor):
|
272 |
+
"""Processor for the SST-2 data set (GLUE version)."""
|
273 |
+
|
274 |
+
def get_train_examples(self, data_dir):
|
275 |
+
"""See base class."""
|
276 |
+
return self._create_examples(
|
277 |
+
self._read_tsv(os.path.join(data_dir, "SST-2", "train.tsv")), "train")
|
278 |
+
|
279 |
+
def get_dev_examples(self, data_dir):
|
280 |
+
"""See base class."""
|
281 |
+
return self._create_examples(
|
282 |
+
self._read_tsv(os.path.join(data_dir, "SST-2", "dev.tsv")), "dev")
|
283 |
+
|
284 |
+
def get_test_examples(self, data_dir):
|
285 |
+
"""See base class."""
|
286 |
+
return self._create_examples(
|
287 |
+
self._read_tsv(os.path.join(data_dir, "SST-2", "test.tsv")), "test")
|
288 |
+
|
289 |
+
def get_labels(self):
|
290 |
+
"""See base class."""
|
291 |
+
return ["0", "1"]
|
292 |
+
|
293 |
+
def _create_examples(self, lines, set_type):
|
294 |
+
"""Creates examples for the training and dev sets."""
|
295 |
+
examples = []
|
296 |
+
for (i, line) in enumerate(lines):
|
297 |
+
if i == 0:
|
298 |
+
continue
|
299 |
+
if set_type != "test":
|
300 |
+
guid = "%s-%s" % (set_type, i)
|
301 |
+
text_a = self.process_text(line[0])
|
302 |
+
label = self.process_text(line[1])
|
303 |
+
else:
|
304 |
+
guid = self.process_text(line[0])
|
305 |
+
# guid = "%s-%s" % (set_type, line[0])
|
306 |
+
text_a = self.process_text(line[1])
|
307 |
+
label = "0"
|
308 |
+
examples.append(
|
309 |
+
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
310 |
+
return examples
|
311 |
+
|
312 |
+
|
313 |
+
class StsbProcessor(DataProcessor):
|
314 |
+
"""Processor for the STS-B data set (GLUE version)."""
|
315 |
+
|
316 |
+
def get_train_examples(self, data_dir):
|
317 |
+
"""See base class."""
|
318 |
+
return self._create_examples(
|
319 |
+
self._read_tsv(os.path.join(data_dir, "STS-B", "train.tsv")), "train")
|
320 |
+
|
321 |
+
def get_dev_examples(self, data_dir):
|
322 |
+
"""See base class."""
|
323 |
+
return self._create_examples(
|
324 |
+
self._read_tsv(os.path.join(data_dir, "STS-B", "dev.tsv")), "dev")
|
325 |
+
|
326 |
+
def get_test_examples(self, data_dir):
|
327 |
+
"""See base class."""
|
328 |
+
return self._create_examples(
|
329 |
+
self._read_tsv(os.path.join(data_dir, "STS-B", "test.tsv")), "test")
|
330 |
+
|
331 |
+
def get_labels(self):
|
332 |
+
"""See base class."""
|
333 |
+
return [None]
|
334 |
+
|
335 |
+
def _create_examples(self, lines, set_type):
|
336 |
+
"""Creates examples for the training and dev sets."""
|
337 |
+
examples = []
|
338 |
+
for (i, line) in enumerate(lines):
|
339 |
+
if i == 0:
|
340 |
+
continue
|
341 |
+
guid = self.process_text(line[0])
|
342 |
+
# guid = "%s-%s" % (set_type, line[0])
|
343 |
+
text_a = self.process_text(line[7])
|
344 |
+
text_b = self.process_text(line[8])
|
345 |
+
if set_type != "test":
|
346 |
+
label = float(line[-1])
|
347 |
+
else:
|
348 |
+
label = 0
|
349 |
+
examples.append(
|
350 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
351 |
+
return examples
|
352 |
+
|
353 |
+
|
354 |
+
class QqpProcessor(DataProcessor):
|
355 |
+
"""Processor for the QQP data set (GLUE version)."""
|
356 |
+
|
357 |
+
def get_train_examples(self, data_dir):
|
358 |
+
"""See base class."""
|
359 |
+
return self._create_examples(
|
360 |
+
self._read_tsv(os.path.join(data_dir, "QQP", "train.tsv")), "train")
|
361 |
+
|
362 |
+
def get_dev_examples(self, data_dir):
|
363 |
+
"""See base class."""
|
364 |
+
return self._create_examples(
|
365 |
+
self._read_tsv(os.path.join(data_dir, "QQP", "dev.tsv")), "dev")
|
366 |
+
|
367 |
+
def get_test_examples(self, data_dir):
|
368 |
+
"""See base class."""
|
369 |
+
return self._create_examples(
|
370 |
+
self._read_tsv(os.path.join(data_dir, "QQP", "test.tsv")), "test")
|
371 |
+
|
372 |
+
def get_labels(self):
|
373 |
+
"""See base class."""
|
374 |
+
return ["0", "1"]
|
375 |
+
|
376 |
+
def _create_examples(self, lines, set_type):
|
377 |
+
"""Creates examples for the training and dev sets."""
|
378 |
+
examples = []
|
379 |
+
for (i, line) in enumerate(lines):
|
380 |
+
if i == 0:
|
381 |
+
continue
|
382 |
+
guid = line[0]
|
383 |
+
# guid = "%s-%s" % (set_type, line[0])
|
384 |
+
if set_type != "test":
|
385 |
+
try:
|
386 |
+
text_a = self.process_text(line[3])
|
387 |
+
text_b = self.process_text(line[4])
|
388 |
+
label = self.process_text(line[5])
|
389 |
+
except IndexError:
|
390 |
+
continue
|
391 |
+
else:
|
392 |
+
text_a = self.process_text(line[1])
|
393 |
+
text_b = self.process_text(line[2])
|
394 |
+
label = "0"
|
395 |
+
examples.append(
|
396 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
397 |
+
return examples
|
398 |
+
|
399 |
+
|
400 |
+
class QnliProcessor(DataProcessor):
|
401 |
+
"""Processor for the QNLI data set (GLUE version)."""
|
402 |
+
|
403 |
+
def get_train_examples(self, data_dir):
|
404 |
+
"""See base class."""
|
405 |
+
return self._create_examples(
|
406 |
+
self._read_tsv(os.path.join(data_dir, "QNLI", "train.tsv")), "train")
|
407 |
+
|
408 |
+
def get_dev_examples(self, data_dir):
|
409 |
+
"""See base class."""
|
410 |
+
return self._create_examples(
|
411 |
+
self._read_tsv(os.path.join(data_dir, "QNLI", "dev.tsv")),
|
412 |
+
"dev_matched")
|
413 |
+
|
414 |
+
def get_test_examples(self, data_dir):
|
415 |
+
"""See base class."""
|
416 |
+
return self._create_examples(
|
417 |
+
self._read_tsv(os.path.join(data_dir, "QNLI", "test.tsv")),
|
418 |
+
"test_matched")
|
419 |
+
|
420 |
+
def get_labels(self):
|
421 |
+
"""See base class."""
|
422 |
+
return ["entailment", "not_entailment"]
|
423 |
+
|
424 |
+
def _create_examples(self, lines, set_type):
|
425 |
+
"""Creates examples for the training and dev sets."""
|
426 |
+
examples = []
|
427 |
+
for (i, line) in enumerate(lines):
|
428 |
+
if i == 0:
|
429 |
+
continue
|
430 |
+
guid = self.process_text(line[0])
|
431 |
+
# guid = "%s-%s" % (set_type, line[0])
|
432 |
+
text_a = self.process_text(line[1])
|
433 |
+
text_b = self.process_text(line[2])
|
434 |
+
if set_type == "test_matched":
|
435 |
+
label = "entailment"
|
436 |
+
else:
|
437 |
+
label = self.process_text(line[-1])
|
438 |
+
examples.append(
|
439 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
440 |
+
return examples
|
441 |
+
|
442 |
+
|
443 |
+
class RteProcessor(DataProcessor):
|
444 |
+
"""Processor for the RTE data set (GLUE version)."""
|
445 |
+
|
446 |
+
def get_train_examples(self, data_dir):
|
447 |
+
"""See base class."""
|
448 |
+
return self._create_examples(
|
449 |
+
self._read_tsv(os.path.join(data_dir, "RTE", "train.tsv")), "train")
|
450 |
+
|
451 |
+
def get_dev_examples(self, data_dir):
|
452 |
+
"""See base class."""
|
453 |
+
return self._create_examples(
|
454 |
+
self._read_tsv(os.path.join(data_dir, "RTE", "dev.tsv")), "dev")
|
455 |
+
|
456 |
+
def get_test_examples(self, data_dir):
|
457 |
+
"""See base class."""
|
458 |
+
return self._create_examples(
|
459 |
+
self._read_tsv(os.path.join(data_dir, "RTE", "test.tsv")), "test")
|
460 |
+
|
461 |
+
def get_labels(self):
|
462 |
+
"""See base class."""
|
463 |
+
return ["entailment", "not_entailment"]
|
464 |
+
|
465 |
+
def _create_examples(self, lines, set_type):
|
466 |
+
"""Creates examples for the training and dev sets."""
|
467 |
+
examples = []
|
468 |
+
for (i, line) in enumerate(lines):
|
469 |
+
if i == 0:
|
470 |
+
continue
|
471 |
+
guid = self.process_text(line[0])
|
472 |
+
# guid = "%s-%s" % (set_type, line[0])
|
473 |
+
text_a = self.process_text(line[1])
|
474 |
+
text_b = self.process_text(line[2])
|
475 |
+
if set_type == "test":
|
476 |
+
label = "entailment"
|
477 |
+
else:
|
478 |
+
label = self.process_text(line[-1])
|
479 |
+
examples.append(
|
480 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
481 |
+
return examples
|
482 |
+
|
483 |
+
|
484 |
+
class WnliProcessor(DataProcessor):
|
485 |
+
"""Processor for the WNLI data set (GLUE version)."""
|
486 |
+
|
487 |
+
def get_train_examples(self, data_dir):
|
488 |
+
"""See base class."""
|
489 |
+
return self._create_examples(
|
490 |
+
self._read_tsv(os.path.join(data_dir, "WNLI", "train.tsv")), "train")
|
491 |
+
|
492 |
+
def get_dev_examples(self, data_dir):
|
493 |
+
"""See base class."""
|
494 |
+
return self._create_examples(
|
495 |
+
self._read_tsv(os.path.join(data_dir, "WNLI", "dev.tsv")), "dev")
|
496 |
+
|
497 |
+
def get_test_examples(self, data_dir):
|
498 |
+
"""See base class."""
|
499 |
+
return self._create_examples(
|
500 |
+
self._read_tsv(os.path.join(data_dir, "WNLI", "test.tsv")), "test")
|
501 |
+
|
502 |
+
def get_labels(self):
|
503 |
+
"""See base class."""
|
504 |
+
return ["0", "1"]
|
505 |
+
|
506 |
+
def _create_examples(self, lines, set_type):
|
507 |
+
"""Creates examples for the training and dev sets."""
|
508 |
+
examples = []
|
509 |
+
for (i, line) in enumerate(lines):
|
510 |
+
if i == 0:
|
511 |
+
continue
|
512 |
+
guid = self.process_text(line[0])
|
513 |
+
# guid = "%s-%s" % (set_type, line[0])
|
514 |
+
text_a = self.process_text(line[1])
|
515 |
+
text_b = self.process_text(line[2])
|
516 |
+
if set_type != "test":
|
517 |
+
label = self.process_text(line[-1])
|
518 |
+
else:
|
519 |
+
label = "0"
|
520 |
+
examples.append(
|
521 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
522 |
+
return examples
|
523 |
+
|
524 |
+
|
525 |
+
class AXProcessor(DataProcessor):
|
526 |
+
"""Processor for the AX data set (GLUE version)."""
|
527 |
+
|
528 |
+
def get_test_examples(self, data_dir):
|
529 |
+
"""See base class."""
|
530 |
+
return self._create_examples(
|
531 |
+
self._read_tsv(os.path.join(data_dir, "diagnostic", "diagnostic.tsv")),
|
532 |
+
"test")
|
533 |
+
|
534 |
+
def get_labels(self):
|
535 |
+
"""See base class."""
|
536 |
+
return ["contradiction", "entailment", "neutral"]
|
537 |
+
|
538 |
+
def _create_examples(self, lines, set_type):
|
539 |
+
"""Creates examples for the training and dev sets."""
|
540 |
+
examples = []
|
541 |
+
for (i, line) in enumerate(lines):
|
542 |
+
if i == 0:
|
543 |
+
continue
|
544 |
+
# Note(mingdachen): We will rely on this guid for GLUE submission.
|
545 |
+
guid = self.process_text(line[0])
|
546 |
+
text_a = self.process_text(line[1])
|
547 |
+
text_b = self.process_text(line[2])
|
548 |
+
if set_type == "test":
|
549 |
+
label = "contradiction"
|
550 |
+
else:
|
551 |
+
label = self.process_text(line[-1])
|
552 |
+
examples.append(
|
553 |
+
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
554 |
+
return examples
|
555 |
+
|
556 |
+
|
557 |
+
def convert_single_example(ex_index, example, label_list, max_seq_length,
|
558 |
+
tokenizer, task_name):
|
559 |
+
"""Converts a single `InputExample` into a single `InputFeatures`."""
|
560 |
+
|
561 |
+
if isinstance(example, PaddingInputExample):
|
562 |
+
return InputFeatures(
|
563 |
+
input_ids=[0] * max_seq_length,
|
564 |
+
input_mask=[0] * max_seq_length,
|
565 |
+
segment_ids=[0] * max_seq_length,
|
566 |
+
label_id=0,
|
567 |
+
is_real_example=False)
|
568 |
+
|
569 |
+
if task_name != "sts-b":
|
570 |
+
label_map = {}
|
571 |
+
for (i, label) in enumerate(label_list):
|
572 |
+
label_map[label] = i
|
573 |
+
|
574 |
+
tokens_a = tokenizer.tokenize(example.text_a)
|
575 |
+
tokens_b = None
|
576 |
+
if example.text_b:
|
577 |
+
tokens_b = tokenizer.tokenize(example.text_b)
|
578 |
+
|
579 |
+
if tokens_b:
|
580 |
+
# Modifies `tokens_a` and `tokens_b` in place so that the total
|
581 |
+
# length is less than the specified length.
|
582 |
+
# Account for [CLS], [SEP], [SEP] with "- 3"
|
583 |
+
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
|
584 |
+
else:
|
585 |
+
# Account for [CLS] and [SEP] with "- 2"
|
586 |
+
if len(tokens_a) > max_seq_length - 2:
|
587 |
+
tokens_a = tokens_a[0:(max_seq_length - 2)]
|
588 |
+
|
589 |
+
# The convention in ALBERT is:
|
590 |
+
# (a) For sequence pairs:
|
591 |
+
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
592 |
+
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
593 |
+
# (b) For single sequences:
|
594 |
+
# tokens: [CLS] the dog is hairy . [SEP]
|
595 |
+
# type_ids: 0 0 0 0 0 0 0
|
596 |
+
#
|
597 |
+
# Where "type_ids" are used to indicate whether this is the first
|
598 |
+
# sequence or the second sequence. The embedding vectors for `type=0` and
|
599 |
+
# `type=1` were learned during pre-training and are added to the
|
600 |
+
# embedding vector (and position vector). This is not *strictly* necessary
|
601 |
+
# since the [SEP] token unambiguously separates the sequences, but it makes
|
602 |
+
# it easier for the model to learn the concept of sequences.
|
603 |
+
#
|
604 |
+
# For classification tasks, the first vector (corresponding to [CLS]) is
|
605 |
+
# used as the "sentence vector". Note that this only makes sense because
|
606 |
+
# the entire model is fine-tuned.
|
607 |
+
tokens = []
|
608 |
+
segment_ids = []
|
609 |
+
tokens.append("[CLS]")
|
610 |
+
segment_ids.append(0)
|
611 |
+
for token in tokens_a:
|
612 |
+
tokens.append(token)
|
613 |
+
segment_ids.append(0)
|
614 |
+
tokens.append("[SEP]")
|
615 |
+
segment_ids.append(0)
|
616 |
+
|
617 |
+
if tokens_b:
|
618 |
+
for token in tokens_b:
|
619 |
+
tokens.append(token)
|
620 |
+
segment_ids.append(1)
|
621 |
+
tokens.append("[SEP]")
|
622 |
+
segment_ids.append(1)
|
623 |
+
|
624 |
+
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
625 |
+
|
626 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
627 |
+
# tokens are attended to.
|
628 |
+
input_mask = [1] * len(input_ids)
|
629 |
+
|
630 |
+
# Zero-pad up to the sequence length.
|
631 |
+
while len(input_ids) < max_seq_length:
|
632 |
+
input_ids.append(0)
|
633 |
+
input_mask.append(0)
|
634 |
+
segment_ids.append(0)
|
635 |
+
|
636 |
+
assert len(input_ids) == max_seq_length
|
637 |
+
assert len(input_mask) == max_seq_length
|
638 |
+
assert len(segment_ids) == max_seq_length
|
639 |
+
|
640 |
+
if task_name != "sts-b":
|
641 |
+
label_id = label_map[example.label]
|
642 |
+
else:
|
643 |
+
label_id = example.label
|
644 |
+
|
645 |
+
if ex_index < 5:
|
646 |
+
tf.logging.info("*** Example ***")
|
647 |
+
tf.logging.info("guid: %s" % (example.guid))
|
648 |
+
tf.logging.info("tokens: %s" % " ".join(
|
649 |
+
[tokenization.printable_text(x) for x in tokens]))
|
650 |
+
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
651 |
+
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
652 |
+
tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
653 |
+
tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
|
654 |
+
|
655 |
+
feature = InputFeatures(
|
656 |
+
input_ids=input_ids,
|
657 |
+
input_mask=input_mask,
|
658 |
+
segment_ids=segment_ids,
|
659 |
+
label_id=label_id,
|
660 |
+
is_real_example=True)
|
661 |
+
return feature
|
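To make the packing concrete, a toy trace of the pair case (token strings are illustrative; real inputs come from the ALBERT SentencePiece tokenizer):

# tokens_a = ["is", "this", "jack", "##son", "##ville", "?"]
# tokens_b = ["no", "it", "is", "not", "."]
# With max_seq_length = 16 the packed example becomes:
#   tokens      = [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + 2 pad
#   segment_ids = [0]*8 + [1]*6 + [0]*2
#   input_mask  = [1]*14 + [0]*2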
662 |
+
|
663 |
+
|
664 |
+
def file_based_convert_examples_to_features(
|
665 |
+
examples, label_list, max_seq_length, tokenizer, output_file, task_name):
|
666 |
+
"""Convert a set of `InputExample`s to a TFRecord file."""
|
667 |
+
|
668 |
+
writer = tf.python_io.TFRecordWriter(output_file)
|
669 |
+
|
670 |
+
for (ex_index, example) in enumerate(examples):
|
671 |
+
if ex_index % 10000 == 0:
|
672 |
+
tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
|
673 |
+
|
674 |
+
feature = convert_single_example(ex_index, example, label_list,
|
675 |
+
max_seq_length, tokenizer, task_name)
|
676 |
+
|
677 |
+
def create_int_feature(values):
|
678 |
+
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
679 |
+
return f
|
680 |
+
|
681 |
+
def create_float_feature(values):
|
682 |
+
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
683 |
+
return f
|
684 |
+
|
685 |
+
features = collections.OrderedDict()
|
686 |
+
features["input_ids"] = create_int_feature(feature.input_ids)
|
687 |
+
features["input_mask"] = create_int_feature(feature.input_mask)
|
688 |
+
features["segment_ids"] = create_int_feature(feature.segment_ids)
|
689 |
+
features["label_ids"] = create_float_feature([feature.label_id])\
|
690 |
+
if task_name == "sts-b" else create_int_feature([feature.label_id])
|
691 |
+
features["is_real_example"] = create_int_feature(
|
692 |
+
[int(feature.is_real_example)])
|
693 |
+
|
694 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
695 |
+
writer.write(tf_example.SerializeToString())
|
696 |
+
writer.close()
|
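End to end, the writer might be wired up as follows (a sketch with placeholder paths; `tokenizer` is assumed to be an already-constructed tokenizer from this package's tokenization module):

processor = ColaProcessor(use_spm=True, do_lower_case=True)
train_examples = processor.get_train_examples("/path/to/glue")  # placeholder path
file_based_convert_examples_to_features(
    train_examples, processor.get_labels(), max_seq_length=128,
    tokenizer=tokenizer, output_file="/tmp/cola_train.tf_record",
    task_name="cola")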
697 |
+
|
698 |
+
|
699 |
+
def file_based_input_fn_builder(input_file, seq_length, is_training,
|
700 |
+
drop_remainder, task_name, use_tpu, bsz,
|
701 |
+
multiple=1):
|
702 |
+
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
703 |
+
labeltype = tf.float32 if task_name == "sts-b" else tf.int64
|
704 |
+
|
705 |
+
name_to_features = {
|
706 |
+
"input_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
|
707 |
+
"input_mask": tf.FixedLenFeature([seq_length * multiple], tf.int64),
|
708 |
+
"segment_ids": tf.FixedLenFeature([seq_length * multiple], tf.int64),
|
709 |
+
"label_ids": tf.FixedLenFeature([], labeltype),
|
710 |
+
"is_real_example": tf.FixedLenFeature([], tf.int64),
|
711 |
+
}
|
712 |
+
|
713 |
+
def _decode_record(record, name_to_features):
|
714 |
+
"""Decodes a record to a TensorFlow example."""
|
715 |
+
example = tf.parse_single_example(record, name_to_features)
|
716 |
+
|
717 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
718 |
+
# So cast all int64 to int32.
|
719 |
+
for name in list(example.keys()):
|
720 |
+
t = example[name]
|
721 |
+
if t.dtype == tf.int64:
|
722 |
+
t = tf.to_int32(t)
|
723 |
+
example[name] = t
|
724 |
+
|
725 |
+
return example
|
726 |
+
|
727 |
+
def input_fn(params):
|
728 |
+
"""The actual input function."""
|
729 |
+
if use_tpu:
|
730 |
+
batch_size = params["batch_size"]
|
731 |
+
else:
|
732 |
+
batch_size = bsz
|
733 |
+
|
734 |
+
# For training, we want a lot of parallel reading and shuffling.
|
735 |
+
# For eval, we want no shuffling and parallel reading doesn't matter.
|
736 |
+
d = tf.data.TFRecordDataset(input_file)
|
737 |
+
if is_training:
|
738 |
+
d = d.repeat()
|
739 |
+
d = d.shuffle(buffer_size=100)
|
740 |
+
|
741 |
+
d = d.apply(
|
742 |
+
contrib_data.map_and_batch(
|
743 |
+
lambda record: _decode_record(record, name_to_features),
|
744 |
+
batch_size=batch_size,
|
745 |
+
drop_remainder=drop_remainder))
|
746 |
+
|
747 |
+
return d
|
748 |
+
|
749 |
+
return input_fn
|
750 |
+
|
751 |
+
|
752 |
+
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
753 |
+
"""Truncates a sequence pair in place to the maximum length."""
|
754 |
+
|
755 |
+
# This is a simple heuristic which will always truncate the longer sequence
|
756 |
+
# one token at a time. This makes more sense than truncating an equal percent
|
757 |
+
# of tokens from each, since if one sequence is very short then each token
|
758 |
+
# that's truncated likely contains more information than a longer sequence.
|
759 |
+
while True:
|
760 |
+
total_length = len(tokens_a) + len(tokens_b)
|
761 |
+
if total_length <= max_length:
|
762 |
+
break
|
763 |
+
if len(tokens_a) > len(tokens_b):
|
764 |
+
tokens_a.pop()
|
765 |
+
else:
|
766 |
+
tokens_b.pop()
|
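For example, with `max_length = 8` a 7-token `tokens_a` and a 5-token `tokens_b` are trimmed alternately from whichever side is longer, ending at 4 tokens each:

tokens_a = list("abcdefg")   # 7 tokens
tokens_b = list("vwxyz")     # 5 tokens
_truncate_seq_pair(tokens_a, tokens_b, 8)
print(len(tokens_a), len(tokens_b))  # -> 4 4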
767 |
+
|
768 |
+
|
769 |
+
def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,
|
770 |
+
labels, num_labels, use_one_hot_embeddings, task_name,
|
771 |
+
hub_module):
|
772 |
+
"""Creates a classification model."""
|
773 |
+
(output_layer, _) = fine_tuning_utils.create_albert(
|
774 |
+
albert_config=albert_config,
|
775 |
+
is_training=is_training,
|
776 |
+
input_ids=input_ids,
|
777 |
+
input_mask=input_mask,
|
778 |
+
segment_ids=segment_ids,
|
779 |
+
use_one_hot_embeddings=use_one_hot_embeddings,
|
780 |
+
use_einsum=True,
|
781 |
+
hub_module=hub_module)
|
782 |
+
|
783 |
+
hidden_size = output_layer.shape[-1].value
|
784 |
+
|
785 |
+
output_weights = tf.get_variable(
|
786 |
+
"output_weights", [num_labels, hidden_size],
|
787 |
+
initializer=tf.truncated_normal_initializer(stddev=0.02))
|
788 |
+
|
789 |
+
output_bias = tf.get_variable(
|
790 |
+
"output_bias", [num_labels], initializer=tf.zeros_initializer())
|
791 |
+
|
792 |
+
with tf.variable_scope("loss"):
|
793 |
+
if is_training:
|
794 |
+
# I.e., 0.1 dropout
|
795 |
+
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
|
796 |
+
|
797 |
+
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
|
798 |
+
logits = tf.nn.bias_add(logits, output_bias)
|
799 |
+
if task_name != "sts-b":
|
800 |
+
probabilities = tf.nn.softmax(logits, axis=-1)
|
801 |
+
predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
|
802 |
+
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
803 |
+
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
|
804 |
+
|
805 |
+
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
806 |
+
else:
|
807 |
+
probabilities = logits
|
808 |
+
logits = tf.squeeze(logits, [-1])
|
809 |
+
predictions = logits
|
810 |
+
per_example_loss = tf.square(logits - labels)
|
811 |
+
loss = tf.reduce_mean(per_example_loss)
|
812 |
+
|
813 |
+
return (loss, per_example_loss, probabilities, logits, predictions)
|
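Numerically, the head above is just an affine map followed by softmax cross-entropy; the same per-example math in NumPy (an illustrative rendering, not used by this file):

import numpy as np

def toy_head_loss(pooled, weights, bias, label):
  # Mirrors tf.matmul(output_layer, output_weights, transpose_b=True) + bias.
  logits = pooled @ weights.T + bias
  # log-softmax, then the negative log-probability of the gold label.
  log_probs = logits - np.log(np.sum(np.exp(logits)))
  return -log_probs[label]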
814 |
+
|
815 |
+
|
816 |
+
def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate,
|
817 |
+
num_train_steps, num_warmup_steps, use_tpu,
|
818 |
+
use_one_hot_embeddings, task_name, hub_module=None,
|
819 |
+
optimizer="adamw"):
|
820 |
+
"""Returns `model_fn` closure for TPUEstimator."""
|
821 |
+
|
822 |
+
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
823 |
+
"""The `model_fn` for TPUEstimator."""
|
824 |
+
|
825 |
+
tf.logging.info("*** Features ***")
|
826 |
+
for name in sorted(features.keys()):
|
827 |
+
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
828 |
+
|
829 |
+
input_ids = features["input_ids"]
|
830 |
+
input_mask = features["input_mask"]
|
831 |
+
segment_ids = features["segment_ids"]
|
832 |
+
label_ids = features["label_ids"]
|
833 |
+
is_real_example = None
|
834 |
+
if "is_real_example" in features:
|
835 |
+
is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
|
836 |
+
else:
|
837 |
+
is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
|
838 |
+
|
839 |
+
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
840 |
+
|
841 |
+
(total_loss, per_example_loss, probabilities, logits, predictions) = \
|
842 |
+
create_model(albert_config, is_training, input_ids, input_mask,
|
843 |
+
segment_ids, label_ids, num_labels, use_one_hot_embeddings,
|
844 |
+
task_name, hub_module)
|
845 |
+
|
846 |
+
tvars = tf.trainable_variables()
|
847 |
+
initialized_variable_names = {}
|
848 |
+
scaffold_fn = None
|
849 |
+
if init_checkpoint:
|
850 |
+
(assignment_map, initialized_variable_names
|
851 |
+
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
|
852 |
+
if use_tpu:
|
853 |
+
|
854 |
+
def tpu_scaffold():
|
855 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
856 |
+
return tf.train.Scaffold()
|
857 |
+
|
858 |
+
scaffold_fn = tpu_scaffold
|
859 |
+
else:
|
860 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
861 |
+
|
862 |
+
tf.logging.info("**** Trainable Variables ****")
|
863 |
+
for var in tvars:
|
864 |
+
init_string = ""
|
865 |
+
if var.name in initialized_variable_names:
|
866 |
+
init_string = ", *INIT_FROM_CKPT*"
|
867 |
+
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
868 |
+
init_string)
|
869 |
+
|
870 |
+
output_spec = None
|
871 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
872 |
+
|
873 |
+
train_op = optimization.create_optimizer(
|
874 |
+
total_loss, learning_rate, num_train_steps, num_warmup_steps,
|
875 |
+
use_tpu, optimizer)
|
876 |
+
|
877 |
+
output_spec = contrib_tpu.TPUEstimatorSpec(
|
878 |
+
mode=mode,
|
879 |
+
loss=total_loss,
|
880 |
+
train_op=train_op,
|
881 |
+
scaffold_fn=scaffold_fn)
|
882 |
+
elif mode == tf.estimator.ModeKeys.EVAL:
|
883 |
+
if task_name not in ["sts-b", "cola"]:
|
884 |
+
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
|
885 |
+
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
886 |
+
accuracy = tf.metrics.accuracy(
|
887 |
+
labels=label_ids, predictions=predictions,
|
888 |
+
weights=is_real_example)
|
889 |
+
loss = tf.metrics.mean(
|
890 |
+
values=per_example_loss, weights=is_real_example)
|
891 |
+
return {
|
892 |
+
"eval_accuracy": accuracy,
|
893 |
+
"eval_loss": loss,
|
894 |
+
}
|
895 |
+
elif task_name == "sts-b":
|
896 |
+
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
|
897 |
+
"""Compute Pearson correlations for STS-B."""
|
898 |
+
# Display labels and predictions
|
899 |
+
concat1 = contrib_metrics.streaming_concat(logits)
|
900 |
+
concat2 = contrib_metrics.streaming_concat(label_ids)
|
901 |
+
|
902 |
+
# Compute Pearson correlation
|
903 |
+
pearson = contrib_metrics.streaming_pearson_correlation(
|
904 |
+
logits, label_ids, weights=is_real_example)
|
905 |
+
|
906 |
+
# Compute MSE
|
907 |
+
# mse = tf.metrics.mean(per_example_loss)
|
908 |
+
mse = tf.metrics.mean_squared_error(
|
909 |
+
label_ids, logits, weights=is_real_example)
|
910 |
+
|
911 |
+
loss = tf.metrics.mean(
|
912 |
+
values=per_example_loss,
|
913 |
+
weights=is_real_example)
|
914 |
+
|
915 |
+
return {"pred": concat1, "label_ids": concat2, "pearson": pearson,
|
916 |
+
"MSE": mse, "eval_loss": loss,}
|
917 |
+
elif task_name == "cola":
|
918 |
+
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
|
919 |
+
"""Compute Matthew's correlations for STS-B."""
|
920 |
+
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
921 |
+
# https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
|
922 |
+
tp, tp_op = tf.metrics.true_positives(
|
923 |
+
predictions, label_ids, weights=is_real_example)
|
924 |
+
tn, tn_op = tf.metrics.true_negatives(
|
925 |
+
predictions, label_ids, weights=is_real_example)
|
926 |
+
fp, fp_op = tf.metrics.false_positives(
|
927 |
+
predictions, label_ids, weights=is_real_example)
|
928 |
+
fn, fn_op = tf.metrics.false_negatives(
|
929 |
+
predictions, label_ids, weights=is_real_example)
|
930 |
+
|
931 |
+
# Compute Matthew's correlation
|
932 |
+
mcc = tf.div_no_nan(
|
933 |
+
tp * tn - fp * fn,
|
934 |
+
tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))
|
935 |
+
|
936 |
+
# Compute accuracy
|
937 |
+
accuracy = tf.metrics.accuracy(
|
938 |
+
labels=label_ids, predictions=predictions,
|
939 |
+
weights=is_real_example)
|
940 |
+
|
941 |
+
loss = tf.metrics.mean(
|
942 |
+
values=per_example_loss,
|
943 |
+
weights=is_real_example)
|
944 |
+
|
945 |
+
return {"matthew_corr": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
|
946 |
+
"eval_accuracy": accuracy, "eval_loss": loss,}
|
947 |
+
|
948 |
+
eval_metrics = (metric_fn,
|
949 |
+
[per_example_loss, label_ids, logits, is_real_example])
|
950 |
+
output_spec = contrib_tpu.TPUEstimatorSpec(
|
951 |
+
mode=mode,
|
952 |
+
loss=total_loss,
|
953 |
+
eval_metrics=eval_metrics,
|
954 |
+
scaffold_fn=scaffold_fn)
|
955 |
+
else:
|
956 |
+
output_spec = contrib_tpu.TPUEstimatorSpec(
|
957 |
+
mode=mode,
|
958 |
+
predictions={
|
959 |
+
"probabilities": probabilities,
|
960 |
+
"predictions": predictions
|
961 |
+
},
|
962 |
+
scaffold_fn=scaffold_fn)
|
963 |
+
return output_spec
|
964 |
+
|
965 |
+
return model_fn
|
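The Matthews correlation computed in the CoLA branch is the standard confusion-matrix formula; in plain Python, mirroring the tf.div_no_nan guard against an empty denominator:

import math

def matthews_corr(tp, tn, fp, fn):
  denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom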
966 |
+
|
967 |
+
|
968 |
+
# This function is not used by this file but is still used by the Colab and
|
969 |
+
# people who depend on it.
|
970 |
+
def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
971 |
+
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
972 |
+
|
973 |
+
all_input_ids = []
|
974 |
+
all_input_mask = []
|
975 |
+
all_segment_ids = []
|
976 |
+
all_label_ids = []
|
977 |
+
|
978 |
+
for feature in features:
|
979 |
+
all_input_ids.append(feature.input_ids)
|
980 |
+
all_input_mask.append(feature.input_mask)
|
981 |
+
all_segment_ids.append(feature.segment_ids)
|
982 |
+
all_label_ids.append(feature.label_id)
|
983 |
+
|
984 |
+
def input_fn(params):
|
985 |
+
"""The actual input function."""
|
986 |
+
batch_size = params["batch_size"]
|
987 |
+
|
988 |
+
num_examples = len(features)
|
989 |
+
|
990 |
+
# This is for demo purposes and does NOT scale to large data sets. We do
|
991 |
+
# not use Dataset.from_generator() because that uses tf.py_func which is
|
992 |
+
# not TPU compatible. The right way to load data is with TFRecordReader.
|
993 |
+
d = tf.data.Dataset.from_tensor_slices({
|
994 |
+
"input_ids":
|
995 |
+
tf.constant(
|
996 |
+
all_input_ids, shape=[num_examples, seq_length],
|
997 |
+
dtype=tf.int32),
|
998 |
+
"input_mask":
|
999 |
+
tf.constant(
|
1000 |
+
all_input_mask,
|
1001 |
+
shape=[num_examples, seq_length],
|
1002 |
+
dtype=tf.int32),
|
1003 |
+
"segment_ids":
|
1004 |
+
tf.constant(
|
1005 |
+
all_segment_ids,
|
1006 |
+
shape=[num_examples, seq_length],
|
1007 |
+
dtype=tf.int32),
|
1008 |
+
"label_ids":
|
1009 |
+
tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
|
1010 |
+
})
|
1011 |
+
|
1012 |
+
if is_training:
|
1013 |
+
d = d.repeat()
|
1014 |
+
d = d.shuffle(buffer_size=100)
|
1015 |
+
|
1016 |
+
d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
|
1017 |
+
return d
|
1018 |
+
|
1019 |
+
return input_fn
|
1020 |
+
|
1021 |
+
|
1022 |
+
# This function is not used by this file but is still used by the Colab and
|
1023 |
+
# people who depend on it.
|
1024 |
+
def convert_examples_to_features(examples, label_list, max_seq_length,
|
1025 |
+
tokenizer, task_name):
|
1026 |
+
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
|
1027 |
+
|
1028 |
+
features = []
|
1029 |
+
for (ex_index, example) in enumerate(examples):
|
1030 |
+
if ex_index % 10000 == 0:
|
1031 |
+
tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
|
1032 |
+
|
1033 |
+
feature = convert_single_example(ex_index, example, label_list,
|
1034 |
+
max_seq_length, tokenizer, task_name)
|
1035 |
+
|
1036 |
+
features.append(feature)
|
1037 |
+
return features
|
Indic-BERT-v1-master/albert/create_pretraining_data.py
ADDED
@@ -0,0 +1,654 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Lint as: python2, python3
|
16 |
+
# coding=utf-8
|
17 |
+
"""Create masked LM/next sentence masked_lm TF examples for ALBERT."""
|
18 |
+
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
from __future__ import print_function
|
22 |
+
import collections
|
23 |
+
import random
|
24 |
+
from albert import tokenization
|
25 |
+
import numpy as np
|
26 |
+
import six
|
27 |
+
from six.moves import range
|
28 |
+
from six.moves import zip
|
29 |
+
import tensorflow.compat.v1 as tf
|
30 |
+
|
31 |
+
flags = tf.flags
|
32 |
+
|
33 |
+
FLAGS = flags.FLAGS
|
34 |
+
|
35 |
+
flags.DEFINE_string("input_file", None,
|
36 |
+
"Input raw text file (or comma-separated list of files).")
|
37 |
+
|
38 |
+
flags.DEFINE_string(
|
39 |
+
"output_file", None,
|
40 |
+
"Output TF example file (or comma-separated list of files).")
|
41 |
+
|
42 |
+
flags.DEFINE_string(
|
43 |
+
"vocab_file", None,
|
44 |
+
"The vocabulary file that the ALBERT model was trained on.")
|
45 |
+
|
46 |
+
flags.DEFINE_string("spm_model_file", None,
|
47 |
+
"The model file for sentence piece tokenization.")
|
48 |
+
|
49 |
+
flags.DEFINE_string("input_file_mode", "r",
|
50 |
+
"The data format of the input file.")
|
51 |
+
|
52 |
+
flags.DEFINE_bool(
|
53 |
+
"do_lower_case", True,
|
54 |
+
"Whether to lower case the input text. Should be True for uncased "
|
55 |
+
"models and False for cased models.")
|
56 |
+
|
57 |
+
flags.DEFINE_bool(
|
58 |
+
"do_whole_word_mask", True,
|
59 |
+
"Whether to use whole word masking rather than per-WordPiece masking.")
|
60 |
+
|
61 |
+
flags.DEFINE_bool(
|
62 |
+
"do_permutation", False,
|
63 |
+
"Whether to do the permutation training.")
|
64 |
+
|
65 |
+
flags.DEFINE_bool(
|
66 |
+
"favor_shorter_ngram", True,
|
67 |
+
"Whether to set higher probabilities for sampling shorter ngrams.")
|
68 |
+
|
69 |
+
flags.DEFINE_bool(
|
70 |
+
"random_next_sentence", False,
|
71 |
+
"Whether to use the sentence that's right before the current sentence "
|
72 |
+
"as the negative sample for next sentence prection, rather than using "
|
73 |
+
"sentences from other random documents.")
|
74 |
+
|
75 |
+
flags.DEFINE_integer("max_seq_length", 512, "Maximum sequence length.")
|
76 |
+
|
77 |
+
flags.DEFINE_integer("ngram", 3, "Maximum number of ngrams to mask.")
|
78 |
+
|
79 |
+
flags.DEFINE_integer("max_predictions_per_seq", 20,
|
80 |
+
"Maximum number of masked LM predictions per sequence.")
|
81 |
+
|
82 |
+
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
83 |
+
|
84 |
+
flags.DEFINE_integer(
|
85 |
+
"dupe_factor", 40,
|
86 |
+
"Number of times to duplicate the input data (with different masks).")
|
87 |
+
|
88 |
+
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
89 |
+
|
90 |
+
flags.DEFINE_float(
|
91 |
+
"short_seq_prob", 0.1,
|
92 |
+
"Probability of creating sequences which are shorter than the "
|
93 |
+
"maximum length.")
|
94 |
+
|
95 |
+
|
96 |
+
class TrainingInstance(object):
|
97 |
+
"""A single training instance (sentence pair)."""
|
98 |
+
|
99 |
+
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
|
100 |
+
is_random_next, token_boundary):
|
101 |
+
self.tokens = tokens
|
102 |
+
self.segment_ids = segment_ids
|
103 |
+
self.is_random_next = is_random_next
|
104 |
+
self.token_boundary = token_boundary
|
105 |
+
self.masked_lm_positions = masked_lm_positions
|
106 |
+
self.masked_lm_labels = masked_lm_labels
|
107 |
+
|
108 |
+
def __str__(self):
|
109 |
+
s = ""
|
110 |
+
s += "tokens: %s\n" % (" ".join(
|
111 |
+
[tokenization.printable_text(x) for x in self.tokens]))
|
112 |
+
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
113 |
+
s += "token_boundary: %s\n" % (" ".join(
|
114 |
+
[str(x) for x in self.token_boundary]))
|
115 |
+
s += "is_random_next: %s\n" % self.is_random_next
|
116 |
+
s += "masked_lm_positions: %s\n" % (" ".join(
|
117 |
+
[str(x) for x in self.masked_lm_positions]))
|
118 |
+
s += "masked_lm_labels: %s\n" % (" ".join(
|
119 |
+
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
|
120 |
+
s += "\n"
|
121 |
+
return s
|
122 |
+
|
123 |
+
def __repr__(self):
|
124 |
+
return self.__str__()
|
125 |
+
|
126 |
+
|
127 |
+
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
|
128 |
+
max_predictions_per_seq, output_files):
|
129 |
+
"""Create TF example files from `TrainingInstance`s."""
|
130 |
+
writers = []
|
131 |
+
for output_file in output_files:
|
132 |
+
writers.append(tf.python_io.TFRecordWriter(output_file))
|
133 |
+
|
134 |
+
writer_index = 0
|
135 |
+
|
136 |
+
total_written = 0
|
137 |
+
for (inst_index, instance) in enumerate(instances):
|
138 |
+
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
139 |
+
input_mask = [1] * len(input_ids)
|
140 |
+
segment_ids = list(instance.segment_ids)
|
141 |
+
token_boundary = list(instance.token_boundary)
|
142 |
+
assert len(input_ids) <= max_seq_length
|
143 |
+
|
144 |
+
while len(input_ids) < max_seq_length:
|
145 |
+
input_ids.append(0)
|
146 |
+
input_mask.append(0)
|
147 |
+
segment_ids.append(0)
|
148 |
+
token_boundary.append(0)
|
149 |
+
|
150 |
+
assert len(input_ids) == max_seq_length
|
151 |
+
assert len(input_mask) == max_seq_length
|
152 |
+
assert len(segment_ids) == max_seq_length
|
153 |
+
|
154 |
+
masked_lm_positions = list(instance.masked_lm_positions)
|
155 |
+
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
156 |
+
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
157 |
+
|
158 |
+
multiplier = 1 + int(FLAGS.do_permutation)
|
159 |
+
while len(masked_lm_positions) < max_predictions_per_seq * multiplier:
|
160 |
+
masked_lm_positions.append(0)
|
161 |
+
masked_lm_ids.append(0)
|
162 |
+
masked_lm_weights.append(0.0)
|
163 |
+
|
164 |
+
sentence_order_label = 1 if instance.is_random_next else 0
|
165 |
+
|
166 |
+
features = collections.OrderedDict()
|
167 |
+
features["input_ids"] = create_int_feature(input_ids)
|
168 |
+
features["input_mask"] = create_int_feature(input_mask)
|
169 |
+
features["segment_ids"] = create_int_feature(segment_ids)
|
170 |
+
features["token_boundary"] = create_int_feature(token_boundary)
|
171 |
+
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
172 |
+
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
173 |
+
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
174 |
+
# Note: We keep this feature name `next_sentence_labels` to be compatible
|
175 |
+
# with the original data created by lanzhzh@. However, in the ALBERT case
|
176 |
+
# it does contain sentence_order_label.
|
177 |
+
features["next_sentence_labels"] = create_int_feature(
|
178 |
+
[sentence_order_label])
|
179 |
+
|
180 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
181 |
+
|
182 |
+
writers[writer_index].write(tf_example.SerializeToString())
|
183 |
+
writer_index = (writer_index + 1) % len(writers)
|
184 |
+
|
185 |
+
total_written += 1
|
186 |
+
|
187 |
+
if inst_index < 20:
|
188 |
+
tf.logging.info("*** Example ***")
|
189 |
+
tf.logging.info("tokens: %s" % " ".join(
|
190 |
+
[tokenization.printable_text(x) for x in instance.tokens]))
|
191 |
+
|
192 |
+
for feature_name in features.keys():
|
193 |
+
feature = features[feature_name]
|
194 |
+
values = []
|
195 |
+
if feature.int64_list.value:
|
196 |
+
values = feature.int64_list.value
|
197 |
+
elif feature.float_list.value:
|
198 |
+
values = feature.float_list.value
|
199 |
+
tf.logging.info(
|
200 |
+
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
|
201 |
+
|
202 |
+
for writer in writers:
|
203 |
+
writer.close()
|
204 |
+
|
205 |
+
tf.logging.info("Wrote %d total instances", total_written)
|
206 |
+
|
207 |
+
|
208 |
+
def create_int_feature(values):
|
209 |
+
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
210 |
+
return feature
|
211 |
+
|
212 |
+
|
213 |
+
def create_float_feature(values):
|
214 |
+
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
215 |
+
return feature
|
216 |
+
|
217 |
+
|
218 |
+
def create_training_instances(input_files, tokenizer, max_seq_length,
|
219 |
+
dupe_factor, short_seq_prob, masked_lm_prob,
|
220 |
+
max_predictions_per_seq, rng):
|
221 |
+
"""Create `TrainingInstance`s from raw text."""
|
222 |
+
all_documents = [[]]
|
223 |
+
|
224 |
+
# Input file format:
|
225 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
226 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
227 |
+
# sentence boundaries for the "next sentence prediction" task).
|
228 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
229 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
230 |
+
for input_file in input_files:
|
231 |
+
with tf.gfile.GFile(input_file, FLAGS.input_file_mode) as reader:
|
232 |
+
while True:
|
233 |
+
line = reader.readline()
|
234 |
+
if not FLAGS.spm_model_file:
|
235 |
+
line = tokenization.convert_to_unicode(line)
|
236 |
+
if not line:
|
237 |
+
break
|
238 |
+
if FLAGS.spm_model_file:
|
239 |
+
line = tokenization.preprocess_text(line, lower=FLAGS.do_lower_case)
|
240 |
+
else:
|
241 |
+
line = line.strip()
|
242 |
+
|
243 |
+
# Empty lines are used as document delimiters
|
244 |
+
if not line:
|
245 |
+
all_documents.append([])
|
246 |
+
tokens = tokenizer.tokenize(line)
|
247 |
+
if tokens:
|
248 |
+
all_documents[-1].append(tokens)
|
249 |
+
|
250 |
+
# Remove empty documents
|
251 |
+
all_documents = [x for x in all_documents if x]
|
252 |
+
rng.shuffle(all_documents)
|
253 |
+
|
254 |
+
vocab_words = list(tokenizer.vocab.keys())
|
255 |
+
instances = []
|
256 |
+
for _ in range(dupe_factor):
|
257 |
+
for document_index in range(len(all_documents)):
|
258 |
+
instances.extend(
|
259 |
+
create_instances_from_document(
|
260 |
+
all_documents, document_index, max_seq_length, short_seq_prob,
|
261 |
+
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
|
262 |
+
|
263 |
+
rng.shuffle(instances)
|
264 |
+
return instances
|
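A toy corpus in the format the comment above describes, one sentence per line with a blank line between documents (contents invented):

toy_corpus = (
    "The cat sat on the mat.\n"
    "It looked very pleased with itself.\n"
    "\n"
    "A second document starts after the blank line.\n"
    "It also contains two sentences.\n")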


def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3

  # We *usually* want to fill up the entire sequence since we are padding
  # to `max_seq_length` anyways, so short sequences are generally wasted
  # computation. However, we *sometimes*
  # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  # sequences to minimize the mismatch between pre-training and fine-tuning.
  # The `target_seq_length` is just a rough target, whereas `max_seq_length`
  # is a hard limit.
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  # We DON'T just concatenate all of the tokens from a document into a long
  # sequence and choose an arbitrary split point, because this would make the
  # next-sentence prediction task too easy. Instead, we split the input into
  # segments "A" and "B" based on the actual "sentences" provided by the user
  # input.
  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the `A`
        # (first) sentence.
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # Random next
        is_random_next = False
        if len(current_chunk) == 1 or \
            (FLAGS.random_next_sentence and rng.random() < 0.5):
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          # This should rarely go for more than one iteration for large
          # corpora. However, just to be careful, we try to make sure that
          # the random document is not the same as the document
          # we're processing.
          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
          # We didn't actually use these segments, so we "put them back" so
          # they don't go to waste.
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        elif not FLAGS.random_next_sentence and rng.random() < 0.5:
          is_random_next = True
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
          # Note(mingdachen): in this case, we just swap tokens_a and tokens_b
          tokens_a, tokens_b = tokens_b, tokens_a
        # Actual next
        else:
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)

        tokens.append("[SEP]")
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        (tokens, masked_lm_positions,
         masked_lm_labels, token_boundary) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            token_boundary=token_boundary,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1

  return instances
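
# Editor's illustration (not part of the original file): a minimal,
# hypothetical packing produced by the function above for
# tokens_a = ["t1", "t2"] and tokens_b = ["u1", "u2", "u3"]. [CLS] + A + [SEP]
# carry segment id 0; B + [SEP] carry segment id 1, and both lists always have
# equal length.
_example_tokens = ["[CLS]", "t1", "t2", "[SEP]", "u1", "u2", "u3", "[SEP]"]
_example_segment_ids = [0, 0, 0, 0, 1, 1, 1, 1]
assert len(_example_tokens) == len(_example_segment_ids)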


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def _is_start_piece_sp(piece):
  """Checks if the current word piece is the starting piece (SentencePiece)."""
  special_pieces = set(list('!"#$%&\'()*+,-./:;?@[\\]^_`{|}~₹'))
  special_pieces.add(u"€".encode("utf-8"))
  special_pieces.add(u"£".encode("utf-8"))
  # Note(mingdachen):
  # For foreign characters, we always treat them as a whole piece.
  english_chars = set(list("abcdefghijklmnopqrstuvwxyz"))  # unused in this check
  if (six.ensure_str(piece).startswith("▁") or
      six.ensure_str(piece).startswith("<") or piece in special_pieces):
    return True
  else:
    return False


def _is_start_piece_bert(piece):
  """Checks if the current word piece is the starting piece (BERT)."""
  # When a word has been split into WordPieces, the first token does not have
  # any marker and any subsequent tokens are prefixed with ##. So whenever we
  # see the ## marker, we append the piece to the previous set of word indexes.
  return not six.ensure_str(piece).startswith("##")


def is_start_piece(piece):
  if FLAGS.spm_model_file:
    return _is_start_piece_sp(piece)
  else:
    return _is_start_piece_bert(piece)
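
# Editor's illustration (not part of the original file): the two start-piece
# conventions side by side. SentencePiece marks word starts with "▁", while
# WordPiece marks continuations with "##"; neither helper needs FLAGS.
assert _is_start_piece_sp("▁language")      # "▁" prefix => starts a word
assert not _is_start_piece_sp("uage")       # continuation piece
assert _is_start_piece_bert("lang")         # no marker => word start
assert not _is_start_piece_bert("##uage")   # "##" => continuation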


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
  """Creates the predictions for the masked LM objective."""

  cand_indexes = []
  # Note(mingdachen): We create a list for recording whether the piece is
  # the starting piece of the current token, where 1 means true, so that
  # on-the-fly whole-word masking is possible.
  token_boundary = [0] * len(tokens)

  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      token_boundary[i] = 1
      continue
    # Whole Word Masking means that we mask all of the wordpieces
    # corresponding to an original word.
    #
    # Note that Whole Word Masking does *not* change the training code
    # at all -- we still predict each WordPiece independently, softmaxed
    # over the entire vocabulary.
    if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
        not is_start_piece(token)):
      cand_indexes[-1].append(i)
    else:
      cand_indexes.append([i])
      if is_start_piece(token):
        token_boundary[i] = 1

  output_tokens = list(tokens)

  masked_lm_positions = []
  masked_lm_labels = []

  if masked_lm_prob == 0:
    return (output_tokens, masked_lm_positions,
            masked_lm_labels, token_boundary)

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  # Note(mingdachen):
  # By default, we set the probabilities to favor shorter ngram sequences.
  ngrams = np.arange(1, FLAGS.ngram + 1, dtype=np.int64)
  pvals = 1. / np.arange(1, FLAGS.ngram + 1)
  pvals /= pvals.sum(keepdims=True)

  if not FLAGS.favor_shorter_ngram:
    pvals = pvals[::-1]

  ngram_indexes = []
  for idx in range(len(cand_indexes)):
    ngram_index = []
    for n in ngrams:
      ngram_index.append(cand_indexes[idx:idx + n])
    ngram_indexes.append(ngram_index)

  rng.shuffle(ngram_indexes)

  masked_lms = []
  covered_indexes = set()
  for cand_index_set in ngram_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    if not cand_index_set:
      continue
    # Note(mingdachen):
    # Skip the current piece if it is covered by LM masking or previous ngrams.
    for index_set in cand_index_set[0]:
      for index in index_set:
        if index in covered_indexes:
          continue

    n = np.random.choice(ngrams[:len(cand_index_set)],
                         p=pvals[:len(cand_index_set)] /
                         pvals[:len(cand_index_set)].sum(keepdims=True))
    index_set = sum(cand_index_set[n - 1], [])
    n -= 1
    # Note(mingdachen):
    # Repeatedly look for a candidate that does not exceed the
    # maximum number of predictions by trying shorter ngrams.
    while len(masked_lms) + len(index_set) > num_to_predict:
      if n == 0:
        break
      index_set = sum(cand_index_set[n - 1], [])
      n -= 1
    # If adding a whole-word mask would exceed the maximum number of
    # predictions, then just skip this candidate.
    if len(masked_lms) + len(index_set) > num_to_predict:
      continue
    is_any_index_covered = False
    for index in index_set:
      if index in covered_indexes:
        is_any_index_covered = True
        break
    if is_any_index_covered:
      continue
    for index in index_set:
      covered_indexes.add(index)

      masked_token = None
      # 80% of the time, replace with [MASK]
      if rng.random() < 0.8:
        masked_token = "[MASK]"
      else:
        # 10% of the time, keep original
        if rng.random() < 0.5:
          masked_token = tokens[index]
        # 10% of the time, replace with random word
        else:
          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

      output_tokens[index] = masked_token

      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
  assert len(masked_lms) <= num_to_predict

  rng.shuffle(ngram_indexes)

  select_indexes = set()
  if FLAGS.do_permutation:
    for cand_index_set in ngram_indexes:
      if len(select_indexes) >= num_to_predict:
        break
      if not cand_index_set:
        continue
      # Note(mingdachen):
      # Skip the current piece if it is covered by LM masking or previous
      # ngrams.
      for index_set in cand_index_set[0]:
        for index in index_set:
          if index in covered_indexes or index in select_indexes:
            continue

      n = np.random.choice(ngrams[:len(cand_index_set)],
                           p=pvals[:len(cand_index_set)] /
                           pvals[:len(cand_index_set)].sum(keepdims=True))
      index_set = sum(cand_index_set[n - 1], [])
      n -= 1

      while len(select_indexes) + len(index_set) > num_to_predict:
        if n == 0:
          break
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
      # If adding a whole-word mask would exceed the maximum number of
      # predictions, then just skip this candidate.
      if len(select_indexes) + len(index_set) > num_to_predict:
        continue
      is_any_index_covered = False
      for index in index_set:
        if index in covered_indexes or index in select_indexes:
          is_any_index_covered = True
          break
      if is_any_index_covered:
        continue
      for index in index_set:
        select_indexes.add(index)
    assert len(select_indexes) <= num_to_predict

    select_indexes = sorted(select_indexes)
    permute_indexes = list(select_indexes)
    rng.shuffle(permute_indexes)
    orig_token = list(output_tokens)

    for src_i, tgt_i in zip(select_indexes, permute_indexes):
      output_tokens[src_i] = orig_token[tgt_i]
      masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)
  return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
+
|
599 |
+
|
600 |
+
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
|
601 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
602 |
+
while True:
|
603 |
+
total_length = len(tokens_a) + len(tokens_b)
|
604 |
+
if total_length <= max_num_tokens:
|
605 |
+
break
|
606 |
+
|
607 |
+
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
608 |
+
assert len(trunc_tokens) >= 1
|
609 |
+
|
610 |
+
# We want to sometimes truncate from the front and sometimes from the
|
611 |
+
# back to add more randomness and avoid biases.
|
612 |
+
if rng.random() < 0.5:
|
613 |
+
del trunc_tokens[0]
|
614 |
+
else:
|
615 |
+
trunc_tokens.pop()
|
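
# Editor's illustration (not part of the original file): truncate_seq_pair
# trims the longer list in place, one token at a time, from a randomly chosen
# end, so the shorter side is untouched here.
_rng = random.Random(12345)
_a, _b = list("abcdefgh"), list("wxyz")
truncate_seq_pair(_a, _b, max_num_tokens=9, rng=_rng)
assert len(_a) + len(_b) <= 9 and len(_b) == 4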


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob,
      FLAGS.max_predictions_per_seq, rng)

  tf.logging.info("number of instances: %i", len(instances))

  output_files = FLAGS.output_file.split(",")
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)


if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("output_file")
  flags.mark_flag_as_required("vocab_file")
  tf.app.run()
Indic-BERT-v1-master/albert/evaluate.py
ADDED
File without changes
Indic-BERT-v1-master/albert/export_checkpoints.py
ADDED
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal module for ALBERT models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf

flags.DEFINE_string(
    "albert_directory", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_string("export_path", None, "Path to the output module.")

FLAGS = flags.FLAGS


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights):
  """From run_pretraining.py."""
  input_tensor = gather_indexes(input_tensor, mlm_positions)
  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(
        input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
  return logits


def get_sentence_order_logits(input_tensor, albert_config):
  """Gets logits for sentence-order prediction."""

  # Simple binary classification. Note that 0 is "next sentence" and 1 is
  # "random sentence". This weight matrix is not used after pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits


def build_model(sess):
  """Module function."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False)

  get_mlm_logits(model.get_sequence_output(), albert_config,
                 mlm_positions, model.get_embedding_table())
  get_sentence_order_logits(model.get_pooled_output(), albert_config)

  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tvars = tf.trainable_variables()
  (assignment_map, initialized_variable_names
  ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
  tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
  init = tf.global_variables_initializer()
  sess.run(init)
  return sess


def main(_):
  sess = tf.Session()
  tf.train.get_or_create_global_step()
  sess = build_model(sess)
  my_vars = []
  for var in tf.global_variables():
    if "lamb_v" not in var.name and "lamb_m" not in var.name:
      my_vars.append(var)
  saver = tf.train.Saver(my_vars)
  saver.save(sess, FLAGS.export_path)


if __name__ == "__main__":
  flags.mark_flag_as_required("albert_directory")
  flags.mark_flag_as_required("export_path")
  app.run(main)
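The index arithmetic in `gather_indexes` above is easy to sanity-check outside TensorFlow. Below is a minimal numpy sketch of the same flattening; the shapes and position values are illustrative only, not taken from the repo:

```python
import numpy as np

batch_size, seq_length, width = 2, 4, 3
sequence = np.arange(batch_size * seq_length * width).reshape(
    batch_size, seq_length, width)
positions = np.array([[1, 3], [0, 2]])  # per-example token positions

# Offset each row's positions into the flattened [batch*seq, width] view.
flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)
flat_positions = (positions + flat_offsets).reshape(-1)
gathered = sequence.reshape(batch_size * seq_length, width)[flat_positions]

assert gathered.shape == (4, 3)
assert (gathered[0] == sequence[0, 1]).all()
assert (gathered[3] == sequence[1, 2]).all()
```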
Indic-BERT-v1-master/albert/export_to_tfhub.py
ADDED
@@ -0,0 +1,177 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal TF-Hub module for ALBERT models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

flags.DEFINE_string(
    "albert_directory", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "use_einsum", True,
    "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
    "be set to False for TFLite compatibility.")

flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.")

FLAGS = flags.FLAGS


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


def get_mlm_logits(model, albert_config, mlm_positions):
  """From run_pretraining.py."""
  input_tensor = gather_indexes(model.get_sequence_output(), mlm_positions)
  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(
        input_tensor, model.get_embedding_table(), transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
  return logits


def module_fn(is_training):
  """Module function."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False,
      use_einsum=FLAGS.use_einsum)

  mlm_logits = get_mlm_logits(model, albert_config, mlm_positions)

  vocab_model_path = os.path.join(FLAGS.albert_directory, "30k-clean.model")
  vocab_file_path = os.path.join(FLAGS.albert_directory, "30k-clean.vocab")

  config_file = tf.constant(
      value=albert_config_path, dtype=tf.string, name="config_file")
  vocab_model = tf.constant(
      value=vocab_model_path, dtype=tf.string, name="vocab_model")
  # This is only for visualization purposes.
  vocab_file = tf.constant(
      value=vocab_file_path, dtype=tf.string, name="vocab_file")

  # By adding `config_file`, `vocab_model` and `vocab_file`
  # to the ASSET_FILEPATHS collection, TF-Hub will
  # rewrite these tensors so that the assets are portable.
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file)
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_model)
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)

  hub.add_signature(
      name="tokens",
      inputs=dict(
          input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
      outputs=dict(
          sequence_output=model.get_sequence_output(),
          pooled_output=model.get_pooled_output()))

  hub.add_signature(
      name="mlm",
      inputs=dict(
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          mlm_positions=mlm_positions),
      outputs=dict(
          sequence_output=model.get_sequence_output(),
          pooled_output=model.get_pooled_output(),
          mlm_logits=mlm_logits))

  hub.add_signature(
      name="tokenization_info",
      inputs={},
      outputs=dict(
          vocab_file=vocab_model,
          do_lower_case=tf.constant(FLAGS.do_lower_case)))


def main(_):
  tags_and_args = []
  for is_training in (True, False):
    tags = set()
    if is_training:
      tags.add("train")
    tags_and_args.append((tags, dict(is_training=is_training)))
  spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tf.logging.info("Using checkpoint {}".format(checkpoint_path))
  spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path)


if __name__ == "__main__":
  flags.mark_flag_as_required("albert_directory")
  flags.mark_flag_as_required("export_path")
  app.run(main)
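After export, the module can be driven through the signatures defined in `module_fn`. A minimal consumption sketch, assuming a TF1-compatible environment with `tensorflow_hub` installed and an export at a placeholder path:

```python
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

albert = hub.Module("/path/to/exported_module")  # placeholder path
input_ids = tf.placeholder(tf.int32, [None, None])
input_mask = tf.placeholder(tf.int32, [None, None])
segment_ids = tf.placeholder(tf.int32, [None, None])

outputs = albert(
    inputs=dict(input_ids=input_ids, input_mask=input_mask,
                segment_ids=segment_ids),
    signature="tokens",
    as_dict=True)
# outputs["sequence_output"]: [batch, seq_len, hidden_size]
# outputs["pooled_output"]:   [batch, hidden_size]
```

The `mlm` signature works the same way but additionally takes `mlm_positions` and returns `mlm_logits`.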
Indic-BERT-v1-master/albert/fine_tuning_utils.py
ADDED
@@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Helper library for ALBERT fine-tuning.

This library can be used to construct ALBERT models for fine-tuning, either
from json config files or from TF-Hub modules.
"""

from albert import modeling
from albert import tokenization
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub


def _create_model_from_hub(hub_module, is_training, input_ids, input_mask,
                           segment_ids):
  """Creates an ALBERT model from TF-Hub."""
  tags = set()
  if is_training:
    tags.add("train")
  albert_module = hub.Module(hub_module, tags=tags, trainable=True)
  albert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  albert_outputs = albert_module(
      inputs=albert_inputs,
      signature="tokens",
      as_dict=True)
  return (albert_outputs["pooled_output"], albert_outputs["sequence_output"])


def _create_model_from_scratch(albert_config, is_training, input_ids,
                               input_mask, segment_ids, use_one_hot_embeddings,
                               use_einsum):
  """Creates an ALBERT model from scratch/config."""
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      use_einsum=use_einsum)
  return (model.get_pooled_output(), model.get_sequence_output())


def create_albert(albert_config, is_training, input_ids, input_mask,
                  segment_ids, use_one_hot_embeddings, use_einsum, hub_module):
  """Creates an ALBERT model, either from TF-Hub or from scratch."""
  if hub_module:
    tf.logging.info("creating model from hub_module: %s", hub_module)
    return _create_model_from_hub(hub_module, is_training, input_ids,
                                  input_mask, segment_ids)
  else:
    tf.logging.info("creating model from albert_config")
    return _create_model_from_scratch(albert_config, is_training, input_ids,
                                      input_mask, segment_ids,
                                      use_one_hot_embeddings, use_einsum)


def create_vocab(vocab_file, do_lower_case, spm_model_file, hub_module):
  """Creates a vocab, either from a vocab file or from a TF-Hub module."""
  if hub_module:
    use_spm = True if spm_model_file else False
    return tokenization.FullTokenizer.from_hub_module(
        hub_module=hub_module, use_spm=use_spm)
  else:
    return tokenization.FullTokenizer.from_scratch(
        vocab_file=vocab_file, do_lower_case=do_lower_case,
        spm_model_file=spm_model_file)
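A minimal sketch of driving `create_albert` from a config file; the config path and sequence length are placeholders, and `hub_module=None` selects the from-scratch branch:

```python
import tensorflow.compat.v1 as tf
from albert import fine_tuning_utils
from albert import modeling

albert_config = modeling.AlbertConfig.from_json_file(
    "configs/albert_base_config.json")  # placeholder path
input_ids = tf.placeholder(tf.int32, [None, 128])
input_mask = tf.placeholder(tf.int32, [None, 128])
segment_ids = tf.placeholder(tf.int32, [None, 128])

pooled_output, sequence_output = fine_tuning_utils.create_albert(
    albert_config, is_training=False, input_ids=input_ids,
    input_mask=input_mask, segment_ids=segment_ids,
    use_one_hot_embeddings=False, use_einsum=True, hub_module=None)
```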
Indic-BERT-v1-master/albert/lamb_optimizer.py
ADDED
@@ -0,0 +1,148 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import six
import tensorflow.compat.v1 as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
# pylint: enable=g-direct-tensorflow-import


class LAMBOptimizer(tf.train.Optimizer):
  """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
  # A new optimizer that includes correct L2 weight decay, adaptive
  # element-wise updating, and layer-wise learning-rate adaptation. The LAMB
  # optimizer was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
  # James Demmel, and Cho-Jui Hsieh in the paper "Reducing BERT Pre-Training
  # Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962).

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="LAMBOptimizer"):
    """Constructs a LAMBOptimizer."""
    super(LAMBOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=six.ensure_str(param_name) + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=six.ensure_str(param_name) + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Standard Adam update.
      next_m = (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = (
          tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                    tf.square(grad)))

      update = next_m / (tf.sqrt(next_v) + self.epsilon)

      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want to decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param

      ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = linalg_ops.norm(param, ord=2)
        g_norm = linalg_ops.norm(update, ord=2)
        ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
            math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

      update_with_lr = ratio * self.learning_rate * update

      next_param = param - update_with_lr

      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Gets the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
    if m is not None:
      param_name = m.group(1)
    return param_name
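The core of `apply_gradients` is an Adam-style step scaled by the layer-wise trust ratio ||w|| / ||update||. A numpy sketch of a single step for one parameter, with illustrative hyperparameters and moments starting at zero:

```python
import numpy as np

beta_1, beta_2, epsilon, lr, wd = 0.9, 0.999, 1e-6, 1e-3, 0.01
param = np.array([0.5, -1.0, 2.0])
grad = np.array([0.1, 0.2, -0.1])

next_m = beta_1 * 0.0 + (1.0 - beta_1) * grad       # first moment (m starts at 0)
next_v = beta_2 * 0.0 + (1.0 - beta_2) * grad ** 2  # second moment (v starts at 0)
update = next_m / (np.sqrt(next_v) + epsilon)       # Adam direction
update += wd * param                                # decoupled weight decay

w_norm = np.linalg.norm(param)
g_norm = np.linalg.norm(update)
ratio = w_norm / g_norm if (w_norm > 0 and g_norm > 0) else 1.0
param = param - ratio * lr * update                 # layer-wise scaled step
print(param)
```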
Indic-BERT-v1-master/albert/modeling.py
ADDED
@@ -0,0 +1,1209 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Lint as: python2, python3
|
16 |
+
"""The main ALBERT model and related functions.
|
17 |
+
|
18 |
+
For a description of the algorithm, see https://arxiv.org/abs/1909.11942.
|
19 |
+
"""
|
20 |
+
|
21 |
+
from __future__ import absolute_import
|
22 |
+
from __future__ import division
|
23 |
+
from __future__ import print_function
|
24 |
+
|
25 |
+
import collections
|
26 |
+
import copy
|
27 |
+
import json
|
28 |
+
import math
|
29 |
+
import re
|
30 |
+
import numpy as np
|
31 |
+
import six
|
32 |
+
from six.moves import range
|
33 |
+
import tensorflow.compat.v1 as tf
|
34 |
+
from tensorflow.contrib import layers as contrib_layers
|
35 |
+
|
36 |
+
|
37 |
+
class AlbertConfig(object):
|
38 |
+
"""Configuration for `AlbertModel`.
|
39 |
+
|
40 |
+
The default settings match the configuration of model `albert_xxlarge`.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
vocab_size,
|
45 |
+
embedding_size=128,
|
46 |
+
hidden_size=4096,
|
47 |
+
num_hidden_layers=12,
|
48 |
+
num_hidden_groups=1,
|
49 |
+
num_attention_heads=64,
|
50 |
+
intermediate_size=16384,
|
51 |
+
inner_group_num=1,
|
52 |
+
down_scale_factor=1,
|
53 |
+
hidden_act="gelu",
|
54 |
+
hidden_dropout_prob=0,
|
55 |
+
attention_probs_dropout_prob=0,
|
56 |
+
max_position_embeddings=512,
|
57 |
+
type_vocab_size=2,
|
58 |
+
initializer_range=0.02):
|
59 |
+
"""Constructs AlbertConfig.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
vocab_size: Vocabulary size of `inputs_ids` in `AlbertModel`.
|
63 |
+
embedding_size: size of voc embeddings.
|
64 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
65 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
66 |
+
num_hidden_groups: Number of group for the hidden layers, parameters in
|
67 |
+
the same group are shared.
|
68 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
69 |
+
the Transformer encoder.
|
70 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
71 |
+
layer in the Transformer encoder.
|
72 |
+
inner_group_num: int, number of inner repetition of attention and ffn.
|
73 |
+
down_scale_factor: float, the scale to apply
|
74 |
+
hidden_act: The non-linear activation function (function or string) in the
|
75 |
+
encoder and pooler.
|
76 |
+
hidden_dropout_prob: The dropout probability for all fully connected
|
77 |
+
layers in the embeddings, encoder, and pooler.
|
78 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
79 |
+
probabilities.
|
80 |
+
max_position_embeddings: The maximum sequence length that this model might
|
81 |
+
ever be used with. Typically set this to something large just in case
|
82 |
+
(e.g., 512 or 1024 or 2048).
|
83 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
84 |
+
`AlbertModel`.
|
85 |
+
initializer_range: The stdev of the truncated_normal_initializer for
|
86 |
+
initializing all weight matrices.
|
87 |
+
"""
|
88 |
+
self.vocab_size = vocab_size
|
89 |
+
self.embedding_size = embedding_size
|
90 |
+
self.hidden_size = hidden_size
|
91 |
+
self.num_hidden_layers = num_hidden_layers
|
92 |
+
self.num_hidden_groups = num_hidden_groups
|
93 |
+
self.num_attention_heads = num_attention_heads
|
94 |
+
self.inner_group_num = inner_group_num
|
95 |
+
self.down_scale_factor = down_scale_factor
|
96 |
+
self.hidden_act = hidden_act
|
97 |
+
self.intermediate_size = intermediate_size
|
98 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
99 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
100 |
+
self.max_position_embeddings = max_position_embeddings
|
101 |
+
self.type_vocab_size = type_vocab_size
|
102 |
+
self.initializer_range = initializer_range
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def from_dict(cls, json_object):
|
106 |
+
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
|
107 |
+
config = AlbertConfig(vocab_size=None)
|
108 |
+
for (key, value) in six.iteritems(json_object):
|
109 |
+
config.__dict__[key] = value
|
110 |
+
return config
|
111 |
+
|
112 |
+
@classmethod
|
113 |
+
def from_json_file(cls, json_file):
|
114 |
+
"""Constructs a `AlbertConfig` from a json file of parameters."""
|
115 |
+
with tf.gfile.GFile(json_file, "r") as reader:
|
116 |
+
text = reader.read()
|
117 |
+
return cls.from_dict(json.loads(text))
|
118 |
+
|
119 |
+
def to_dict(self):
|
120 |
+
"""Serializes this instance to a Python dictionary."""
|
121 |
+
output = copy.deepcopy(self.__dict__)
|
122 |
+
return output
|
123 |
+
|
124 |
+
def to_json_string(self):
|
125 |
+
"""Serializes this instance to a JSON string."""
|
126 |
+
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
127 |
+
|
128 |
+
|
129 |
+
class AlbertModel(object):
|
130 |
+
"""BERT model ("Bidirectional Encoder Representations from Transformers").
|
131 |
+
|
132 |
+
Example usage:
|
133 |
+
|
134 |
+
```python
|
135 |
+
# Already been converted from strings into ids
|
136 |
+
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
|
137 |
+
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
|
138 |
+
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
|
139 |
+
|
140 |
+
config = modeling.AlbertConfig(vocab_size=32000, hidden_size=512,
|
141 |
+
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
|
142 |
+
|
143 |
+
model = modeling.AlbertModel(config=config, is_training=True,
|
144 |
+
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
|
145 |
+
|
146 |
+
label_embeddings = tf.get_variable(...)
|
147 |
+
pooled_output = model.get_pooled_output()
|
148 |
+
logits = tf.matmul(pooled_output, label_embeddings)
|
149 |
+
...
|
150 |
+
```
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(self,
|
154 |
+
config,
|
155 |
+
is_training,
|
156 |
+
input_ids,
|
157 |
+
input_mask=None,
|
158 |
+
token_type_ids=None,
|
159 |
+
use_one_hot_embeddings=False,
|
160 |
+
use_einsum=True,
|
161 |
+
scope=None):
|
162 |
+
"""Constructor for AlbertModel.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
config: `AlbertConfig` instance.
|
166 |
+
is_training: bool. true for training model, false for eval model. Controls
|
167 |
+
whether dropout will be applied.
|
168 |
+
input_ids: int32 Tensor of shape [batch_size, seq_length].
|
169 |
+
input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
|
170 |
+
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
|
171 |
+
use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
|
172 |
+
embeddings or tf.embedding_lookup() for the word embeddings.
|
173 |
+
use_einsum: (optional) bool. Whether to use einsum or reshape+matmul for
|
174 |
+
dense layers
|
175 |
+
scope: (optional) variable scope. Defaults to "bert".
|
176 |
+
|
177 |
+
Raises:
|
178 |
+
ValueError: The config is invalid or one of the input tensor shapes
|
179 |
+
is invalid.
|
180 |
+
"""
|
181 |
+
config = copy.deepcopy(config)
|
182 |
+
if not is_training:
|
183 |
+
config.hidden_dropout_prob = 0.0
|
184 |
+
config.attention_probs_dropout_prob = 0.0
|
185 |
+
|
186 |
+
input_shape = get_shape_list(input_ids, expected_rank=2)
|
187 |
+
batch_size = input_shape[0]
|
188 |
+
seq_length = input_shape[1]
|
189 |
+
|
190 |
+
if input_mask is None:
|
191 |
+
input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
|
192 |
+
|
193 |
+
if token_type_ids is None:
|
194 |
+
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
|
195 |
+
|
196 |
+
with tf.variable_scope(scope, default_name="bert"):
|
197 |
+
with tf.variable_scope("embeddings"):
|
198 |
+
# Perform embedding lookup on the word ids.
|
199 |
+
(self.word_embedding_output,
|
200 |
+
         self.output_embedding_table) = embedding_lookup(
             input_ids=input_ids,
             vocab_size=config.vocab_size,
             embedding_size=config.embedding_size,
             initializer_range=config.initializer_range,
             word_embedding_name="word_embeddings",
             use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.word_embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings)

      with tf.variable_scope("encoder"):
        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=input_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_hidden_groups=config.num_hidden_groups,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            inner_group_num=config.inner_group_num,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True,
            use_einsum=use_einsum)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the final hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    return self.all_encoder_layers

  def get_word_embedding_output(self):
    """Get output of the word(piece) embedding lookup.

    This is BEFORE positional embeddings and token type embeddings have been
    added.

    Returns:
      float Tensor of shape [batch_size, seq_length, embedding_size]
      corresponding to the output of the word(piece) embedding layer.
    """
    return self.word_embedding_output

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, embedding_size]
      corresponding to the output of the embedding layer, after summing the
      word embeddings with the positional embeddings and the token type
      embeddings, then performing layer normalization. This is the input to
      the transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    return self.output_embedding_table

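# A minimal usage sketch for the `AlbertModel` class above. The tensor values
# and config sizes are made up for illustration, and the sketch assumes the
# `AlbertConfig` / `AlbertModel` definitions earlier in this file.
def _example_albert_model():
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
  config = AlbertConfig(vocab_size=32000, hidden_size=512,
                        num_hidden_layers=8, num_attention_heads=8,
                        intermediate_size=1024)
  model = AlbertModel(config=config, is_training=True, input_ids=input_ids,
                      input_mask=input_mask, token_type_ids=token_type_ids)
  # [batch_size, hidden_size] summary of each sequence, taken from the
  # first-token ("pooled") representation built in the "pooler" scope above.
  return model.get_pooled_output()
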
def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415
  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf

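# Sanity-check sketch for `gelu`: the tanh formula above is the approximation
# from the GELU paper and should track the exact erf-based form
# 0.5 * x * (1 + erf(x / sqrt(2))) closely. Assumes a TF1-style session.
def _example_gelu_close_to_exact():
  x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
  exact = 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0)))
  with tf.Session() as sess:
    approx_v, exact_v = sess.run([gelu(x), exact])
  # The two forms agree to roughly 1e-3 over this range.
  np.testing.assert_allclose(approx_v, exact_v, atol=1e-3)
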
def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)

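# Quick illustration of the `get_activation` contract: strings map to
# functions, "linear"/empty map to None, and non-strings pass through.
def _example_get_activation():
  assert get_activation("gelu") is gelu
  assert get_activation("relu") is tf.nn.relu
  assert get_activation("linear") is None
  assert get_activation(tf.tanh) is tf.tanh  # already a callable
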
def get_assignment_map_from_checkpoint(tvars, init_checkpoint, num_of_group=0):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var
  init_vars = tf.train.list_variables(init_checkpoint)
  init_vars_name = [name for (name, _) in init_vars]

  if num_of_group > 0:
    assignment_map = []
    for gid in range(num_of_group):
      assignment_map.append(collections.OrderedDict())
  else:
    assignment_map = collections.OrderedDict()

  for name in name_to_variable:
    if name in init_vars_name:
      tvar_name = name
    elif (re.sub(r"/group_\d+/", "/group_0/",
                 six.ensure_str(name)) in init_vars_name and
          num_of_group > 1):
      tvar_name = re.sub(r"/group_\d+/", "/group_0/", six.ensure_str(name))
    elif (re.sub(r"/ffn_\d+/", "/ffn_1/", six.ensure_str(name))
          in init_vars_name and num_of_group > 1):
      tvar_name = re.sub(r"/ffn_\d+/", "/ffn_1/", six.ensure_str(name))
    elif (re.sub(r"/attention_\d+/", "/attention_1/", six.ensure_str(name))
          in init_vars_name and num_of_group > 1):
      tvar_name = re.sub(r"/attention_\d+/", "/attention_1/",
                         six.ensure_str(name))
    else:
      tf.logging.info("name %s does not get matched", name)
      continue
    tf.logging.info("name %s match to %s", name, tvar_name)
    if num_of_group > 0:
      group_matched = False
      for gid in range(1, num_of_group):
        if (("/group_" + str(gid) + "/" in name) or
            ("/ffn_" + str(gid) + "/" in name) or
            ("/attention_" + str(gid) + "/" in name)):
          group_matched = True
          tf.logging.info("%s belongs to %dth", name, gid)
          assignment_map[gid][tvar_name] = name
      if not group_matched:
        assignment_map[0][tvar_name] = name
    else:
      assignment_map[tvar_name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[six.ensure_str(name) + ":0"] = 1

  return (assignment_map, initialized_variable_names)

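# Typical use of `get_assignment_map_from_checkpoint` when warm-starting from
# a pretrained checkpoint (sketch only; `init_checkpoint` is a hypothetical
# path supplied by the caller):
def _example_init_from_checkpoint(init_checkpoint):
  tvars = tf.trainable_variables()
  (assignment_map,
   initialized_variable_names) = get_assignment_map_from_checkpoint(
       tvars, init_checkpoint)
  # With num_of_group == 0 the map is a single OrderedDict, which can be
  # passed straight to init_from_checkpoint.
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  return initialized_variable_names
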
def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, rate=dropout_prob)
  return output


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return contrib_layers.layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)

def get_timing_signal_1d_given_position(channels,
                                        position,
                                        min_timescale=1.0,
                                        max_timescale=1.0e4):
  """Get sinusoids of diff frequencies, with timing position given.

  Adapted from add_timing_signal_1d_given_position in
  //third_party/py/tensor2tensor/layers/common_attention.py

  Args:
    channels: scalar, size of timing embeddings to create. The number of
      different timescales is equal to channels / 2.
    position: a Tensor with shape [batch, seq_len]
    min_timescale: a float
    max_timescale: a float

  Returns:
    a Tensor of timing signals [batch, seq_len, channels]
  """
  num_timescales = channels // 2
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (tf.to_float(num_timescales) - 1))
  inv_timescales = min_timescale * tf.exp(
      tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
  scaled_time = (
      tf.expand_dims(tf.to_float(position), 2) * tf.expand_dims(
          tf.expand_dims(inv_timescales, 0), 0))
  signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2)
  signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]])
  return signal

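# Toy call of `get_timing_signal_1d_given_position` (illustrative values): for
# a batch of explicit positions it returns one sinusoidal row per position.
def _example_timing_signal():
  position = tf.constant([[0, 1, 2, 3]])  # [batch=1, seq_len=4]
  signal = get_timing_signal_1d_given_position(channels=8, position=position)
  # signal has shape [1, 4, 8]: sin halves then cos halves, one pair per
  # geometrically spaced timescale.
  return signal
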
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function assumes that the input is of shape [batch_size, seq_length,
  # num_inputs].
  #
  # If the input is a 2D tensor of shape [batch_size, seq_length], we
  # reshape to [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  if use_one_hot_embeddings:
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  input_shape = get_shape_list(input_ids)

  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)

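# Minimal sketch of `embedding_lookup` on toy ids (assumes a fresh TF1 graph;
# the "word_embeddings" variable is created on first call):
def _example_embedding_lookup():
  input_ids = tf.constant([[7, 42, 0]], dtype=tf.int32)  # [1, 3]
  output, table = embedding_lookup(
      input_ids, vocab_size=100, embedding_size=8)
  # output: [1, 3, 8]; table: [100, 8], also reused later as the output
  # embedding table via get_embedding_table().
  return output, table
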
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1,
                            use_one_hot_embeddings=True):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary, unless converting to tflite model.
    if use_one_hot_embeddings:
      flat_token_type_ids = tf.reshape(token_type_ids, [-1])
      one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
      token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
      token_type_embeddings = tf.reshape(token_type_embeddings,
                                         [batch_size, seq_length, width])
    else:
      token_type_embeddings = tf.nn.embedding_lookup(token_type_table,
                                                     token_type_ids)
    output += token_type_embeddings

  if use_position_embeddings:
    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      full_position_embeddings = tf.get_variable(
          name=position_embedding_name,
          shape=[max_position_embeddings, width],
          initializer=create_initializer(initializer_range))
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
      # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
      # perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`),
      # so we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output

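# Sketch showing how `embedding_lookup` and `embedding_postprocessor` compose
# (toy sizes; inside `AlbertModel` above this happens under the "embeddings"
# variable scope):
def _example_embedding_pipeline():
  input_ids = tf.constant([[7, 42, 0]], dtype=tf.int32)
  token_type_ids = tf.constant([[0, 0, 1]], dtype=tf.int32)
  word_emb, _ = embedding_lookup(input_ids, vocab_size=100, embedding_size=8)
  # Adds token-type and position embeddings, then layer-norm + dropout.
  return embedding_postprocessor(
      word_emb,
      use_token_type=True,
      token_type_ids=token_type_ids,
      token_type_vocab_size=2,
      max_position_embeddings=16,
      dropout_prob=0.0)
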
def einsum_via_matmul(input_tensor, w, num_inner_dims):
  """Implements einsum via matmul and reshape ops.

  Args:
    input_tensor: float Tensor of shape [<batch_dims>, <inner_dims>].
    w: float Tensor of shape [<inner_dims>, <outer_dims>].
    num_inner_dims: int. number of dimensions to use for inner products.

  Returns:
    float Tensor of shape [<batch_dims>, <outer_dims>].
  """
  input_shape = get_shape_list(input_tensor)
  w_shape = get_shape_list(w)
  batch_dims = input_shape[:-num_inner_dims]
  inner_dims = input_shape[-num_inner_dims:]
  outer_dims = w_shape[num_inner_dims:]
  inner_dim = np.prod(inner_dims)
  outer_dim = np.prod(outer_dims)
  if num_inner_dims > 1:
    input_tensor = tf.reshape(input_tensor, batch_dims + [inner_dim])
  if len(w_shape) > 2:
    w = tf.reshape(w, [inner_dim, outer_dim])
  ret = tf.matmul(input_tensor, w)
  if len(outer_dims) > 1:
    ret = tf.reshape(ret, batch_dims + outer_dims)
  return ret

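# Equivalence sketch for `einsum_via_matmul`: with num_inner_dims=1 it matches
# tf.einsum("BFH,HND->BFND", ...); modeling_test.py checks the same property.
def _example_einsum_via_matmul():
  x = tf.random.uniform([2, 5, 6])   # [B, F, H]
  w = tf.random.uniform([6, 3, 4])   # [H, N, D]
  ret_einsum = tf.einsum("BFH,HND->BFND", x, w)
  ret_matmul = einsum_via_matmul(x, w, num_inner_dims=1)
  return ret_einsum, ret_matmul  # numerically identical up to float error
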
def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   head_size,
                   initializer,
                   activation,
                   use_einsum,
                   name=None):
  """A dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
    num_attention_heads: Number of attention heads.
    head_size: The size per attention head.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """

  input_shape = get_shape_list(input_tensor)
  hidden_size = input_shape[2]

  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[hidden_size, num_attention_heads * head_size],
        initializer=initializer)
    w = tf.reshape(w, [hidden_size, num_attention_heads, head_size])
    b = tf.get_variable(
        name="bias",
        shape=[num_attention_heads * head_size],
        initializer=tf.zeros_initializer)
    b = tf.reshape(b, [num_attention_heads, head_size])
    if use_einsum:
      ret = tf.einsum("BFH,HND->BFND", input_tensor, w)
    else:
      ret = einsum_via_matmul(input_tensor, w, 1)
    ret += b
  if activation is not None:
    return activation(ret)
  else:
    return ret

def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        head_size,
                        initializer,
                        activation,
                        use_einsum,
                        name=None):
  """A dense layer with 3D kernel for projection.

  Args:
    input_tensor: float Tensor of shape [batch, from_seq_length,
      num_attention_heads, size_per_head].
    hidden_size: The size of hidden layer.
    head_size: The size of head.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  input_shape = get_shape_list(input_tensor)
  num_attention_heads = input_shape[2]
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[num_attention_heads * head_size, hidden_size],
        initializer=initializer)
    w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
    b = tf.get_variable(
        name="bias", shape=[hidden_size], initializer=tf.zeros_initializer)
    if use_einsum:
      ret = tf.einsum("BFND,NDH->BFH", input_tensor, w)
    else:
      ret = einsum_via_matmul(input_tensor, w, 2)
    ret += b
  if activation is not None:
    return activation(ret)
  else:
    return ret

def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   use_einsum,
                   num_attention_heads=1,
                   name=None):
  """A dense layer with 2D kernel.

  Args:
    input_tensor: Float tensor with rank 3.
    output_size: The size of output dimension.
    initializer: Kernel initializer.
    activation: Activation function.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.
    num_attention_heads: number of attention heads in attention layer.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  del num_attention_heads  # unused
  input_shape = get_shape_list(input_tensor)
  hidden_size = input_shape[2]
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[hidden_size, output_size],
        initializer=initializer)
    b = tf.get_variable(
        name="bias", shape=[output_size], initializer=tf.zeros_initializer)
    if use_einsum:
      ret = tf.einsum("BFH,HO->BFO", input_tensor, w)
    else:
      ret = tf.matmul(input_tensor, w)
    ret += b
  if activation is not None:
    return activation(ret)
  else:
    return ret

def dot_product_attention(q, k, v, bias, dropout_rate=0.0):
  """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, length_kv]
  logits = tf.multiply(logits, 1.0 / math.sqrt(float(get_shape_list(q)[-1])))
  if bias is not None:
    # `attention_mask` = [B, T]
    from_shape = get_shape_list(q)
    if len(from_shape) == 4:
      broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], 1], tf.float32)
    elif len(from_shape) == 5:
      # from_shape = [B, N, Block_num, block_size, depth]
      broadcast_ones = tf.ones([from_shape[0], 1, from_shape[2], from_shape[3],
                                1], tf.float32)

    bias = tf.matmul(broadcast_ones,
                     tf.cast(bias, tf.float32), transpose_b=True)

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - bias) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    logits += adder
  else:
    adder = 0.0

  attention_probs = tf.nn.softmax(logits, name="attention_probs")
  attention_probs = dropout(attention_probs, dropout_rate)
  return tf.matmul(attention_probs, v)

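# Toy sketch of `dot_product_attention` with an explicit mask (made-up shapes).
# As in `attention_layer` below, the [B, T] mask is reshaped to [B, 1, T, 1]
# before being passed as `bias`; masked positions get -10000 added to their
# logits and therefore receive near-zero probability after the softmax.
def _example_dot_product_attention():
  q = tf.random.uniform([2, 1, 3, 4])  # [B, N, F, H]
  k = tf.random.uniform([2, 1, 3, 4])
  v = tf.random.uniform([2, 1, 3, 4])
  mask = tf.reshape(
      tf.constant([[1, 1, 0], [1, 1, 1]], tf.float32), [2, 1, 3, 1])
  return dot_product_attention(q, k, v, bias=mask)  # [B, N, F, H]
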
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None,
                    use_einsum=True):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
  size_per_head = int(from_shape[2] / num_attention_heads)

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  # `query_layer` = [B, F, N, H]
  q = dense_layer_3d(from_tensor, num_attention_heads, size_per_head,
                     create_initializer(initializer_range), query_act,
                     use_einsum, "query")

  # `key_layer` = [B, T, N, H]
  k = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                     create_initializer(initializer_range), key_act,
                     use_einsum, "key")
  # `value_layer` = [B, T, N, H]
  v = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                     create_initializer(initializer_range), value_act,
                     use_einsum, "value")
  q = tf.transpose(q, [0, 2, 1, 3])
  k = tf.transpose(k, [0, 2, 1, 3])
  v = tf.transpose(v, [0, 2, 1, 3])
  if attention_mask is not None:
    attention_mask = tf.reshape(
        attention_mask, [batch_size, 1, to_seq_length, 1])
  # 'new_embeddings = [B, N, F, H]'
  new_embeddings = dot_product_attention(q, k, v, attention_mask,
                                         attention_probs_dropout_prob)

  return tf.transpose(new_embeddings, [0, 2, 1, 3])

def attention_ffn_block(layer_input,
                        hidden_size=768,
                        attention_mask=None,
                        num_attention_heads=1,
                        attention_head_size=64,
                        attention_probs_dropout_prob=0.0,
                        intermediate_size=3072,
                        intermediate_act_fn=None,
                        initializer_range=0.02,
                        hidden_dropout_prob=0.0,
                        use_einsum=True):
  """A network with attention-ffn as sub-block.

  Args:
    layer_input: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    hidden_size: (optional) int, size of hidden layer.
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    attention_head_size: int. Size of attention head.
    attention_probs_dropout_prob: float. dropout probability for
      attention_layer.
    intermediate_size: int. Size of intermediate hidden layer.
    intermediate_act_fn: (optional) Activation function for the intermediate
      layer.
    initializer_range: float. Range of the weight initializer.
    hidden_dropout_prob: (optional) float. Dropout probability of the hidden
      layer.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.

  Returns:
    layer output
  """

  with tf.variable_scope("attention_1"):
    with tf.variable_scope("self"):
      attention_output = attention_layer(
          from_tensor=layer_input,
          to_tensor=layer_input,
          attention_mask=attention_mask,
          num_attention_heads=num_attention_heads,
          attention_probs_dropout_prob=attention_probs_dropout_prob,
          initializer_range=initializer_range,
          use_einsum=use_einsum)

    # Run a linear projection of `hidden_size` then add a residual
    # with `layer_input`.
    with tf.variable_scope("output"):
      attention_output = dense_layer_3d_proj(
          attention_output,
          hidden_size,
          attention_head_size,
          create_initializer(initializer_range),
          None,
          use_einsum=use_einsum,
          name="dense")
      attention_output = dropout(attention_output, hidden_dropout_prob)
  attention_output = layer_norm(attention_output + layer_input)
  with tf.variable_scope("ffn_1"):
    with tf.variable_scope("intermediate"):
      intermediate_output = dense_layer_2d(
          attention_output,
          intermediate_size,
          create_initializer(initializer_range),
          intermediate_act_fn,
          use_einsum=use_einsum,
          num_attention_heads=num_attention_heads,
          name="dense")
      with tf.variable_scope("output"):
        ffn_output = dense_layer_2d(
            intermediate_output,
            hidden_size,
            create_initializer(initializer_range),
            None,
            use_einsum=use_einsum,
            num_attention_heads=num_attention_heads,
            name="dense")
      ffn_output = dropout(ffn_output, hidden_dropout_prob)
  ffn_output = layer_norm(ffn_output + attention_output)
  return ffn_output

def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_hidden_groups=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      inner_group_num=1,
                      intermediate_act_fn="gelu",
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False,
                      use_einsum=True):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_hidden_groups: int. Number of groups for the hidden layers; parameters
      in the same group are shared.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    inner_group_num: int, number of inner repetitions of attention and ffn.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.
    use_einsum: bool. Whether to use einsum or reshape+matmul for dense layers.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = hidden_size // num_attention_heads
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  input_width = input_shape[2]

  all_layer_outputs = []
  if input_width != hidden_size:
    prev_output = dense_layer_2d(
        input_tensor, hidden_size, create_initializer(initializer_range),
        None, use_einsum=use_einsum, name="embedding_hidden_mapping_in")
  else:
    prev_output = input_tensor
  with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE):
    for layer_idx in range(num_hidden_layers):
      group_idx = int(layer_idx / num_hidden_layers * num_hidden_groups)
      with tf.variable_scope("group_%d" % group_idx):
        with tf.name_scope("layer_%d" % layer_idx):
          layer_output = prev_output
          for inner_group_idx in range(inner_group_num):
            with tf.variable_scope("inner_group_%d" % inner_group_idx):
              layer_output = attention_ffn_block(
                  layer_input=layer_output,
                  hidden_size=hidden_size,
                  attention_mask=attention_mask,
                  num_attention_heads=num_attention_heads,
                  attention_head_size=attention_head_size,
                  attention_probs_dropout_prob=attention_probs_dropout_prob,
                  intermediate_size=intermediate_size,
                  intermediate_act_fn=intermediate_act_fn,
                  initializer_range=initializer_range,
                  hidden_dropout_prob=hidden_dropout_prob,
                  use_einsum=use_einsum)
              prev_output = layer_output
              all_layer_outputs.append(layer_output)
  if do_return_all_layers:
    return all_layer_outputs
  else:
    return all_layer_outputs[-1]

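# End-to-end toy call of `transformer_model` (made-up sizes; hidden_size must
# be divisible by num_attention_heads):
def _example_transformer_model():
  embeddings = tf.random.uniform([2, 5, 16])  # [batch, seq_len, hidden]
  final_layer = transformer_model(
      embeddings,
      hidden_size=16,
      num_hidden_layers=2,
      num_hidden_groups=1,  # ALBERT-style: both layers share one group
      num_attention_heads=4,
      intermediate_size=32,
      intermediate_act_fn=gelu)
  return final_layer  # [2, 5, 16]
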
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape

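# Behavior sketch for `get_shape_list` (assumes a TF1 graph with a
# placeholder): static dims come back as Python ints, unknown dims as scalar
# tensors.
def _example_get_shape_list():
  x = tf.placeholder(tf.float32, shape=[None, 7])
  dims = get_shape_list(x, expected_rank=2)
  # dims[0] is a tf.Tensor (dynamic batch size), dims[1] is the int 7.
  return dims
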
def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor

def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])

def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  if name is None:
    name = tensor.name

  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    scope_name = tf.get_variable_scope().name
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%d` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
Indic-BERT-v1-master/albert/modeling_test.py
ADDED
@@ -0,0 +1,309 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import random
import re

from albert import modeling
import numpy as np
import six
from six.moves import range
import tensorflow.compat.v1 as tf


class AlbertModelTest(tf.test.TestCase):

  class AlbertModelTester(object):

    def __init__(self,
                 parent,
                 batch_size=13,
                 seq_length=7,
                 is_training=True,
                 use_input_mask=True,
                 use_token_type_ids=True,
                 vocab_size=99,
                 embedding_size=32,
                 hidden_size=32,
                 num_hidden_layers=5,
                 num_attention_heads=4,
                 intermediate_size=37,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 scope=None):
      self.parent = parent
      self.batch_size = batch_size
      self.seq_length = seq_length
      self.is_training = is_training
      self.use_input_mask = use_input_mask
      self.use_token_type_ids = use_token_type_ids
      self.vocab_size = vocab_size
      self.embedding_size = embedding_size
      self.hidden_size = hidden_size
      self.num_hidden_layers = num_hidden_layers
      self.num_attention_heads = num_attention_heads
      self.intermediate_size = intermediate_size
      self.hidden_act = hidden_act
      self.hidden_dropout_prob = hidden_dropout_prob
      self.attention_probs_dropout_prob = attention_probs_dropout_prob
      self.max_position_embeddings = max_position_embeddings
      self.type_vocab_size = type_vocab_size
      self.initializer_range = initializer_range
      self.scope = scope

    def create_model(self):
      input_ids = AlbertModelTest.ids_tensor([self.batch_size, self.seq_length],
                                             self.vocab_size)

      input_mask = None
      if self.use_input_mask:
        input_mask = AlbertModelTest.ids_tensor(
            [self.batch_size, self.seq_length], vocab_size=2)

      token_type_ids = None
      if self.use_token_type_ids:
        token_type_ids = AlbertModelTest.ids_tensor(
            [self.batch_size, self.seq_length], self.type_vocab_size)

      config = modeling.AlbertConfig(
          vocab_size=self.vocab_size,
          embedding_size=self.embedding_size,
          hidden_size=self.hidden_size,
          num_hidden_layers=self.num_hidden_layers,
          num_attention_heads=self.num_attention_heads,
          intermediate_size=self.intermediate_size,
          hidden_act=self.hidden_act,
          hidden_dropout_prob=self.hidden_dropout_prob,
          attention_probs_dropout_prob=self.attention_probs_dropout_prob,
          max_position_embeddings=self.max_position_embeddings,
          type_vocab_size=self.type_vocab_size,
          initializer_range=self.initializer_range)

      model = modeling.AlbertModel(
          config=config,
          is_training=self.is_training,
          input_ids=input_ids,
          input_mask=input_mask,
          token_type_ids=token_type_ids,
          scope=self.scope)

      outputs = {
          "embedding_output": model.get_embedding_output(),
          "sequence_output": model.get_sequence_output(),
          "pooled_output": model.get_pooled_output(),
          "all_encoder_layers": model.get_all_encoder_layers(),
      }
      return outputs

    def check_output(self, result):
      self.parent.assertAllEqual(
          result["embedding_output"].shape,
          [self.batch_size, self.seq_length, self.embedding_size])

      self.parent.assertAllEqual(
          result["sequence_output"].shape,
          [self.batch_size, self.seq_length, self.hidden_size])

      self.parent.assertAllEqual(result["pooled_output"].shape,
                                 [self.batch_size, self.hidden_size])

  def test_default(self):
    self.run_tester(AlbertModelTest.AlbertModelTester(self))

  def test_config_to_json_string(self):
    config = modeling.AlbertConfig(vocab_size=99, hidden_size=37)
    obj = json.loads(config.to_json_string())
    self.assertEqual(obj["vocab_size"], 99)
    self.assertEqual(obj["hidden_size"], 37)

  def test_einsum_via_matmul(self):
    batch_size = 8
    seq_length = 12
    num_attention_heads = 3
    head_size = 6
    hidden_size = 10

    input_tensor = np.random.uniform(0, 1,
                                     [batch_size, seq_length, hidden_size])
    input_tensor = tf.constant(input_tensor, dtype=tf.float32)
    w = np.random.uniform(0, 1, [hidden_size, num_attention_heads, head_size])
    w = tf.constant(w, dtype=tf.float32)
    ret1 = tf.einsum("BFH,HND->BFND", input_tensor, w)
    ret2 = modeling.einsum_via_matmul(input_tensor, w, 1)
    self.assertAllClose(ret1, ret2)

    input_tensor = np.random.uniform(0, 1,
                                     [batch_size, seq_length,
                                      num_attention_heads, head_size])
    input_tensor = tf.constant(input_tensor, dtype=tf.float32)
    w = np.random.uniform(0, 1, [num_attention_heads, head_size, hidden_size])
    w = tf.constant(w, dtype=tf.float32)
    ret1 = tf.einsum("BFND,NDH->BFH", input_tensor, w)
    ret2 = modeling.einsum_via_matmul(input_tensor, w, 2)
    self.assertAllClose(ret1, ret2)

  def run_tester(self, tester):
    with self.test_session() as sess:
      ops = tester.create_model()
      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)
      output_result = sess.run(ops)
      tester.check_output(output_result)

      self.assert_all_tensors_reachable(sess, [init_op, ops])

  @classmethod
  def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
      rng = random.Random()

    total_dims = 1
    for dim in shape:
      total_dims *= dim

    values = []
    for _ in range(total_dims):
      values.append(rng.randint(0, vocab_size - 1))

    return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)

  def assert_all_tensors_reachable(self, sess, outputs):
    """Checks that all the tensors in the graph are reachable from outputs."""
    graph = sess.graph

    ignore_strings = [
        "^.*/assert_less_equal/.*$",
        "^.*/dilation_rate$",
        "^.*/Tensordot/concat$",
        "^.*/Tensordot/concat/axis$",
        "^testing/.*$",
    ]

    ignore_regexes = [re.compile(x) for x in ignore_strings]

    unreachable = self.get_unreachable_ops(graph, outputs)
    filtered_unreachable = []
    for x in unreachable:
      do_ignore = False
      for r in ignore_regexes:
        m = r.match(six.ensure_str(x.name))
        if m is not None:
          do_ignore = True
      if do_ignore:
        continue
      filtered_unreachable.append(x)
    unreachable = filtered_unreachable

    self.assertEqual(
        len(unreachable), 0, "The following ops are unreachable: %s" %
        (" ".join([x.name for x in unreachable])))

  @classmethod
  def get_unreachable_ops(cls, graph, outputs):
    """Finds all of the tensors in graph that are unreachable from outputs."""
    outputs = cls.flatten_recursive(outputs)
    output_to_op = collections.defaultdict(list)
    op_to_all = collections.defaultdict(list)
    assign_out_to_in = collections.defaultdict(list)

    for op in graph.get_operations():
      for x in op.inputs:
        op_to_all[op.name].append(x.name)
      for y in op.outputs:
        output_to_op[y.name].append(op.name)
        op_to_all[op.name].append(y.name)
      if str(op.type) == "Assign":
        for y in op.outputs:
          for x in op.inputs:
            assign_out_to_in[y.name].append(x.name)

    assign_groups = collections.defaultdict(list)
    for out_name in assign_out_to_in.keys():
      name_group = assign_out_to_in[out_name]
      for n1 in name_group:
        assign_groups[n1].append(out_name)
        for n2 in name_group:
          if n1 != n2:
            assign_groups[n1].append(n2)

    seen_tensors = {}
    stack = [x.name for x in outputs]
    while stack:
      name = stack.pop()
      if name in seen_tensors:
        continue
      seen_tensors[name] = True

      if name in output_to_op:
        for op_name in output_to_op[name]:
          if op_name in op_to_all:
            for input_name in op_to_all[op_name]:
              if input_name not in stack:
                stack.append(input_name)

      expanded_names = []
      if name in assign_groups:
        for assign_name in assign_groups[name]:
          expanded_names.append(assign_name)

      for expanded_name in expanded_names:
        if expanded_name not in stack:
          stack.append(expanded_name)

    unreachable_ops = []
    for op in graph.get_operations():
      is_unreachable = False
      all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
      for name in all_names:
        if name not in seen_tensors:
          is_unreachable = True
      if is_unreachable:
        unreachable_ops.append(op)
    return unreachable_ops

  @classmethod
  def flatten_recursive(cls, item):
    """Flattens (potentially nested) a tuple/dictionary/list to a list."""
    output = []
    if isinstance(item, list):
      output.extend(item)
    elif isinstance(item, tuple):
      output.extend(list(item))
    elif isinstance(item, dict):
      for (_, v) in six.iteritems(item):
        output.append(v)
    else:
      return [item]

    flat_output = []
    for x in output:
      flat_output.extend(cls.flatten_recursive(x))
    return flat_output


if __name__ == "__main__":
  tf.test.main()
Indic-BERT-v1-master/albert/optimization.py
ADDED
@@ -0,0 +1,204 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from albert import lamb_optimizer
import six
from six.moves import zip
import tensorflow.compat.v1 as tf
from tensorflow.contrib import tpu as contrib_tpu


def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu,
                     optimizer="adamw", poly_power=1.0, start_warmup_step=0,
                     colocate_gradients_with_ops=False):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=poly_power,
      cycle=False)

  # Implements linear warmup. I.e., if global_step - start_warmup_step <
  # num_warmup_steps, the learning rate will be
  # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step)
                    + ", for " + str(num_warmup_steps) + " steps ++++++")
    global_steps_int = tf.cast(global_step, tf.int32)
    start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
    global_steps_int = global_steps_int - start_warm_int
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

  # It is OK to use this optimizer for finetuning, since this is how the model
  # was trained (note that the Adam m/v variables are NOT loaded from
  # init_checkpoint). It is also OK to use AdamW for finetuning even if the
  # model was pretrained with LAMB. As reported on the public BERT GitHub,
  # the learning rate for SQuAD 1.1 finetuning is 3e-5, 4e-5, or 5e-5; for
  # LAMB, 3e-4, 4e-4, or 5e-4 works for a batch size of 64.
  if optimizer == "adamw":
    tf.logging.info("using adamw")
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  elif optimizer == "lamb":
    tf.logging.info("using lamb")
    optimizer = lamb_optimizer.LAMBOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  else:
    raise ValueError("Unsupported optimizer: %s" % optimizer)

  if use_tpu:
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

  tvars = tf.trainable_variables()
  grads = tf.gradients(
      loss, tvars, colocate_gradients_with_ops=colocate_gradients_with_ops)

  # This is how the model was pre-trained.
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = optimizer.apply_gradients(
      list(zip(grads, tvars)), global_step=global_step)

  # Normally the global step update is done inside of `apply_gradients`.
  # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does this.
  # But if you use a different optimizer, you should probably take this line
  # out.
  new_global_step = global_step + 1
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])
  return train_op

class AdamWeightDecayOptimizer(tf.train.Optimizer):
|
118 |
+
"""A basic Adam optimizer that includes "correct" L2 weight decay."""
|
119 |
+
|
120 |
+
def __init__(self,
|
121 |
+
learning_rate,
|
122 |
+
weight_decay_rate=0.0,
|
123 |
+
beta_1=0.9,
|
124 |
+
beta_2=0.999,
|
125 |
+
epsilon=1e-6,
|
126 |
+
exclude_from_weight_decay=None,
|
127 |
+
name="AdamWeightDecayOptimizer"):
|
128 |
+
"""Constructs a AdamWeightDecayOptimizer."""
|
129 |
+
super(AdamWeightDecayOptimizer, self).__init__(False, name)
|
130 |
+
|
131 |
+
self.learning_rate = learning_rate
|
132 |
+
self.weight_decay_rate = weight_decay_rate
|
133 |
+
self.beta_1 = beta_1
|
134 |
+
self.beta_2 = beta_2
|
135 |
+
self.epsilon = epsilon
|
136 |
+
self.exclude_from_weight_decay = exclude_from_weight_decay
|
137 |
+
|
138 |
+
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
|
139 |
+
"""See base class."""
|
140 |
+
assignments = []
|
141 |
+
for (grad, param) in grads_and_vars:
|
142 |
+
if grad is None or param is None:
|
143 |
+
continue
|
144 |
+
|
145 |
+
param_name = self._get_variable_name(param.name)
|
146 |
+
|
147 |
+
m = tf.get_variable(
|
148 |
+
name=six.ensure_str(param_name) + "/adam_m",
|
149 |
+
shape=param.shape.as_list(),
|
150 |
+
dtype=tf.float32,
|
151 |
+
trainable=False,
|
152 |
+
initializer=tf.zeros_initializer())
|
153 |
+
v = tf.get_variable(
|
154 |
+
name=six.ensure_str(param_name) + "/adam_v",
|
155 |
+
shape=param.shape.as_list(),
|
156 |
+
dtype=tf.float32,
|
157 |
+
trainable=False,
|
158 |
+
initializer=tf.zeros_initializer())
|
159 |
+
|
160 |
+
# Standard Adam update.
|
161 |
+
next_m = (
|
162 |
+
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
|
163 |
+
next_v = (
|
164 |
+
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
|
165 |
+
tf.square(grad)))
|
166 |
+
|
167 |
+
update = next_m / (tf.sqrt(next_v) + self.epsilon)
|
168 |
+
|
169 |
+
# Just adding the square of the weights to the loss function is *not*
|
170 |
+
# the correct way of using L2 regularization/weight decay with Adam,
|
171 |
+
# since that will interact with the m and v parameters in strange ways.
|
172 |
+
#
|
173 |
+
# Instead we want ot decay the weights in a manner that doesn't interact
|
174 |
+
# with the m/v parameters. This is equivalent to adding the square
|
175 |
+
# of the weights to the loss with plain (non-momentum) SGD.
|
176 |
+
if self._do_use_weight_decay(param_name):
|
177 |
+
update += self.weight_decay_rate * param
|
178 |
+
|
179 |
+
update_with_lr = self.learning_rate * update
|
180 |
+
|
181 |
+
next_param = param - update_with_lr
|
182 |
+
|
183 |
+
assignments.extend(
|
184 |
+
[param.assign(next_param),
|
185 |
+
m.assign(next_m),
|
186 |
+
v.assign(next_v)])
|
187 |
+
return tf.group(*assignments, name=name)
|
188 |
+
|
189 |
+
def _do_use_weight_decay(self, param_name):
|
190 |
+
"""Whether to use L2 weight decay for `param_name`."""
|
191 |
+
if not self.weight_decay_rate:
|
192 |
+
return False
|
193 |
+
if self.exclude_from_weight_decay:
|
194 |
+
for r in self.exclude_from_weight_decay:
|
195 |
+
if re.search(r, param_name) is not None:
|
196 |
+
return False
|
197 |
+
return True
|
198 |
+
|
199 |
+
def _get_variable_name(self, param_name):
|
200 |
+
"""Get the variable name from the tensor name."""
|
201 |
+
m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
|
202 |
+
if m is not None:
|
203 |
+
param_name = m.group(1)
|
204 |
+
return param_name
|
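Taken together, the warmup comment and the polynomial-decay call above fully determine the learning-rate schedule, and the weight-decay comment determines the parameter update. The following minimal NumPy sketch restates both outside of TensorFlow; `effective_lr` and `adamw_update` are illustrative helper names, and the default hyperparameter values are assumptions for the example, not values fixed by this repo:

    import numpy as np

    def effective_lr(step, init_lr=1e-5, num_train_steps=10000,
                     num_warmup_steps=1000, poly_power=1.0,
                     start_warmup_step=0):
        # Linear warmup: (step - start_warmup_step) / num_warmup_steps * init_lr.
        if step - start_warmup_step < num_warmup_steps:
            return init_lr * max(step - start_warmup_step, 0) / num_warmup_steps
        # Polynomial decay (linear when poly_power == 1.0) to 0 at num_train_steps.
        frac = min(step, num_train_steps) / num_train_steps
        return init_lr * (1.0 - frac) ** poly_power

    def adamw_update(param, grad, m, v, lr, beta_1=0.9, beta_2=0.999,
                     epsilon=1e-6, weight_decay_rate=0.01):
        # Standard Adam moment updates (no bias correction, matching the
        # AdamWeightDecayOptimizer class above).
        m = beta_1 * m + (1.0 - beta_1) * grad
        v = beta_2 * v + (1.0 - beta_2) * grad ** 2
        update = m / (np.sqrt(v) + epsilon)
        # Decoupled weight decay: applied to the raw weights, outside m and v.
        update = update + weight_decay_rate * param
        return param - lr * update, m, v

Note the decay term never passes through m or v; that is exactly the distinction the comment in `apply_gradients` draws against naive L2 regularization added to the loss.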
Indic-BERT-v1-master/albert/optimization_test.py
ADDED
@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from albert import optimization
from six.moves import range
from six.moves import zip
import tensorflow.compat.v1 as tf


class OptimizationTest(tf.test.TestCase):

  def test_adam(self):
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      x = tf.constant([0.4, 0.2, -0.5])
      loss = tf.reduce_mean(tf.square(x - w))
      tvars = tf.trainable_variables()
      grads = tf.gradients(loss, tvars)
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(list(zip(grads, tvars)), global_step)
      init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
      sess.run(init_op)
      for _ in range(100):
        sess.run(train_op)
      w_np = sess.run(w)
      self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
  tf.test.main()
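The test drives `AdamWeightDecayOptimizer` for 100 steps on a 3-parameter least-squares problem and checks that `w` converges to the target `[0.4, 0.2, -0.5]`. Assuming the repo root is on PYTHONPATH so the directory is importable as the `albert` package (as the `from albert import optimization` import suggests), it can be run as a module:

    python -m albert.optimization_test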
Indic-BERT-v1-master/albert/race_utils.py
ADDED
@@ -0,0 +1,432 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for the RACE dataset."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import collections
import json
import os
from albert import classifier_utils
from albert import fine_tuning_utils
from albert import modeling
from albert import optimization
from albert import tokenization
import tensorflow.compat.v1 as tf
from tensorflow.contrib import tpu as contrib_tpu


class InputExample(object):
  """A single training/test example for the RACE dataset."""

  def __init__(self,
               example_id,
               context_sentence,
               start_ending,
               endings,
               label=None):
    self.example_id = example_id
    self.context_sentence = context_sentence
    self.start_ending = start_ending
    self.endings = endings
    self.label = label

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    l = [
        "id: {}".format(self.example_id),
        "context_sentence: {}".format(self.context_sentence),
        "start_ending: {}".format(self.start_ending),
        "ending_0: {}".format(self.endings[0]),
        "ending_1: {}".format(self.endings[1]),
        "ending_2: {}".format(self.endings[2]),
        "ending_3: {}".format(self.endings[3]),
    ]

    if self.label is not None:
      l.append("label: {}".format(self.label))

    return ", ".join(l)


class RaceProcessor(object):
  """Processor for the RACE data set."""

  def __init__(self, use_spm, do_lower_case, high_only, middle_only):
    super(RaceProcessor, self).__init__()
    self.use_spm = use_spm
    self.do_lower_case = do_lower_case
    self.high_only = high_only
    self.middle_only = middle_only

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    return self.read_examples(
        os.path.join(data_dir, "RACE", "train"))

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    return self.read_examples(
        os.path.join(data_dir, "RACE", "dev"))

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    return self.read_examples(
        os.path.join(data_dir, "RACE", "test"))

  def get_labels(self):
    """Gets the list of labels for this data set."""
    return ["A", "B", "C", "D"]

  def process_text(self, text):
    if self.use_spm:
      return tokenization.preprocess_text(text, lower=self.do_lower_case)
    else:
      return tokenization.convert_to_unicode(text)

  def read_examples(self, data_dir):
    """Reads examples from RACE json files."""
    examples = []
    for level in ["middle", "high"]:
      if level == "middle" and self.high_only: continue
      if level == "high" and self.middle_only: continue
      cur_dir = os.path.join(data_dir, level)

      cur_path = os.path.join(cur_dir, "all.txt")
      with tf.gfile.Open(cur_path) as f:
        for line in f:
          cur_data = json.loads(line.strip())

          answers = cur_data["answers"]
          options = cur_data["options"]
          questions = cur_data["questions"]
          context = self.process_text(cur_data["article"])

          for i in range(len(answers)):
            label = ord(answers[i]) - ord("A")
            qa_list = []

            question = self.process_text(questions[i])
            for j in range(4):
              option = self.process_text(options[i][j])

              if "_" in question:
                qa_cat = question.replace("_", option)
              else:
                qa_cat = " ".join([question, option])

              qa_list.append(qa_cat)

            examples.append(
                InputExample(
                    example_id=cur_data["id"],
                    context_sentence=context,
                    start_ending=None,
                    endings=[qa_list[0], qa_list[1], qa_list[2], qa_list[3]],
                    label=label
                )
            )

    return examples


def convert_single_example(example_index, example, label_size, max_seq_length,
                           tokenizer, max_qa_length):
  """Converts a single RACE `InputExample` into a multi-choice `InputFeatures`."""

  # RACE is a multiple choice task. To perform this task using ALBERT,
  # we will use the formatting proposed in "Improving Language
  # Understanding by Generative Pre-Training" and suggested by
  # @jacobdevlin-google in this issue
  # https://github.com/google-research/bert/issues/38.
  #
  # Each choice will correspond to a sample on which we run the
  # inference. For a given RACE example, we will create the 4
  # following inputs:
  # - [CLS] context [SEP] choice_1 [SEP]
  # - [CLS] context [SEP] choice_2 [SEP]
  # - [CLS] context [SEP] choice_3 [SEP]
  # - [CLS] context [SEP] choice_4 [SEP]
  # The model will output a single value for each input. To get the
  # final decision of the model, we will run a softmax over these 4
  # outputs.
  if isinstance(example, classifier_utils.PaddingInputExample):
    return classifier_utils.InputFeatures(
        example_id=0,
        input_ids=[[0] * max_seq_length] * label_size,
        input_mask=[[0] * max_seq_length] * label_size,
        segment_ids=[[0] * max_seq_length] * label_size,
        label_id=0,
        is_real_example=False)
  else:
    context_tokens = tokenizer.tokenize(example.context_sentence)
    if example.start_ending is not None:
      start_ending_tokens = tokenizer.tokenize(example.start_ending)

    all_input_tokens = []
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    for ending in example.endings:
      # We create a copy of the context tokens in order to be
      # able to shrink it according to ending_tokens
      context_tokens_choice = context_tokens[:]
      if example.start_ending is not None:
        ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
      else:
        ending_tokens = tokenizer.tokenize(ending)
      # Truncates `ending_tokens` and `context_tokens_choice` so that the
      # total length is less than the specified length, accounting for
      # [CLS], [SEP], [SEP] with "- 3".
      ending_tokens = ending_tokens[- max_qa_length:]

      if len(context_tokens_choice) + len(ending_tokens) > max_seq_length - 3:
        context_tokens_choice = context_tokens_choice[: (
            max_seq_length - 3 - len(ending_tokens))]
      tokens = ["[CLS]"] + context_tokens_choice + (
          ["[SEP]"] + ending_tokens + ["[SEP]"])
      segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (
          len(ending_tokens) + 1)

      input_ids = tokenizer.convert_tokens_to_ids(tokens)
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      padding = [0] * (max_seq_length - len(input_ids))
      input_ids += padding
      input_mask += padding
      segment_ids += padding

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length

      all_input_tokens.append(tokens)
      all_input_ids.append(input_ids)
      all_input_mask.append(input_mask)
      all_segment_ids.append(segment_ids)

    label = example.label
    if example_index < 5:
      tf.logging.info("*** Example ***")
      tf.logging.info("id: {}".format(example.example_id))
      for choice_idx, (tokens, input_ids, input_mask, segment_ids) in \
          enumerate(zip(all_input_tokens, all_input_ids, all_input_mask, all_segment_ids)):
        tf.logging.info("choice: {}".format(choice_idx))
        tf.logging.info("tokens: {}".format(" ".join(tokens)))
        tf.logging.info(
            "input_ids: {}".format(" ".join(map(str, input_ids))))
        tf.logging.info(
            "input_mask: {}".format(" ".join(map(str, input_mask))))
        tf.logging.info(
            "segment_ids: {}".format(" ".join(map(str, segment_ids))))
      tf.logging.info("label: {}".format(label))

    return classifier_utils.InputFeatures(
        example_id=example.example_id,
        input_ids=all_input_ids,
        input_mask=all_input_mask,
        segment_ids=all_segment_ids,
        label_id=label
    )


def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer,
    output_file, max_qa_length):
  """Converts a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, len(label_list),
                                     max_seq_length, tokenizer, max_qa_length)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(sum(feature.input_ids, []))
    features["input_mask"] = create_int_feature(sum(feature.input_mask, []))
    features["segment_ids"] = create_int_feature(sum(feature.segment_ids, []))
    features["label_ids"] = create_int_feature([feature.label_id])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings, max_seq_length,
                 dropout_prob, hub_module):
  """Creates a classification model."""
  bsz_per_core = tf.shape(input_ids)[0]

  input_ids = tf.reshape(input_ids, [bsz_per_core * num_labels, max_seq_length])
  input_mask = tf.reshape(input_mask,
                          [bsz_per_core * num_labels, max_seq_length])
  token_type_ids = tf.reshape(segment_ids,
                              [bsz_per_core * num_labels, max_seq_length])

  (output_layer, _) = fine_tuning_utils.create_albert(
      albert_config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=token_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      use_einsum=True,
      hub_module=hub_module)

  hidden_size = output_layer.shape[-1].value

  output_weights = tf.get_variable(
      "output_weights", [1, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [1],
      initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(
          output_layer, keep_prob=1 - dropout_prob)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    logits = tf.reshape(logits, [bsz_per_core, num_labels])
    probabilities = tf.nn.softmax(logits, axis=-1)
    predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    one_hot_labels = tf.one_hot(
        labels, depth=tf.cast(num_labels, dtype=tf.int32), dtype=tf.float32)

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

  return (loss, per_example_loss, probabilities, logits, predictions)


def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings, max_seq_length, dropout_prob,
                     hub_module):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]
    is_real_example = None
    if "is_real_example" in features:
      is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
    else:
      is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (total_loss, per_example_loss, probabilities, logits, predictions) = \
        create_model(albert_config, is_training, input_ids, input_mask,
                     segment_ids, label_ids, num_labels,
                     use_one_hot_embeddings, max_seq_length, dropout_prob,
                     hub_module)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:
      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions,
            weights=is_real_example)
        loss = tf.metrics.mean(
            values=per_example_loss, weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

      eval_metrics = (metric_fn,
                      [per_example_loss, label_ids, logits, is_real_example])
      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          predictions={"probabilities": probabilities,
                       "predictions": predictions},
          scaffold_fn=scaffold_fn)
    return output_spec

  return model_fn
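The multiple-choice flattening described in `convert_single_example` can be illustrated with a tiny standalone sketch. The "tokenizer" here is a toy whitespace split, not the repo's SentencePiece model, and the article/options are made-up examples:

    # A minimal sketch of the RACE input formatting, assuming a toy
    # whitespace tokenizer instead of the repo's SentencePiece model.
    article = "the race was held in spring".split()
    options = ["spring", "summer", "autumn", "winter"]
    question = "the race took place in _".split()  # "_" is the cloze slot

    for option in options:
        # RaceProcessor substitutes the option into the "_" slot.
        qa = [option if tok == "_" else tok for tok in question]
        tokens = ["[CLS]"] + article + ["[SEP]"] + qa + ["[SEP]"]
        segment_ids = [0] * (len(article) + 2) + [1] * (len(qa) + 1)
        print(tokens)
        print(segment_ids)
    # Each of the 4 flattened sequences gets one logit from create_model;
    # a softmax over those 4 logits yields the predicted answer A-D.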
Indic-BERT-v1-master/albert/requirements.txt
ADDED
@@ -0,0 +1,5 @@
# Run pip install --upgrade pip if tensorflow 1.15 cannot be found
tensorflow==1.15.2  # CPU version of TensorFlow
tensorflow_hub==0.7
# tensorflow-gpu==1.15  # GPU version of TensorFlow
sentencepiece
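These pins can be installed in the usual way, e.g. (path assumed from this upload's layout):

    pip install -r Indic-BERT-v1-master/albert/requirements.txt

Per the file's own comments, swap the `tensorflow` pin for the commented `tensorflow-gpu` line if you have a compatible GPU.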
Indic-BERT-v1-master/albert/run_classifier.py
ADDED
@@ -0,0 +1,488 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT fine-tuning on classification tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from albert import classifier_utils
from albert import fine_tuning_utils
from albert import modeling
import tensorflow.compat.v1 as tf
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("task_name", None, "The name of the task to train.")

flags.DEFINE_string(
    "vocab_file", None,
    "The vocabulary file that the ALBERT model was trained on.")

flags.DEFINE_string("spm_model_file", None,
                    "The model file for sentence piece tokenization.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string("cached_dir", None,
                    "Path to the cached training and dev tfrecord files. "
                    "The files will be generated if they do not exist.")

## Other parameters

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_string(
    "albert_hub_module_handle", None,
    "If set, the ALBERT hub module to use.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 512,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool(
    "do_predict", False,
    "Whether to run the model in inference mode on the test set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_integer("train_step", 1000,
                     "Total number of training steps to perform.")

flags.DEFINE_integer(
    "warmup_step", 0,
    "Number of steps to perform linear learning rate warmup for.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("keep_checkpoint_max", 5,
                     "How many checkpoints to keep.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string("optimizer", "adamw", "Optimizer to use.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "cola": classifier_utils.ColaProcessor,
      "mnli": classifier_utils.MnliProcessor,
      "mismnli": classifier_utils.MisMnliProcessor,
      "mrpc": classifier_utils.MrpcProcessor,
      "rte": classifier_utils.RteProcessor,
      "sst-2": classifier_utils.Sst2Processor,
      "sts-b": classifier_utils.StsbProcessor,
      "qqp": classifier_utils.QqpProcessor,
      "qnli": classifier_utils.QnliProcessor,
      "wnli": classifier_utils.WnliProcessor,
  }

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict` must be True.")

  if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
    raise ValueError("At least one of `--albert_config_file` and "
                     "`--albert_hub_module_handle` must be set")

  if FLAGS.albert_config_file:
    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)
    if FLAGS.max_seq_length > albert_config.max_position_embeddings:
      raise ValueError(
          "Cannot use sequence length %d because the ALBERT model "
          "was only trained up to sequence length %d" %
          (FLAGS.max_seq_length, albert_config.max_position_embeddings))
  else:
    albert_config = None  # Get the config from TF-Hub.

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name](
      use_spm=True if FLAGS.spm_model_file else False,
      do_lower_case=FLAGS.do_lower_case)

  label_list = processor.get_labels()

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
      keep_checkpoint_max=0,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
  model_fn = classifier_utils.model_fn_builder(
      albert_config=albert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.train_step,
      num_warmup_steps=FLAGS.warmup_step,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      task_name=task_name,
      hub_module=FLAGS.albert_hub_module_handle,
      optimizer=FLAGS.optimizer)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    cached_dir = FLAGS.cached_dir
    if not cached_dir:
      cached_dir = FLAGS.output_dir
    train_file = os.path.join(cached_dir, task_name + "_train.tf_record")
    if not tf.gfile.Exists(train_file):
      classifier_utils.file_based_convert_examples_to_features(
          train_examples, label_list, FLAGS.max_seq_length, tokenizer,
          train_file, task_name)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", FLAGS.train_step)
    train_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on. These do NOT count towards the metric (all tf.metrics
      # support a per-instance weight, and these get a weight of 0.0).
      while len(eval_examples) % FLAGS.eval_batch_size != 0:
        eval_examples.append(classifier_utils.PaddingInputExample())

    cached_dir = FLAGS.cached_dir
    if not cached_dir:
      cached_dir = FLAGS.output_dir
    eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record")
    if not tf.gfile.Exists(eval_file):
      classifier_utils.file_based_convert_examples_to_features(
          eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
          eval_file, task_name)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      assert len(eval_examples) % FLAGS.eval_batch_size == 0
      eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.eval_batch_size)

    best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt")

    def _best_trial_info():
      """Returns information about which checkpoints have been evaluated so far."""
      if tf.gfile.Exists(best_trial_info_file):
        with tf.gfile.GFile(best_trial_info_file, "r") as best_info:
          global_step, best_metric_global_step, metric_value = (
              best_info.read().split(":"))
          global_step = int(global_step)
          best_metric_global_step = int(best_metric_global_step)
          metric_value = float(metric_value)
      else:
        metric_value = -1
        best_metric_global_step = -1
        global_step = -1
      tf.logging.info(
          "Best trial info: Step: %s, Best Value Step: %s, "
          "Best Value: %s", global_step, best_metric_global_step, metric_value)
      return global_step, best_metric_global_step, metric_value

    def _remove_checkpoint(checkpoint_path):
      for ext in ["meta", "data-00000-of-00001", "index"]:
        src_ckpt = checkpoint_path + ".{}".format(ext)
        tf.logging.info("removing {}".format(src_ckpt))
        tf.gfile.Remove(src_ckpt)

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")

    if task_name == "sts-b":
      key_name = "pearson"
    elif task_name == "cola":
      key_name = "matthew_corr"
    else:
      key_name = "eval_accuracy"

    global_step, best_perf_global_step, best_perf = _best_trial_info()
    writer = tf.gfile.GFile(output_eval_file, "w")
    while global_step < FLAGS.train_step:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(60)
      else:
        for checkpoint in sorted(steps_and_files.items()):
          step, checkpoint_path = checkpoint
          if global_step >= step:
            if (best_perf_global_step != step and
                len(_find_valid_cands(step)) > 1):
              _remove_checkpoint(checkpoint_path)
            continue
          result = estimator.evaluate(
              input_fn=eval_input_fn,
              steps=eval_steps,
              checkpoint_path=checkpoint_path)
          global_step = result["global_step"]
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          writer.write("best = {}\n".format(best_perf))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            best_perf_global_step = global_step
          elif len(_find_valid_cands(global_step)) > 1:
            _remove_checkpoint(checkpoint_path)
          writer.write("=" * 50 + "\n")
          writer.flush()
          with tf.gfile.GFile(best_trial_info_file, "w") as best_info:
            best_info.write("{}:{}:{}".format(
                global_step, best_perf_global_step, best_perf))
    writer.close()

    for ext in ["meta", "data-00000-of-00001", "index"]:
      src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext)
      tgt_ckpt = "model.ckpt-best.{}".format(ext)
      tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
      tf.io.gfile.rename(
          os.path.join(FLAGS.output_dir, src_ckpt),
          os.path.join(FLAGS.output_dir, tgt_ckpt),
          overwrite=True)

  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    num_actual_predict_examples = len(predict_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on.
      while len(predict_examples) % FLAGS.predict_batch_size != 0:
        predict_examples.append(classifier_utils.PaddingInputExample())

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    classifier_utils.file_based_convert_examples_to_features(
        predict_examples, label_list,
        FLAGS.max_seq_length, tokenizer,
        predict_file, task_name)

    tf.logging.info("***** Running prediction *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = True if FLAGS.use_tpu else False
    predict_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size)

    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    result = estimator.predict(
        input_fn=predict_input_fn,
        checkpoint_path=checkpoint_path)

    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv")
    with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\
        tf.gfile.GFile(output_submit_file, "w") as sub_writer:
      sub_writer.write("index" + "\t" + "prediction\n")
      num_written_lines = 0
      tf.logging.info("***** Predict results *****")
      for (i, (example, prediction)) in\
          enumerate(zip(predict_examples, result)):
        probabilities = prediction["probabilities"]
        if i >= num_actual_predict_examples:
          break
        output_line = "\t".join(
            str(class_probability)
            for class_probability in probabilities) + "\n"
        pred_writer.write(output_line)

        if task_name != "sts-b":
          actual_label = label_list[int(prediction["predictions"])]
        else:
          actual_label = str(prediction["predictions"])
        sub_writer.write(example.guid + "\t" + actual_label + "\n")
        num_written_lines += 1
    assert num_written_lines == num_actual_predict_examples


if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("spm_model_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
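`_best_trial_info` above reads and writes `best_trial.txt` as a single colon-separated `global_step:best_metric_global_step:metric_value` triple. A hypothetical run that has evaluated through step 2000, with its best dev metric 0.8734 seen at step 1500, would therefore store:

    2000:1500:0.8734

(the numbers are illustrative only; the metric key depends on the task: `pearson` for STS-B, `matthew_corr` for CoLA, `eval_accuracy` otherwise).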
Indic-BERT-v1-master/albert/run_glue.sh
ADDED
@@ -0,0 +1,52 @@
#!/bin/bash
# This is a convenience script for evaluating ALBERT on the GLUE benchmark.
#
# By default, this script uses a pretrained ALBERT v1 BASE model, but you may
# use a custom checkpoint or any compatible TF-Hub checkpoint with minimal
# edits to environment variables (see ALBERT_HUB_MODULE_HANDLE below).
#
# This script does fine-tuning and evaluation on 8 tasks, so it may take a
# while to complete if you do not have a hardware accelerator.

set -ex

python3 -m venv $HOME/albertenv
. $HOME/albertenv/bin/activate

OUTPUT_DIR_BASE="$(mktemp -d)"
OUTPUT_DIR="${OUTPUT_DIR_BASE}/output"

# To start from a custom pretrained checkpoint, set ALBERT_HUB_MODULE_HANDLE
# below to an empty string and set INIT_CHECKPOINT to your checkpoint path.
ALBERT_HUB_MODULE_HANDLE="https://tfhub.dev/google/albert_base/1"
INIT_CHECKPOINT=""

pip3 install --upgrade pip
pip3 install numpy
pip3 install -r requirements.txt

function run_task() {
  COMMON_ARGS="--output_dir="${OUTPUT_DIR}/$1" --data_dir="${ALBERT_ROOT}/glue" --vocab_file="${ALBERT_ROOT}/vocab.txt" --spm_model_file="${ALBERT_ROOT}/30k-clean.model" --do_lower_case --max_seq_length=512 --optimizer=adamw --task_name=$1 --warmup_step=$2 --learning_rate=$3 --train_step=$4 --save_checkpoints_steps=$5 --train_batch_size=$6"
  python3 -m run_classifier \
    ${COMMON_ARGS} \
    --do_train \
    --nodo_eval \
    --nodo_predict \
    --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" \
    --init_checkpoint="${INIT_CHECKPOINT}"
  python3 -m run_classifier \
    ${COMMON_ARGS} \
    --nodo_train \
    --do_eval \
    --do_predict \
    --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}"
}

run_task SST-2 1256 1e-5 20935 100 32
run_task MNLI 1000 3e-5 10000 100 128
run_task CoLA 320 1e-5 5336 100 16
run_task QNLI 1986 1e-5 33112 200 32
run_task QQP 1000 5e-5 14000 100 128
run_task RTE 200 3e-5 800 100 32
run_task STS-B 214 2e-5 3598 100 16
run_task MRPC 200 2e-5 800 100 32
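For reference, `run_task`'s positional parameters map to `$1=task_name`, `$2=warmup_step`, `$3=learning_rate`, `$4=train_step`, `$5=save_checkpoints_steps`, `$6=train_batch_size`; so the final line above fine-tunes MRPC for 800 steps (200 of them warmup) at a 2e-5 learning rate with batch size 32, checkpointing every 100 steps. Note the script also assumes `ALBERT_ROOT` is exported by the caller to a directory containing the GLUE data, vocab file, and SentencePiece model.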
Indic-BERT-v1-master/albert/run_pretraining.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Run masked LM/next sentence masked_lm pre-training for ALBERT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from albert import modeling
from albert import optimization
from six.moves import range
import tensorflow.compat.v1 as tf
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import data as contrib_data
from tensorflow.contrib import tpu as contrib_tpu

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string(
    "input_file", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_integer(
    "max_seq_length", 512,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.")

flags.DEFINE_integer(
    "max_predictions_per_seq", 20,
    "Maximum number of masked LM predictions per sequence. "
    "Must match data generation.")

flags.DEFINE_bool("do_train", True, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 4096, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 64, "Total batch size for eval.")

flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
                  "The optimizer for training.")

flags.DEFINE_float("learning_rate", 0.00176, "The initial learning rate.")

flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")

flags.DEFINE_integer("num_train_steps", 125000, "Number of training steps.")

flags.DEFINE_integer("num_warmup_steps", 3125, "Number of warmup steps.")

flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")

flags.DEFINE_integer("save_checkpoints_steps", 5000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("keep_checkpoint_max", 5,
                     "How many checkpoints to keep.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_bool("init_from_group0", False, "Whether to initialize "
                  "parameters of other groups from group 0.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")

flags.DEFINE_float(
    "masked_lm_budget", 0,
    "If >0, the ratio of masked ngrams to unmasked ngrams. Default 0, "
    "for offline masking.")


def model_fn_builder(albert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings, optimizer, poly_power,
                     start_warmup_step):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    # Note: We keep this feature name `next_sentence_labels` to be compatible
    # with the original data created by lanzhzh@. However, in the ALBERT case
    # it does represent sentence_order_labels.
    sentence_order_labels = features["next_sentence_labels"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model = modeling.AlbertModel(
        config=albert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = get_masked_lm_output(albert_config,
                                                 model.get_sequence_output(),
                                                 model.get_embedding_table(),
                                                 masked_lm_positions,
                                                 masked_lm_ids,
                                                 masked_lm_weights)

    # (sentence_order_loss, sentence_order_example_loss,
    #  sentence_order_log_probs) = get_sentence_order_output(
    #      albert_config, model.get_pooled_output(), sentence_order_labels)

    total_loss = masked_lm_loss  # + sentence_order_loss

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      tf.logging.info("number of hidden group %d to initialize",
                      albert_config.num_hidden_groups)
      num_of_initialize_group = 1
      if FLAGS.init_from_group0:
        num_of_initialize_group = albert_config.num_hidden_groups
        if albert_config.net_structure_type > 0:
          num_of_initialize_group = albert_config.num_hidden_layers
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint, num_of_initialize_group)
      if use_tpu:

        def tpu_scaffold():
          for gid in range(num_of_initialize_group):
            tf.logging.info("initialize the %dth layer", gid)
            tf.logging.info(assignment_map[gid])
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid])
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        for gid in range(num_of_initialize_group):
          tf.logging.info("initialize the %dth layer", gid)
          tf.logging.info(assignment_map[gid])
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map[gid])

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps,
          use_tpu, optimizer, poly_power, start_warmup_step)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(*args):
        """Computes the loss and accuracy of the model."""
        (masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
         masked_lm_weights, sentence_order_example_loss,
         sentence_order_log_probs, sentence_order_labels) = args[:7]

        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                         [-1, masked_lm_log_probs.shape[-1]])
        masked_lm_predictions = tf.argmax(
            masked_lm_log_probs, axis=-1, output_type=tf.int32)
        masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
        masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_accuracy = tf.metrics.accuracy(
            labels=masked_lm_ids,
            predictions=masked_lm_predictions,
            weights=masked_lm_weights)
        masked_lm_mean_loss = tf.metrics.mean(
            values=masked_lm_example_loss, weights=masked_lm_weights)

        metrics = {
            "masked_lm_accuracy": masked_lm_accuracy,
            "masked_lm_loss": masked_lm_mean_loss,
        }

        sentence_order_log_probs = tf.reshape(
            sentence_order_log_probs, [-1, sentence_order_log_probs.shape[-1]])
        sentence_order_predictions = tf.argmax(
            sentence_order_log_probs, axis=-1, output_type=tf.int32)
        sentence_order_labels = tf.reshape(sentence_order_labels, [-1])
        sentence_order_accuracy = tf.metrics.accuracy(
            labels=sentence_order_labels,
            predictions=sentence_order_predictions)
        sentence_order_mean_loss = tf.metrics.mean(
            values=sentence_order_example_loss)
        metrics.update({
            "sentence_order_accuracy": sentence_order_accuracy,
            "sentence_order_loss": sentence_order_mean_loss
        })
        return metrics

      metric_values = [
          masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
          masked_lm_weights, sentence_order_example_loss,
          sentence_order_log_probs, sentence_order_labels
      ]

      eval_metrics = (metric_fn, metric_values)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    return output_spec

  return model_fn


def get_masked_lm_output(albert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
  """Get loss and log probs for the masked LM."""
  input_tensor = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=albert_config.vocab_size, dtype=tf.float32)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs)


def get_sentence_order_output(albert_config, input_tensor, labels):
  """Get loss and log probs for the next sentence prediction."""

  # Simple binary classification. Note that 0 is "next sentence" and 1 is
  # "random sentence". This weight matrix is not used after pre-training.
  with tf.variable_scope("cls/seq_relationship"):
    output_weights = tf.get_variable(
        "output_weights",
        shape=[2, albert_config.hidden_size],
        initializer=modeling.create_initializer(
            albert_config.initializer_range))
    output_bias = tf.get_variable(
        "output_bias", shape=[2], initializer=tf.zeros_initializer())

    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, per_example_loss, log_probs)


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


def input_fn_builder(input_files,
                     max_seq_length,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]

    name_to_features = {
        "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
        # Note: We keep this feature name `next_sentence_labels` to be
        # compatible with the original data created by lanzhzh@. However, in
        # the ALBERT case it does represent sentence_order_labels.
        "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
    }

    if FLAGS.masked_lm_budget:
      name_to_features.update({
          "token_boundary":
              tf.FixedLenFeature([max_seq_length], tf.int64)})
    else:
      name_to_features.update({
          "masked_lm_positions":
              tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
          "masked_lm_ids":
              tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
          "masked_lm_weights":
              tf.FixedLenFeature([max_predictions_per_seq], tf.float32)})

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
      d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
      d = d.repeat()
      d = d.shuffle(buffer_size=len(input_files))

      # `cycle_length` is the number of parallel files that get read.
      cycle_length = min(num_cpu_threads, len(input_files))

      # `sloppy` mode means that the interleaving is not exact. This adds
      # even more randomness to the training pipeline.
      d = d.apply(
          contrib_data.parallel_interleave(
              tf.data.TFRecordDataset,
              sloppy=is_training,
              cycle_length=cycle_length))
      d = d.shuffle(buffer_size=100)
    else:
      d = tf.data.TFRecordDataset(input_files)
      # Since we evaluate for a fixed number of steps we don't want to
      # encounter out-of-range exceptions.
      d = d.repeat()

    # We must `drop_remainder` on training because the TPU requires fixed
    # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
    # and we *don't* want to drop the remainder, otherwise we won't cover
    # every sample.
    d = d.apply(
        tf.data.experimental.map_and_batch_with_legacy_function(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            num_parallel_batches=num_cpu_threads,
            drop_remainder=True))
    tf.logging.info(d)
    return d

  return input_fn


def _decode_record(record, name_to_features):
  """Decodes a record to a TensorFlow example."""
  example = tf.parse_single_example(record, name_to_features)

  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
  # So cast all int64 to int32.
  for name in list(example.keys()):
    t = example[name]
    if t.dtype == tf.int64:
      t = tf.to_int32(t)
    example[name] = t

  return example


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Input Files ***")
  for input_file in input_files:
    tf.logging.info("  %s" % input_file)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  model_fn = model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.num_train_steps,
      num_warmup_steps=FLAGS.num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      optimizer=FLAGS.optimizer,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.do_train:
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

  if FLAGS.do_eval:
    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
    global_step = -1
    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    writer = tf.gfile.GFile(output_eval_file, "w")
    eval_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False)
    best_perf = 0
    key_name = "masked_lm_accuracy"
    while global_step < FLAGS.num_train_steps:
      if estimator.latest_checkpoint() is None:
        tf.logging.info("No checkpoint found yet. Sleeping.")
        time.sleep(1)
      else:
        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
        global_step = result["global_step"]
        tf.logging.info("***** Eval results *****")
        checkpoint_path = estimator.latest_checkpoint()
        for key in sorted(result.keys()):
          tf.logging.info("  %s = %s", key, str(result[key]))
          writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))


if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("albert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
Indic-BERT-v1-master/albert/run_pretraining_test.py
ADDED
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Tests for run_pretraining."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import tempfile
from absl.testing import flagsaver
from albert import modeling
from albert import run_pretraining
import tensorflow.compat.v1 as tf

FLAGS = tf.app.flags.FLAGS


def _create_config_file(filename, max_seq_length, vocab_size):
  """Creates an AlbertConfig and saves it to file."""
  albert_config = modeling.AlbertConfig(
      vocab_size,
      embedding_size=5,
      hidden_size=14,
      num_hidden_layers=3,
      num_hidden_groups=1,
      num_attention_heads=2,
      intermediate_size=19,
      inner_group_num=1,
      down_scale_factor=1,
      hidden_act="gelu",
      hidden_dropout_prob=0,
      attention_probs_dropout_prob=0,
      max_position_embeddings=max_seq_length,
      type_vocab_size=2,
      initializer_range=0.02)
  with tf.gfile.Open(filename, "w") as outfile:
    outfile.write(albert_config.to_json_string())


def _create_record(max_predictions_per_seq, max_seq_length, vocab_size):
  """Returns a tf.train.Example containing random data."""
  example = tf.train.Example()
  example.features.feature["input_ids"].int64_list.value.extend(
      [random.randint(0, vocab_size - 1) for _ in range(max_seq_length)])
  example.features.feature["input_mask"].int64_list.value.extend(
      [random.randint(0, 1) for _ in range(max_seq_length)])
  example.features.feature["masked_lm_positions"].int64_list.value.extend([
      random.randint(0, max_seq_length - 1)
      for _ in range(max_predictions_per_seq)
  ])
  example.features.feature["masked_lm_ids"].int64_list.value.extend([
      random.randint(0, vocab_size - 1) for _ in range(max_predictions_per_seq)
  ])
  example.features.feature["masked_lm_weights"].float_list.value.extend(
      [1. for _ in range(max_predictions_per_seq)])
  example.features.feature["segment_ids"].int64_list.value.extend(
      [0 for _ in range(max_seq_length)])
  example.features.feature["next_sentence_labels"].int64_list.value.append(
      random.randint(0, 1))
  return example


def _create_input_file(filename,
                       max_predictions_per_seq,
                       max_seq_length,
                       vocab_size,
                       size=1000):
  """Creates an input TFRecord file of specified size."""
  with tf.io.TFRecordWriter(filename) as writer:
    for _ in range(size):
      ex = _create_record(max_predictions_per_seq, max_seq_length, vocab_size)
      writer.write(ex.SerializeToString())


class RunPretrainingTest(tf.test.TestCase):

  def _verify_output_file(self, basename):
    self.assertTrue(tf.gfile.Exists(os.path.join(FLAGS.output_dir, basename)))

  def _verify_checkpoint_files(self, name):
    self._verify_output_file(name + ".meta")
    self._verify_output_file(name + ".index")
    self._verify_output_file(name + ".data-00000-of-00001")

  @flagsaver.flagsaver
  def test_pretraining(self):
    # Set up required flags.
    vocab_size = 97
    FLAGS.max_predictions_per_seq = 7
    FLAGS.max_seq_length = 13
    FLAGS.output_dir = tempfile.mkdtemp("output_dir")
    FLAGS.albert_config_file = os.path.join(
        tempfile.mkdtemp("config_dir"), "albert_config.json")
    FLAGS.input_file = os.path.join(
        tempfile.mkdtemp("input_dir"), "input_data.tfrecord")
    FLAGS.do_train = True
    FLAGS.do_eval = True
    FLAGS.num_train_steps = 1
    FLAGS.save_checkpoints_steps = 1

    # Construct requisite input files.
    _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
                        vocab_size)
    _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
                       FLAGS.max_seq_length, vocab_size)

    # Run the pretraining.
    run_pretraining.main(None)

    # Verify output.
    self._verify_checkpoint_files("model.ckpt-best")
    self._verify_checkpoint_files("model.ckpt-1")
    self._verify_output_file("eval_results.txt")
    self._verify_output_file("checkpoint")


if __name__ == "__main__":
  tf.test.main()
Indic-BERT-v1-master/albert/run_race.py
ADDED
@@ -0,0 +1,458 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT finetuning runner with sentence piece tokenization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from albert import classifier_utils
from albert import fine_tuning_utils
from albert import modeling
from albert import race_utils
import tensorflow.compat.v1 as tf
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("task_name", "race", "The name of the task to train.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the ALBERT model was trained on.")

flags.DEFINE_string("train_file", None,
                    "Path to the preprocessed tfrecord file. "
                    "The file will be generated if it does not exist.")

flags.DEFINE_string("eval_file", None,
                    "Path to the preprocessed tfrecord file. "
                    "The file will be generated if it does not exist.")

flags.DEFINE_string("predict_file", None,
                    "Path to the preprocessed tfrecord file. "
                    "The file will be generated if it does not exist.")

flags.DEFINE_string("spm_model_file", None,
                    "The model file for sentence piece tokenization.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_string(
    "albert_hub_module_handle", None,
    "If set, the ALBERT hub module to use.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_float("dropout_prob", 0.1, "Dropout probability.")

flags.DEFINE_integer(
    "max_seq_length", 512,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "max_qa_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "num_keep_checkpoint", 5,
    "Maximum number of checkpoints to keep.")


flags.DEFINE_bool(
    "high_only", False,
    "Whether to only run the model on the high school set.")

flags.DEFINE_bool(
    "middle_only", False,
    "Whether to only run the model on the middle school set.")

flags.DEFINE_bool("do_train", True, "Whether to run training.")

flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")

flags.DEFINE_bool(
    "do_predict", False,
    "Whether to run the model in inference mode on the test set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 1e-5, "The initial learning rate for Adam.")

flags.DEFINE_integer("train_step", 12000,
                     "Total number of training steps to perform.")

flags.DEFINE_integer(
    "warmup_step", 1000,
    "Number of steps to perform linear learning rate warmup for.")

flags.DEFINE_integer("save_checkpoints_steps", 100,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "race": race_utils.RaceProcessor
  }

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict` must be True.")

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  if FLAGS.max_seq_length > albert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the ALBERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, albert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name](
      use_spm=True if FLAGS.spm_model_file else False,
      do_lower_case=FLAGS.do_lower_case,
      high_only=FLAGS.high_only,
      middle_only=FLAGS.middle_only)

  label_list = processor.get_labels()

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
      keep_checkpoint_max=0,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)

  model_fn = race_utils.model_fn_builder(
      albert_config=albert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=FLAGS.train_step,
      num_warmup_steps=FLAGS.warmup_step,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      max_seq_length=FLAGS.max_seq_length,
      dropout_prob=FLAGS.dropout_prob,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    if not tf.gfile.Exists(FLAGS.train_file):
      race_utils.file_based_convert_examples_to_features(
          train_examples, label_list, FLAGS.max_seq_length, tokenizer,
          FLAGS.train_file, FLAGS.max_qa_length)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", FLAGS.train_step)
    train_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=FLAGS.train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        multiple=len(label_list))
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on. These do NOT count towards the metric (all tf.metrics
      # support a per-instance weight, and these get a weight of 0.0).
      while len(eval_examples) % FLAGS.eval_batch_size != 0:
        eval_examples.append(classifier_utils.PaddingInputExample())

    if not tf.gfile.Exists(FLAGS.eval_file):
      race_utils.file_based_convert_examples_to_features(
          eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
          FLAGS.eval_file, FLAGS.max_qa_length)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      assert len(eval_examples) % FLAGS.eval_batch_size == 0
      eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=FLAGS.eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.eval_batch_size,
        multiple=len(label_list))

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    key_name = "eval_accuracy"
    if tf.gfile.Exists(checkpoint_path + ".index"):
      result = estimator.evaluate(
          input_fn=eval_input_fn,
          steps=eval_steps,
          checkpoint_path=checkpoint_path)
      best_perf = result[key_name]
      global_step = result["global_step"]
    else:
      global_step = -1
      best_perf = -1
      checkpoint_path = None
    writer = tf.gfile.GFile(output_eval_file, "w")
    while global_step < FLAGS.train_step:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      # steps_and_files = sorted(steps_and_files, key=lambda x: x[0])
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(1)
      else:
        for ele in sorted(steps_and_files.items()):
          step, checkpoint_path = ele
          if global_step >= step:
            if len(_find_valid_cands(step)) > 1:
              for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)
            continue
          result = estimator.evaluate(
              input_fn=eval_input_fn,
              steps=eval_steps,
              checkpoint_path=checkpoint_path)
          global_step = result["global_step"]
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          writer.write("best = {}\n".format(best_perf))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit("-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))

          if len(_find_valid_cands(global_step)) > 1:
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tf.logging.info("removing {}".format(src_ckpt))
              tf.gfile.Remove(src_ckpt)
          writer.write("=" * 50 + "\n")
    writer.close()
  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    num_actual_predict_examples = len(predict_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on.
      while len(predict_examples) % FLAGS.predict_batch_size != 0:
        predict_examples.append(classifier_utils.PaddingInputExample())
    assert len(predict_examples) % FLAGS.predict_batch_size == 0
    predict_steps = int(len(predict_examples) // FLAGS.predict_batch_size)

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    race_utils.file_based_convert_examples_to_features(
        predict_examples, label_list,
        FLAGS.max_seq_length, tokenizer,
        predict_file, FLAGS.max_qa_length)

    tf.logging.info("***** Running prediction *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = True if FLAGS.use_tpu else False
    predict_input_fn = classifier_utils.file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder,
        task_name=task_name,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        multiple=len(label_list))

    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    result = estimator.evaluate(
        input_fn=predict_input_fn,
        steps=predict_steps,
        checkpoint_path=checkpoint_path)

    output_predict_file = os.path.join(FLAGS.output_dir, "predict_results.txt")
    with tf.gfile.GFile(output_predict_file, "w") as pred_writer:
      # num_written_lines = 0
      tf.logging.info("***** Predict results *****")
      pred_writer.write("***** Predict results *****\n")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        pred_writer.write("%s = %s\n" % (key, str(result[key])))
      pred_writer.write("best = {}\n".format(best_perf))


if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("spm_model_file")
  flags.mark_flag_as_required("albert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
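
One detail worth calling out in the eval loop above: the global step of each candidate checkpoint is recovered purely by parsing the trailing integer of its `model.ckpt-<step>.index` filename, with the special `model.ckpt-best` copy skipped. A self-contained sketch of that parsing under the same naming scheme (the helper name is ours, for illustration; the script does this inline):

# Minimal sketch of the checkpoint-name parsing in the eval loop above:
# "model.ckpt-1500.index" -> 1500; the "best" copy and non-.index files
# are skipped. The helper name is illustrative, not from the script.
def parse_step(filename):
  if not filename.endswith(".index"):
    return None
  idx = filename[:-len(".index")].split("-")[-1]
  return None if idx == "best" else int(idx)

assert parse_step("model.ckpt-1500.index") == 1500
assert parse_step("model.ckpt-best.index") is None
assert parse_step("model.ckpt-1500.meta") is None
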
Indic-BERT-v1-master/albert/run_squad_v1.py
ADDED
@@ -0,0 +1,547 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Run ALBERT on SQuAD v1.1 using sentence piece tokenization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import json
import os
import random
import time
from albert import fine_tuning_utils
from albert import modeling
from albert import squad_utils
import six
import tensorflow.compat.v1 as tf

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu


# pylint: disable=g-import-not-at-top
if six.PY2:
  import six.moves.cPickle as pickle
else:
  import pickle
# pylint: enable=g-import-not-at-top

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the ALBERT model was trained on.")

flags.DEFINE_string("spm_model_file", None,
                    "The model file for sentence piece tokenization.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None,
                    "SQuAD json for training. E.g., train-v1.1.json")

flags.DEFINE_string(
    "predict_file", None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")

flags.DEFINE_string("train_feature_file", None,
                    "Training feature file.")

flags.DEFINE_string(
    "predict_feature_file", None,
    "Location of predict features. If it doesn't exist, it will be written. "
    "If it does exist, it will be read.")

flags.DEFINE_string(
    "predict_feature_left_file", None,
    "Location of predict features not passed to TPU. If it doesn't exist, it "
    "will be written. If it does exist, it will be read.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_string(
    "albert_hub_module_handle", None,
    "If set, the ALBERT hub module to use.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 384,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer(
    "n_best_size", 20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.")

flags.DEFINE_integer(
    "max_answer_length", 30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
165 |
+
"[Optional] Project name for the Cloud TPU-enabled project. If not "
|
166 |
+
"specified, we will attempt to automatically detect the GCE project from "
|
167 |
+
"metadata.")
|
168 |
+
|
169 |
+
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
|
170 |
+
|
171 |
+
flags.DEFINE_integer(
|
172 |
+
"num_tpu_cores", 8,
|
173 |
+
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
|
174 |
+
|
175 |
+
flags.DEFINE_bool(
|
176 |
+
"use_einsum", True,
|
177 |
+
"Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
|
178 |
+
"be set to False for TFLite compatibility.")
|
179 |
+
|
180 |
+
flags.DEFINE_string(
|
181 |
+
"export_dir",
|
182 |
+
default=None,
|
183 |
+
help=("The directory where the exported SavedModel will be stored."))
|
184 |
+
|
185 |
+
|
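# Example invocation (illustrative only; every path below is a hypothetical
# placeholder, not a file shipped with this repo):
#
#   python -m albert.run_squad_v1 \
#     --albert_config_file=/path/to/albert_config.json \
#     --spm_model_file=/path/to/30k-clean.model \
#     --output_dir=/tmp/squad_v1_out \
#     --train_file=train-v1.1.json \
#     --predict_file=dev-v1.1.json \
#     --train_feature_file=/tmp/train.tfrecord \
#     --predict_feature_file=/tmp/dev.tfrecord \
#     --predict_feature_left_file=/tmp/dev_left.pkl \
#     --do_train --do_predict
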
def validate_flags_or_throw(albert_config):
  """Validate the input FLAGS or throw an exception."""

  if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_dir:
    err_msg = ("At least one of `do_train`, `do_predict` or `export_dir` "
               "must be True.")
    raise ValueError(err_msg)

  if FLAGS.do_train:
    if not FLAGS.train_file:
      raise ValueError(
          "If `do_train` is True, then `train_file` must be specified.")
  if FLAGS.do_predict:
    if not FLAGS.predict_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_file` must be specified.")
    if not FLAGS.predict_feature_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_file` must be "
          "specified.")
    if not FLAGS.predict_feature_left_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_left_file` must be "
          "specified.")

  if FLAGS.max_seq_length > albert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the ALBERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, albert_config.max_position_embeddings))

  if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
    raise ValueError(
        "The max_seq_length (%d) must be greater than max_query_length "
        "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))


def build_squad_serving_input_fn(seq_length):
  """Builds a serving input fn for raw input."""

  def _seq_serving_input_fn():
    """Serving input fn for raw inputs."""
    input_ids = tf.placeholder(
        shape=[1, seq_length], name="input_ids", dtype=tf.int32)
    input_mask = tf.placeholder(
        shape=[1, seq_length], name="input_mask", dtype=tf.int32)
    segment_ids = tf.placeholder(
        shape=[1, seq_length], name="segment_ids", dtype=tf.int32)

    inputs = {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids
    }
    return tf.estimator.export.ServingInputReceiver(features=inputs,
                                                    receiver_tensors=inputs)

  return _seq_serving_input_fn

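# Note: the receiver above pins every input to a static [1, seq_length]
# shape. That single-example, fixed-shape signature is what lets the
# exported SavedModel be converted to TFLite at the end of main().
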
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  validate_flags_or_throw(albert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      keep_checkpoint_max=0,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.train_file, is_training=True)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  model_fn = squad_utils.v1_model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      use_einsum=FLAGS.use_einsum,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.

    if not tf.gfile.Exists(FLAGS.train_feature_file):
      train_writer = squad_utils.FeatureWriter(
          filename=os.path.join(FLAGS.train_feature_file), is_training=True)
      squad_utils.convert_examples_to_features(
          examples=train_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          do_lower_case=FLAGS.do_lower_case)
      train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    # tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    del train_examples

    train_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.train_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        is_v2=False)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    with tf.gfile.Open(FLAGS.predict_file) as predict_file:
      prediction_json = json.load(predict_file)["data"]

    eval_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
        FLAGS.predict_feature_left_file)):
      tf.logging.info("Loading eval features from {}".format(
          FLAGS.predict_feature_left_file))
      with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
        eval_features = pickle.load(fin)
    else:
      eval_writer = squad_utils.FeatureWriter(
          filename=FLAGS.predict_feature_file, is_training=False)
      eval_features = []

      def append_feature(feature):
        eval_features.append(feature)
        eval_writer.process_feature(feature)

      squad_utils.convert_examples_to_features(
          examples=eval_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=False,
          output_fn=append_feature,
          do_lower_case=FLAGS.do_lower_case)
      eval_writer.close()

      with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
        pickle.dump(eval_features, fout)

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.predict_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        is_v2=False)

    def get_result(checkpoint):
      """Evaluate the checkpoint on SQuAD v1.1."""
      # If running eval on the TPU, you will need to specify the number of
      # steps.
      reader = tf.train.NewCheckpointReader(checkpoint)
      global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True,
          checkpoint_path=checkpoint):
        if len(all_results) % 1000 == 0:
          tf.logging.info("Processing example: %d" % (len(all_results)))
        unique_id = int(result["unique_ids"])
        start_log_prob = [float(x) for x in result["start_log_prob"].flat]
        end_log_prob = [float(x) for x in result["end_log_prob"].flat]
        all_results.append(
            squad_utils.RawResult(
                unique_id=unique_id,
                start_log_prob=start_log_prob,
                end_log_prob=end_log_prob))

      output_prediction_file = os.path.join(
          FLAGS.output_dir, "predictions.json")
      output_nbest_file = os.path.join(
          FLAGS.output_dir, "nbest_predictions.json")

      result_dict = {}
      squad_utils.accumulate_predictions_v1(
          result_dict, eval_examples, eval_features,
          all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
      predictions = squad_utils.write_predictions_v1(
          result_dict, eval_examples, eval_features, all_results,
          FLAGS.n_best_size, FLAGS.max_answer_length,
          output_prediction_file, output_nbest_file)

      return squad_utils.evaluate_v1(
          prediction_json, predictions), int(global_step)

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    key_name = "f1"
    writer = tf.gfile.GFile(output_eval_file, "w")
    if tf.gfile.Exists(checkpoint_path + ".index"):
      result = get_result(checkpoint_path)
      best_perf = result[0][key_name]
      global_step = result[1]
    else:
      global_step = -1
      best_perf = -1
      checkpoint_path = None
    # Poll the output directory for new checkpoints, evaluate each one, and
    # keep a copy of the best-scoring checkpoint as "model.ckpt-best".
    while global_step < num_train_steps:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(60)
      else:
        for ele in sorted(steps_and_files.items()):
          step, checkpoint_path = ele
          if global_step >= step:
            if len(_find_valid_cands(step)) > 1:
              for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)
            continue
          result, global_step = get_result(checkpoint_path)
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
          writer.write("best {} = {}\n".format(key_name, best_perf))
          tf.logging.info("  best {} = {}\n".format(key_name, best_perf))

          if len(_find_valid_cands(global_step)) > 2:
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tf.logging.info("removing {}".format(src_ckpt))
              tf.gfile.Remove(src_ckpt)
          writer.write("=" * 50 + "\n")

    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    result, global_step = get_result(checkpoint_path)
    tf.logging.info("***** Final Eval results *****")
    for key in sorted(result.keys()):
      tf.logging.info("  %s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))
    writer.write("best perf happened at step: {}".format(global_step))

  if FLAGS.export_dir:
    tf.gfile.MakeDirs(FLAGS.export_dir)
    squad_serving_input_fn = (
        build_squad_serving_input_fn(FLAGS.max_seq_length))
    tf.logging.info("Starting to export model.")
    subfolder = estimator.export_saved_model(
        export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"),
        serving_input_receiver_fn=squad_serving_input_fn)

    tf.logging.info("Starting to export TFLite.")
    converter = tf.lite.TFLiteConverter.from_saved_model(
        subfolder,
        input_arrays=["input_ids", "input_mask", "segment_ids"],
        output_arrays=["start_logits", "end_logits"])
    float_model = converter.convert()
    tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite")
    with tf.gfile.GFile(tflite_file, "wb") as f:
      f.write(float_model)


if __name__ == "__main__":
  flags.mark_flag_as_required("spm_model_file")
  flags.mark_flag_as_required("albert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
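When `export_dir` is set, run_squad_v1.py above emits both a SavedModel and a float TFLite flatbuffer. A minimal sketch of loading that `albert_model.tflite` for single-example inference (the model path and the zero-filled inputs are illustrative assumptions, not artifacts shipped with this repo):

import numpy as np
import tensorflow.compat.v1 as tf

interpreter = tf.lite.Interpreter(model_path="/path/to/albert_model.tflite")
interpreter.allocate_tensors()

# The serving fn fixed every input to a static [1, max_seq_length] int32
# tensor, so zero-filled arrays of the reported shapes are valid feeds.
for detail in interpreter.get_input_details():
  interpreter.set_tensor(detail["index"],
                         np.zeros(detail["shape"], dtype=np.int32))
interpreter.invoke()

# Match outputs by name; the ordering of get_output_details() is not
# guaranteed to be [start_logits, end_logits].
outputs = {d["name"]: interpreter.get_tensor(d["index"])
           for d in interpreter.get_output_details()}

Real inputs would of course come from the same SentencePiece tokenization and padding that `squad_utils.convert_examples_to_features` applies.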
Indic-BERT-v1-master/albert/run_squad_v2.py
ADDED
@@ -0,0 +1,516 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Run ALBERT on SQuAD v2.0 using sentence piece tokenization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import json
import os
import random
import time

from albert import fine_tuning_utils
from albert import modeling
from albert import squad_utils
import six
import tensorflow.compat.v1 as tf

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import tpu as contrib_tpu


# pylint: disable=g-import-not-at-top
if six.PY2:
  import six.moves.cPickle as pickle
else:
  import pickle
# pylint: enable=g-import-not-at-top

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "albert_config_file", None,
    "The config json file corresponding to the pre-trained ALBERT model. "
    "This specifies the model architecture.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the ALBERT model was trained on.")

flags.DEFINE_string("spm_model_file", None,
                    "The model file for sentence piece tokenization.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None,
                    "SQuAD json for training. E.g., train-v1.1.json")

flags.DEFINE_string(
    "predict_file", None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")

flags.DEFINE_string("train_feature_file", None,
                    "Training feature file.")

flags.DEFINE_string(
    "predict_feature_file", None,
    "Location of predict features. If it doesn't exist, it will be written. "
    "If it does exist, it will be read.")

flags.DEFINE_string(
    "predict_feature_left_file", None,
    "Location of predict features not passed to TPU. If it doesn't exist, it "
    "will be written. If it does exist, it will be read.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained ALBERT model).")

flags.DEFINE_string(
    "albert_hub_module_handle", None,
    "If set, the ALBERT hub module to use.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 384,
    "The maximum total input sequence length after SentencePiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8,
                     "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")

flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer(
    "n_best_size", 20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.")

flags.DEFINE_integer(
    "max_answer_length", 30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE zone from "
    "metadata.")

tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


flags.DEFINE_integer("start_n_top", 5, "Beam size for the start positions.")

flags.DEFINE_integer("end_n_top", 5, "Beam size for the end positions.")

flags.DEFINE_float("dropout_prob", 0.1, "Dropout probability.")

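# v2-only knobs: unlike the v1 runner, the v2 head scores the top
# `start_n_top` start positions, pairs each with `end_n_top` end positions,
# and adds a `cls_logits` answerability score for SQuAD v2.0's unanswerable
# questions (see RawResultV2 and accumulate_predictions_v2 in squad_utils.py).
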
def validate_flags_or_throw(albert_config):
  """Validate the input FLAGS or throw an exception."""

  if not FLAGS.do_train and not FLAGS.do_predict:
    raise ValueError("At least one of `do_train` or `do_predict` must be True.")

  if FLAGS.do_train:
    if not FLAGS.train_file:
      raise ValueError(
          "If `do_train` is True, then `train_file` must be specified.")
  if FLAGS.do_predict:
    if not FLAGS.predict_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_file` must be specified.")
    if not FLAGS.predict_feature_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_file` must be "
          "specified.")
    if not FLAGS.predict_feature_left_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_feature_left_file` must be "
          "specified.")

  if FLAGS.max_seq_length > albert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the ALBERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, albert_config.max_position_embeddings))

  if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
    raise ValueError(
        "The max_seq_length (%d) must be greater than max_query_length "
        "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)

  validate_flags_or_throw(albert_config)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  tokenizer = fine_tuning_utils.create_vocab(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file,
      hub_module=FLAGS.albert_hub_module_handle)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
  if FLAGS.do_train:
    iterations_per_loop = int(min(FLAGS.iterations_per_loop,
                                  FLAGS.save_checkpoints_steps))
  else:
    iterations_per_loop = FLAGS.iterations_per_loop
  run_config = contrib_tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      keep_checkpoint_max=0,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=contrib_tpu.TPUConfig(
          iterations_per_loop=iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  num_warmup_steps = None
  # Note: unlike the v1 runner, training examples are read even in
  # predict-only mode, because `num_train_steps` also bounds the
  # checkpoint-polling eval loop below.
  train_examples = squad_utils.read_squad_examples(
      input_file=FLAGS.train_file, is_training=True)
  num_train_steps = int(
      len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
  if FLAGS.do_train:
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

  model_fn = squad_utils.v2_model_fn_builder(
      albert_config=albert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      max_seq_length=FLAGS.max_seq_length,
      start_n_top=FLAGS.start_n_top,
      end_n_top=FLAGS.end_n_top,
      dropout_prob=FLAGS.dropout_prob,
      hub_module=FLAGS.albert_hub_module_handle)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = contrib_tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)

  if FLAGS.do_train:
    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.

    if not tf.gfile.Exists(FLAGS.train_feature_file):
      train_writer = squad_utils.FeatureWriter(
          filename=os.path.join(FLAGS.train_feature_file), is_training=True)
      squad_utils.convert_examples_to_features(
          examples=train_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=True,
          output_fn=train_writer.process_feature,
          do_lower_case=FLAGS.do_lower_case)
      train_writer.close()

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num orig examples = %d", len(train_examples))
    # tf.logging.info("  Num split examples = %d", train_writer.num_features)
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    del train_examples

    train_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.train_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.train_batch_size,
        is_v2=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_predict:
    with tf.gfile.Open(FLAGS.predict_file) as predict_file:
      prediction_json = json.load(predict_file)["data"]
    eval_examples = squad_utils.read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False)

    if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
        FLAGS.predict_feature_left_file)):
      tf.logging.info("Loading eval features from {}".format(
          FLAGS.predict_feature_left_file))
      with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
        eval_features = pickle.load(fin)
    else:
      eval_writer = squad_utils.FeatureWriter(
          filename=FLAGS.predict_feature_file, is_training=False)
      eval_features = []

      def append_feature(feature):
        eval_features.append(feature)
        eval_writer.process_feature(feature)

      squad_utils.convert_examples_to_features(
          examples=eval_examples,
          tokenizer=tokenizer,
          max_seq_length=FLAGS.max_seq_length,
          doc_stride=FLAGS.doc_stride,
          max_query_length=FLAGS.max_query_length,
          is_training=False,
          output_fn=append_feature,
          do_lower_case=FLAGS.do_lower_case)
      eval_writer.close()

      with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
        pickle.dump(eval_features, fout)

    tf.logging.info("***** Running predictions *****")
    tf.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.logging.info("  Num split examples = %d", len(eval_features))
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = squad_utils.input_fn_builder(
        input_file=FLAGS.predict_feature_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False,
        use_tpu=FLAGS.use_tpu,
        bsz=FLAGS.predict_batch_size,
        is_v2=True)

    def get_result(checkpoint):
      """Evaluate the checkpoint on SQuAD v2.0."""
      # If running eval on the TPU, you will need to specify the number of
      # steps.
      reader = tf.train.NewCheckpointReader(checkpoint)
      global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
      all_results = []
      for result in estimator.predict(
          predict_input_fn, yield_single_examples=True,
          checkpoint_path=checkpoint):
        if len(all_results) % 1000 == 0:
          tf.logging.info("Processing example: %d" % (len(all_results)))
        unique_id = int(result["unique_ids"])
        start_top_log_probs = (
            [float(x) for x in result["start_top_log_probs"].flat])
        start_top_index = [int(x) for x in result["start_top_index"].flat]
        end_top_log_probs = (
            [float(x) for x in result["end_top_log_probs"].flat])
        end_top_index = [int(x) for x in result["end_top_index"].flat]

        cls_logits = float(result["cls_logits"].flat[0])
        all_results.append(
            squad_utils.RawResultV2(
                unique_id=unique_id,
                start_top_log_probs=start_top_log_probs,
                start_top_index=start_top_index,
                end_top_log_probs=end_top_log_probs,
                end_top_index=end_top_index,
                cls_logits=cls_logits))

      output_prediction_file = os.path.join(
          FLAGS.output_dir, "predictions.json")
      output_nbest_file = os.path.join(
          FLAGS.output_dir, "nbest_predictions.json")
      output_null_log_odds_file = os.path.join(
          FLAGS.output_dir, "null_odds.json")

      result_dict = {}
      cls_dict = {}
      squad_utils.accumulate_predictions_v2(
          result_dict, cls_dict, eval_examples, eval_features,
          all_results, FLAGS.n_best_size, FLAGS.max_answer_length,
          FLAGS.start_n_top, FLAGS.end_n_top)

      return squad_utils.evaluate_v2(
          result_dict, cls_dict, prediction_json, eval_examples,
          eval_features, all_results, FLAGS.n_best_size,
          FLAGS.max_answer_length, output_prediction_file, output_nbest_file,
          output_null_log_odds_file), int(global_step)

    def _find_valid_cands(curr_step):
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      candidates = []
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          idx = ckpt_name.split("-")[-1]
          if idx != "best" and int(idx) > curr_step:
            candidates.append(filename)
      return candidates

    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    key_name = "f1"
    writer = tf.gfile.GFile(output_eval_file, "w")
    if tf.gfile.Exists(checkpoint_path + ".index"):
      result = get_result(checkpoint_path)
      best_perf = result[0][key_name]
      global_step = result[1]
    else:
      global_step = -1
      best_perf = -1
      checkpoint_path = None
    while global_step < num_train_steps:
      steps_and_files = {}
      filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
      for filename in filenames:
        if filename.endswith(".index"):
          ckpt_name = filename[:-6]
          cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
          if cur_filename.split("-")[-1] == "best":
            continue
          gstep = int(cur_filename.split("-")[-1])
          if gstep not in steps_and_files:
            tf.logging.info("Add {} to eval list.".format(cur_filename))
            steps_and_files[gstep] = cur_filename
      tf.logging.info("found {} files.".format(len(steps_and_files)))
      if not steps_and_files:
        tf.logging.info("found 0 file, global step: {}. Sleeping."
                        .format(global_step))
        time.sleep(60)
      else:
        for ele in sorted(steps_and_files.items()):
          step, checkpoint_path = ele
          if global_step >= step:
            if len(_find_valid_cands(step)) > 1:
              for ext in ["meta", "data-00000-of-00001", "index"]:
                src_ckpt = checkpoint_path + ".{}".format(ext)
                tf.logging.info("removing {}".format(src_ckpt))
                tf.gfile.Remove(src_ckpt)
            continue
          result, global_step = get_result(checkpoint_path)
          tf.logging.info("***** Eval results *****")
          for key in sorted(result.keys()):
            tf.logging.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
          if result[key_name] > best_perf:
            best_perf = result[key_name]
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tgt_ckpt = checkpoint_path.rsplit(
                  "-", 1)[0] + "-best.{}".format(ext)
              tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
              tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
              writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
          writer.write("best {} = {}\n".format(key_name, best_perf))
          tf.logging.info("  best {} = {}\n".format(key_name, best_perf))

          if len(_find_valid_cands(global_step)) > 2:
            for ext in ["meta", "data-00000-of-00001", "index"]:
              src_ckpt = checkpoint_path + ".{}".format(ext)
              tf.logging.info("removing {}".format(src_ckpt))
              tf.gfile.Remove(src_ckpt)
          writer.write("=" * 50 + "\n")

    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
    result, global_step = get_result(checkpoint_path)
    tf.logging.info("***** Final Eval results *****")
    for key in sorted(result.keys()):
      tf.logging.info("  %s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))
    writer.write("best perf happened at step: {}".format(global_step))


if __name__ == "__main__":
  flags.mark_flag_as_required("spm_model_file")
  flags.mark_flag_as_required("albert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
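Each RawResultV2 produced in `get_result` above carries beam-searched start/end candidates plus a single no-answer logit. The real reduction lives in `squad_utils.accumulate_predictions_v2` / `evaluate_v2`; the sketch below only illustrates the idea and simplifies one detail (in the actual model the end beam is computed per start hypothesis, not as one flat list), with `null_threshold` as an assumed tuning knob:

def best_span(start_top_log_probs, start_top_index,
              end_top_log_probs, end_top_index,
              cls_logits, null_threshold=0.0, max_answer_length=30):
  """Pick the highest-scoring (start, end) pair, or None for no answer."""
  best_score, best_pair = float("-inf"), None
  for s_lp, s_idx in zip(start_top_log_probs, start_top_index):
    for e_lp, e_idx in zip(end_top_log_probs, end_top_index):
      # Skip inverted or over-long spans.
      if e_idx < s_idx or e_idx - s_idx + 1 > max_answer_length:
        continue
      if s_lp + e_lp > best_score:
        best_score, best_pair = s_lp + e_lp, (s_idx, e_idx)
  # Declare "no answer" when the null head beats the best span by a margin.
  if best_pair is None or cls_logits - best_score > null_threshold:
    return None
  return best_pair

A span is kept only if it is well-formed and short enough; otherwise the `cls_logits` null score wins, which is how SQuAD v2.0's unanswerable questions are handled.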
Indic-BERT-v1-master/albert/run_trivial_model_test.sh
ADDED
@@ -0,0 +1,27 @@
#!/bin/bash
# Small integration test script.
# The values in this file are **not** meant for reproducing actual results.

set -e
set -x

virtualenv -p python3 .
source ./bin/activate

OUTPUT_DIR_BASE="$(mktemp -d)"
OUTPUT_DIR="${OUTPUT_DIR_BASE}/output"

pip install numpy
pip install -r requirements.txt
python -m run_pretraining_test \
  --output_dir="${OUTPUT_DIR}" \
  --do_train \
  --do_eval \
  --nouse_tpu \
  --train_batch_size=2 \
  --eval_batch_size=1 \
  --max_seq_length=4 \
  --num_train_steps=2 \
  --max_eval_steps=3
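Where running the bash wrapper is inconvenient, the same smoke test can be driven from Python; the flag values below simply mirror the script, and it assumes the dependencies are already installed in the current environment:

import subprocess
import tempfile

# Fresh scratch directory, analogous to the script's mktemp -d.
output_dir = tempfile.mkdtemp() + "/output"
subprocess.run(
    ["python", "-m", "run_pretraining_test",
     "--output_dir=" + output_dir,
     "--do_train", "--do_eval", "--nouse_tpu",
     "--train_batch_size=2", "--eval_batch_size=1",
     "--max_seq_length=4", "--num_train_steps=2", "--max_eval_steps=3"],
    check=True)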
Indic-BERT-v1-master/albert/squad_utils.py
ADDED
@@ -0,0 +1,1735 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# Lint as: python2, python3
|
16 |
+
"""Utility functions for SQuAD v1.1/v2.0 datasets."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
# from __future__ import google_type_annotations
|
21 |
+
from __future__ import print_function
|
22 |
+
import collections
|
23 |
+
import json
|
24 |
+
import math
|
25 |
+
import re
|
26 |
+
import string
|
27 |
+
import sys
|
28 |
+
from albert import fine_tuning_utils
|
29 |
+
from albert import modeling
|
30 |
+
from albert import optimization
|
31 |
+
from albert import tokenization
|
32 |
+
import numpy as np
|
33 |
+
import six
|
34 |
+
from six.moves import map
|
35 |
+
from six.moves import range
|
36 |
+
import tensorflow.compat.v1 as tf
|
37 |
+
from tensorflow.contrib import data as contrib_data
|
38 |
+
from tensorflow.contrib import layers as contrib_layers
|
39 |
+
from tensorflow.contrib import tpu as contrib_tpu
|
40 |
+
|
41 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
42 |
+
"PrelimPrediction",
|
43 |
+
["feature_index", "start_index", "end_index",
|
44 |
+
"start_log_prob", "end_log_prob"])
|
45 |
+
|
46 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
47 |
+
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
48 |
+
|
49 |
+
RawResult = collections.namedtuple("RawResult",
|
50 |
+
["unique_id",
|
51 |
+
"start_log_prob",
|
52 |
+
"end_log_prob"])
|
53 |
+
|
54 |
+
RawResultV2 = collections.namedtuple(
|
55 |
+
"RawResultV2",
|
56 |
+
["unique_id", "start_top_log_probs", "start_top_index",
|
57 |
+
"end_top_log_probs", "end_top_index", "cls_logits"])
|
58 |
+
|
59 |
+
|
60 |
+
class SquadExample(object):
|
61 |
+
"""A single training/test example for simple sequence classification.
|
62 |
+
|
63 |
+
For examples without an answer, the start and end position are -1.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self,
|
67 |
+
qas_id,
|
68 |
+
question_text,
|
69 |
+
paragraph_text,
|
70 |
+
orig_answer_text=None,
|
71 |
+
start_position=None,
|
72 |
+
end_position=None,
|
73 |
+
is_impossible=False):
|
74 |
+
self.qas_id = qas_id
|
75 |
+
self.question_text = question_text
|
76 |
+
self.paragraph_text = paragraph_text
|
77 |
+
self.orig_answer_text = orig_answer_text
|
78 |
+
self.start_position = start_position
|
79 |
+
self.end_position = end_position
|
80 |
+
self.is_impossible = is_impossible
|
81 |
+
|
82 |
+
def __str__(self):
|
83 |
+
return self.__repr__()
|
84 |
+
|
85 |
+
def __repr__(self):
|
86 |
+
s = ""
|
87 |
+
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
|
88 |
+
s += ", question_text: %s" % (
|
89 |
+
tokenization.printable_text(self.question_text))
|
90 |
+
s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
|
91 |
+
if self.start_position:
|
92 |
+
s += ", start_position: %d" % (self.start_position)
|
93 |
+
if self.start_position:
|
94 |
+
s += ", end_position: %d" % (self.end_position)
|
95 |
+
if self.start_position:
|
96 |
+
s += ", is_impossible: %r" % (self.is_impossible)
|
97 |
+
return s
|
98 |
+
|
99 |
+
|
100 |
+


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self,
               unique_id,
               example_index,
               doc_span_index,
               tok_start_to_orig_index,
               tok_end_to_orig_index,
               token_is_max_context,
               tokens,
               input_ids,
               input_mask,
               segment_ids,
               paragraph_len,
               p_mask=None,
               start_position=None,
               end_position=None,
               is_impossible=None):
    self.unique_id = unique_id
    self.example_index = example_index
    self.doc_span_index = doc_span_index
    self.tok_start_to_orig_index = tok_start_to_orig_index
    self.tok_end_to_orig_index = tok_end_to_orig_index
    self.token_is_max_context = token_is_max_context
    self.tokens = tokens
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.paragraph_len = paragraph_len
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible
    self.p_mask = p_mask


def read_squad_examples(input_file, is_training):
  """Read a SQuAD json file into a list of SquadExample."""
  with tf.gfile.Open(input_file, "r") as reader:
    input_data = json.load(reader)["data"]

  examples = []
  for entry in input_data:
    for paragraph in entry["paragraphs"]:
      paragraph_text = paragraph["context"]

      for qa in paragraph["qas"]:
        qas_id = qa["id"]
        question_text = qa["question"]
        start_position = None
        orig_answer_text = None
        is_impossible = False

        if is_training:
          is_impossible = qa.get("is_impossible", False)
          if (len(qa["answers"]) != 1) and (not is_impossible):
            raise ValueError(
                "For training, each question should have exactly 1 answer.")
          if not is_impossible:
            answer = qa["answers"][0]
            orig_answer_text = answer["text"]
            start_position = answer["answer_start"]
          else:
            start_position = -1
            orig_answer_text = ""

        example = SquadExample(
            qas_id=qas_id,
            question_text=question_text,
            paragraph_text=paragraph_text,
            orig_answer_text=orig_answer_text,
            start_position=start_position,
            is_impossible=is_impossible)
        examples.append(example)

  return examples
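
# The expected input layout is standard SQuAD-style JSON, e.g.:
# {"data": [{"paragraphs": [{"context": "...",
#                            "qas": [{"id": "...", "question": "...",
#                                     "is_impossible": false,
#                                     "answers": [{"text": "...",
#                                                  "answer_start": 0}]}]}]}]}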


def _convert_index(index, pos, m=None, is_start=True):
  """Converts an index, filling None gaps from the nearest mapped neighbor."""
  if index[pos] is not None:
    return index[pos]
  n = len(index)
  rear = pos
  while rear < n - 1 and index[rear] is None:
    rear += 1
  front = pos
  while front > 0 and index[front] is None:
    front -= 1
  assert index[front] is not None or index[rear] is not None
  if index[front] is None:
    if index[rear] >= 1:
      if is_start:
        return 0
      else:
        return index[rear] - 1
    return index[rear]
  if index[rear] is None:
    if m is not None and index[front] < m - 1:
      if is_start:
        return index[front] + 1
      else:
        return m - 1
    return index[front]
  if is_start:
    if index[rear] > index[front] + 1:
      return index[front] + 1
    else:
      return index[rear]
  else:
    if index[rear] > index[front] + 1:
      return index[rear] - 1
    else:
      return index[front]


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 output_fn, do_lower_case):
  """Loads a data file into a list of `InputBatch`s."""

  cnt_pos, cnt_neg = 0, 0
  unique_id = 1000000000
  max_n, max_m = 1024, 1024
  f = np.zeros((max_n, max_m), dtype=np.float32)

  for (example_index, example) in enumerate(examples):

    if example_index % 100 == 0:
      tf.logging.info("Converting {}/{} pos {} neg {}".format(
          example_index, len(examples), cnt_pos, cnt_neg))

    query_tokens = tokenization.encode_ids(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.question_text, lower=do_lower_case))

    if len(query_tokens) > max_query_length:
      query_tokens = query_tokens[0:max_query_length]

    paragraph_text = example.paragraph_text
    para_tokens = tokenization.encode_pieces(
        tokenizer.sp_model,
        tokenization.preprocess_text(
            example.paragraph_text, lower=do_lower_case),
        return_unicode=False)

    chartok_to_tok_index = []
    tok_start_to_chartok_index = []
    tok_end_to_chartok_index = []
    char_cnt = 0
    para_tokens = [six.ensure_text(token, "utf-8") for token in para_tokens]
    for i, token in enumerate(para_tokens):
      new_token = six.ensure_text(token).replace(
          tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ")
      chartok_to_tok_index.extend([i] * len(new_token))
      tok_start_to_chartok_index.append(char_cnt)
      char_cnt += len(new_token)
      tok_end_to_chartok_index.append(char_cnt - 1)

    tok_cat_text = "".join(para_tokens).replace(
        tokenization.SPIECE_UNDERLINE.decode("utf-8"), " ")
    n, m = len(paragraph_text), len(tok_cat_text)

    if n > max_n or m > max_m:
      max_n = max(n, max_n)
      max_m = max(m, max_m)
      f = np.zeros((max_n, max_m), dtype=np.float32)

    g = {}

    def _lcs_match(max_dist, n=n, m=m):
      """Longest-common-subsequence algorithm."""
      f.fill(0)
      g.clear()

      ### longest common sub sequence
      # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
      for i in range(n):

        # note(zhiliny):
        # unlike standard LCS, this is specifically optimized for the setting
        # because the mismatch between sentence pieces and original text will
        # be small
        for j in range(i - max_dist, i + max_dist):
          if j >= m or j < 0: continue

          if i > 0:
            g[(i, j)] = 0
            f[i, j] = f[i - 1, j]

          if j > 0 and f[i, j - 1] > f[i, j]:
            g[(i, j)] = 1
            f[i, j] = f[i, j - 1]

          f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
          if (tokenization.preprocess_text(
              paragraph_text[i], lower=do_lower_case,
              remove_space=False) == tok_cat_text[j]
              and f_prev + 1 > f[i, j]):
            g[(i, j)] = 2
            f[i, j] = f_prev + 1

    max_dist = abs(n - m) + 5
    for _ in range(2):
      _lcs_match(max_dist)
      if f[n - 1, m - 1] > 0.8 * n: break
      max_dist *= 2
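
    # If the band was too narrow to align at least 80% of the paragraph
    # characters, the match is retried once with a doubled band width.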

    orig_to_chartok_index = [None] * n
    chartok_to_orig_index = [None] * m
    i, j = n - 1, m - 1
    while i >= 0 and j >= 0:
      if (i, j) not in g: break
      if g[(i, j)] == 2:
        orig_to_chartok_index[i] = j
        chartok_to_orig_index[j] = i
        i, j = i - 1, j - 1
      elif g[(i, j)] == 1:
        j = j - 1
      else:
        i = i - 1

    if (all(v is None for v in orig_to_chartok_index) or
        f[n - 1, m - 1] < 0.8 * n):
      tf.logging.info("MISMATCH DETECTED!")
      continue

    tok_start_to_orig_index = []
    tok_end_to_orig_index = []
    for i in range(len(para_tokens)):
      start_chartok_pos = tok_start_to_chartok_index[i]
      end_chartok_pos = tok_end_to_chartok_index[i]
      start_orig_pos = _convert_index(chartok_to_orig_index, start_chartok_pos,
                                      n, is_start=True)
      end_orig_pos = _convert_index(chartok_to_orig_index, end_chartok_pos,
                                    n, is_start=False)

      tok_start_to_orig_index.append(start_orig_pos)
      tok_end_to_orig_index.append(end_orig_pos)

    if not is_training:
      tok_start_position = tok_end_position = None

    if is_training and example.is_impossible:
      tok_start_position = 0
      tok_end_position = 0

    if is_training and not example.is_impossible:
      start_position = example.start_position
      end_position = start_position + len(example.orig_answer_text) - 1

      start_chartok_pos = _convert_index(orig_to_chartok_index, start_position,
                                         is_start=True)
      tok_start_position = chartok_to_tok_index[start_chartok_pos]

      end_chartok_pos = _convert_index(orig_to_chartok_index, end_position,
                                       is_start=False)
      tok_end_position = chartok_to_tok_index[end_chartok_pos]
      assert tok_start_position <= tok_end_position

    def _piece_to_id(x):
      if six.PY2 and isinstance(x, six.text_type):
        x = six.ensure_binary(x, "utf-8")
      return tokenizer.sp_model.PieceToId(x)

    all_doc_tokens = list(map(_piece_to_id, para_tokens))

    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

    # We can have documents that are longer than the maximum sequence length.
    # To deal with this we do a sliding window approach, where we take chunks
    # of up to our max length with a stride of `doc_stride`.
    _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
        "DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
      length = len(all_doc_tokens) - start_offset
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(_DocSpan(start=start_offset, length=length))
      if start_offset + length == len(all_doc_tokens):
        break
      start_offset += min(length, doc_stride)
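
    # For example, with 500 document tokens, max_tokens_for_doc = 381 and
    # doc_stride = 128, this produces two spans: (start=0, length=381) and
    # (start=128, length=372).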

    for (doc_span_index, doc_span) in enumerate(doc_spans):
      tokens = []
      token_is_max_context = {}
      segment_ids = []
      p_mask = []

      cur_tok_start_to_orig_index = []
      cur_tok_end_to_orig_index = []

      tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
      segment_ids.append(0)
      p_mask.append(0)
      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
        p_mask.append(1)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(0)
      p_mask.append(1)

      for i in range(doc_span.length):
        split_token_index = doc_span.start + i

        cur_tok_start_to_orig_index.append(
            tok_start_to_orig_index[split_token_index])
        cur_tok_end_to_orig_index.append(
            tok_end_to_orig_index[split_token_index])

        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                               split_token_index)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[split_token_index])
        segment_ids.append(1)
        p_mask.append(0)
      tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
      segment_ids.append(1)
      p_mask.append(1)

      paragraph_len = len(tokens)
      input_ids = tokens

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        p_mask.append(1)

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length

      span_is_impossible = example.is_impossible
      start_position = None
      end_position = None
      if is_training and not span_is_impossible:
        # For training, if our document chunk does not contain an annotation
        # we throw it out, since there is nothing to predict.
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if not (tok_start_position >= doc_start and
                tok_end_position <= doc_end):
          out_of_span = True
        if out_of_span:
          # continue
          start_position = 0
          end_position = 0
          span_is_impossible = True
        else:
          doc_offset = len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset

      if is_training and span_is_impossible:
        start_position = 0
        end_position = 0

      if example_index < 20:
        tf.logging.info("*** Example ***")
        tf.logging.info("unique_id: %s" % (unique_id))
        tf.logging.info("example_index: %s" % (example_index))
        tf.logging.info("doc_span_index: %s" % (doc_span_index))
        tf.logging.info("tok_start_to_orig_index: %s" % " ".join(
            [str(x) for x in cur_tok_start_to_orig_index]))
        tf.logging.info("tok_end_to_orig_index: %s" % " ".join(
            [str(x) for x in cur_tok_end_to_orig_index]))
        tf.logging.info("token_is_max_context: %s" % " ".join([
            "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
        ]))
        tf.logging.info("input_pieces: %s" % " ".join(
            [tokenizer.sp_model.IdToPiece(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info(
            "input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info(
            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

        if is_training and span_is_impossible:
          tf.logging.info("impossible example span")

        if is_training and not span_is_impossible:
          pieces = [tokenizer.sp_model.IdToPiece(token) for token in
                    tokens[start_position: (end_position + 1)]]
          answer_text = tokenizer.sp_model.DecodePieces(pieces)
          tf.logging.info("start_position: %d" % (start_position))
          tf.logging.info("end_position: %d" % (end_position))
          tf.logging.info(
              "answer: %s" % (tokenization.printable_text(answer_text)))

      # note(zhiliny): With multi processing,
      # the example_index is actually the index within the current process
      # therefore we use example_index=None to avoid being used in the future.
      # The current code does not use example_index of training data.
      if is_training:
        feat_example_index = None
      else:
        feat_example_index = example_index

      feature = InputFeatures(
          unique_id=unique_id,
          example_index=feat_example_index,
          doc_span_index=doc_span_index,
          tok_start_to_orig_index=cur_tok_start_to_orig_index,
          tok_end_to_orig_index=cur_tok_end_to_orig_index,
          token_is_max_context=token_is_max_context,
          tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          paragraph_len=paragraph_len,
          start_position=start_position,
          end_position=end_position,
          is_impossible=span_is_impossible,
          p_mask=p_mask)

      # Run callback
      output_fn(feature)

      unique_id += 1
      if span_is_impossible:
        cnt_neg += 1
      else:
        cnt_pos += 1

  tf.logging.info("Total number of instances: {} = pos {} neg {}".format(
      cnt_pos + cnt_neg, cnt_pos, cnt_neg))


def _check_is_max_context(doc_spans, cur_span_index, position):
  """Check if this is the 'max context' doc span for the token."""

  # Because of the sliding window approach taken to scoring documents, a single
  # token can appear in multiple documents. E.g.
  #  Doc: the man went to the store and bought a gallon of milk
  #  Span A: the man went to the
  #  Span B: to the store and bought
  #  Span C: and bought a gallon of
  #  ...
  #
  # Now the word 'bought' will have two scores from spans B and C. We only
  # want to consider the score with "maximum context", which we define as
  # the *minimum* of its left and right context (the *sum* of left and
  # right context will always be the same, of course).
  #
  # In the example the maximum context for 'bought' would be span C since
  # it has 1 left context and 3 right context, while span B has 4 left context
  # and 0 right context.
  best_score = None
  best_span_index = None
  for (span_index, doc_span) in enumerate(doc_spans):
    end = doc_span.start + doc_span.length - 1
    if position < doc_span.start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_span.start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = span_index

  return cur_span_index == best_span_index


def _get_best_indexes(logits, n_best_size):
  """Get the indexes of the n-best logits from a list."""
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes


def _compute_softmax(scores):
  """Compute softmax probability over raw logits."""
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs
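
# Worked example: _compute_softmax([1.0, 2.0, 3.0]) subtracts the max score
# (3.0) for numerical stability and returns roughly [0.090, 0.245, 0.665].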


class FeatureWriter(object):
  """Writes InputFeatures to a TF example file."""

  def __init__(self, filename, is_training):
    self.filename = filename
    self.is_training = is_training
    self.num_features = 0
    self._writer = tf.python_io.TFRecordWriter(filename)

  def process_feature(self, feature):
    """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
    self.num_features += 1

    def create_int_feature(values):
      feature = tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))
      return feature

    features = collections.OrderedDict()
    features["unique_ids"] = create_int_feature([feature.unique_id])
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["p_mask"] = create_int_feature(feature.p_mask)

    if self.is_training:
      features["start_positions"] = create_int_feature([feature.start_position])
      features["end_positions"] = create_int_feature([feature.end_position])
      impossible = 0
      if feature.is_impossible:
        impossible = 1
      features["is_impossible"] = create_int_feature([impossible])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    self._writer.write(tf_example.SerializeToString())

  def close(self):
    self._writer.close()


def input_fn_builder(input_file, seq_length, is_training,
                     drop_remainder, use_tpu, bsz, is_v2):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  name_to_features = {
      "unique_ids": tf.FixedLenFeature([], tf.int64),
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
  }
  # p_mask is not required for SQuAD v1.1
  if is_v2:
    name_to_features["p_mask"] = tf.FixedLenFeature([seq_length], tf.int64)

  if is_training:
    name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
    name_to_features["is_impossible"] = tf.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t

    return example

  def input_fn(params):
    """The actual input function."""
    if use_tpu:
      batch_size = params["batch_size"]
    else:
      batch_size = bsz

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
      d = d.repeat()
      d = d.shuffle(buffer_size=100)

    d = d.apply(
        contrib_data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn
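
# A minimal usage sketch (file name and parameter values are hypothetical):
#   train_input_fn = input_fn_builder(
#       input_file="train.tf_record", seq_length=384, is_training=True,
#       drop_remainder=True, use_tpu=False, bsz=32, is_v2=True)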


def create_v1_model(albert_config, is_training, input_ids, input_mask,
                    segment_ids, use_one_hot_embeddings, use_einsum,
                    hub_module):
  """Creates a SQuAD v1 span-prediction model."""
  (_, final_hidden) = fine_tuning_utils.create_albert(
      albert_config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      use_einsum=use_einsum,
      hub_module=hub_module)

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]

  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  logits = tf.transpose(logits, [2, 0, 1])

  unstacked_logits = tf.unstack(logits, axis=0)

  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

  return (start_logits, end_logits)
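
# The v1 head is a single dense layer mapping each final hidden state to two
# scalars; after the transpose and unstack above, `start_logits` and
# `end_logits` each have shape [batch_size, seq_length].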


def v1_model_fn_builder(albert_config, init_checkpoint, learning_rate,
                        num_train_steps, num_warmup_steps, use_tpu,
                        use_one_hot_embeddings, use_einsum, hub_module):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    if "unique_ids" in features:
      unique_ids = features["unique_ids"]
    else:
      unique_ids = None
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (start_logits, end_logits) = create_v1_model(
        albert_config=albert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
        use_einsum=use_einsum,
        hub_module=hub_module)

    # Assign names to the logits so that we can refer to them as output
    # tensors.
    start_logits = tf.identity(start_logits, name="start_logits")
    end_logits = tf.identity(end_logits, name="end_logits")

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

      def compute_loss(logits, positions):
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "start_log_prob": start_logits,
          "end_log_prob": end_logits,
      }
      if unique_ids is not None:
        predictions["unique_ids"] = unique_ids
      output_spec = contrib_tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))
    return output_spec

  return model_fn


def accumulate_predictions_v1(result_dict, all_examples, all_features,
                              all_results, n_best_size, max_answer_length):
  """Accumulates candidate span predictions for each position in a dict."""
  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    if example_index not in result_dict:
      result_dict[example_index] = {}
    features = example_index_to_features[example_index]

    prelim_predictions = []
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      if feature.unique_id not in result_dict[example_index]:
        result_dict[example_index][feature.unique_id] = {}
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_log_prob, n_best_size)
      end_indexes = _get_best_indexes(result.end_log_prob, n_best_size)
      for start_index in start_indexes:
        for end_index in end_indexes:
          doc_offset = feature.tokens.index("[SEP]") + 1
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
            continue
          if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          start_log_prob = result.start_log_prob[start_index]
          end_log_prob = result.end_log_prob[end_index]
          start_idx = start_index - doc_offset
          end_idx = end_index - doc_offset
          if (start_idx, end_idx) not in result_dict[example_index][feature.unique_id]:
            result_dict[example_index][feature.unique_id][(start_idx, end_idx)] = []
          result_dict[example_index][feature.unique_id][(start_idx, end_idx)].append((start_log_prob, end_log_prob))
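
# Each (start_idx, end_idx) key can accumulate several (start_log_prob,
# end_log_prob) pairs when this function is called repeatedly (e.g., once per
# checkpoint in an ensemble); write_predictions_v1 below averages them.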


def write_predictions_v1(result_dict, all_examples, all_features,
                         all_results, n_best_size, max_answer_length,
                         output_prediction_file, output_nbest_file):
  """Write final predictions to the json file and log-odds of null if needed."""
  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      for ((start_idx, end_idx), logprobs) in \
          result_dict[example_index][feature.unique_id].items():
        start_log_prob = 0
        end_log_prob = 0
        for logprob in logprobs:
          start_log_prob += logprob[0]
          end_log_prob += logprob[1]
        prelim_predictions.append(
            _PrelimPrediction(
                feature_index=feature_index,
                start_index=start_idx,
                end_index=end_idx,
                start_log_prob=start_log_prob / len(logprobs),
                end_log_prob=end_log_prob / len(logprobs)))

    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_log_prob + x.end_log_prob),
        reverse=True)

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index >= 0:  # this is a non-null prediction
        tok_start_to_orig_index = feature.tok_start_to_orig_index
        tok_end_to_orig_index = feature.tok_end_to_orig_index
        start_orig_pos = tok_start_to_orig_index[pred.start_index]
        end_orig_pos = tok_end_to_orig_index[pred.end_index]

        paragraph_text = example.paragraph_text
        final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_log_prob=pred.start_log_prob,
              end_log_prob=pred.end_log_prob))

    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_log_prob=0.0, end_log_prob=0.0))

    assert len(nbest) >= 1

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_log_prob + entry.end_log_prob)
      if not best_non_null_entry:
        if entry.text:
          best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_log_prob"] = entry.start_log_prob
      output["end_log_prob"] = entry.end_log_prob
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    all_predictions[example.qas_id] = nbest_json[0]["text"]
    all_nbest_json[example.qas_id] = nbest_json

  with tf.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  return all_predictions


####### following are from official SQuAD v1.1 evaluation scripts
def normalize_answer_v1(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
  """Token-level F1 between a prediction and a single ground truth."""
  prediction_tokens = normalize_answer_v1(prediction).split()
  ground_truth_tokens = normalize_answer_v1(ground_truth).split()
  common = (
      collections.Counter(prediction_tokens)
      & collections.Counter(ground_truth_tokens))
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1
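
# Worked example: f1_score("the cat sat", "cat sat down") normalizes both
# strings (dropping the article "the"), leaving the common tokens
# {"cat", "sat"}: precision = 2/2, recall = 2/3, so F1 = 0.8.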


def exact_match_score(prediction, ground_truth):
  return (normalize_answer_v1(prediction) == normalize_answer_v1(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate_v1(dataset, predictions):
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = ("Unanswered question " + six.ensure_str(qa["id"]) +
                     " will receive score 0.")
          print(message, file=sys.stderr)
          continue
        ground_truths = [x["text"] for x in qa["answers"]]
        # ground_truths = list(map(lambda x: x["text"], qa["answers"]))
        prediction = predictions[qa["id"]]
        exact_match += metric_max_over_ground_truths(exact_match_score,
                                                     prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

  exact_match = 100.0 * exact_match / total
  f1 = 100.0 * f1 / total

  return {"exact_match": exact_match, "f1": f1}

####### above are from official SQuAD v1.1 evaluation scripts
####### following are from official SQuAD v2.0 evaluation scripts
def make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def normalize_answer_v2(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
  if not s:
    return []
  return normalize_answer_v2(s).split()


def compute_exact(a_gold, a_pred):
  return int(normalize_answer_v2(a_gold) == normalize_answer_v2(a_pred))


def compute_f1(a_gold, a_pred):
  gold_toks = get_tokens(a_gold)
  pred_toks = get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if len(gold_toks) == 0 or len(pred_toks) == 0:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def get_raw_scores(dataset, preds):
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if normalize_answer_v2(a['text'])]
        if not gold_answers:
          # For unanswerable questions, only correct answer is empty string
          gold_answers = ['']
        if qid not in preds:
          print('Missing prediction for %s' % qid)
          continue
        a_pred = preds[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def make_eval_dict(exact_scores, f1_scores, qid_list=None):
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for i, qid in enumerate(qid_list):
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if preds[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh
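
# The sweep starts from the score of answering nothing (one point per
# unanswerable question), then walks questions in order of increasing
# no-answer probability, flipping each to "answered" and keeping the
# threshold that maximizes the running score.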


def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
                         qid_to_has_ans):
  best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs,
                                              qid_to_has_ans)
  best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs,
                                        qid_to_has_ans)
  main_eval['best_exact'] = best_exact
  main_eval['best_exact_thresh'] = exact_thresh
  main_eval['best_f1'] = best_f1
  main_eval['best_f1_thresh'] = f1_thresh


def merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]

####### above are from official SQuAD v2.0 evaluation scripts


def accumulate_predictions_v2(result_dict, cls_dict, all_examples,
                              all_features, all_results, n_best_size,
                              max_answer_length, start_n_top, end_n_top):
  """Accumulates candidate span predictions for each position in a dict."""

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    if example_index not in result_dict:
      result_dict[example_index] = {}
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive

    for (feature_index, feature) in enumerate(features):
      if feature.unique_id not in result_dict[example_index]:
        result_dict[example_index][feature.unique_id] = {}
      result = unique_id_to_result[feature.unique_id]
      cur_null_score = result.cls_logits

      # if we could have irrelevant answers, get the min score of irrelevant
      score_null = min(score_null, cur_null_score)

      doc_offset = feature.tokens.index("[SEP]") + 1
      for i in range(start_n_top):
        for j in range(end_n_top):
          start_log_prob = result.start_top_log_probs[i]
          start_index = result.start_top_index[i]

          j_index = i * end_n_top + j

          end_log_prob = result.end_top_log_probs[j_index]
          end_index = result.end_top_index[j_index]
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
            continue
          if start_index - doc_offset < 0:
            continue
          if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          start_idx = start_index - doc_offset
          end_idx = end_index - doc_offset
          if (start_idx, end_idx) not in result_dict[example_index][feature.unique_id]:
            result_dict[example_index][feature.unique_id][(start_idx, end_idx)] = []
          result_dict[example_index][feature.unique_id][(start_idx, end_idx)].append((start_log_prob, end_log_prob))
    if example_index not in cls_dict:
      cls_dict[example_index] = []
    cls_dict[example_index].append(score_null)


def write_predictions_v2(result_dict, cls_dict, all_examples, all_features,
                         all_results, n_best_size, max_answer_length,
                         output_prediction_file,
                         output_nbest_file, output_null_log_odds_file,
                         null_score_diff_threshold):
  """Write final predictions to the json file and log-odds of null if needed."""
  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    # score_null = 1000000  # large and positive

    for (feature_index, feature) in enumerate(features):
      for ((start_idx, end_idx), logprobs) in \
          result_dict[example_index][feature.unique_id].items():
        start_log_prob = 0
        end_log_prob = 0
        for logprob in logprobs:
          start_log_prob += logprob[0]
          end_log_prob += logprob[1]
        prelim_predictions.append(
            _PrelimPrediction(
                feature_index=feature_index,
                start_index=start_idx,
                end_index=end_idx,
                start_log_prob=start_log_prob / len(logprobs),
                end_log_prob=end_log_prob / len(logprobs)))

    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_log_prob + x.end_log_prob),
        reverse=True)

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]

      tok_start_to_orig_index = feature.tok_start_to_orig_index
      tok_end_to_orig_index = feature.tok_end_to_orig_index
      start_orig_pos = tok_start_to_orig_index[pred.start_index]
      end_orig_pos = tok_end_to_orig_index[pred.end_index]

      paragraph_text = example.paragraph_text
      final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()

      if final_text in seen_predictions:
        continue

      seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_log_prob=pred.start_log_prob,
              end_log_prob=pred.end_log_prob))

    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(
              text="",
              start_log_prob=-1e6,
              end_log_prob=-1e6))

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_log_prob + entry.end_log_prob)
      if not best_non_null_entry:
        best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_log_prob"] = entry.start_log_prob
      output["end_log_prob"] = entry.end_log_prob
      nbest_json.append(output)

    assert len(nbest_json) >= 1
    assert best_non_null_entry is not None

    score_diff = sum(cls_dict[example_index]) / len(cls_dict[example_index])
    scores_diff_json[example.qas_id] = score_diff
    # predict null answers when null threshold is provided
    if null_score_diff_threshold is None or score_diff < null_score_diff_threshold:
      all_predictions[example.qas_id] = best_non_null_entry.text
    else:
      all_predictions[example.qas_id] = ""

    all_nbest_json[example.qas_id] = nbest_json
    assert len(nbest_json) >= 1

  with tf.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
  return all_predictions, scores_diff_json
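
# The null-answer decision: `score_diff` is the answerability logit
# (`cls_logits`) averaged over all accumulated passes for the example. When a
# threshold is supplied and `score_diff` is at or above it, the prediction is
# replaced by the empty string (no answer).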


def create_v2_model(albert_config, is_training, input_ids, input_mask,
                    segment_ids, use_one_hot_embeddings, features,
                    max_seq_length, start_n_top, end_n_top, dropout_prob,
                    hub_module):
  """Creates a SQuAD v2 model with span and answerability heads."""
  (_, output) = fine_tuning_utils.create_albert(
      albert_config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      use_einsum=True,
      hub_module=hub_module)

  bsz = tf.shape(output)[0]
  return_dict = {}
  output = tf.transpose(output, [1, 0, 2])

  # invalid position mask such as query and special symbols (PAD, SEP, CLS)
  p_mask = tf.cast(features["p_mask"], dtype=tf.float32)

  # logit of the start position
  with tf.variable_scope("start_logits"):
    start_logits = tf.layers.dense(
        output,
        1,
        kernel_initializer=modeling.create_initializer(
            albert_config.initializer_range))
    start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
    start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
    start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)

  # logit of the end position
  with tf.variable_scope("end_logits"):
    if is_training:
      # during training, compute the end logits based on the
      # ground truth of the start position
      start_positions = tf.reshape(features["start_positions"], [-1])
      start_index = tf.one_hot(start_positions, depth=max_seq_length, axis=-1,
                               dtype=tf.float32)
      start_features = tf.einsum("lbh,bl->bh", output, start_index)
      start_features = tf.tile(start_features[None], [max_seq_length, 1, 1])
      end_logits = tf.layers.dense(
          tf.concat([output, start_features], axis=-1),
          albert_config.hidden_size,
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range),
          activation=tf.tanh,
          name="dense_0")
      end_logits = contrib_layers.layer_norm(end_logits, begin_norm_axis=-1)

      end_logits = tf.layers.dense(
          end_logits,
          1,
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range),
          name="dense_1")
      end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
      end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
      end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
    else:
      # during inference, compute the end logits based on beam search

      start_top_log_probs, start_top_index = tf.nn.top_k(
          start_log_probs, k=start_n_top)
      start_index = tf.one_hot(start_top_index,
                               depth=max_seq_length, axis=-1, dtype=tf.float32)
      start_features = tf.einsum("lbh,bkl->bkh", output, start_index)
      end_input = tf.tile(output[:, :, None],
                          [1, 1, start_n_top, 1])
      start_features = tf.tile(start_features[None],
                               [max_seq_length, 1, 1, 1])
      end_input = tf.concat([end_input, start_features], axis=-1)
      end_logits = tf.layers.dense(
          end_input,
          albert_config.hidden_size,
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range),
          activation=tf.tanh,
          name="dense_0")
      end_logits = contrib_layers.layer_norm(end_logits, begin_norm_axis=-1)
      end_logits = tf.layers.dense(
          end_logits,
          1,
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range),
          name="dense_1")
      end_logits = tf.reshape(end_logits, [max_seq_length, -1, start_n_top])
      end_logits = tf.transpose(end_logits, [1, 2, 0])
      end_logits_masked = end_logits * (
          1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
      end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
      end_top_log_probs, end_top_index = tf.nn.top_k(
          end_log_probs, k=end_n_top)
      end_top_log_probs = tf.reshape(
          end_top_log_probs,
          [-1, start_n_top * end_n_top])
      end_top_index = tf.reshape(
          end_top_index,
          [-1, start_n_top * end_n_top])

  if is_training:
    return_dict["start_log_probs"] = start_log_probs
    return_dict["end_log_probs"] = end_log_probs
  else:
    return_dict["start_top_log_probs"] = start_top_log_probs
    return_dict["start_top_index"] = start_top_index
    return_dict["end_top_log_probs"] = end_top_log_probs
    return_dict["end_top_index"] = end_top_index

  # an additional layer to predict answerability
  with tf.variable_scope("answer_class"):
    # get the representation of CLS
    cls_index = tf.one_hot(tf.zeros([bsz], dtype=tf.int32),
                           max_seq_length,
                           axis=-1, dtype=tf.float32)
    cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)

    # get the representation of START
    start_p = tf.nn.softmax(start_logits_masked, axis=-1,
                            name="softmax_start")
    start_feature = tf.einsum("lbh,bl->bh", output, start_p)

    # note(zhiliny): no dependency on end_feature so that we can obtain
    # one single `cls_logits` for each sample
    ans_feature = tf.concat([start_feature, cls_feature], -1)
    ans_feature = tf.layers.dense(
        ans_feature,
        albert_config.hidden_size,
        activation=tf.tanh,
        kernel_initializer=modeling.create_initializer(
            albert_config.initializer_range),
        name="dense_0")
    ans_feature = tf.layers.dropout(ans_feature, dropout_prob,
                                    training=is_training)
    cls_logits = tf.layers.dense(
        ans_feature,
        1,
        kernel_initializer=modeling.create_initializer(
            albert_config.initializer_range),
        name="dense_1",
        use_bias=False)
    cls_logits = tf.squeeze(cls_logits, -1)

    return_dict["cls_logits"] = cls_logits

  return return_dict
def v2_model_fn_builder(albert_config, init_checkpoint, learning_rate,
|
1581 |
+
num_train_steps, num_warmup_steps, use_tpu,
|
1582 |
+
use_one_hot_embeddings, max_seq_length, start_n_top,
|
1583 |
+
end_n_top, dropout_prob, hub_module):
|
1584 |
+
"""Returns `model_fn` closure for TPUEstimator."""
|
1585 |
+
|
1586 |
+
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
1587 |
+
"""The `model_fn` for TPUEstimator."""
|
1588 |
+
|
1589 |
+
tf.logging.info("*** Features ***")
|
1590 |
+
for name in sorted(features.keys()):
|
1591 |
+
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
1592 |
+
|
1593 |
+
# unique_ids = features["unique_ids"]
|
1594 |
+
input_ids = features["input_ids"]
|
1595 |
+
input_mask = features["input_mask"]
|
1596 |
+
segment_ids = features["segment_ids"]
|
1597 |
+
|
1598 |
+
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
1599 |
+
|
1600 |
+
outputs = create_v2_model(
|
1601 |
+
albert_config=albert_config,
|
1602 |
+
is_training=is_training,
|
1603 |
+
input_ids=input_ids,
|
1604 |
+
input_mask=input_mask,
|
1605 |
+
segment_ids=segment_ids,
|
1606 |
+
use_one_hot_embeddings=use_one_hot_embeddings,
|
1607 |
+
features=features,
|
1608 |
+
max_seq_length=max_seq_length,
|
1609 |
+
start_n_top=start_n_top,
|
1610 |
+
end_n_top=end_n_top,
|
1611 |
+
dropout_prob=dropout_prob,
|
1612 |
+
hub_module=hub_module)
|
1613 |
+
|
1614 |
+
tvars = tf.trainable_variables()
|
1615 |
+
|
1616 |
+
initialized_variable_names = {}
|
1617 |
+
scaffold_fn = None
|
1618 |
+
if init_checkpoint:
|
1619 |
+
(assignment_map, initialized_variable_names
|
1620 |
+
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
|
1621 |
+
if use_tpu:
|
1622 |
+
|
1623 |
+
def tpu_scaffold():
|
1624 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
1625 |
+
return tf.train.Scaffold()
|
1626 |
+
|
1627 |
+
scaffold_fn = tpu_scaffold
|
1628 |
+
else:
|
1629 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
1630 |
+
|
1631 |
+
tf.logging.info("**** Trainable Variables ****")
|
1632 |
+
for var in tvars:
|
1633 |
+
init_string = ""
|
1634 |
+
if var.name in initialized_variable_names:
|
1635 |
+
init_string = ", *INIT_FROM_CKPT*"
|
1636 |
+
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
|
1637 |
+
init_string)
|
1638 |
+
|
1639 |
+
output_spec = None
|
1640 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
1641 |
+
seq_length = modeling.get_shape_list(input_ids)[1]
|
1642 |
+
|
1643 |
+
def compute_loss(log_probs, positions):
|
1644 |
+
one_hot_positions = tf.one_hot(
|
1645 |
+
positions, depth=seq_length, dtype=tf.float32)
|
1646 |
+
|
1647 |
+
loss = - tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
|
1648 |
+
loss = tf.reduce_mean(loss)
|
1649 |
+
return loss
|
1650 |
+
|
1651 |
+
start_loss = compute_loss(
|
1652 |
+
outputs["start_log_probs"], features["start_positions"])
|
1653 |
+
end_loss = compute_loss(
|
1654 |
+
outputs["end_log_probs"], features["end_positions"])
|
1655 |
+
|
1656 |
+
total_loss = (start_loss + end_loss) * 0.5
|
1657 |
+
|
1658 |
+
cls_logits = outputs["cls_logits"]
|
1659 |
+
is_impossible = tf.reshape(features["is_impossible"], [-1])
|
1660 |
+
regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
1661 |
+
labels=tf.cast(is_impossible, dtype=tf.float32), logits=cls_logits)
|
1662 |
+
regression_loss = tf.reduce_mean(regression_loss)
|
1663 |
+
|
1664 |
+
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
|
1665 |
+
# comparable to start_loss and end_loss
|
1666 |
+
total_loss += regression_loss * 0.5
|
1667 |
+
train_op = optimization.create_optimizer(
|
1668 |
+
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
|
1669 |
+
|
1670 |
+
output_spec = contrib_tpu.TPUEstimatorSpec(
|
1671 |
+
mode=mode,
|
1672 |
+
loss=total_loss,
|
1673 |
+
train_op=train_op,
|
1674 |
+
scaffold_fn=scaffold_fn)
|
1675 |
+
elif mode == tf.estimator.ModeKeys.PREDICT:
|
1676 |
+
predictions = {
|
1677 |
+
"unique_ids": features["unique_ids"],
|
1678 |
+
"start_top_index": outputs["start_top_index"],
|
1679 |
+
"start_top_log_probs": outputs["start_top_log_probs"],
|
1680 |
+
"end_top_index": outputs["end_top_index"],
|
1681 |
+
"end_top_log_probs": outputs["end_top_log_probs"],
|
1682 |
+
"cls_logits": outputs["cls_logits"]
|
1683 |
+
}
|
1684 |
+
output_spec = contrib_tpu.TPUEstimatorSpec(
|
1685 |
+
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
|
1686 |
+
else:
|
1687 |
+
raise ValueError(
|
1688 |
+
"Only TRAIN and PREDICT modes are supported: %s" % (mode))
|
1689 |
+
|
1690 |
+
return output_spec
|
1691 |
+
|
1692 |
+
return model_fn
|
1693 |
+
|
1694 |
+
|
1695 |
+
def evaluate_v2(result_dict, cls_dict, prediction_json, eval_examples,
|
1696 |
+
eval_features, all_results, n_best_size, max_answer_length,
|
1697 |
+
output_prediction_file, output_nbest_file,
|
1698 |
+
output_null_log_odds_file):
|
1699 |
+
null_score_diff_threshold = None
|
1700 |
+
predictions, na_probs = write_predictions_v2(
|
1701 |
+
result_dict, cls_dict, eval_examples, eval_features,
|
1702 |
+
all_results, n_best_size, max_answer_length,
|
1703 |
+
output_prediction_file, output_nbest_file,
|
1704 |
+
output_null_log_odds_file, null_score_diff_threshold)
|
1705 |
+
|
1706 |
+
na_prob_thresh = 1.0 # default value taken from the eval script
|
1707 |
+
qid_to_has_ans = make_qid_to_has_ans(prediction_json) # maps qid to True/False
|
1708 |
+
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
1709 |
+
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
1710 |
+
exact_raw, f1_raw = get_raw_scores(prediction_json, predictions)
|
1711 |
+
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
|
1712 |
+
na_prob_thresh)
|
1713 |
+
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
|
1714 |
+
na_prob_thresh)
|
1715 |
+
out_eval = make_eval_dict(exact_thresh, f1_thresh)
|
1716 |
+
find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
|
1717 |
+
null_score_diff_threshold = out_eval["best_f1_thresh"]
|
1718 |
+
|
1719 |
+
predictions, na_probs = write_predictions_v2(
|
1720 |
+
result_dict, cls_dict,eval_examples, eval_features,
|
1721 |
+
all_results, n_best_size, max_answer_length,
|
1722 |
+
output_prediction_file, output_nbest_file,
|
1723 |
+
output_null_log_odds_file, null_score_diff_threshold)
|
1724 |
+
|
1725 |
+
qid_to_has_ans = make_qid_to_has_ans(prediction_json) # maps qid to True/False
|
1726 |
+
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
1727 |
+
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
1728 |
+
exact_raw, f1_raw = get_raw_scores(prediction_json, predictions)
|
1729 |
+
exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
|
1730 |
+
na_prob_thresh)
|
1731 |
+
f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
|
1732 |
+
na_prob_thresh)
|
1733 |
+
out_eval = make_eval_dict(exact_thresh, f1_thresh)
|
1734 |
+
out_eval["null_score_diff_threshold"] = null_score_diff_threshold
|
1735 |
+
return out_eval
|
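Note how `evaluate_v2` above makes two passes: it first writes predictions with no null-score threshold so `find_all_best_thresh` can pick `best_f1_thresh`, then re-generates predictions with that tuned threshold. A minimal, self-contained sketch of the thresholding idea follows; the helper below is an illustrative re-implementation of what `apply_no_ans_threshold` does in the SQuAD 2.0 eval script, not the repo's own code, and the toy qids are invented for the example.

```python
def apply_no_ans_threshold_sketch(scores, na_probs, qid_to_has_ans, thresh):
    """Zero out span credit for questions the model should abstain on.

    scores:         qid -> exact/F1 score of the predicted span
    na_probs:       qid -> no-answer score (e.g. null score diff)
    qid_to_has_ans: qid -> True if the question has a gold answer
    thresh:         abstain whenever na_probs[qid] > thresh
    """
    new_scores = {}
    for qid, score in scores.items():
        if na_probs[qid] > thresh:
            # Model abstains: correct iff the question truly has no answer.
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = score
    return new_scores


# Toy example: "q2" has no gold answer and a high no-answer score.
scores = {"q1": 1.0, "q2": 0.0}
na_probs = {"q1": 0.1, "q2": 0.9}
qid_to_has_ans = {"q1": True, "q2": False}
print(apply_no_ans_threshold_sketch(scores, na_probs, qid_to_has_ans, 0.5))
# {'q1': 1.0, 'q2': 1.0}
```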
Indic-BERT-v1-master/albert/tokenization.py
ADDED
@@ -0,0 +1,465 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import unicodedata
import six
from six.moves import range
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import sentencepiece as spm

SPIECE_UNDERLINE = u"▁".encode("utf-8")


def preprocess_text(inputs, remove_space=True, lower=False):
  """Preprocesses data by removing extra spaces and normalizing it."""
  outputs = inputs
  if remove_space:
    outputs = " ".join(inputs.strip().split())

  if six.PY2 and isinstance(outputs, str):
    try:
      outputs = six.ensure_text(outputs, "utf-8")
    except UnicodeDecodeError:
      outputs = six.ensure_text(outputs, "latin-1")

  outputs = unicodedata.normalize("NFKD", outputs)
  outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
  if lower:
    outputs = outputs.lower()

  return outputs


def encode_pieces(sp_model, text, return_unicode=True, sample=False):
  """turn sentences into word pieces."""

  if six.PY2 and isinstance(text, six.text_type):
    text = six.ensure_binary(text, "utf-8")

  if not sample:
    pieces = sp_model.EncodeAsPieces(text)
  else:
    pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  new_pieces = []
  for piece in pieces:
    piece = printable_text(piece)
    if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
      cur_pieces = sp_model.EncodeAsPieces(
          six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
      if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
        if len(cur_pieces[0]) == 1:
          cur_pieces = cur_pieces[1:]
        else:
          cur_pieces[0] = cur_pieces[0][1:]
      cur_pieces.append(piece[-1])
      new_pieces.extend(cur_pieces)
    else:
      new_pieces.append(piece)

  # note(zhiliny): convert back to unicode for py2
  if six.PY2 and return_unicode:
    ret_pieces = []
    for piece in new_pieces:
      if isinstance(piece, str):
        piece = six.ensure_text(piece, "utf-8")
      ret_pieces.append(piece)
    new_pieces = ret_pieces

  return new_pieces


def encode_ids(sp_model, text, sample=False):
  pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
  ids = [sp_model.PieceToId(piece) for piece in pieces]
  return ids


def convert_to_unicode(text):
  """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return six.ensure_text(text, "utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return six.ensure_text(text, "utf-8", "ignore")
    elif isinstance(text, six.text_type):
      return text
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
  """Returns text encoded in a way suitable for print or `tf.logging`."""

  # These functions want `str` for both Python2 and Python3, but in one case
  # it's a Unicode string and in the other it's a byte string.
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return six.ensure_text(text, "utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text
    elif isinstance(text, six.text_type):
      return six.ensure_binary(text, "utf-8")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  with tf.gfile.GFile(vocab_file, "r") as reader:
    while True:
      token = convert_to_unicode(reader.readline())
      if not token:
        break
      token = token.strip().split()[0] if token.strip() else " "
      if token not in vocab:
        vocab[token] = len(vocab)
  return vocab


def convert_by_vocab(vocab, items):
  """Converts a sequence of [tokens|ids] using the vocab."""
  output = []
  for item in items:
    output.append(vocab[item])
  return output


def convert_tokens_to_ids(vocab, tokens):
  return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
  return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
  """Runs basic whitespace cleaning and splitting on a piece of text."""
  text = text.strip()
  if not text:
    return []
  tokens = text.split()
  return tokens


class FullTokenizer(object):
  """Runs end-to-end tokenization."""

  def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None):
    self.vocab = None
    self.sp_model = None
    if spm_model_file:
      self.sp_model = spm.SentencePieceProcessor()
      tf.logging.info("loading sentence piece model")
      # Handle cases where SP can't load the file, but gfile can.
      sp_model_ = tf.gfile.GFile(spm_model_file, "rb").read()
      self.sp_model.LoadFromSerializedProto(sp_model_)
      # Note(mingdachen): For the purpose of a consistent API, we are
      # generating a vocabulary for the sentence piece tokenizer.
      self.vocab = {self.sp_model.IdToPiece(i): i for i
                    in range(self.sp_model.GetPieceSize())}
    else:
      self.vocab = load_vocab(vocab_file)
      self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
      self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}

  @classmethod
  def from_scratch(cls, vocab_file, do_lower_case, spm_model_file):
    return FullTokenizer(vocab_file, do_lower_case, spm_model_file)

  @classmethod
  def from_hub_module(cls, hub_module, use_spm=True):
    """Get the vocab file and casing info from the Hub module."""
    with tf.Graph().as_default():
      albert_module = hub.Module(hub_module)
      tokenization_info = albert_module(signature="tokenization_info",
                                        as_dict=True)
      with tf.Session() as sess:
        vocab_file, do_lower_case = sess.run(
            [tokenization_info["vocab_file"],
             tokenization_info["do_lower_case"]])
    if use_spm:
      spm_model_file = vocab_file
      vocab_file = None
    return FullTokenizer(
        vocab_file=vocab_file, do_lower_case=do_lower_case,
        spm_model_file=spm_model_file)

  def tokenize(self, text):
    if self.sp_model:
      split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
    else:
      split_tokens = []
      for token in self.basic_tokenizer.tokenize(text):
        for sub_token in self.wordpiece_tokenizer.tokenize(token):
          split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    if self.sp_model:
      tf.logging.info("using sentence piece tokenizer.")
      return [self.sp_model.PieceToId(
          printable_text(token)) for token in tokens]
    else:
      return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    if self.sp_model:
      tf.logging.info("using sentence piece tokenizer.")
      return [self.sp_model.IdToPiece(id_) for id_ in ids]
    else:
      return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True):
    """Constructs a BasicTokenizer.

    Args:
      do_lower_case: Whether to lower case the input.
    """
    self.do_lower_case = do_lower_case

  def tokenize(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._clean_text(text)

    # This was added on November 1st, 2018 for the multilingual and Chinese
    # models. This is also applied to the English models now, but it doesn't
    # matter since the English models were not trained on any Chinese data
    # and generally don't have any Chinese data in them (there are Chinese
    # characters in the vocabulary because Wikipedia does have some Chinese
    # words in the English Wikipedia.).
    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean
    # characters, despite its name. The modern Korean Hangul alphabet is a
    # different block, as are Japanese Hiragana and Katakana. Those alphabets
    # are used to write space-separated words, so they are not treated
    # specially and are handled like all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
        (cp >= 0x3400 and cp <= 0x4DBF) or
        (cp >= 0x20000 and cp <= 0x2A6DF) or
        (cp >= 0x2A700 and cp <= 0x2B73F) or
        (cp >= 0x2B740 and cp <= 0x2B81F) or
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or
        (cp >= 0x2F800 and cp <= 0x2FA1F)):
      return True

    return False

  def _clean_text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)


class WordpieceTokenizer(object):
  """Runs WordPiece tokenization."""

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    """Tokenizes a piece of text into its word pieces.

    This uses a greedy longest-match-first algorithm to perform tokenization
    using the given vocabulary.

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]

    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.

    Returns:
      A list of wordpiece tokens.
    """

    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + six.ensure_str(substr)
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens


def _is_whitespace(char):
  """Checks whether `char` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False


def _is_control(char):
  """Checks whether `char` is a control character."""
  # These are technically control characters but we count them as whitespace
  # characters.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False
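A minimal usage sketch of the `FullTokenizer` above in SentencePiece mode. The path `spiece.model` is a placeholder for wherever the IndicBERT vocab model lives, and the example pieces shown in the comment depend entirely on the vocabulary, so treat them as illustrative only:

```python
from albert import tokenization

# In spm mode, vocab_file is unused; the vocab is derived from the spm model.
tokenizer = tokenization.FullTokenizer(
    vocab_file=None, do_lower_case=True,
    spm_model_file="spiece.model")  # placeholder path

tokens = tokenizer.tokenize("unaffable")        # e.g. ['▁un', 'aff', 'able']
ids = tokenizer.convert_tokens_to_ids(tokens)   # vocab-dependent integer ids
```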
Indic-BERT-v1-master/albert/tokenization_test.py
ADDED
@@ -0,0 +1,137 @@
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from albert import tokenization
import six
import tensorflow.compat.v1 as tf


class TokenizationTest(tf.test.TestCase):

  def test_full_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      if six.PY2:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
      else:
        contents = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens])
        vocab_writer.write(six.ensure_binary(contents, "utf-8"))

      vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])

  def test_wordpiece_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

    self.assertAllEqual(tokenizer.tokenize(""), [])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

  def test_convert_tokens_to_ids(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i

    self.assertAllEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])

  def test_is_whitespace(self):
    self.assertTrue(tokenization._is_whitespace(u" "))
    self.assertTrue(tokenization._is_whitespace(u"\t"))
    self.assertTrue(tokenization._is_whitespace(u"\r"))
    self.assertTrue(tokenization._is_whitespace(u"\n"))
    self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

    self.assertFalse(tokenization._is_whitespace(u"A"))
    self.assertFalse(tokenization._is_whitespace(u"-"))

  def test_is_control(self):
    self.assertTrue(tokenization._is_control(u"\u0005"))

    self.assertFalse(tokenization._is_control(u"A"))
    self.assertFalse(tokenization._is_control(u" "))
    self.assertFalse(tokenization._is_control(u"\t"))
    self.assertFalse(tokenization._is_control(u"\r"))
    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))

  def test_is_punctuation(self):
    self.assertTrue(tokenization._is_punctuation(u"-"))
    self.assertTrue(tokenization._is_punctuation(u"$"))
    self.assertTrue(tokenization._is_punctuation(u"`"))
    self.assertTrue(tokenization._is_punctuation(u"."))

    self.assertFalse(tokenization._is_punctuation(u"A"))
    self.assertFalse(tokenization._is_punctuation(u" "))


if __name__ == "__main__":
  tf.test.main()
Indic-BERT-v1-master/albert/train.py
ADDED
File without changes
Indic-BERT-v1-master/configs/albert_base_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "model_type": "albert",
  "attention_probs_dropout_prob": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0,
  "embedding_size": 128,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_hidden_groups": 1,
  "net_structure_type": 0,
  "gap_size": 0,
  "num_memory_blocks": 0,
  "inner_group_num": 1,
  "down_scale_factor": 1,
  "type_vocab_size": 2,
  "vocab_size": 200000
}
Indic-BERT-v1-master/configs/albert_large_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "model_type": "albert",
  "attention_probs_dropout_prob": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0,
  "embedding_size": 128,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_hidden_groups": 1,
  "net_structure_type": 0,
  "gap_size": 0,
  "num_memory_blocks": 0,
  "inner_group_num": 1,
  "down_scale_factor": 1,
  "type_vocab_size": 2,
  "vocab_size": 200000
}
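Since these configs are in HF format (note `"model_type": "albert"`), they can be loaded directly with the transformers library. A minimal sketch, assuming you run from the repo root so the relative path resolves:

```python
from transformers import AlbertConfig

# Parse the JSON config into an AlbertConfig object.
config = AlbertConfig.from_json_file("configs/albert_base_config.json")
print(config.hidden_size, config.num_hidden_layers)  # 768 12
```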
Indic-BERT-v1-master/docs/advanced-usage.md
ADDED
@@ -0,0 +1,45 @@
## Advanced Usage

Note that the following sections describe how to use the fine-tuning CLI for advanced purposes. To do this on Colab, simply use the arguments mentioned here in the `argvec` list in our [Colab notebook](https://colab.research.google.com/github/ai4bharat/indic-bert/blob/master/notebooks/finetuning.ipynb)

#### Using any Huggingface Model

```bash
python3 -m fine_tune.cli --model <HF name*> --dataset <dataset name> --lang <iso lang code> --iglue_dir <base path to indic glue dir> --output_dir <output dir>
```

where HF name refers to the Huggingface shortcut name for the model. For the list of all shortcut names, refer to the official docs: [https://huggingface.co/transformers/pretrained_models.html](https://huggingface.co/transformers/pretrained_models.html)

#### Loading Model from Local File

All models in the code are loaded through the HF transformers library. For any model, you need the following three files:

* `config.json`: config file in HF format; check config files used by transformers, for example [here](https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_bert.py).
* `tok.model`: the tokenizer (spm, wordpiece etc.) model file.
* `pytorch_model.bin`: pytorch binary of the transformer model which stores parameters.

If you have tensorflow checkpoints instead of a pytorch binary, then use the following command to first generate the pytorch binary file:

```bash
MODEL_DIR=$1

# modify model_type and filenames accordingly
transformers-cli convert --model_type albert \
  --tf_checkpoint $MODEL_DIR/tf_model \
  --config $MODEL_DIR/config.json \
  --pytorch_dump_output $MODEL_DIR/pytorch_model.bin
```

Finally, run the evaluation using the following command:

```bash
python3 -m fine_tune.cli --model <path to the directory containing pytorch_model.bin> --tokenizer_name <path to the tokenizer file> --config_name <path to the config file> --dataset <dataset name> --lang <iso lang code> --iglue_dir <base path to indic glue dir> --output_dir <output dir>
```

#### Running Cross-lingual Experiments

_Add later_
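Regarding the "Loading Model from Local File" section above: conceptually, loading those three files boils down to the standard HF Auto classes. A minimal sketch, assuming HF-default filenames inside the directory (the path is a placeholder, and this is an illustration of the loading pattern, not the repo's exact loading code):

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_dir = "/path/to/local/model"  # placeholder directory

# Expects config.json, a tokenizer model, and pytorch_model.bin in model_dir.
config = AutoConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir, config=config)
```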
Indic-BERT-v1-master/docs/arxiv2020_indicnlp_corpus.pdf
ADDED
Binary file (200 kB).
Indic-BERT-v1-master/fine_tune/__init__.py
ADDED
File without changes
Indic-BERT-v1-master/fine_tune/cli.py
ADDED
@@ -0,0 +1,196 @@
import argparse
import os
import sys

from .modules import get_modules


# For every dataset, add an entry:
#   [<transformer module>, <do train?>]
ALL_DATASETS = {
    'indicnlp-articles': ['text_classification', True],
    'wikiann-ner': ['token_classification', True],
    'wiki-cloze': ['masked_lm', False],
    'wiki-section-titles': ['multiple_choice', True],
    'indicnlp-articles-headlines': ['multiple_choice', True],
    'cvit-mkb': ['xsent_retrieval', False],
    'bbc-articles': ['text_classification', True],
    'iitp-movie-reviews': ['text_classification', True],
    'iitp-product-reviews': ['text_classification', True],
    'soham-articles': ['text_classification', True],
    'inltk-headlines': ['text_classification', True],
    'actsa': ['text_classification', True],
    'midas-discourse': ['text_classification', True],
    'wnli-translated': ['text_classification', True],
    'copa-translated': ['multiple_choice', True],
    'amrita-paraphrase-exact': ['text_classification', True],
    'amrita-paraphrase-fuzzy': ['text_classification', True]
}


def add_generic_args(parser, root_dir):
    # task-specific args START
    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        help='The evaluation dataset to use'
    )

    parser.add_argument(
        '--lang',
        default=None,
        type=str,
        required=True,
        help='ISO code of test language',
    )
    parser.add_argument(
        '--train_lang',
        default=None,
        type=str,
        help='ISO code of train language. If not specified, it is assumed to be the same as the test language',
    )
    # task-specific args END

    # model structural parameters START
    parser.add_argument(
        '--model',
        default=None,
        type=str,
        required=True,
        help='Path to pretrained model or model identifier from huggingface.co/models',
    )

    parser.add_argument(
        '--config_name', default='', type=str, help='Pretrained config name or path if not the same as model_name'
    )

    parser.add_argument(
        '--tokenizer_name',
        default='',
        type=str,
        help='Pretrained tokenizer name or path if not the same as model_name',
    )

    parser.add_argument(
        '--max_seq_length',
        default=128,
        type=int,
        help='The maximum total input sequence length after tokenization. Sequences longer '
        'than this will be truncated, sequences shorter will be padded.',
    )
    # model structural parameters END

    # data I/O args START
    parser.add_argument(
        '--iglue_dir',
        default=None,
        type=str,
        required=True,
        help='The input data dir',
    )

    parser.add_argument(
        '--overwrite_cache', action='store_true', help='Overwrite the cached training and evaluation sets'
    )

    parser.add_argument(
        '--output_dir',
        default=None,
        type=str,
        required=True,
        help='The output directory where the model predictions and checkpoints will be written.',
    )

    parser.add_argument(
        '--cache_dir',
        default=None,
        type=str,
        help='Where do you want to store the pre-trained models downloaded from s3',
    )
    # data I/O args END

    # model training and inference parameters START
    parser.add_argument(
        '--fp16',
        action='store_true',
        help='Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit',
    )

    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help='For fp16: Apex AMP optimization level selected in ["O0", "O1", "O2", and "O3"]. '
        'See details at https://nvidia.github.io/apex/amp.html',
    )

    parser.add_argument('--n_gpu', type=int, default=1)
    parser.add_argument('--n_tpu_cores', type=int, default=0)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, help='Max gradient norm.')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run predictions on the test set.')
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help='Number of updates steps to accumulate before performing a backward/update pass.',
    )

    parser.add_argument('--seed', type=int, default=2, help='random seed for initialization')
    parser.add_argument('--learning_rate', default=2e-5, type=float, help='The initial learning rate for Adam.')
    parser.add_argument('--weight_decay', default=0.0, type=float, help='Weight decay if we apply some.')
    parser.add_argument('--adam_epsilon', default=1e-8, type=float, help='Epsilon for Adam optimizer.')
    parser.add_argument('--warmup_steps', default=0, type=int, help='Linear warmup over warmup_steps.')
    parser.add_argument(
        '--num_train_epochs', default=3, type=int, help='Total number of training epochs to perform.'
    )
    parser.add_argument('--train_batch_size', default=32, type=int)
    parser.add_argument('--eval_batch_size', default=32, type=int)
    # model training and inference parameters END


def main(argvec=None):
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    for module in get_modules():
        module.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args(argvec)
    hparams = vars(args)

    # high-level command line parameters
    dataset = hparams['dataset']
    # `train_lang` defaults to None in argparse, so fall back to the test
    # language explicitly (dict.get alone would just return the stored None).
    train_lang = hparams.get('train_lang') or hparams['lang']
    test_lang = hparams['lang']
    model = hparams['model']
    iglue_dir = hparams['iglue_dir']

    # Validate the dataset name before it is used to index ALL_DATASETS.
    if dataset not in ALL_DATASETS:
        print('Unrecognized dataset')
        sys.exit()

    data_dir = os.path.join(iglue_dir, dataset)
    output_dir = os.path.join(hparams['output_dir'], dataset,
                              'train-{}'.format(train_lang),
                              'model-{}'.format(model.replace('/', '-')))

    hparams['model_name_or_path'] = hparams['model']
    hparams['train_lang'] = train_lang
    hparams['test_lang'] = test_lang
    hparams['data_dir'] = data_dir
    hparams['output_dir'] = output_dir
    hparams['do_train'] = ALL_DATASETS[dataset][1]
    hparams['do_predict'] = True

    os.makedirs(output_dir, exist_ok=True)

    module_name = ALL_DATASETS[dataset][0]
    module_class = get_modules(module_name)
    module = module_class(hparams)
    module.run_module()

    return module


if __name__ == '__main__':
    main()
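A sketch of driving this CLI programmatically via `argvec`, as the Colab notebook described in `docs/advanced-usage.md` does. The directory paths are placeholders; `actsa` is one of the keys in `ALL_DATASETS` above, and `te` (Telugu) is used since ACTSA is a Telugu corpus:

```python
from fine_tune import cli

# Equivalent to:
#   python3 -m fine_tune.cli --model ai4bharat/indic-bert --dataset actsa ...
module = cli.main(argvec=[
    '--model', 'ai4bharat/indic-bert',
    '--dataset', 'actsa',
    '--lang', 'te',
    '--iglue_dir', '/path/to/indic-glue',   # placeholder
    '--output_dir', '/path/to/outputs',     # placeholder
])
```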
Indic-BERT-v1-master/fine_tune/data/__init__.py
ADDED
@@ -0,0 +1,27 @@
from .processors import *


PROCESSORS_TABLE = {
    'indicnlp-articles-headlines': IndicNLPHeadlines,
    'wiki-cloze': WikiCloze,
    'indicnlp-articles': IndicNLPGenre,
    'wikiann-ner': WikiNER,
    'wiki-section-titles': WikiSectionTitles,
    'cvit-mkb': ManKiBaat,
    'actsa': ACTSA,
    'bbc-articles': BBCNews,
    'iitp-movie-reviews': IITPMovies,
    'iitp-product-reviews': IITProducts,
    'inltk-headlines': INLTKHeadlines,
    'soham-articles': SohamArticles,
    'midas-discourse': MidasDiscourse,
    'wnli-translated': WNLI,
    'copa-translated': COPA,
    'amrita-paraphrase-exact': AmritaParaphraseExact,
    'amrita-paraphrase-fuzzy': AmritaParaphraseFuzzy
}


def load_dataset(dataset_name, data_dir):
    return PROCESSORS_TABLE[dataset_name](data_dir)
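A one-line usage sketch of `load_dataset`: it simply instantiates the processor registered for the dataset name, bound to its data directory (the path below is a placeholder; the layout under it is whatever the processor expects):

```python
from fine_tune.data import load_dataset

# Returns the WikiCloze processor bound to its data directory.
processor = load_dataset('wiki-cloze', '/path/to/indic-glue/wiki-cloze')
```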
Indic-BERT-v1-master/fine_tune/data/examples.py
ADDED
@@ -0,0 +1,327 @@
import dataclasses  # needed by TextExample.to_json_string
import json  # needed by TextExample.to_json_string
import logging

import tqdm

from dataclasses import dataclass
from typing import Optional, List, Any, Union
from transformers import PreTrainedTokenizer


logger = logging.getLogger(__name__)


@dataclass
class TextExample:
    """
    A single training/test example for simple sequence classification.
    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
        text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
        label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(dataclasses.asdict(self), indent=2) + "\n"


@dataclass(frozen=True)
class MultipleChoiceExample:
    """
    A single training/test example for multiple choice

    Args:
        example_id: Unique id for the example.
        question: string. The untokenized text of the second sequence
        (question).
        contexts: list of str. The untokenized text of the first sequence
        (context of corresponding question).
        endings: list of str. multiple choice's options. Its length must be
        equal to contexts' length.
        label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    example_id: str
    question: str
    contexts: List[str]
    endings: List[str]
    label: Optional[str]


@dataclass
class TokensExample:
    """
    A single training/test example for token classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        labels: (Optional) list. The labels for each word of the sequence. This
        should be specified for train and dev examples, but not for test
        examples.
    """
    guid: str
    words: List[str]
    labels: Optional[List[str]]


@dataclass
class InputFeatures:
    """
    A single set of features of data.
    Property names are the same names as the corresponding inputs to a model.
    """
    input_ids: Any
    attention_mask: Any
    token_type_ids: Any = None
    label: Any = None
    candidates: Any = None
    example_id: str = None


def convert_multiple_choice_examples_to_features(
    examples: List[MultipleChoiceExample],
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    label_list: List[str],
    pad_token_segment_id=0,
    pad_on_left=False,
    pad_token=0,
    mask_padding_with_zero=True,
) -> List[InputFeatures]:
    """
    Loads a data file into a list of `InputFeatures`
    """

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))
        choices_inputs = []
        for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
            text_a = context
            if example.question.find("_") != -1:
                # this is for cloze question
                text_b = example.question.replace("_", ending)
            else:
                text_b = example.question + " " + ending

            inputs = tokenizer(
                text_a,
                text_b,
                add_special_tokens=True,
                max_length=max_length,
                truncation='longest_first',
                pad_to_max_length=True,
            )
            if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
                logger.info(
                    "Attention! you are cropping tokens (swag task is ok). "
                    "If you are training ARC and RACE and you are popping question + options, "
                    "you need to try to use a bigger max seq length!"
                )

            choices_inputs.append(inputs)

        label = label_map[example.label]

        input_ids = [x["input_ids"] for x in choices_inputs]
        attention_mask = (
            [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
        )
        token_type_ids = (
            [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
        )

        features.append(
            InputFeatures(
                example_id=example.example_id,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                label=label,
            )
        )

    for f in features[:2]:
        logger.info("*** Example ***")
        logger.info("feature: %s" % f)

    return features


def convert_tokens_examples_to_features(
    examples: List[TokensExample],
    label_list: List[str],
    max_seq_length: int,
    tokenizer: PreTrainedTokenizer,
    cls_token_at_end=False,
    cls_token='[CLS]',
    cls_token_segment_id=1,
    sep_token='[SEP]',
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    pad_token_label_id=-100,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True,
) -> List[InputFeatures]:
    """ Loads a data file into a list of `InputFeatures`
        `cls_token_at_end` defines the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` defines the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
    """
    # TODO clean up all this to leverage built-in features of tokenizers

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10_000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            word_tokens = tokenizer.tokenize(word)

            # bert-base-multilingual-cased sometimes outputs "nothing" ([])
            # when calling tokenize with just a space.
            if len(word_tokens) > 0:
                tokens.extend(word_tokens)
                # Use the real label id for the first token of the word, and padding ids for the remaining tokens
                label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        special_tokens_count = tokenizer.num_special_tokens_to_add()
        if len(tokens) > max_seq_length - special_tokens_count:
            tokens = tokens[: (max_seq_length - special_tokens_count)]
            label_ids = label_ids[: (max_seq_length - special_tokens_count)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1  1   1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
|
229 |
+
tokens += [sep_token]
|
230 |
+
label_ids += [pad_token_label_id]
|
231 |
+
if sep_token_extra:
|
232 |
+
# roberta uses an extra separator b/w pairs of sentences
|
233 |
+
tokens += [sep_token]
|
234 |
+
label_ids += [pad_token_label_id]
|
235 |
+
segment_ids = [sequence_a_segment_id] * len(tokens)
|
236 |
+
|
237 |
+
if cls_token_at_end:
|
238 |
+
tokens += [cls_token]
|
239 |
+
label_ids += [pad_token_label_id]
|
240 |
+
segment_ids += [cls_token_segment_id]
|
241 |
+
else:
|
242 |
+
tokens = [cls_token] + tokens
|
243 |
+
label_ids = [pad_token_label_id] + label_ids
|
244 |
+
segment_ids = [cls_token_segment_id] + segment_ids
|
245 |
+
|
246 |
+
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
247 |
+
|
248 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
249 |
+
# tokens are attended to.
|
250 |
+
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
251 |
+
|
252 |
+
# Zero-pad up to the sequence length.
|
253 |
+
padding_length = max_seq_length - len(input_ids)
|
254 |
+
if pad_on_left:
|
255 |
+
input_ids = ([pad_token] * padding_length) + input_ids
|
256 |
+
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
|
257 |
+
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
258 |
+
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
259 |
+
else:
|
260 |
+
input_ids += [pad_token] * padding_length
|
261 |
+
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
|
262 |
+
segment_ids += [pad_token_segment_id] * padding_length
|
263 |
+
label_ids += [pad_token_label_id] * padding_length
|
264 |
+
|
265 |
+
assert len(input_ids) == max_seq_length
|
266 |
+
assert len(input_mask) == max_seq_length
|
267 |
+
assert len(segment_ids) == max_seq_length
|
268 |
+
assert len(label_ids) == max_seq_length
|
269 |
+
|
270 |
+
if ex_index < 5:
|
271 |
+
logger.info("*** Example ***")
|
272 |
+
logger.info("guid: %s", example.guid)
|
273 |
+
logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
|
274 |
+
logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
|
275 |
+
logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
|
276 |
+
logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
|
277 |
+
logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
|
278 |
+
|
279 |
+
if "token_type_ids" not in tokenizer.model_input_names:
|
280 |
+
segment_ids = None
|
281 |
+
|
282 |
+
features.append(
|
283 |
+
InputFeatures(
|
284 |
+
input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids, label=label_ids
|
285 |
+
)
|
286 |
+
)
|
287 |
+
return features
|
288 |
+
|
289 |
+
|
290 |
+
def convert_text_examples_to_features(
|
291 |
+
examples: List[TextExample],
|
292 |
+
tokenizer: PreTrainedTokenizer,
|
293 |
+
max_length: Optional[int] = None,
|
294 |
+
label_list=None,
|
295 |
+
output_mode=None,
|
296 |
+
):
|
297 |
+
if max_length is None:
|
298 |
+
max_length = tokenizer.model_max_length
|
299 |
+
|
300 |
+
label_map = {label: i for i, label in enumerate(label_list)}
|
301 |
+
|
302 |
+
def label_from_example(example: TextExample) -> Union[int, float, None]:
|
303 |
+
if example.label is None:
|
304 |
+
return None
|
305 |
+
if output_mode == "classification":
|
306 |
+
return label_map[example.label]
|
307 |
+
elif output_mode == "regression":
|
308 |
+
return float(example.label)
|
309 |
+
raise KeyError(output_mode)
|
310 |
+
|
311 |
+
labels = [label_from_example(example) for example in examples]
|
312 |
+
|
313 |
+
batch_encoding = tokenizer(
|
314 |
+
[example.text_a if example.text_b is None else (example.text_a, example.text_b) for example in examples],
|
315 |
+
max_length=max_length,
|
316 |
+
padding="max_length",
|
317 |
+
truncation=True,
|
318 |
+
)
|
319 |
+
|
320 |
+
features = []
|
321 |
+
for i in range(len(examples)):
|
322 |
+
inputs = {k: batch_encoding[k][i] for k in batch_encoding}
|
323 |
+
|
324 |
+
feature = InputFeatures(**inputs, label=labels[i])
|
325 |
+
features.append(feature)
|
326 |
+
|
327 |
+
return features
|
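
As a quick sanity check, the converters above can be driven by hand. A minimal sketch for convert_text_examples_to_features (model name, texts, and labels are placeholders, not repository fixtures; assumes the repo root is on sys.path and the HuggingFace transformers package is installed):

# Illustrative only: exercising convert_text_examples_to_features directly.
from transformers import AutoTokenizer
from fine_tune.data.examples import TextExample, convert_text_examples_to_features

tokenizer = AutoTokenizer.from_pretrained('ai4bharat/indic-bert')  # example model
examples = [
    TextExample(guid='train-0', text_a='An example sentence.', label='pos'),
    TextExample(guid='train-1', text_a='Another sentence.', label='neg'),
]
features = convert_text_examples_to_features(
    examples, tokenizer,
    max_length=128, label_list=['pos', 'neg'], output_mode='classification',
)
print(features[0].label, features[0].input_ids[:8])  # mapped label id + padded ids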
Indic-BERT-v1-master/fine_tune/data/processors.py
ADDED
@@ -0,0 +1,521 @@

import csv
import json
import os

from .examples import MultipleChoiceExample, TextExample, TokensExample


class DataProcessor:
    """Base class for data converters for sequence classification data sets."""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_examples(self, lang, mode):
        if mode == 'train':
            return self.get_train_examples(lang)
        elif mode == 'dev':
            return self.get_dev_examples(lang)
        elif mode == 'test':
            return self.get_test_examples(lang)

    def modes(self):
        return ['train', 'dev', 'test']

    def get_train_examples(self, lang):
        """Gets a collection of :class:`InputExample` for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, lang):
        """Gets a collection of :class:`InputExample` for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, lang):
        """Gets a collection of :class:`InputExample` for the test set."""
        raise NotImplementedError()

    def get_labels(self, lang):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def read_csv(cls, input_file, quotechar=None):
        """Reads a comma-separated values file."""
        with open(input_file, encoding='utf-8') as fp:
            return list(csv.reader(fp, delimiter=','))

    @classmethod
    def read_json(cls, input_file):
        """Reads a JSON file."""
        with open(input_file, encoding='utf-8') as fp:
            return json.load(fp)

    @classmethod
    def readlines(cls, filepath):
        with open(filepath, encoding='utf-8') as fp:
            return fp.readlines()

    @classmethod
    def read_jsonl(cls, filepath):
        with open(filepath, 'r', encoding='utf-8') as fp:
            data = fp.readlines()
            data = list(map(lambda l: json.loads(l), data))
            return data


class IndicNLPHeadlines(DataProcessor):
    """Processor for the Headline Prediction dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/{}-train.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/{}-valid.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'dev')

    def get_test_examples(self, lang):
        """See base class."""
        fname = '{}/{}-test.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        return ['A', 'B', 'C', 'D']

    def _create_examples(self, items, set_type):
        """Creates examples for the training and dev sets."""
        examples = [
            MultipleChoiceExample(
                example_id=idx,
                question='',
                contexts=[item['content'], item['content'], item['content'],
                          item['content']],
                endings=[item['optionA'], item['optionB'], item['optionC'],
                         item['optionD']],
                label=item['correctOption'],
            )
            for idx, item in enumerate(items)
        ]
        return examples


class WikiCloze(DataProcessor):
    """Processor for the Wiki Cloze QA dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def modes(self):
        return ['test']

    def get_test_examples(self, lang):
        """See base class."""
        fname = '{}.json'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath)['cloze_data'], 'test')

    def get_labels(self, lang):
        """See base class."""
        return list(range(4))

    def _create_examples(self, items, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, item) in enumerate(items):
            if '' in [option.strip() for option in item['options']]:
                continue
            example = MultipleChoiceExample(
                example_id=i,
                question=item['question'].replace('<MASK>', '[MASK]'),
                contexts=[],
                endings=item['options'],
                label=item['options'].index(item['answer'])
            )
            examples.append(example)
        return examples


class IndicNLPGenre(DataProcessor):
    """Processor for the Article Genre Classification data set"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/{}-train.csv'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/{}-valid.csv'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'dev')

    def get_test_examples(self, lang):
        fname = '{}/{}-test.csv'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        filename = '{}/{}-train.csv'.format(lang, lang)
        lines = self.read_csv(os.path.join(self.data_dir, filename))
        labels = map(lambda l: l[0], lines)
        labels = list(set(labels))
        return labels

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            example = TextExample(
                guid=('%s-%s' % (set_type, i)),
                text_a=line[1],
                label=line[0]
            )
            examples.append(example)
        return examples


class WikiNER(DataProcessor):

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_examples(self, lang, mode):
        mode = 'valid' if mode == 'dev' else mode
        file_path = os.path.join(self.data_dir, lang, f'{mode}.txt')
        guid_index = 1
        examples = []
        with open(file_path, encoding='utf-8') as f:
            words = []
            labels = []
            for line in f:
                if line.startswith('-DOCSTART-') or line == '' or line == '\n':
                    if words:
                        example = TokensExample(
                            guid=f'{mode}-{guid_index}',
                            words=words,
                            labels=labels
                        )
                        examples.append(example)
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split(' ')
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[-1].replace('\n', ''))
                    else:
                        # Examples could have no label for mode = 'test'
                        labels.append('O')
            if words:
                example = TokensExample(
                    guid=f'{mode}-{guid_index}',
                    words=words,
                    labels=labels
                )
                examples.append(example)
        return examples

    def get_labels(self, lang):
        path = os.path.join(self.data_dir, lang, 'labels.txt')
        with open(path, 'r') as f:
            labels = f.read().splitlines()
        if 'O' not in labels:
            labels = ['O'] + labels
        return labels


class WikiSectionTitles(DataProcessor):
    """Processor for the Wikipedia Section Title Prediction dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/{}-train.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/{}-valid.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'dev')

    def get_test_examples(self, lang):
        """See base class."""
        fname = '{}/{}-test.json'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        return ['titleA', 'titleB', 'titleC', 'titleD']

    def _create_examples(self, items, set_type):
        """Creates examples for the training and dev sets."""
        examples = [
            MultipleChoiceExample(
                example_id=idx,
                question='',
                contexts=[item['sectionText'], item['sectionText'],
                          item['sectionText'], item['sectionText']],
                endings=[item['titleA'], item['titleB'], item['titleC'],
                         item['titleD']],
                label=item['correctTitle'],
            )
            for idx, item in enumerate(items)
        ]
        return examples


class ManKiBaat(DataProcessor):
    """Processor for the Man ki Baat dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def modes(self):
        return ['en', 'in']

    def get_examples(self, lang, mode):
        if mode == 'en':
            return self.get_examples_en(lang)
        elif mode == 'in':
            return self.get_examples_in(lang)

    def get_examples_en(self, lang):
        """Get examples of the English side"""
        fname = 'en-{}/mkb.en'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.readlines(fpath), 'en')

    def get_examples_in(self, lang):
        """Get examples of the Indian-language side"""
        fname = 'en-{}/mkb.{}'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.readlines(fpath), 'in')

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            example = TextExample(
                guid=('%s-%s' % (set_type, i)),
                text_a=line,
                label=i
            )
            examples.append(example)
        return examples

    def get_labels(self, lang):
        # return a dummy label list larger than the number of examples
        return list(range(10000))


class ACTSA(IndicNLPGenre):
    pass


class BBCNews(IndicNLPGenre):

    def get_dev_examples(self, lang):
        """See base class. (The test file doubles as the dev set.)"""
        fname = '{}/{}-test.csv'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'dev')


class INLTKHeadlines(IndicNLPGenre):
    pass


class SohamArticles(IndicNLPGenre):
    pass


class IITPMovies(IndicNLPGenre):
    pass


class IITProducts(IndicNLPGenre):
    pass


class AmritaParaphraseExact(IndicNLPGenre):

    def get_dev_examples(self, lang):
        """See base class. (The test file doubles as the dev set.)"""
        fname = '{}/{}-test.csv'.format(lang, lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'dev')

    def get_labels(self, lang):
        """See base class."""
        filename = '{}/{}-train.csv'.format(lang, lang)
        lines = self.read_csv(os.path.join(self.data_dir, filename))
        labels = map(lambda l: l[2], lines)
        labels = list(set(labels))
        return labels

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            example = TextExample(
                guid=('%s-%s' % (set_type, i)),
                text_a=line[0],
                text_b=line[1],
                label=line[2]
            )
            examples.append(example)
        return examples


class AmritaParaphraseFuzzy(AmritaParaphraseExact):
    pass


class MidasDiscourse(DataProcessor):
    """Processor for the MIDAS Discourse Mode classification dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/train.json'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/val.json'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'dev')

    def get_test_examples(self, lang):
        fname = '{}/test.json'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_json(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        filename = '{}/train.json'.format(lang)
        lines = self.read_json(os.path.join(self.data_dir, filename))
        labels = map(lambda l: l['Discourse Mode'], lines)
        labels = list(set(labels))
        return labels

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            example = TextExample(
                guid=('%s-%s' % (set_type, i)),
                text_a=line['Sentence'],
                label=line['Discourse Mode']
            )
            examples.append(example)
        return examples


class WNLI(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/train.csv'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/dev.csv'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'dev')

    def get_test_examples(self, lang):
        """See base class. (The dev file doubles as the test set.)"""
        fname = '{}/dev.csv'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_csv(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(TextExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class COPA(DataProcessor):
    """Processor for the COPA (Choice of Plausible Alternatives) dataset"""

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def get_train_examples(self, lang):
        """See base class."""
        fname = '{}/train.jsonl'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_jsonl(fpath), 'train')

    def get_dev_examples(self, lang):
        """See base class."""
        fname = '{}/val.jsonl'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_jsonl(fpath), 'dev')

    def get_test_examples(self, lang):
        """See base class. (The val file doubles as the test set.)"""
        fname = '{}/val.jsonl'.format(lang)
        fpath = os.path.join(self.data_dir, fname)
        return self._create_examples(self.read_jsonl(fpath), 'test')

    def get_labels(self, lang):
        """See base class."""
        return [0, 1]

    def _create_examples(self, items, set_type):
        """Creates examples for the training and dev sets."""
        examples = [
            MultipleChoiceExample(
                example_id=idx,
                question='',
                contexts=[item['premise'], item['premise']],
                endings=[item['choice1'], item['choice2']],
                label=item['label'],
            )
            for idx, item in enumerate(items)
        ]
        return examples
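
Processors are thin wrappers over the on-disk layout; a sketch of driving one directly (the data directory path is hypothetical):

# Hypothetical usage; assumes <data_dir>/<lang>/<lang>-{train,valid,test}.json exist.
proc = IndicNLPHeadlines('/path/to/indicnlp-headlines')
train_examples = proc.get_examples('hi', 'train')  # dispatches to get_train_examples('hi')
print(len(train_examples), proc.get_labels('hi'))  # labels are always ['A', 'B', 'C', 'D']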
Indic-BERT-v1-master/fine_tune/modules/__init__.py
ADDED
@@ -0,0 +1,22 @@

from .masked_lm import MaskedLM
from .multiple_choice import MultipleChoice
from .text_classification import TextClassification
from .token_classification import TokenClassification
from .xsent_retrieval import XSentRetrieval


modules = {
    'masked_lm': MaskedLM,
    'multiple_choice': MultipleChoice,
    'text_classification': TextClassification,
    'token_classification': TokenClassification,
    'xsent_retrieval': XSentRetrieval
}


def get_modules(name=None):
    if name:
        return modules[name]
    return modules.values()
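
The registry gives the CLI a single lookup point for resolving a task module by name; for example:

module_cls = get_modules('text_classification')  # -> the TextClassification class
all_modules = list(get_modules())                # every registered module class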
Indic-BERT-v1-master/fine_tune/modules/base.py
ADDED
@@ -0,0 +1,397 @@

import argparse
import logging
import os
import glob
import random
import copy
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn

from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader, TensorDataset
from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from ..data import load_dataset
from ..data.examples import *


logger = logging.getLogger(__name__)


MODEL_MODES = {
    'base': AutoModel,
    'sequence-classification': AutoModelForSequenceClassification,
    'question-answering': AutoModelForQuestionAnswering,
    'pretraining': AutoModelForPreTraining,
    'token-classification': AutoModelForTokenClassification,
    'language-modeling': AutoModelWithLMHead,
    'multiple-choice': AutoModelForMultipleChoice,
}


def get_model_class(model_type, mode):
    return MODEL_MODES[mode]


def set_seed(hparams):
    random.seed(hparams['seed'])
    np.random.seed(hparams['seed'])
    torch.manual_seed(hparams['seed'])
    if hparams['n_gpu'] > 0:
        torch.cuda.manual_seed_all(hparams['seed'])


class BaseModule(pl.LightningModule):
    """
    The base module has four components: config, tokenizer, transformer model,
    and dataset

    Loading of a dataset:
        1. Load instances of a dataset in the form of `Examples`
        2. Convert all examples into features - may require tokenizer
        3. Create a tensor dataset and loader given all the converted features
    """

    def __init__(self, hparams):
        super().__init__()

        hparams['mode'] = self.mode
        hparams['output_mode'] = self.output_mode
        hparams['example_type'] = self.example_type
        hparams['dev_lang'] = hparams['train_lang']
        self.hparams = hparams  # must come after super
        self.dataset = load_dataset(hparams['dataset'], hparams['data_dir'])
        if self.output_mode == 'classification':
            self.labels = self.dataset.get_labels(hparams['train_lang'])

        # setup config object
        config_name = hparams['config_name'] or hparams['model_name_or_path']
        args = {}
        if self.output_mode == 'classification':
            hparams['num_labels'] = len(self.dataset.get_labels(hparams['train_lang']))
            args = {'num_labels': hparams['num_labels']}

        self.config = AutoConfig.from_pretrained(
            config_name,
            **args,
            cache_dir=hparams['cache_dir']
        )

        # setup tokenizer object
        tok_name = hparams['tokenizer_name'] or hparams['model_name_or_path']
        self.tokenizer = AutoTokenizer.from_pretrained(
            tok_name,
            config=self.config,
            cache_dir=hparams['cache_dir'],
        )

        # setup transformer model
        model_class = get_model_class(self.config.model_type, hparams['mode'])
        self.model = model_class.from_pretrained(
            hparams['model_name_or_path'],
            config=self.config,
            cache_dir=hparams['cache_dir'],
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def prepare_data(self):
        """Cache feature files on disk for every mode at the onset"""
        modes = self.dataset.modes()
        for mode in modes:
            cached_features_file = self._feature_file(mode)
            if not os.path.exists(cached_features_file)\
                    or self.hparams['overwrite_cache']:
                self.load_features(mode)

    def load_features(self, mode):
        """Load examples and convert them into features"""
        if mode in ('train', 'dev', 'test'):
            lang = self.hparams['{}_lang'.format(mode)]
        else:
            lang = self.hparams['test_lang']
        examples = self.dataset.get_examples(lang, mode)

        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file)\
                and not self.hparams['overwrite_cache']:
            features = torch.load(cached_features_file)
        else:
            features = self.convert_examples_to_features(examples)
            torch.save(features, cached_features_file)

        return features

    def convert_examples_to_features(self, examples):
        if self.hparams['example_type'] == 'multiple-choice':
            features = convert_multiple_choice_examples_to_features(
                examples,
                self.tokenizer,
                max_length=self.hparams['max_seq_length'],
                label_list=self.labels
            )
        elif self.hparams['example_type'] == 'text':
            features = convert_text_examples_to_features(
                examples,
                self.tokenizer,
                max_length=self.hparams['max_seq_length'],
                label_list=self.labels,
                output_mode=self.output_mode,
            )
        elif self.hparams['example_type'] == 'tokens':
            features = convert_tokens_examples_to_features(
                examples,
                self.labels,
                self.hparams['max_seq_length'],
                self.tokenizer,
                cls_token_at_end=bool(self.config.model_type in ["xlnet"]),
                cls_token=self.tokenizer.cls_token,
                cls_token_segment_id=2 if self.config.model_type in ["xlnet"] else 0,
                sep_token=self.tokenizer.sep_token,
                sep_token_extra=bool(self.config.model_type in ["roberta"]),
                pad_on_left=bool(self.config.model_type in ["xlnet"]),
                pad_token=self.tokenizer.pad_token_id,
                pad_token_segment_id=self.tokenizer.pad_token_type_id,
                pad_token_label_id=self.pad_token_label_id,
            )
        return features

    def make_loader(self, features, batch_size):
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids or 0 for f in features], dtype=torch.long)
        # all_candidates = torch.tensor([f.candidates for f in features], dtype=torch.long)
        if self.hparams['output_mode'] == 'classification':
            all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        elif self.hparams['output_mode'] == 'regression':
            all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels),
            # TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels, all_candidates),
            batch_size=batch_size,
        )

    def train_dataloader(self):
        train_batch_size = self.hparams['train_batch_size']
        train_features = self.load_features('train')
        dataloader = self.make_loader(train_features, train_batch_size)

        t_total = (
            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams['n_gpu'])))
            // self.hparams['gradient_accumulation_steps']
            * float(self.hparams['num_train_epochs'])
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams['warmup_steps'], num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        dev_features = self.load_features('dev')
        dataloader = self.make_loader(dev_features, self.hparams['eval_batch_size'])
        return dataloader

    def test_dataloader(self):
        test_features = self.load_features('test')
        dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
        return dataloader

    def training_step(self, batch, batch_idx):
        inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'labels': batch[3]}
        if self.config.model_type != 'distilbert':
            inputs['token_type_ids'] = (
                batch[2] if self.config.model_type in ['bert', 'xlnet', 'albert'] else None
            )  # XLM and RoBERTa don't use token_type_ids

        outputs = self(**inputs)
        loss = outputs[0]

        tensorboard_logs = {'loss': loss, 'rate': self.lr_scheduler.get_last_lr()[-1]}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'labels': batch[3]}

        # XLM and RoBERTa don't use token_type_ids
        inputs['token_type_ids'] = None
        if self.config.model_type in ['bert', 'xlnet', 'albert']:
            inputs['token_type_ids'] = batch[2]

        outputs = self(**inputs)
        tmp_eval_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs['labels'].detach().cpu().numpy()

        return {'val_loss': tmp_eval_loss.detach().cpu(),
                'pred': preds,
                'target': out_label_ids}

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def _feature_file(self, mode):
        if mode in ('train', 'dev', 'test'):
            lang = self.hparams['{}_lang'.format(mode)]
        else:
            lang = self.hparams['test_lang']
        return os.path.join(
            self.hparams['data_dir'],
            'cached_{}_{}_{}_{}'.format(
                lang,
                mode,
                list(filter(None, self.hparams['model_name_or_path'].split('/'))).pop(),
                str(self.hparams['max_seq_length']),
            ),
        )

    def is_logger(self):
        return self.trainer.global_rank <= 0

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""

        model = self.model
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': self.hparams['weight_decay'],
            },
            {
                'params': [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.hparams['learning_rate'],
                          eps=self.hparams['adam_epsilon'])
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None):
        if self.trainer.use_tpu:
            import torch_xla.core.xla_model as xm
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.lr_scheduler.step()

    def get_tqdm_dict(self):
        avg_loss = getattr(self.trainer, 'avg_loss', 0.0)
        tqdm_dict = {'loss': '{:.3f}'.format(avg_loss), 'lr': self.lr_scheduler.get_last_lr()[-1]}
        return tqdm_dict

    def run_module(self):
        trainer = create_trainer(self, self.hparams)
        hparams_copy = copy.deepcopy(self.hparams)

        if self.hparams['do_train']:
            checkpoints = list(sorted(glob.glob(os.path.join(self.hparams['output_dir'], 'checkpointepoch=*.ckpt'), recursive=True)))
            if len(checkpoints) == 0:
                trainer.fit(self)
                checkpoints = list(sorted(glob.glob(os.path.join(self.hparams['output_dir'], 'checkpointepoch=*.ckpt'), recursive=True)))
            self.trained_model = self.load_from_checkpoint(checkpoints[-1])
            self.trained_model.hparams = hparams_copy

        # Optionally, run prediction on the test set and write results to output_dir
        if self.hparams['do_predict']:
            trainer.test(self.trained_model)


# Fixes __temp_weight_ddp_end.ckpt bug
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/1142
class MonkeyPatchedTrainer(pl.Trainer):
    def load_spawn_weights(self, original_model):
        pass


pl.Trainer = MonkeyPatchedTrainer


class LoggingCallback(pl.Callback):
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        logger.info("***** Validation results *****")
        if pl_module.is_logger():
            metrics = trainer.callback_metrics
            # Log results
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    logger.info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        logger.info("***** Test results *****")
        print(trainer.callback_metrics)

        if pl_module.is_logger():
            metrics = trainer.callback_metrics

            # Log and save results to file
            output_dir = pl_module.hparams['output_dir']
            test_lang = pl_module.hparams['test_lang']
            output_test_results_file = os.path.join(output_dir, 'test_results_{}.txt'.format(test_lang))
            with open(output_test_results_file, "w") as writer:
                for key in sorted(metrics):
                    if key not in ["log", "progress_bar"]:
                        logger.info("{} = {}\n".format(key, str(metrics[key])))
                        writer.write("{} = {}\n".format(key, str(metrics[key])))


def create_trainer(model, hparams):
    # init model
    set_seed(hparams)

    # if os.path.exists(hparams['output_dir']) and os.listdir(hparams['output_dir']) and hparams['do_train']:
    #     raise ValueError('Output directory ({}) already exists and is not empty.'.format(hparams['output_dir']))

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=hparams['output_dir'], prefix='checkpoint', monitor='val_loss', mode='min', save_top_k=5
    )

    train_params = dict(
        accumulate_grad_batches=hparams['gradient_accumulation_steps'],
        gpus=hparams['n_gpu'],
        max_epochs=hparams['num_train_epochs'],
        early_stop_callback=False,
        gradient_clip_val=hparams['max_grad_norm'],
        checkpoint_callback=checkpoint_callback,
        callbacks=[LoggingCallback()],
    )

    if hparams['fp16']:
        train_params['use_amp'] = hparams['fp16']
        train_params['amp_level'] = hparams['fp16_opt_level']

    if hparams['n_tpu_cores'] > 0:
        train_params['tpu_cores'] = hparams['n_tpu_cores']
        train_params['gpus'] = 0

    if hparams['n_gpu'] > 1:
        train_params['distributed_backend'] = 'ddp'

    trainer = pl.Trainer(**train_params)
    return trainer
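
Since BaseModule reads its entire configuration from a plain hparams dict, the following is a sketch of the keys the code above touches (values are placeholders, not recommended settings):

# Illustrative hparams for a BaseModule subclass; keys come from the code above,
# values are placeholders only.
hparams = {
    'seed': 42, 'n_gpu': 1, 'n_tpu_cores': 0, 'fp16': False, 'fp16_opt_level': 'O1',
    'dataset': 'indicnlp-headlines', 'data_dir': '/path/to/data',
    'train_lang': 'hi', 'test_lang': 'hi',
    'model_name_or_path': 'ai4bharat/indic-bert',
    'config_name': '', 'tokenizer_name': '', 'cache_dir': '',
    'max_seq_length': 128, 'train_batch_size': 32, 'eval_batch_size': 32,
    'learning_rate': 2e-5, 'adam_epsilon': 1e-8, 'weight_decay': 0.0,
    'warmup_steps': 0, 'num_train_epochs': 3, 'gradient_accumulation_steps': 1,
    'max_grad_norm': 1.0, 'overwrite_cache': False,
    'do_train': True, 'do_predict': True, 'output_dir': '/tmp/output',
}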
Indic-BERT-v1-master/fine_tune/modules/masked_lm.py
ADDED
@@ -0,0 +1,155 @@
"""
Based on https://github.com/huggingface/transformers/issues/80
"""

import json
import argparse
import glob
import sys
import logging
import os
import time
import string
from filelock import FileLock

import numpy as np
import pickle
import torch
from torch.utils.data import DataLoader, TensorDataset
from collections import ChainMap

from .base import BaseModule, create_trainer
from ..data.examples import InputFeatures


logger = logging.getLogger(__name__)


class MaskedLM(BaseModule):

    mode = 'language-modeling'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        super().__init__(hparams)

        self.mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')
        self.test_results_fpath = 'test_results'
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)

    def convert_examples_to_features(self, examples):

        batch_encoding = self.tokenizer(
            [example.question for example in examples],
            max_length=self.hparams['max_seq_length'],
            padding='max_length',
            truncation=True,
        )

        features = []
        for i in range(len(examples)):
            inputs = {k: batch_encoding[k][i] for k in batch_encoding}
            candidates = examples[i].endings
            tokens = [self.tokenizer.tokenize(cand) for cand in candidates]
            token_candidates = []

            for toks in tokens:
                if len(toks) == 0:
                    token_candidates.append(self.tokenizer.unk_token)
                else:
                    token_candidates.append(max(toks, key=lambda t: len(t.strip(string.punctuation))))
            candidate_ids = self.tokenizer.convert_tokens_to_ids(token_candidates)

            feature = InputFeatures(**inputs, candidates=candidate_ids, label=examples[i].label)
            features.append(feature)

        return features

    def test_dataloader(self):
        mode = 'test'
        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file) and not self.hparams['overwrite_cache']:
            features = torch.load(cached_features_file)
        else:
            features = self.load_features(mode)
            torch.save(features, cached_features_file)

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids or 0 for f in features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        all_cands = torch.tensor([f.candidates for f in features], dtype=torch.long)
        all_answers = torch.tensor([f.label for f in features], dtype=torch.long)

        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels, all_cands, all_answers),
            batch_size=self.hparams['eval_batch_size'],
        )

    def test_step(self, batch, batch_idx):
        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
                  'attention_mask': batch[1]}

        answers = batch[3].detach().cpu().numpy()
        candidates = batch[4].detach().cpu().numpy()

        # get the first mask location
        input_ids = batch[0].detach().cpu().numpy()
        mask_ids = (input_ids == self.mask_id).argmax(axis=1)
        mask_ids = torch.from_numpy(mask_ids)

        predictions = self(**inputs)[0]

        i = torch.arange(0, predictions.shape[0], dtype=torch.int64)
        predictions = predictions[i, mask_ids]
        predictions = predictions.detach().cpu().numpy()

        right, wrong = 0, 0

        for i, pred in enumerate(predictions):
            prob = pred[candidates[i]]
            pred_answer = int(np.argmax(prob))
            if answers[i] == pred_answer:
                right += 1
            else:
                wrong += 1

        return {"right": right, "wrong": wrong}

    def test_epoch_end(self, outputs):
        right = sum(output['right'] for output in outputs)
        wrong = sum(output['wrong'] for output in outputs)
        merged = {'right': right, 'wrong': wrong}

        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = {'right': data['right'] + merged['right'], 'wrong': data['wrong'] + merged['wrong']}
            else:
                data = merged
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)

        return data

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        return parser

    def run_module(self):
        self.eval()
        self.freeze()
        # a bare `torch.no_grad()` call is a no-op, so disable gradients globally instead
        torch.set_grad_enabled(False)

        trainer = create_trainer(self, self.hparams)

        trainer.test(self)
        preds = pickle.load(open(self.test_results_fpath, 'rb'))
        correct, wrong = preds['right'], preds['wrong']
        with open(os.path.join(self.hparams['output_dir'], 'test_results.txt'), 'w') as fp:
            json.dump({'test_acc': correct / (correct + wrong)}, fp)
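
The candidate-restricted [MASK] scoring used in test_step above, in isolation: take the logits at the first mask position and argmax over the candidate token ids only. A self-contained sketch (model name and strings are illustrative; assumes a transformers release that still ships AutoModelWithLMHead, matching the imports above; for brevity it scores the first sub-token of each candidate, whereas the module above picks the longest):

# Minimal sketch of candidate-restricted [MASK] scoring (illustrative, not the module itself).
import numpy as np
import torch
from transformers import AutoModelWithLMHead, AutoTokenizer

tok = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')
model = AutoModelWithLMHead.from_pretrained('bert-base-multilingual-cased')

question = 'The capital of France is [MASK].'
candidates = ['Paris', 'London']

enc = tok(question, return_tensors='pt')
mask_pos = (enc['input_ids'][0] == tok.mask_token_id).nonzero()[0].item()
with torch.no_grad():
    logits = model(**enc)[0][0, mask_pos]        # scores over the whole vocabulary
cand_ids = tok.convert_tokens_to_ids([tok.tokenize(c)[0] for c in candidates])
print(candidates[int(np.argmax(logits[cand_ids].numpy()))])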
Indic-BERT-v1-master/fine_tune/modules/multiple_choice.py
ADDED
@@ -0,0 +1,51 @@
import torch
import numpy as np

from .base import BaseModule
from .utils import mean_accuracy


class MultipleChoice(BaseModule):

    mode = 'multiple-choice'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        super().__init__(hparams)

    def _eval_end(self, outputs):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs])\
            .mean().detach().cpu().item()
        preds = np.concatenate([x['pred'] for x in outputs], axis=0)
        preds = np.argmax(preds, axis=1)

        out_label_ids = np.concatenate([x['target'] for x in outputs], axis=0)
        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
        preds_list = [[] for _ in range(out_label_ids.shape[0])]

        results = {**{'val_loss': val_loss_mean},
                   **mean_accuracy(preds, out_label_ids)}

        ret = {k: v for k, v in results.items()}
        ret['log'] = results
        return ret, preds_list, out_label_list

    def validation_epoch_end(self, outputs: list) -> dict:
        ret, preds, targets = self._eval_end(outputs)
        logs = ret['log']
        return {'val_loss': logs['val_loss'], 'log': logs, 'progress_bar': logs}

    def test_epoch_end(self, outputs):
        ret, predictions, targets = self._eval_end(outputs)

        # Converting to the dict required by pl
        logs = ret['log']
        # `val_loss` is the key returned by `self._eval_end()`
        # but actually refers to `test_loss`
        return {'avg_test_loss': logs['val_loss'],
                'log': logs, 'progress_bar': logs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        return parser
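
mean_accuracy lives in modules/utils.py, which is not reproduced in this listing; given how its result is splatted into the metrics dict above, a plausible implementation would be (an assumption, not the repository's actual code):

# Hypothetical mean_accuracy: returns a metrics dict, as the ** expansion above expects.
def mean_accuracy(preds, labels):
    return {'acc': (preds == labels).mean()}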
Indic-BERT-v1-master/fine_tune/modules/question_answering.py
ADDED
File without changes
Indic-BERT-v1-master/fine_tune/modules/text_classification.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
"""
|
2 |
+
Code inspired from the Huggingface's transformer library:
|
3 |
+
File path: transformers/examples/text-classification/run_pl_glue.py
|
4 |
+
|
5 |
+
To handle large documents, we use head-truncation. Check the following
|
6 |
+
paper for a detailed analysis of text classification techniques using
|
7 |
+
bert-like models: https://arxiv.org/pdf/1905.05583.pdf
|
8 |
+
"""
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import glob
|
12 |
+
import logging
|
13 |
+
import os
|
14 |
+
import time
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from torch.utils.data import DataLoader, TensorDataset
|
19 |
+
|
20 |
+
from .base import BaseModule, create_trainer
|
21 |
+
from .utils import mean_accuracy
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class TextClassification(BaseModule):
|
28 |
+
|
29 |
+
mode = 'sequence-classification'
|
30 |
+
output_mode = 'classification'
|
31 |
+
example_type = 'text'
|
32 |
+
|
33 |
+
def __init__(self, hparams):
|
34 |
+
super().__init__(hparams)
|
35 |
+
|
36 |
+
def _eval_end(self, outputs):
|
37 |
+
val_loss_mean = torch.stack([x['val_loss'] for x in outputs])\
|
38 |
+
.mean().detach().cpu().item()
|
39 |
+
preds = np.concatenate([x['pred'] for x in outputs], axis=0)
|
40 |
+
preds = np.argmax(preds, axis=1)
|
41 |
+
|
42 |
+
out_label_ids = np.concatenate([x['target'] for x in outputs], axis=0)
|
43 |
+
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
|
44 |
+
preds_list = [[] for _ in range(out_label_ids.shape[0])]
|
45 |
+
|
46 |
+
results = {**{'val_loss': val_loss_mean},
|
47 |
+
**mean_accuracy(preds, out_label_ids)}
|
48 |
+
|
49 |
+
ret = {k: v for k, v in results.items()}
|
50 |
+
ret['log'] = results
|
51 |
+
return ret, preds_list, out_label_list
|
52 |
+
|
53 |
+
def validation_epoch_end(self, outputs: list) -> dict:
|
54 |
+
ret, preds, targets = self._eval_end(outputs)
|
55 |
+
logs = ret['log']
|
56 |
+
return {'val_loss': logs['val_loss'], 'log': logs, 'progress_bar': logs}
|
57 |
+
|
58 |
+
def test_epoch_end(self, outputs):
|
59 |
+
ret, predictions, targets = self._eval_end(outputs)
|
60 |
+
|
61 |
+
# Converting to the dic required by pl
|
62 |
+
logs = ret['log']
|
63 |
+
# `val_loss` is the key returned by `self._eval_end()`
|
64 |
+
# but actually refers to `test_loss`
|
65 |
+
return {'avg_test_loss': logs['val_loss'],
|
66 |
+
'log': logs, 'progress_bar': logs}
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def add_model_specific_args(parser, root_dir):
|
70 |
+
return parser
|
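The docstring above refers to head-truncation: a long document is cut down to its first `max_seq_length` tokens. The truncation itself happens in the shared data processors rather than in this module; a hedged sketch of the idea with a Hugging Face tokenizer (the checkpoint name and `max_length` value are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('ai4bharat/indic-bert')  # illustrative checkpoint
long_document = 'word ' * 10000  # far longer than the model's input limit

# Head-truncation: everything past max_length is simply dropped.
encoded = tokenizer(long_document, truncation=True, max_length=128, padding='max_length')
print(len(encoded['input_ids']))  # 128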
Indic-BERT-v1-master/fine_tune/modules/token_classification.py
ADDED
@@ -0,0 +1,87 @@
+import argparse
+import glob
+import logging
+import os
+import subprocess
+
+import numpy as np
+import torch
+from seqeval.metrics import f1_score, precision_score, recall_score
+from torch.nn import CrossEntropyLoss
+
+from .base import BaseModule
+
+
+logger = logging.getLogger(__name__)
+
+
+class TokenClassification(BaseModule):
+
+    mode = 'token-classification'
+    output_mode = 'classification'
+    example_type = 'tokens'
+
+    def __init__(self, hparams):
+        self.pad_token_label_id = CrossEntropyLoss().ignore_index
+
+        script_path = os.path.join(os.path.dirname(__file__), '../..', 'scripts/ner_preprocess.sh')
+        cmd = f"bash {script_path} {hparams['data_dir']} {hparams['train_lang']} "\
+              f"{hparams['test_lang']} {hparams['model_name_or_path']} {hparams['max_seq_length']}"
+        subprocess.call(cmd, shell=True)
+
+        super().__init__(hparams)
+
+    def _eval_end(self, outputs):
+        """Evaluation called for both Val and Test"""
+        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
+        preds = np.concatenate([x['pred'] for x in outputs], axis=0)
+        preds = np.argmax(preds, axis=2)
+        out_label_ids = np.concatenate([x['target'] for x in outputs], axis=0)
+
+        label_map = {i: label for i, label in enumerate(self.labels)}
+        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
+        preds_list = [[] for _ in range(out_label_ids.shape[0])]
+
+        for i in range(out_label_ids.shape[0]):
+            for j in range(out_label_ids.shape[1]):
+                if out_label_ids[i, j] != self.pad_token_label_id:
+                    out_label_list[i].append(label_map[out_label_ids[i][j]])
+                    preds_list[i].append(label_map[preds[i][j]])
+
+        results = {
+            'val_loss': val_loss_mean,
+            'precision': precision_score(out_label_list, preds_list),
+            'recall': recall_score(out_label_list, preds_list),
+            'f1': f1_score(out_label_list, preds_list),
+        }
+
+        ret = {k: v for k, v in results.items()}
+        ret['log'] = results
+        return ret, preds_list, out_label_list
+
+    def validation_epoch_end(self, outputs):
+        # aggregate metrics over the full validation epoch
+        ret, preds, targets = self._eval_end(outputs)
+        logs = ret['log']
+        return {'val_loss': logs['val_loss'], 'log': logs, 'progress_bar': logs}
+
+    def test_epoch_end(self, outputs):
+        # updated to test_epoch_end instead of the deprecated test_end
+        ret, predictions, targets = self._eval_end(outputs)
+
+        # Converting to the dict required by pl
+        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
+        # pytorch_lightning/trainer/logging.py#L139
+        logs = ret['log']
+        # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
+        return {'avg_test_loss': logs['val_loss'], 'log': logs, 'progress_bar': logs}
+
+    @staticmethod
+    def add_model_specific_args(parser, root_dir):
+        parser.add_argument(
+            '--labels',
+            default='',
+            type=str,
+            help='Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.',
+        )
+        return parser
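The `--labels` flag above expects a plain file with one label per line; when it is empty, a CoNLL-2003-style NER label set is the usual fallback. A hedged sketch of how such a file is typically read (the real loader lives in the shared data code, so `read_labels` is a hypothetical name):

def read_labels(path):
    # Hypothetical helper mirroring the common transformers NER convention.
    if not path:
        # CoNLL-2003 default label set, as the --labels help text suggests
        return ['O', 'B-MISC', 'I-MISC', 'B-PER', 'I-PER',
                'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
    with open(path, encoding='utf-8') as fp:
        return [line.strip() for line in fp if line.strip()]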
Indic-BERT-v1-master/fine_tune/modules/utils.py
ADDED
@@ -0,0 +1,4 @@
+
+
+def mean_accuracy(preds, labels):
+    return {'acc': (preds == labels).mean()}
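For reference, `mean_accuracy` compares two equal-shape NumPy arrays element-wise and returns the match rate wrapped in a dict, e.g. (using the function defined above):

import numpy as np
mean_accuracy(np.array([1, 0, 2]), np.array([1, 0, 0]))  # {'acc': 0.6666666666666666}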
Indic-BERT-v1-master/fine_tune/modules/xsent_retrieval.py
ADDED
@@ -0,0 +1,111 @@
+"""Cross-lingual sentence retrieval: evaluate mean-pooled sentence vectors
+using precision@10 between English and Indic test sets."""
+import logging
+import json
+import os
+import pickle
+import scipy.spatial as sp
+from filelock import FileLock
+
+import numpy as np
+import torch
+
+from .base import BaseModule, create_trainer
+
+
+logger = logging.getLogger(__name__)
+
+
+class XSentRetrieval(BaseModule):
+
+    mode = 'base'
+    output_mode = 'classification'
+    example_type = 'text'
+
+    def __init__(self, hparams):
+        self.test_results_fpath = 'test_results'
+        if os.path.exists(self.test_results_fpath):
+            os.remove(self.test_results_fpath)
+
+        super().__init__(hparams)
+
+    def forward(self, **inputs):
+        outputs = self.model(**inputs)
+        last_hidden = outputs[0]
+        mean_pooled = torch.mean(last_hidden, 1)  # average over all token positions, padding included
+        return mean_pooled
+
+    def test_dataloader_en(self):
+        test_features = self.load_features('en')
+        dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
+        return dataloader
+
+    def test_dataloader_in(self):
+        test_features = self.load_features('in')
+        dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
+        return dataloader
+
+    def test_step(self, batch, batch_idx):
+        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
+                  'attention_mask': batch[1]}
+        labels = batch[3].detach().cpu().numpy()
+        sentvecs = self(**inputs)
+        sentvecs = sentvecs.detach().cpu().numpy()
+        sentvecs = np.hstack([labels[:, None], sentvecs])
+
+        return {'sentvecs': sentvecs}
+
+    def test_epoch_end(self, outputs):
+        all_sentvecs = np.vstack([x['sentvecs'] for x in outputs])
+
+        with FileLock(self.test_results_fpath + '.lock'):
+            if os.path.exists(self.test_results_fpath):
+                with open(self.test_results_fpath, 'rb') as fp:
+                    data = pickle.load(fp)
+                data = np.vstack([data, all_sentvecs])
+            else:
+                data = all_sentvecs
+            with open(self.test_results_fpath, 'wb') as fp:
+                pickle.dump(data, fp)
+
+        return {'sentvecs': all_sentvecs}
+
+    @staticmethod
+    def add_model_specific_args(parser, root_dir):
+        return parser
+
+    def run_module(self):
+        self.eval()
+        self.freeze()
+
+        trainer = create_trainer(self, self.hparams)
+
+        trainer.test(self, self.test_dataloader_en())
+        sentvecs1 = pickle.load(open(self.test_results_fpath, 'rb'))
+        os.remove(self.test_results_fpath)
+
+        trainer.test(self, self.test_dataloader_in())
+        sentvecs2 = pickle.load(open(self.test_results_fpath, 'rb'))
+        os.remove(self.test_results_fpath)
+
+        sentvecs1 = sentvecs1[sentvecs1[:, 0].argsort()][:, 1:]
+        sentvecs2 = sentvecs2[sentvecs2[:, 0].argsort()][:, 1:]
+
+        result_path = os.path.join(self.hparams['output_dir'], 'test_results.txt')
+        with open(result_path, 'w') as fp:
+            metrics = {'test_acc': precision_at_10(sentvecs1, sentvecs2)}
+            json.dump(metrics, fp)
+
+
+def precision_at_10(sentvecs1, sentvecs2):
+    n = sentvecs1.shape[0]
+
+    # mean centering
+    sentvecs1 = sentvecs1 - np.mean(sentvecs1, axis=0)
+    sentvecs2 = sentvecs2 - np.mean(sentvecs2, axis=0)
+
+    sim = sp.distance.cdist(sentvecs1, sentvecs2, 'cosine')
+    actual = np.array(range(n))
+    preds = sim.argsort(axis=1)[:, :10]
+    matches = np.any(preds == actual[:, None], axis=1)
+    return matches.mean()
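`precision_at_10` treats row i of `sentvecs1` and row i of `sentvecs2` as a translation pair and counts how often the true pair lands among the ten nearest neighbours under cosine distance after mean-centering. A self-contained toy check with synthetic vectors (values are made up; near-identical inputs should score close to 1.0):

import numpy as np
import scipy.spatial as sp

rng = np.random.default_rng(0)
base = rng.normal(size=(50, 8))
sentvecs1 = base
sentvecs2 = base + 0.01 * rng.normal(size=(50, 8))  # slightly perturbed "translations"

# Same steps as precision_at_10 above: mean-centre, cosine distances, top-10 check.
sentvecs1 = sentvecs1 - sentvecs1.mean(axis=0)
sentvecs2 = sentvecs2 - sentvecs2.mean(axis=0)
sim = sp.distance.cdist(sentvecs1, sentvecs2, 'cosine')
top10 = sim.argsort(axis=1)[:, :10]
print(np.any(top10 == np.arange(50)[:, None], axis=1).mean())  # ~1.0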