TymaaHammouda committed
Commit ceed500 · verified · 1 Parent(s): b60df2a

Upload 106 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. arabert/.gitignore +142 -0
  2. arabert/AJGT.xlsx +0 -0
  3. arabert/README.md +227 -0
  4. arabert/__init__.py +1 -0
  5. arabert/__pycache__/__init__.cpython-310.pyc +0 -0
  6. arabert/__pycache__/__init__.cpython-38.pyc +0 -0
  7. arabert/__pycache__/__init__.cpython-39.pyc +0 -0
  8. arabert/__pycache__/preprocess.cpython-310.pyc +0 -0
  9. arabert/__pycache__/preprocess.cpython-38.pyc +0 -0
  10. arabert/__pycache__/preprocess.cpython-39.pyc +0 -0
  11. arabert/arabert/LICENSE +75 -0
  12. arabert/arabert/Readme.md +75 -0
  13. arabert/arabert/__init__.py +14 -0
  14. arabert/arabert/create_classification_data.py +260 -0
  15. arabert/arabert/create_pretraining_data.py +534 -0
  16. arabert/arabert/extract_features.py +444 -0
  17. arabert/arabert/lamb_optimizer.py +158 -0
  18. arabert/arabert/modeling.py +1027 -0
  19. arabert/arabert/optimization.py +202 -0
  20. arabert/arabert/run_classifier.py +1078 -0
  21. arabert/arabert/run_pretraining.py +593 -0
  22. arabert/arabert/run_squad.py +1440 -0
  23. arabert/arabert/sample_text.txt +38 -0
  24. arabert/arabert/tokenization.py +414 -0
  25. arabert/arabert_logo.png +0 -0
  26. arabert/araelectra/.gitignore +4 -0
  27. arabert/araelectra/LICENSE +76 -0
  28. arabert/araelectra/README.md +144 -0
  29. arabert/araelectra/__init__.py +1 -0
  30. arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
  31. arabert/araelectra/build_pretraining_dataset.py +230 -0
  32. arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
  33. arabert/araelectra/configure_finetuning.py +172 -0
  34. arabert/araelectra/configure_pretraining.py +143 -0
  35. arabert/araelectra/finetune/__init__.py +14 -0
  36. arabert/araelectra/finetune/classification/classification_metrics.py +116 -0
  37. arabert/araelectra/finetune/classification/classification_tasks.py +439 -0
  38. arabert/araelectra/finetune/feature_spec.py +56 -0
  39. arabert/araelectra/finetune/preprocessing.py +173 -0
  40. arabert/araelectra/finetune/qa/mrqa_official_eval.py +120 -0
  41. arabert/araelectra/finetune/qa/qa_metrics.py +401 -0
  42. arabert/araelectra/finetune/qa/qa_tasks.py +628 -0
  43. arabert/araelectra/finetune/qa/squad_official_eval.py +317 -0
  44. arabert/araelectra/finetune/qa/squad_official_eval_v1.py +126 -0
  45. arabert/araelectra/finetune/scorer.py +54 -0
  46. arabert/araelectra/finetune/tagging/tagging_metrics.py +116 -0
  47. arabert/araelectra/finetune/tagging/tagging_tasks.py +253 -0
  48. arabert/araelectra/finetune/tagging/tagging_utils.py +58 -0
  49. arabert/araelectra/finetune/task.py +74 -0
  50. arabert/araelectra/finetune/task_builder.py +70 -0
arabert/.gitignore ADDED
@@ -0,0 +1,142 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# vscode stuff
.vscode/

# Local History for Visual Studio Code
.history/
# Pyre type checker
.pyre/
testing_squad.py
FarasaSegmenterJar.jar
data/
testing/
*.tsv
*.zip
model_cards/
optuna/
arabert/AJGT.xlsx ADDED
Binary file (107 kB).
 
arabert/README.md ADDED
@@ -0,0 +1,227 @@
# AraBERTv2 / AraGPT2 / AraELECTRA

<img src="https://github.com/aub-mind/arabert/blob/master/arabert_logo.png" width="100" align="right"/>

This repository now contains code and implementations for:
- **AraBERT v0.1/v1**: The original AraBERT
- **AraBERT v0.2/v2**: Base and large versions with a better vocabulary, more data, and more training [Read More...](#AraBERT)
- **AraGPT2**: base, medium, large and MEGA, trained from scratch on Arabic [Read More...](#AraGPT2)
- **AraELECTRA**: Trained from scratch on Arabic [Read More...](#AraELECTRA)

If you want to clone the old repository:
```bash
git clone https://github.com/aub-mind/arabert/
cd arabert && git checkout 6a58ca118911ef311cbe8cdcdcc1d03601123291
```
# Update

- **02-Apr-2021:** AraELECTRA-powered Arabic Wikipedia QA system [![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/wissamantoun/arabic-wikipedia-qa-streamlit/main)

# AraBERTv2

## What's New!

AraBERT now comes in 4 new variants that replace the old v1 versions:

More details are available in the AraBERT folder, in the [README](https://github.com/aub-mind/arabert/tree/master/arabert), and in the [AraBERT Paper](https://arxiv.org/abs/2003.00104)

Model | HuggingFace Model Name | Size (MB/Params) | Pre-Segmentation | DataSet (Sentences/Size/nWords) |
---|:---:|:---:|:---:|:---:
AraBERTv0.2-base | [bert-base-arabertv02](https://huggingface.co/aubmindlab/bert-base-arabertv02) | 543MB / 136M | No | 200M / 77GB / 8.6B |
AraBERTv0.2-large | [bert-large-arabertv02](https://huggingface.co/aubmindlab/bert-large-arabertv02) | 1.38G / 371M | No | 200M / 77GB / 8.6B |
AraBERTv2-base | [bert-base-arabertv2](https://huggingface.co/aubmindlab/bert-base-arabertv2) | 543MB / 136M | Yes | 200M / 77GB / 8.6B |
AraBERTv2-large | [bert-large-arabertv2](https://huggingface.co/aubmindlab/bert-large-arabertv2) | 1.38G / 371M | Yes | 200M / 77GB / 8.6B |
AraBERTv0.1-base | [bert-base-arabertv01](https://huggingface.co/aubmindlab/bert-base-arabertv01) | 543MB / 136M | No | 77M / 23GB / 2.7B |
AraBERTv1-base | [bert-base-arabert](https://huggingface.co/aubmindlab/bert-base-arabert) | 543MB / 136M | Yes | 77M / 23GB / 2.7B |

All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.

## Better Pre-Processing and New Vocab

We identified an issue with AraBERTv1's wordpiece vocabulary: punctuation and numbers were still attached to words when the wordpiece vocabulary was learned. We now insert a space between numbers and characters and around punctuation characters.

The new vocabulary was learned with the `BertWordpieceTokenizer` from the `tokenizers` library, and should now support the fast tokenizer implementation from the `transformers` library.

**P.S.**: All the old BERT code should work with the new models; just change the model name and check the new preprocessing function.

**Please read the section on how to use the [preprocessing function](#Preprocessing)**

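As a quick check that the fast tokenizer loads for the new models, something like the following should work (a minimal sketch, assuming a recent `transformers` release):

```python
from transformers import AutoTokenizer

# use_fast=True requests the Rust-backed tokenizer built on the new wordpiece vocab
tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02", use_fast=True)
print(tokenizer.is_fast)  # expected: True when the fast implementation is available
```
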
## Bigger Dataset and More Compute

We used ~3.5 times more data and trained for longer.
For dataset sources see the [Dataset Section](#Dataset)

Model | Hardware | num of examples with seq len (128 / 512) | 128 (Batch Size / Num of Steps) | 512 (Batch Size / Num of Steps) | Total Steps | Total Time (in Days) |
---|:---:|:---:|:---:|:---:|:---:|:---:
AraBERTv0.2-base | TPUv3-8 | 420M / 207M | 2560 / 1M | 384 / 2M | 3M | 36
AraBERTv0.2-large | TPUv3-128 | 420M / 207M | 13440 / 250K | 2056 / 300K | 550K | 7
AraBERTv2-base | TPUv3-8 | 420M / 207M | 2560 / 1M | 384 / 2M | 3M | 36
AraBERTv2-large | TPUv3-128 | 520M / 245M | 13440 / 250K | 2056 / 300K | 550K | 7
AraBERT-base (v1/v0.1) | TPUv2-8 | - | 512 / 900K | 128 / 300K | 1.2M | 4

# AraGPT2

More details and code are available in the AraGPT2 folder and [README](https://github.com/aub-mind/arabert/blob/master/aragpt2/README.md)

## Model

Model | HuggingFace Model Name | Size / Params |
---|:---:|:---:
AraGPT2-base | [aragpt2-base](https://huggingface.co/aubmindlab/aragpt2-base) | 527MB / 135M |
AraGPT2-medium | [aragpt2-medium](https://huggingface.co/aubmindlab/aragpt2-medium) | 1.38G / 370M |
AraGPT2-large | [aragpt2-large](https://huggingface.co/aubmindlab/aragpt2-large) | 2.98GB / 792M |
AraGPT2-mega | [aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) | 5.5GB / 1.46B |
AraGPT2-mega-detector-long | [aragpt2-mega-detector-long](https://huggingface.co/aubmindlab/aragpt2-mega-detector-long) | 516MB / 135M |

All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.

## Dataset and Compute

For dataset sources see the [Dataset Section](#Dataset)

Model | Hardware | num of examples (seq len = 1024) | Batch Size | Num of Steps | Time (in days)
---|:---:|:---:|:---:|:---:|:---:
AraGPT2-base | TPUv3-128 | 9.7M | 1792 | 125K | 1.5
AraGPT2-medium | TPUv3-128 | 9.7M | 1152 | 85K | 1.5
AraGPT2-large | TPUv3-128 | 9.7M | 256 | 220K | 3
AraGPT2-mega | TPUv3-128 | 9.7M | 256 | 800K | 9

# AraELECTRA

More details and code are available in the AraELECTRA folder and [README](https://github.com/aub-mind/arabert/blob/master/araelectra/README.md)

## Model

Model | HuggingFace Model Name | Size (MB/Params) |
---|:---:|:---:
AraELECTRA-base-generator | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) | 227MB / 60M |
AraELECTRA-base-discriminator | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) | 516MB / 135M |

## Dataset and Compute
Model | Hardware | num of examples (seq len = 512) | Batch Size | Num of Steps | Time (in days)
---|:---:|:---:|:---:|:---:|:---:
ELECTRA-base | TPUv3-8 | - | 256 | 2M | 24

# Dataset

The pretraining data used for the new AraBERT model is also used for **AraGPT2 and AraELECTRA**.

The dataset consists of 77GB of text: 200,095,961 lines, 8,655,948,860 words, or 82,232,988,358 characters (before applying Farasa segmentation).

For the new dataset we added the unshuffled OSCAR corpus (after thoroughly filtering it) to the dataset used in AraBERTv1, but without the websites that we previously crawled:
- OSCAR unshuffled and filtered.
- [Arabic Wikipedia dump](https://archive.org/details/arwiki-20190201) from 2020/09/01
- [The 1.5B words Arabic Corpus](https://www.semanticscholar.org/paper/1.5-billion-words-Arabic-Corpus-El-Khair/f3eeef4afb81223df96575adadf808fe7fe440b4)
- [The OSIAN Corpus](https://www.aclweb.org/anthology/W19-4619)
- Assafir news articles. Huge thank you to Assafir for the data

# Preprocessing

It is recommended to apply our preprocessing function before training/testing on any dataset.
**Install `farasapy` to segment text for AraBERT v1 & v2: `pip install farasapy`**

```python
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv2"
arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
arabert_prep.preprocess(text)
>>>"و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
```

You can also use the `unpreprocess()` function to reverse the preprocessing changes, by fixing the spacing around non-alphabetical characters and de-segmenting the text if the selected model requires pre-segmentation. We highly recommend unpreprocessing the generated output of the `AraGPT2` models to make it look more natural.
```python
output_text = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
arabert_prep.unpreprocess(output_text)
>>>"ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
```

### Accepted Model Names:
The `ArabertPreprocessor` class expects one of the following model names:

Note: You can also use the same model name from the `HuggingFace` model repository without removing the `aubmindlab/` prefix. The default is `bert-base-arabertv02` with no pre-segmentation (see the short example after the list below).

```
bert-base-arabertv01
bert-base-arabert
bert-base-arabertv02
bert-base-arabertv2
bert-large-arabertv02
bert-large-arabertv2
araelectra-base-discriminator
araelectra-base-generator
aragpt2-base
aragpt2-medium
aragpt2-large
aragpt2-mega
```
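A small illustration of the prefix handling (a sketch; it assumes the `arabert` package is importable as in the examples above):

```python
from arabert.preprocess import ArabertPreprocessor

# the "aubmindlab/" prefix is optional: both select the bert-base-arabertv2 settings
prep_a = ArabertPreprocessor(model_name="aubmindlab/bert-base-arabertv2")
prep_b = ArabertPreprocessor(model_name="bert-base-arabertv2")

text = "ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري"
print(prep_a.preprocess(text) == prep_b.preprocess(text))  # expected: True
```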
# Examples Notebooks

- You can find the old examples that work with AraBERTv1 in the `examples/old` folder
- Check the [Readme.md](https://github.com/aub-mind/arabert/tree/master/examples) file in the examples folder for new links to colab notebooks

# TensorFlow 1.x models

**You can find the PyTorch, TF2 and TF1 models in HuggingFace's Transformer Library under the ```aubmindlab``` username**

- `wget https://huggingface.co/aubmindlab/MODEL_NAME/resolve/main/tf1_model.tar.gz` where `MODEL_NAME` is any model under the `aubmindlab` name (a scripted version is sketched below)

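The same download can also be scripted. A minimal sketch using the `huggingface_hub` client (the `tf1_model.tar.gz` filename comes from the command above; the target folder name is just an example):

```python
import tarfile

from huggingface_hub import hf_hub_download

# download the TF1 checkpoint archive of one of the aubmindlab models
archive_path = hf_hub_download(
    repo_id="aubmindlab/bert-base-arabertv02",
    filename="tf1_model.tar.gz",
)

# unpack it into a local folder
with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall(path="bert-base-arabertv02-tf1")
```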

# If you used this model please cite us as:
## AraBERT
Google Scholar has our Bibtex wrong (missing name), use this instead
```
@inproceedings{antoun2020arabert,
  title={AraBERT: Transformer-based Model for Arabic Language Understanding},
  author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
  booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
  pages={9}
}
```
## AraGPT2
```
@inproceedings{antoun-etal-2021-aragpt2,
    title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation",
    author = "Antoun, Wissam and
      Baly, Fady and
      Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.21",
    pages = "196--207",
}
```

## AraELECTRA
```
@inproceedings{antoun-etal-2021-araelectra,
    title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
    author = "Antoun, Wissam and
      Baly, Fady and
      Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
    pages = "191--195",
}
```


# Acknowledgments
Thanks to TensorFlow Research Cloud (TFRC) for the free access to Cloud TPUs; we couldn't have done it without this program. Thanks also to the [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) members for the continuous support, to [Yakshof](https://www.yakshof.com/#/) and Assafir for data and storage access, and to [Habib Rahal](https://www.behance.net/rahalhabib) for putting a face to AraBERT.

# Contacts
**Wissam Antoun**: [Linkedin](https://www.linkedin.com/in/wissam-antoun-622142b4/) | [Twitter](https://twitter.com/wissam_antoun) | [Github](https://github.com/WissamAntoun) | wfa07 (AT) mail (DOT) aub (DOT) edu | wissam.antoun (AT) gmail (DOT) com

**Fady Baly**: [Linkedin](https://www.linkedin.com/in/fadybaly/) | [Twitter](https://twitter.com/fadybaly) | [Github](https://github.com/fadybaly) | fgb06 (AT) mail (DOT) aub (DOT) edu | baly.fady (AT) gmail (DOT) com

arabert/__init__.py ADDED
@@ -0,0 +1 @@
# coding=utf-8
arabert/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (117 Bytes).

arabert/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (143 Bytes).

arabert/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (139 Bytes).

arabert/__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (21.8 kB).

arabert/__pycache__/preprocess.cpython-38.pyc ADDED
Binary file (21.9 kB).

arabert/__pycache__/preprocess.cpython-39.pyc ADDED
Binary file (21.8 kB).

arabert/arabert/LICENSE ADDED
@@ -0,0 +1,75 @@
==========================================
SOFTWARE LICENSE AGREEMENT - AraBERT
==========================================

* NAME: AraBERT : Arabic Bidirectional Encoder Representations from Transformers

* ACKNOWLEDGMENTS

This [software] was generated by [American
University of Beirut] (“Owners”). The statements
made herein are solely the responsibility of the author[s].

The following software programs and programs have been used in the
generation of [AraBERT]:

+ Farasa Segmenter
  - Abdelali, Ahmed, Kareem Darwish, Nadir Durrani, and Hamdy Mubarak.
    "Farasa: A fast and furious segmenter for arabic." In Proceedings of
    the 2016 Conference of the North American Chapter of the Association
    for Computational Linguistics: Demonstrations, pp. 11-16. 2016.
  - License and link : http://alt.qcri.org/farasa/segmenter.html

+ BERT
  - Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova.
    "Bert: Pre-training of deep bidirectional transformers for language
    understanding." arXiv preprint arXiv:1810.04805 (2018).
  - License and link : https://github.com/google-research/bert

+ PyArabic
  - T. Zerrouki, Pyarabic, An Arabic language library for Python,
    https://pypi.python.org/pypi/pyarabic/, 2010
  - License and link: https://github.com/linuxscout/pyarabic/

* LICENSE

This software and database is being provided to you, the LICENSEE,
by the Owners under the following license. By obtaining, using and/or
copying this software and database, you agree that you have read,
understood, and will comply with these terms and conditions. You
further agree that you have read and you will abide by the license
agreements provided in the above links under “acknowledgements”:
Permission to use, copy, modify and distribute this software and
database and its documentation for any purpose and without fee or
royalty is hereby granted, provided that you agree to comply with the
following copyright notice and statements, including the disclaimer,
and that the same appear on ALL copies of the software, database and
documentation, including modifications that you make for internal use
or for distribution. [AraBERT] Copyright 2020 by [American University
of Beirut]. All rights reserved. If you remix, transform, or build
upon the material, you must distribute your contributions under the
same license as this one. You may not apply legal terms or technological
measures that legally restrict others from doing anything this license
permits. THIS SOFTWARE IS PROVIDED "AS IS" AND THE OWNERS MAKE NO
REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE,
BUT NOT LIMITATION, THE OWNERS MAKE NO REPRESENTATIONS OR WARRANTIES OF
MERCHANT-ABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF
THE LICENSED SOFTWARE, DATABASE OR DOCUMENTATION WILL NOT INFRINGE ANY THIRD
PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS. The name of the
Owners may not be used in advertising or publicity pertaining to
distribution of the software and/or database. Title to copyright in
this software, database and any associated documentation shall at all
times remain with the Owners and LICENSEE agrees to preserve same.

The use of AraBERT should be cited as follows:
@inproceedings{antoun2020arabert,
  title={AraBERT: Transformer-based Model for Arabic Language Understanding},
  author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
  booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
  pages={9}
}

[AraBERT] Copyright 2020 by [American University of Beirut].
All rights reserved.
==========================================

arabert/arabert/Readme.md ADDED
@@ -0,0 +1,75 @@
# AraBERT v1 & v2 : Pre-training BERT for Arabic Language Understanding
<img src="https://github.com/aub-mind/arabert/blob/master/arabert_logo.png" width="100" align="left"/>

**AraBERT** is an Arabic pretrained language model based on [Google's BERT architecture](https://github.com/google-research/bert). AraBERT uses the same BERT-Base config. More details are available in the [AraBERT Paper](https://arxiv.org/abs/2003.00104v2) and in the [AraBERT Meetup](https://github.com/WissamAntoun/pydata_khobar_meetup)

There are two versions of the model, AraBERTv0.1 and AraBERTv1, with the difference being that AraBERTv1 uses pre-segmented text where prefixes and suffixes were split using the [Farasa Segmenter](http://alt.qcri.org/farasa/segmenter.html).


We evaluate AraBERT models on different downstream tasks and compare them to [mBERT](https://github.com/google-research/bert/blob/master/multilingual.md) and other state-of-the-art models (*to the best of our knowledge*). The tasks were Sentiment Analysis on 6 different datasets ([HARD](https://github.com/elnagara/HARD-Arabic-Dataset), [ASTD-Balanced](https://www.aclweb.org/anthology/D15-1299), [ArsenTD-Lev](https://staff.aub.edu.lb/~we07/Publications/ArSentD-LEV_Sentiment_Corpus.pdf), [LABR](https://github.com/mohamedadaly/LABR)), Named Entity Recognition with the [ANERcorp](http://curtis.ml.cmu.edu/w/courses/index.php/ANERcorp), and Arabic Question Answering on [Arabic-SQuAD and ARCD](https://github.com/husseinmozannar/SOQAL)


## Results
Task | Metric | AraBERTv0.1 | AraBERTv1 | AraBERTv0.2-base | AraBERTv2-Base | AraBERTv0.2-large | AraBERTv2-large | AraELECTRA-Base
:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
HARD | Acc. | **96.2** | 96.1 | - | - | - | - | -
ASTD | Acc. | 92.2 | **92.6** | - | - | - | - | -
ArsenTD-Lev | macro-F1 | 53.56 | - | 55.71 | - | 56.94 | - | **57.20**
AJGT | Acc. | 93.1 | **93.8** | - | - | - | - | -
LABR | Acc. | 85.9 | **86.7** | - | - | - | - | -
ANERcorp | macro-F1 | 83.1 | 82.4 | 83.70 | - | 83.08 | - | **83.95**
ARCD | EM - F1 | 31.62 - 67.45 | 31.7 - 67.8 | 32.76 - 66.53 | 31.34 - 67.23 | 36.89 - **71.32** | 34.19 - 68.12 | **37.03** - 71.22
TyDiQA-ar | EM - F1 | 68.51 - 82.86 | - | 73.07 - 85.41 | - | 73.72 - 86.03 | - | **74.91 - 86.68**


## How to use

You can easily use AraBERT since it is almost fully compatible with existing codebases. (Use this repo instead of the official BERT one; the only difference is in the ```tokenization.py``` file, where we modify the _is_punctuation function to make it compatible with the "+" symbol and the "[" and "]" characters, as sketched below.)

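Conceptually, the change is along these lines (a rough sketch of the idea, not the exact patch in this repo):

```python
import unicodedata


def _is_punctuation(char):
    """Sketch: same as BERT's check, except '+', '[' and ']' are not punctuation."""
    if char in ("+", "[", "]"):
        # keep Farasa segmentation markers (e.g. 'ال+') and bracket tokens intact
        return False
    cp = ord(char)
    # ASCII punctuation ranges, as in the original BERT tokenizer
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    # anything Unicode classifies as punctuation
    return unicodedata.category(char).startswith("P")
```
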

**AraBERTv1 and v2 always need pre-segmentation**
```python
from transformers import AutoTokenizer, AutoModel
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv2"
arabert_tokenizer = AutoTokenizer.from_pretrained(model_name)
arabert_model = AutoModel.from_pretrained(model_name)

arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري"
text_preprocessed = arabert_prep.preprocess(text)
>>> "و+ لن نبالغ إذا قل +نا إن هاتف أو كمبيوتر ال+ مكتب في زمن +نا هذا ضروري"

arabert_tokenizer.tokenize(text_preprocessed)

>>> ['و+', 'لن', 'نبال', '##غ', 'إذا', 'قل', '+نا', 'إن', 'هاتف', 'أو', 'كمبيوتر', 'ال+', 'مكتب', 'في', 'زمن', '+نا', 'هذا', 'ضروري']
```

**AraBERTv0.1 and v0.2 need no pre-segmentation.**
```python
from transformers import AutoTokenizer, AutoModel
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv01"
arabert_tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
arabert_model = AutoModel.from_pretrained(model_name)

arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري"
text_preprocessed = arabert_prep.preprocess(text)

arabert_tokenizer.tokenize(text_preprocessed)

>>> ['ولن', 'ن', '##بالغ', 'إذا', 'قلنا', 'إن', 'هاتف', 'أو', 'كمبيوتر', 'المكتب', 'في', 'زمن', '##ن', '##ا', 'هذا', 'ضروري']
```

## Model Weights and Vocab Download

**You can find the PyTorch, TF2 and TF1 models in HuggingFace's Transformer Library under the ```aubmindlab``` username**

- `wget https://huggingface.co/aubmindlab/MODEL_NAME/resolve/main/tf1_model.tar.gz` where `MODEL_NAME` is any model under the `aubmindlab` name
arabert/arabert/__init__.py ADDED
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
arabert/arabert/create_classification_data.py ADDED
@@ -0,0 +1,260 @@
# Scripts used to pre-process and create the data for classifier evaluation
#%%
import pandas as pd
from sklearn.model_selection import train_test_split

import sys
sys.path.append("..")

from arabert.preprocess import ArabertPreprocessor

from tqdm import tqdm

tqdm.pandas()

from tokenization import FullTokenizer
import run_classifier  # needed below for InputExample and convert_examples_to_features
from run_classifier import input_fn_builder, model_fn_builder


model_name = "bert-base-arabert"
arabert_prep = ArabertPreprocessor(model_name=model_name, keep_emojis=False)


class Dataset:
    def __init__(
        self,
        name,
        train,
        test,
        label_list,
        train_InputExamples=None,
        test_InputExamples=None,
        train_features=None,
        test_features=None,
    ):
        self.name = name
        self.train = train
        self.test = test
        self.label_list = label_list
        self.train_InputExamples = train_InputExamples
        self.test_InputExamples = test_InputExamples
        self.train_features = train_features
        self.test_features = test_features


all_datasets = []
#%%
# *************HARD************
df_HARD = pd.read_csv("Datasets\\HARD\\balanced-reviews-utf8.tsv", sep="\t", header=0)

df_HARD = df_HARD[["rating", "review"]]  # we are interested in rating and review only
# code rating as +ve if > 3, -ve if less, no 3s in dataset
df_HARD["rating"] = df_HARD["rating"].apply(lambda x: 0 if x < 3 else 1)
# rename columns to fit default constructor in fastai
df_HARD.columns = ["label", "text"]
df_HARD["text"] = df_HARD["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
train_HARD, test_HARD = train_test_split(df_HARD, test_size=0.2, random_state=42)
label_list_HARD = [0, 1]

data_Hard = Dataset("HARD", train_HARD, test_HARD, label_list_HARD)
all_datasets.append(data_Hard)

#%%
# *************ASTD-Unbalanced************
df_ASTD_UN = pd.read_csv(
    "Datasets\\ASTD-master\\data\\Tweets.txt", sep="\t", header=None
)

DATA_COLUMN = "text"
LABEL_COLUMN = "label"
df_ASTD_UN.columns = [DATA_COLUMN, LABEL_COLUMN]

df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(
    lambda x: 0 if (x == "NEG") else x
)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(
    lambda x: 1 if (x == "POS") else x
)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(
    lambda x: 2 if (x == "NEUTRAL") else x
)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(
    lambda x: 3 if (x == "OBJ") else x
)
df_ASTD_UN["text"] = df_ASTD_UN["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
train_ASTD_UN, test_ASTD_UN = train_test_split(
    df_ASTD_UN, test_size=0.2, random_state=42
)
label_list_ASTD_UN = [0, 1, 2, 3]

data_ASTD_UN = Dataset(
    "ASTD-Unbalanced", train_ASTD_UN, test_ASTD_UN, label_list_ASTD_UN
)
all_datasets.append(data_ASTD_UN)
#%%
# *************ASTD-Dahou-Balanced************

df_ASTD_B = pd.read_csv(
    "Datasets\\Dahou\\data_csv_balanced\\ASTD-balanced-not-linked.csv",
    sep=",",
    header=0,
)

df_ASTD_B.columns = [DATA_COLUMN, LABEL_COLUMN]

df_ASTD_B[LABEL_COLUMN] = df_ASTD_B[LABEL_COLUMN].apply(lambda x: 0 if (x == -1) else x)
df_ASTD_B["text"] = df_ASTD_B["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
train_ASTD_B, test_ASTD_B = train_test_split(df_ASTD_B, test_size=0.2, random_state=42)
label_list_ASTD_B = [0, 1]

data_ASTD_B = Dataset(
    "ASTD-Dahou-Balanced", train_ASTD_B, test_ASTD_B, label_list_ASTD_B
)
all_datasets.append(data_ASTD_B)

#%%
# *************ArSenTD-LEV************
df_ArSenTD = pd.read_csv(
    "Datasets\\ArSenTD-LEV\\ArSenTD-LEV-processed-no-emojis2.csv", sep=",", header=0
)

df_ArSenTD.columns = [DATA_COLUMN, LABEL_COLUMN]

df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(
    lambda x: 0 if (x == "very_negative") else x
)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(
    lambda x: 1 if (x == "negative") else x
)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(
    lambda x: 2 if (x == "neutral") else x
)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(
    lambda x: 3 if (x == "positive") else x
)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(
    lambda x: 4 if (x == "very_positive") else x
)
df_ArSenTD["text"] = df_ArSenTD["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
label_list_ArSenTD = [0, 1, 2, 3, 4]

train_ArSenTD, test_ArSenTD = train_test_split(
    df_ArSenTD, test_size=0.2, random_state=42
)

data_ArSenTD = Dataset("ArSenTD-LEV", train_ArSenTD, test_ArSenTD, label_list_ArSenTD)
all_datasets.append(data_ArSenTD)

#%%
# *************AJGT************
df_AJGT = pd.read_excel("Datasets\\Ajgt\\AJGT.xlsx", header=0)

df_AJGT = df_AJGT[["Feed", "Sentiment"]]
df_AJGT.columns = [DATA_COLUMN, LABEL_COLUMN]

df_AJGT[LABEL_COLUMN] = df_AJGT[LABEL_COLUMN].apply(
    lambda x: 0 if (x == "Negative") else x
)
df_AJGT[LABEL_COLUMN] = df_AJGT[LABEL_COLUMN].apply(
    lambda x: 1 if (x == "Positive") else x
)
df_AJGT["text"] = df_AJGT["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
train_AJGT, test_AJGT = train_test_split(df_AJGT, test_size=0.2, random_state=42)
label_list_AJGT = [0, 1]

data_AJGT = Dataset("AJGT", train_AJGT, test_AJGT, label_list_AJGT)
all_datasets.append(data_AJGT)
#%%
# *************LABR-UN-Binary************
from labr import LABR

labr_helper = LABR()

(d_train, y_train, d_test, y_test) = labr_helper.get_train_test(
    klass="2", balanced="unbalanced"
)

train_LABR_B_U = pd.DataFrame({"text": d_train, "label": y_train})
test_LABR_B_U = pd.DataFrame({"text": d_test, "label": y_test})

train_LABR_B_U["text"] = train_LABR_B_U["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
test_LABR_B_U["text"] = test_LABR_B_U["text"].progress_apply(
    lambda x: arabert_prep.preprocess(x)
)
label_list_LABR_B_U = [0, 1]

data_LABR_B_U = Dataset(
    "LABR-UN-Binary", train_LABR_B_U, test_LABR_B_U, label_list_LABR_B_U
)
# all_datasets.append(data_LABR_B_U)

#%%
for data in tqdm(all_datasets):
    # Use the InputExample class from BERT's run_classifier code to create examples from the data
    data.train_InputExamples = data.train.apply(
        lambda x: run_classifier.InputExample(
            guid=None,  # Globally unique ID for bookkeeping, unused in this example
            text_a=x[DATA_COLUMN],
            text_b=None,
            label=x[LABEL_COLUMN],
        ),
        axis=1,
    )

    data.test_InputExamples = data.test.apply(
        lambda x: run_classifier.InputExample(
            guid=None, text_a=x[DATA_COLUMN], text_b=None, label=x[LABEL_COLUMN]
        ),
        axis=1,
    )
#%%
# We'll set sequences to be at most 256 tokens long.
MAX_SEQ_LENGTH = 256

VOC_FNAME = "./64000_vocab_sp_70m.txt"
tokenizer = FullTokenizer(VOC_FNAME)

for data in tqdm(all_datasets):
    # Convert our train and test examples to InputFeatures that BERT understands.
    data.train_features = run_classifier.convert_examples_to_features(
        data.train_InputExamples, data.label_list, MAX_SEQ_LENGTH, tokenizer
    )
    data.test_features = run_classifier.convert_examples_to_features(
        data.test_InputExamples, data.label_list, MAX_SEQ_LENGTH, tokenizer
    )

# %%
import pickle

with open("all_datasets_64k_farasa_256.pickle", "wb") as fp:  # Pickling
    pickle.dump(all_datasets, fp)


# %%
arabert/arabert/create_pretraining_data.py ADDED
@@ -0,0 +1,534 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Create masked LM/next sentence masked_lm TF examples for BERT."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import collections
22
+ import random
23
+ import tokenization
24
+ import tensorflow as tf
25
+
26
+ flags = tf.flags
27
+
28
+ FLAGS = flags.FLAGS
29
+
30
+ flags.DEFINE_string(
31
+ "input_file", None, "Input raw text file (or comma-separated list of files)."
32
+ )
33
+
34
+ flags.DEFINE_string(
35
+ "output_file", None, "Output TF example file (or comma-separated list of files)."
36
+ )
37
+
38
+ flags.DEFINE_string(
39
+ "vocab_file", None, "The vocabulary file that the BERT model was trained on."
40
+ )
41
+
42
+ flags.DEFINE_bool(
43
+ "do_lower_case",
44
+ True,
45
+ "Whether to lower case the input text. Should be True for uncased "
46
+ "models and False for cased models.",
47
+ )
48
+
49
+ flags.DEFINE_bool(
50
+ "do_whole_word_mask",
51
+ False,
52
+ "Whether to use whole word masking rather than per-WordPiece masking.",
53
+ )
54
+
55
+ flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
56
+
57
+ flags.DEFINE_integer(
58
+ "max_predictions_per_seq",
59
+ 20,
60
+ "Maximum number of masked LM predictions per sequence.",
61
+ )
62
+
63
+ flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
64
+
65
+ flags.DEFINE_integer(
66
+ "dupe_factor",
67
+ 10,
68
+ "Number of times to duplicate the input data (with different masks).",
69
+ )
70
+
71
+ flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
72
+
73
+ flags.DEFINE_float(
74
+ "short_seq_prob",
75
+ 0.1,
76
+ "Probability of creating sequences which are shorter than the " "maximum length.",
77
+ )
78
+
79
+
80
+ class TrainingInstance(object):
81
+ """A single training instance (sentence pair)."""
82
+
83
+ def __init__(
84
+ self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next
85
+ ):
86
+ self.tokens = tokens
87
+ self.segment_ids = segment_ids
88
+ self.is_random_next = is_random_next
89
+ self.masked_lm_positions = masked_lm_positions
90
+ self.masked_lm_labels = masked_lm_labels
91
+
92
+ def __str__(self):
93
+ s = ""
94
+ s += "tokens: %s\n" % (
95
+ " ".join([tokenization.printable_text(x) for x in self.tokens])
96
+ )
97
+ s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
98
+ s += "is_random_next: %s\n" % self.is_random_next
99
+ s += "masked_lm_positions: %s\n" % (
100
+ " ".join([str(x) for x in self.masked_lm_positions])
101
+ )
102
+ s += "masked_lm_labels: %s\n" % (
103
+ " ".join([tokenization.printable_text(x) for x in self.masked_lm_labels])
104
+ )
105
+ s += "\n"
106
+ return s
107
+
108
+ def __repr__(self):
109
+ return self.__str__()
110
+
111
+
112
+ def write_instance_to_example_files(
113
+ instances, tokenizer, max_seq_length, max_predictions_per_seq, output_files
114
+ ):
115
+ """Create TF example files from `TrainingInstance`s."""
116
+ writers = []
117
+ for output_file in output_files:
118
+ writers.append(tf.python_io.TFRecordWriter(output_file))
119
+
120
+ writer_index = 0
121
+
122
+ total_written = 0
123
+ for (inst_index, instance) in enumerate(instances):
124
+ input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
125
+ input_mask = [1] * len(input_ids)
126
+ segment_ids = list(instance.segment_ids)
127
+ assert len(input_ids) <= max_seq_length
128
+
129
+ while len(input_ids) < max_seq_length:
130
+ input_ids.append(0)
131
+ input_mask.append(0)
132
+ segment_ids.append(0)
133
+
134
+ assert len(input_ids) == max_seq_length
135
+ assert len(input_mask) == max_seq_length
136
+ assert len(segment_ids) == max_seq_length
137
+
138
+ masked_lm_positions = list(instance.masked_lm_positions)
139
+ masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
140
+ masked_lm_weights = [1.0] * len(masked_lm_ids)
141
+
142
+ while len(masked_lm_positions) < max_predictions_per_seq:
143
+ masked_lm_positions.append(0)
144
+ masked_lm_ids.append(0)
145
+ masked_lm_weights.append(0.0)
146
+
147
+ next_sentence_label = 1 if instance.is_random_next else 0
148
+
149
+ features = collections.OrderedDict()
150
+ features["input_ids"] = create_int_feature(input_ids)
151
+ features["input_mask"] = create_int_feature(input_mask)
152
+ features["segment_ids"] = create_int_feature(segment_ids)
153
+ features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
154
+ features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
155
+ features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
156
+ features["next_sentence_labels"] = create_int_feature([next_sentence_label])
157
+
158
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
159
+
160
+ writers[writer_index].write(tf_example.SerializeToString())
161
+ writer_index = (writer_index + 1) % len(writers)
162
+
163
+ total_written += 1
164
+
165
+ if inst_index < 20:
166
+ tf.logging.info("*** Example ***")
167
+ tf.logging.info(
168
+ "tokens: %s"
169
+ % " ".join([tokenization.printable_text(x) for x in instance.tokens])
170
+ )
171
+
172
+ for feature_name in features.keys():
173
+ feature = features[feature_name]
174
+ values = []
175
+ if feature.int64_list.value:
176
+ values = feature.int64_list.value
177
+ elif feature.float_list.value:
178
+ values = feature.float_list.value
179
+ tf.logging.info(
180
+ "%s: %s" % (feature_name, " ".join([str(x) for x in values]))
181
+ )
182
+
183
+ for writer in writers:
184
+ writer.close()
185
+
186
+ tf.logging.info("Wrote %d total instances", total_written)
187
+
188
+
189
+ def create_int_feature(values):
190
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
191
+ return feature
192
+
193
+
194
+ def create_float_feature(values):
195
+ feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
196
+ return feature
197
+
198
+
199
+ def create_training_instances(
200
+ input_files,
201
+ tokenizer,
202
+ max_seq_length,
203
+ dupe_factor,
204
+ short_seq_prob,
205
+ masked_lm_prob,
206
+ max_predictions_per_seq,
207
+ rng,
208
+ ):
209
+ """Create `TrainingInstance`s from raw text."""
210
+ all_documents = [[]]
211
+
212
+ # Input file format:
213
+ # (1) One sentence per line. These should ideally be actual sentences, not
214
+ # entire paragraphs or arbitrary spans of text. (Because we use the
215
+ # sentence boundaries for the "next sentence prediction" task).
216
+ # (2) Blank lines between documents. Document boundaries are needed so
217
+ # that the "next sentence prediction" task doesn't span between documents.
218
+ for input_file in input_files:
219
+ with tf.gfile.GFile(input_file, "r") as reader:
220
+ while True:
221
+ line = tokenization.convert_to_unicode(reader.readline())
222
+ if not line:
223
+ break
224
+ line = line.strip()
225
+
226
+ # Empty lines are used as document delimiters
227
+ if not line:
228
+ all_documents.append([])
229
+ tokens = tokenizer.tokenize(line)
230
+ if tokens:
231
+ all_documents[-1].append(tokens)
232
+
233
+ # Remove empty documents
234
+ all_documents = [x for x in all_documents if x]
235
+ rng.shuffle(all_documents)
236
+
237
+ vocab_words = list(tokenizer.vocab.keys())
238
+ instances = []
239
+ for _ in range(dupe_factor):
240
+ for document_index in range(len(all_documents)):
241
+ instances.extend(
242
+ create_instances_from_document(
243
+ all_documents,
244
+ document_index,
245
+ max_seq_length,
246
+ short_seq_prob,
247
+ masked_lm_prob,
248
+ max_predictions_per_seq,
249
+ vocab_words,
250
+ rng,
251
+ )
252
+ )
253
+
254
+ rng.shuffle(instances)
255
+ return instances
256
+
257
+
258
+ def create_instances_from_document(
259
+ all_documents,
260
+ document_index,
261
+ max_seq_length,
262
+ short_seq_prob,
263
+ masked_lm_prob,
264
+ max_predictions_per_seq,
265
+ vocab_words,
266
+ rng,
267
+ ):
268
+ """Creates `TrainingInstance`s for a single document."""
269
+ document = all_documents[document_index]
270
+
271
+ # Account for [CLS], [SEP], [SEP]
272
+ max_num_tokens = max_seq_length - 3
273
+
274
+ # We *usually* want to fill up the entire sequence since we are padding
275
+ # to `max_seq_length` anyways, so short sequences are generally wasted
276
+ # computation. However, we *sometimes*
277
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
278
+ # sequences to minimize the mismatch between pre-training and fine-tuning.
279
+ # The `target_seq_length` is just a rough target however, whereas
280
+ # `max_seq_length` is a hard limit.
281
+ target_seq_length = max_num_tokens
282
+ if rng.random() < short_seq_prob:
283
+ target_seq_length = rng.randint(2, max_num_tokens)
284
+
285
+ # We DON'T just concatenate all of the tokens from a document into a long
286
+ # sequence and choose an arbitrary split point because this would make the
287
+ # next sentence prediction task too easy. Instead, we split the input into
288
+ # segments "A" and "B" based on the actual "sentences" provided by the user
289
+ # input.
290
+ instances = []
291
+ current_chunk = []
292
+ current_length = 0
293
+ i = 0
294
+ while i < len(document):
295
+ segment = document[i]
296
+ current_chunk.append(segment)
297
+ current_length += len(segment)
298
+ if i == len(document) - 1 or current_length >= target_seq_length:
299
+ if current_chunk:
300
+ # `a_end` is how many segments from `current_chunk` go into the `A`
301
+ # (first) sentence.
302
+ a_end = 1
303
+ if len(current_chunk) >= 2:
304
+ a_end = rng.randint(1, len(current_chunk) - 1)
305
+
306
+ tokens_a = []
307
+ for j in range(a_end):
308
+ tokens_a.extend(current_chunk[j])
309
+
310
+ tokens_b = []
311
+ # Random next
312
+ is_random_next = False
313
+ if len(current_chunk) == 1 or rng.random() < 0.5:
314
+ is_random_next = True
315
+ target_b_length = target_seq_length - len(tokens_a)
316
+
317
+ # This should rarely go for more than one iteration for large
318
+ # corpora. However, just to be careful, we try to make sure that
319
+ # the random document is not the same as the document
320
+ # we're processing.
321
+ for _ in range(10):
322
+ random_document_index = rng.randint(0, len(all_documents) - 1)
323
+ if random_document_index != document_index:
324
+ break
325
+
326
+ random_document = all_documents[random_document_index]
327
+ random_start = rng.randint(0, len(random_document) - 1)
328
+ for j in range(random_start, len(random_document)):
329
+ tokens_b.extend(random_document[j])
330
+ if len(tokens_b) >= target_b_length:
331
+ break
332
+ # We didn't actually use these segments so we "put them back" so
333
+ # they don't go to waste.
334
+ num_unused_segments = len(current_chunk) - a_end
335
+ i -= num_unused_segments
336
+ # Actual next
337
+ else:
338
+ is_random_next = False
339
+ for j in range(a_end, len(current_chunk)):
340
+ tokens_b.extend(current_chunk[j])
341
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
342
+
343
+ assert len(tokens_a) >= 1
344
+ assert len(tokens_b) >= 1
345
+
346
+ tokens = []
347
+ segment_ids = []
348
+ tokens.append("[CLS]")
349
+ segment_ids.append(0)
350
+ for token in tokens_a:
351
+ tokens.append(token)
352
+ segment_ids.append(0)
353
+
354
+ tokens.append("[SEP]")
355
+ segment_ids.append(0)
356
+
357
+ for token in tokens_b:
358
+ tokens.append(token)
359
+ segment_ids.append(1)
360
+ tokens.append("[SEP]")
361
+ segment_ids.append(1)
362
+
363
+ (
364
+ tokens,
365
+ masked_lm_positions,
366
+ masked_lm_labels,
367
+ ) = create_masked_lm_predictions(
368
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
369
+ )
370
+ instance = TrainingInstance(
371
+ tokens=tokens,
372
+ segment_ids=segment_ids,
373
+ is_random_next=is_random_next,
374
+ masked_lm_positions=masked_lm_positions,
375
+ masked_lm_labels=masked_lm_labels,
376
+ )
377
+ instances.append(instance)
378
+ current_chunk = []
379
+ current_length = 0
380
+ i += 1
381
+
382
+ return instances
383
+
384
+
385
+ MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
386
+
387
+
388
+ def create_masked_lm_predictions(
389
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
390
+ ):
391
+ """Creates the predictions for the masked LM objective."""
392
+
393
+ cand_indexes = []
394
+ for (i, token) in enumerate(tokens):
395
+ if token == "[CLS]" or token == "[SEP]":
396
+ continue
397
+ # Whole Word Masking means that if we mask all of the wordpieces
398
+ # corresponding to an original word. When a word has been split into
399
+ # WordPieces, the first token does not have any marker and any subsequence
400
+ # tokens are prefixed with ##. So whenever we see the ## token, we
401
+ # append it to the previous set of word indexes.
402
+ #
403
+ # Note that Whole Word Masking does *not* change the training code
404
+ # at all -- we still predict each WordPiece independently, softmaxed
405
+ # over the entire vocabulary.
406
+ if (
407
+ FLAGS.do_whole_word_mask
408
+ and len(cand_indexes) >= 1
409
+ and token.startswith("##")
410
+ ):
411
+ cand_indexes[-1].append(i)
412
+ else:
413
+ cand_indexes.append([i])
414
+
415
+ rng.shuffle(cand_indexes)
416
+
417
+ output_tokens = list(tokens)
418
+
419
+ num_to_predict = min(
420
+ max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
421
+ )
422
+
423
+ masked_lms = []
424
+ covered_indexes = set()
425
+ for index_set in cand_indexes:
426
+ if len(masked_lms) >= num_to_predict:
427
+ break
428
+ # If adding a whole-word mask would exceed the maximum number of
429
+ # predictions, then just skip this candidate.
430
+ if len(masked_lms) + len(index_set) > num_to_predict:
431
+ continue
432
+ is_any_index_covered = False
433
+ for index in index_set:
434
+ if index in covered_indexes:
435
+ is_any_index_covered = True
436
+ break
437
+ if is_any_index_covered:
438
+ continue
439
+ for index in index_set:
440
+ covered_indexes.add(index)
441
+
442
+ masked_token = None
443
+ # 80% of the time, replace with [MASK]
444
+ if rng.random() < 0.8:
445
+ masked_token = "[MASK]"
446
+ else:
447
+ # 10% of the time, keep original
448
+ if rng.random() < 0.5:
449
+ masked_token = tokens[index]
450
+ # 10% of the time, replace with random word
451
+ else:
452
+ masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
453
+
454
+ output_tokens[index] = masked_token
455
+
456
+ masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
457
+ assert len(masked_lms) <= num_to_predict
458
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
459
+
460
+ masked_lm_positions = []
461
+ masked_lm_labels = []
462
+ for p in masked_lms:
463
+ masked_lm_positions.append(p.index)
464
+ masked_lm_labels.append(p.label)
465
+
466
+ return (output_tokens, masked_lm_positions, masked_lm_labels)
467
+
468
+
469
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
470
+ """Truncates a pair of sequences to a maximum sequence length."""
471
+ while True:
472
+ total_length = len(tokens_a) + len(tokens_b)
473
+ if total_length <= max_num_tokens:
474
+ break
475
+
476
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
477
+ assert len(trunc_tokens) >= 1
478
+
479
+ # We want to sometimes truncate from the front and sometimes from the
480
+ # back to add more randomness and avoid biases.
481
+ if rng.random() < 0.5:
482
+ del trunc_tokens[0]
483
+ else:
484
+ trunc_tokens.pop()
485
+
486
+
487
+ def main(_):
488
+ tf.logging.set_verbosity(tf.logging.INFO)
489
+ logger = tf.get_logger()
490
+ logger.propagate = False
491
+
492
+ tokenizer = tokenization.FullTokenizer(
493
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
494
+ )
495
+
496
+ input_files = []
497
+ for input_pattern in FLAGS.input_file.split(","):
498
+ input_files.extend(tf.gfile.Glob(input_pattern))
499
+
500
+ tf.logging.info("*** Reading from input files ***")
501
+ for input_file in input_files:
502
+ tf.logging.info(" %s", input_file)
503
+
504
+ rng = random.Random(FLAGS.random_seed)
505
+ instances = create_training_instances(
506
+ input_files,
507
+ tokenizer,
508
+ FLAGS.max_seq_length,
509
+ FLAGS.dupe_factor,
510
+ FLAGS.short_seq_prob,
511
+ FLAGS.masked_lm_prob,
512
+ FLAGS.max_predictions_per_seq,
513
+ rng,
514
+ )
515
+
516
+ output_files = FLAGS.output_file.split(",")
517
+ tf.logging.info("*** Writing to output files ***")
518
+ for output_file in output_files:
519
+ tf.logging.info(" %s", output_file)
520
+
521
+ write_instance_to_example_files(
522
+ instances,
523
+ tokenizer,
524
+ FLAGS.max_seq_length,
525
+ FLAGS.max_predictions_per_seq,
526
+ output_files,
527
+ )
528
+
529
+
530
+ if __name__ == "__main__":
531
+ flags.mark_flag_as_required("input_file")
532
+ flags.mark_flag_as_required("output_file")
533
+ flags.mark_flag_as_required("vocab_file")
534
+ tf.app.run()
arabert/arabert/extract_features.py ADDED
@@ -0,0 +1,444 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Extract pre-computed feature vectors from BERT."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import codecs
22
+ import collections
23
+ import json
24
+ import re
25
+
26
+ import modeling
27
+ import tokenization
28
+ import tensorflow as tf
29
+
30
+ flags = tf.flags
31
+
32
+ FLAGS = flags.FLAGS
33
+
34
+ flags.DEFINE_string("input_file", None, "")
35
+
36
+ flags.DEFINE_string("output_file", None, "")
37
+
38
+ flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
39
+
40
+ flags.DEFINE_string(
41
+ "bert_config_file",
42
+ None,
43
+ "The config json file corresponding to the pre-trained BERT model. "
44
+ "This specifies the model architecture.",
45
+ )
46
+
47
+ flags.DEFINE_integer(
48
+ "max_seq_length",
49
+ 128,
50
+ "The maximum total input sequence length after WordPiece tokenization. "
51
+ "Sequences longer than this will be truncated, and sequences shorter "
52
+ "than this will be padded.",
53
+ )
54
+
55
+ flags.DEFINE_string(
56
+ "init_checkpoint",
57
+ None,
58
+ "Initial checkpoint (usually from a pre-trained BERT model).",
59
+ )
60
+
61
+ flags.DEFINE_string(
62
+ "vocab_file", None, "The vocabulary file that the BERT model was trained on."
63
+ )
64
+
65
+ flags.DEFINE_bool(
66
+ "do_lower_case",
67
+ True,
68
+ "Whether to lower case the input text. Should be True for uncased "
69
+ "models and False for cased models.",
70
+ )
71
+
72
+ flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")
73
+
74
+ flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
75
+
76
+ flags.DEFINE_string("master", None, "If using a TPU, the address of the master.")
77
+
78
+ flags.DEFINE_integer(
79
+ "num_tpu_cores",
80
+ 8,
81
+ "Only used if `use_tpu` is True. Total number of TPU cores to use.",
82
+ )
83
+
84
+ flags.DEFINE_bool(
85
+ "use_one_hot_embeddings",
86
+ False,
87
+ "If True, tf.one_hot will be used for embedding lookups, otherwise "
88
+ "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
89
+ "since it is much faster.",
90
+ )
91
+
92
+
93
+ class InputExample(object):
94
+ def __init__(self, unique_id, text_a, text_b):
95
+ self.unique_id = unique_id
96
+ self.text_a = text_a
97
+ self.text_b = text_b
98
+
99
+
100
+ class InputFeatures(object):
101
+ """A single set of features of data."""
102
+
103
+ def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
104
+ self.unique_id = unique_id
105
+ self.tokens = tokens
106
+ self.input_ids = input_ids
107
+ self.input_mask = input_mask
108
+ self.input_type_ids = input_type_ids
109
+
110
+
111
+ def input_fn_builder(features, seq_length):
112
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
113
+
114
+ all_unique_ids = []
115
+ all_input_ids = []
116
+ all_input_mask = []
117
+ all_input_type_ids = []
118
+
119
+ for feature in features:
120
+ all_unique_ids.append(feature.unique_id)
121
+ all_input_ids.append(feature.input_ids)
122
+ all_input_mask.append(feature.input_mask)
123
+ all_input_type_ids.append(feature.input_type_ids)
124
+
125
+ def input_fn(params):
126
+ """The actual input function."""
127
+ batch_size = params["batch_size"]
128
+
129
+ num_examples = len(features)
130
+
131
+ # This is for demo purposes and does NOT scale to large data sets. We do
132
+ # not use Dataset.from_generator() because that uses tf.py_func which is
133
+ # not TPU compatible. The right way to load data is with TFRecordReader.
134
+ d = tf.data.Dataset.from_tensor_slices(
135
+ {
136
+ "unique_ids": tf.constant(
137
+ all_unique_ids, shape=[num_examples], dtype=tf.int32
138
+ ),
139
+ "input_ids": tf.constant(
140
+ all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32
141
+ ),
142
+ "input_mask": tf.constant(
143
+ all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32
144
+ ),
145
+ "input_type_ids": tf.constant(
146
+ all_input_type_ids, shape=[num_examples, seq_length], dtype=tf.int32
147
+ ),
148
+ }
149
+ )
150
+
151
+ d = d.batch(batch_size=batch_size, drop_remainder=False)
152
+ return d
153
+
154
+ return input_fn
155
+
156
+
157
+ def model_fn_builder(
158
+ bert_config, init_checkpoint, layer_indexes, use_tpu, use_one_hot_embeddings
159
+ ):
160
+ """Returns `model_fn` closure for TPUEstimator."""
161
+
162
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
163
+ """The `model_fn` for TPUEstimator."""
164
+
165
+ unique_ids = features["unique_ids"]
166
+ input_ids = features["input_ids"]
167
+ input_mask = features["input_mask"]
168
+ input_type_ids = features["input_type_ids"]
169
+
170
+ model = modeling.BertModel(
171
+ config=bert_config,
172
+ is_training=False,
173
+ input_ids=input_ids,
174
+ input_mask=input_mask,
175
+ token_type_ids=input_type_ids,
176
+ use_one_hot_embeddings=use_one_hot_embeddings,
177
+ )
178
+
179
+ if mode != tf.estimator.ModeKeys.PREDICT:
180
+ raise ValueError("Only PREDICT modes are supported: %s" % (mode))
181
+
182
+ tvars = tf.trainable_variables()
183
+ scaffold_fn = None
184
+ (
185
+ assignment_map,
186
+ initialized_variable_names,
187
+ ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
188
+ if use_tpu:
189
+
190
+ def tpu_scaffold():
191
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
192
+ return tf.train.Scaffold()
193
+
194
+ scaffold_fn = tpu_scaffold
195
+ else:
196
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
197
+
198
+ tf.logging.info("**** Trainable Variables ****")
199
+ for var in tvars:
200
+ init_string = ""
201
+ if var.name in initialized_variable_names:
202
+ init_string = ", *INIT_FROM_CKPT*"
203
+ tf.logging.info(
204
+ " name = %s, shape = %s%s", var.name, var.shape, init_string
205
+ )
206
+
207
+ all_layers = model.get_all_encoder_layers()
208
+
209
+ predictions = {
210
+ "unique_id": unique_ids,
211
+ }
212
+
213
+ for (i, layer_index) in enumerate(layer_indexes):
214
+ predictions["layer_output_%d" % i] = all_layers[layer_index]
215
+
216
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
217
+ mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
218
+ )
219
+ return output_spec
220
+
221
+ return model_fn
222
+
223
+
224
+ def convert_examples_to_features(examples, seq_length, tokenizer):
225
+ """Loads a data file into a list of `InputBatch`s."""
226
+
227
+ features = []
228
+ for (ex_index, example) in enumerate(examples):
229
+ tokens_a = tokenizer.tokenize(example.text_a)
230
+
231
+ tokens_b = None
232
+ if example.text_b:
233
+ tokens_b = tokenizer.tokenize(example.text_b)
234
+
235
+ if tokens_b:
236
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
237
+ # length is less than the specified length.
238
+ # Account for [CLS], [SEP], [SEP] with "- 3"
239
+ _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
240
+ else:
241
+ # Account for [CLS] and [SEP] with "- 2"
242
+ if len(tokens_a) > seq_length - 2:
243
+ tokens_a = tokens_a[0 : (seq_length - 2)]
244
+
245
+ # The convention in BERT is:
246
+ # (a) For sequence pairs:
247
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
248
+ # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
249
+ # (b) For single sequences:
250
+ # tokens: [CLS] the dog is hairy . [SEP]
251
+ # type_ids: 0 0 0 0 0 0 0
252
+ #
253
+ # Where "type_ids" are used to indicate whether this is the first
254
+ # sequence or the second sequence. The embedding vectors for `type=0` and
255
+ # `type=1` were learned during pre-training and are added to the wordpiece
256
+ # embedding vector (and position vector). This is not *strictly* necessary
257
+ # since the [SEP] token unambiguously separates the sequences, but it makes
258
+ # it easier for the model to learn the concept of sequences.
259
+ #
260
+ # For classification tasks, the first vector (corresponding to [CLS]) is
261
+ # used as as the "sentence vector". Note that this only makes sense because
262
+ # the entire model is fine-tuned.
263
+ tokens = []
264
+ input_type_ids = []
265
+ tokens.append("[CLS]")
266
+ input_type_ids.append(0)
267
+ for token in tokens_a:
268
+ tokens.append(token)
269
+ input_type_ids.append(0)
270
+ tokens.append("[SEP]")
271
+ input_type_ids.append(0)
272
+
273
+ if tokens_b:
274
+ for token in tokens_b:
275
+ tokens.append(token)
276
+ input_type_ids.append(1)
277
+ tokens.append("[SEP]")
278
+ input_type_ids.append(1)
279
+
280
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
281
+
282
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
283
+ # tokens are attended to.
284
+ input_mask = [1] * len(input_ids)
285
+
286
+ # Zero-pad up to the sequence length.
287
+ while len(input_ids) < seq_length:
288
+ input_ids.append(0)
289
+ input_mask.append(0)
290
+ input_type_ids.append(0)
291
+
292
+ assert len(input_ids) == seq_length
293
+ assert len(input_mask) == seq_length
294
+ assert len(input_type_ids) == seq_length
295
+
296
+ if ex_index < 5:
297
+ tf.logging.info("*** Example ***")
298
+ tf.logging.info("unique_id: %s" % (example.unique_id))
299
+ tf.logging.info(
300
+ "tokens: %s"
301
+ % " ".join([tokenization.printable_text(x) for x in tokens])
302
+ )
303
+ tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
304
+ tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
305
+ tf.logging.info(
306
+ "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])
307
+ )
308
+
309
+ features.append(
310
+ InputFeatures(
311
+ unique_id=example.unique_id,
312
+ tokens=tokens,
313
+ input_ids=input_ids,
314
+ input_mask=input_mask,
315
+ input_type_ids=input_type_ids,
316
+ )
317
+ )
318
+ return features
319
+
320
+
321
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
322
+ """Truncates a sequence pair in place to the maximum length."""
323
+
324
+ # This is a simple heuristic which will always truncate the longer sequence
325
+ # one token at a time. This makes more sense than truncating an equal percent
326
+ # of tokens from each, since if one sequence is very short then each token
327
+ # that's truncated likely contains more information than a longer sequence.
328
+ while True:
329
+ total_length = len(tokens_a) + len(tokens_b)
330
+ if total_length <= max_length:
331
+ break
332
+ if len(tokens_a) > len(tokens_b):
333
+ tokens_a.pop()
334
+ else:
335
+ tokens_b.pop()
336
+
337
+
338
+ def read_examples(input_file):
339
+ """Read a list of `InputExample`s from an input file."""
340
+ examples = []
341
+ unique_id = 0
342
+ with tf.gfile.GFile(input_file, "r") as reader:
343
+ while True:
344
+ line = tokenization.convert_to_unicode(reader.readline())
345
+ if not line:
346
+ break
347
+ line = line.strip()
348
+ text_a = None
349
+ text_b = None
350
+ m = re.match(r"^(.*) \|\|\| (.*)$", line)
351
+ if m is None:
352
+ text_a = line
353
+ else:
354
+ text_a = m.group(1)
355
+ text_b = m.group(2)
356
+ examples.append(
357
+ InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)
358
+ )
359
+ unique_id += 1
360
+ return examples
361
+
362
+
363
+ def main(_):
364
+ tf.logging.set_verbosity(tf.logging.INFO)
365
+ logger = tf.get_logger()
366
+ logger.propagate = False
367
+
368
+ layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
369
+
370
+ bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
371
+
372
+ tokenizer = tokenization.FullTokenizer(
373
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
374
+ )
375
+
376
+ is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
377
+ run_config = tf.contrib.tpu.RunConfig(
378
+ master=FLAGS.master,
379
+ tpu_config=tf.contrib.tpu.TPUConfig(
380
+ num_shards=FLAGS.num_tpu_cores, per_host_input_for_training=is_per_host
381
+ ),
382
+ )
383
+
384
+ examples = read_examples(FLAGS.input_file)
385
+
386
+ features = convert_examples_to_features(
387
+ examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer
388
+ )
389
+
390
+ unique_id_to_feature = {}
391
+ for feature in features:
392
+ unique_id_to_feature[feature.unique_id] = feature
393
+
394
+ model_fn = model_fn_builder(
395
+ bert_config=bert_config,
396
+ init_checkpoint=FLAGS.init_checkpoint,
397
+ layer_indexes=layer_indexes,
398
+ use_tpu=FLAGS.use_tpu,
399
+ use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
400
+ )
401
+
402
+ # If TPU is not available, this will fall back to normal Estimator on CPU
403
+ # or GPU.
404
+ estimator = tf.contrib.tpu.TPUEstimator(
405
+ use_tpu=FLAGS.use_tpu,
406
+ model_fn=model_fn,
407
+ config=run_config,
408
+ predict_batch_size=FLAGS.batch_size,
409
+ )
410
+
411
+ input_fn = input_fn_builder(features=features, seq_length=FLAGS.max_seq_length)
412
+
413
+ with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, "w")) as writer:
414
+ for result in estimator.predict(input_fn, yield_single_examples=True):
415
+ unique_id = int(result["unique_id"])
416
+ feature = unique_id_to_feature[unique_id]
417
+ output_json = collections.OrderedDict()
418
+ output_json["linex_index"] = unique_id
419
+ all_features = []
420
+ for (i, token) in enumerate(feature.tokens):
421
+ all_layers = []
422
+ for (j, layer_index) in enumerate(layer_indexes):
423
+ layer_output = result["layer_output_%d" % j]
424
+ layers = collections.OrderedDict()
425
+ layers["index"] = layer_index
426
+ layers["values"] = [
427
+ round(float(x), 6) for x in layer_output[i : (i + 1)].flat
428
+ ]
429
+ all_layers.append(layers)
430
+ features = collections.OrderedDict()
431
+ features["token"] = token
432
+ features["layers"] = all_layers
433
+ all_features.append(features)
434
+ output_json["features"] = all_features
435
+ writer.write(json.dumps(output_json) + "\n")
436
+
437
+
438
+ if __name__ == "__main__":
439
+ flags.mark_flag_as_required("input_file")
440
+ flags.mark_flag_as_required("vocab_file")
441
+ flags.mark_flag_as_required("bert_config_file")
442
+ flags.mark_flag_as_required("init_checkpoint")
443
+ flags.mark_flag_as_required("output_file")
444
+ tf.app.run()
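As read_examples above shows, extract_features.py expects one example per line, with an optional " ||| " separator between text_a and text_b, and main() writes one JSON object per example using the keys built above (linex_index, features, token, layers, index, values). A small sketch of preparing an input file and reading the output afterwards follows; the file names are placeholders, and the read step assumes the script has already been run.

import json

# One example per line; "text_a ||| text_b" marks a sentence pair (see read_examples).
with open("input.txt", "w", encoding="utf-8") as f:
    f.write("مرحبا بالعالم ||| كيف الحال\n")

# After running extract_features.py, each output line holds per-token activations.
with open("output.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        for entry in record["features"]:
            # "layers" follows the order of the --layers flag, e.g. -1,-2,-3,-4.
            print(entry["token"], len(entry["layers"][0]["values"]))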
arabert/arabert/lamb_optimizer.py ADDED
@@ -0,0 +1,158 @@
1
+ # coding=utf-8
2
+ # Copyright 2019 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python2, python3
17
+ """Functions and classes related to optimization (weight updates)."""
18
+
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+ from __future__ import print_function
22
+
23
+ import re
24
+ import six
25
+ import tensorflow as tf
26
+
27
+ # pylint: disable=g-direct-tensorflow-import
28
+ from tensorflow.python.ops import array_ops
29
+ from tensorflow.python.ops import linalg_ops
30
+ from tensorflow.python.ops import math_ops
31
+
32
+ # pylint: enable=g-direct-tensorflow-import
33
+
34
+
35
+ class LAMBOptimizer(tf.train.Optimizer):
36
+ """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
37
+
38
+ # A new optimizer that includes correct L2 weight decay, adaptive
39
+ # element-wise updating, and layer-wise justification. The LAMB optimizer
40
+ # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
41
+ # James Demmel, and Cho-Jui Hsieh in a paper titled as Reducing BERT
42
+ # Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962)
43
+
44
+ def __init__(
45
+ self,
46
+ learning_rate,
47
+ weight_decay_rate=0.0,
48
+ beta_1=0.9,
49
+ beta_2=0.999,
50
+ epsilon=1e-6,
51
+ exclude_from_weight_decay=None,
52
+ exclude_from_layer_adaptation=None,
53
+ name="LAMBOptimizer",
54
+ ):
55
+ """Constructs a LAMBOptimizer."""
56
+ super(LAMBOptimizer, self).__init__(False, name)
57
+
58
+ self.learning_rate = learning_rate
59
+ self.weight_decay_rate = weight_decay_rate
60
+ self.beta_1 = beta_1
61
+ self.beta_2 = beta_2
62
+ self.epsilon = epsilon
63
+ self.exclude_from_weight_decay = exclude_from_weight_decay
64
+ # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
65
+ # arg is None.
66
+ # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
67
+ if exclude_from_layer_adaptation:
68
+ self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
69
+ else:
70
+ self.exclude_from_layer_adaptation = exclude_from_weight_decay
71
+
72
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
73
+ """See base class."""
74
+ assignments = []
75
+ for (grad, param) in grads_and_vars:
76
+ if grad is None or param is None:
77
+ continue
78
+
79
+ param_name = self._get_variable_name(param.name)
80
+
81
+ m = tf.get_variable(
82
+ name=six.ensure_str(param_name) + "/adam_m",
83
+ shape=param.shape.as_list(),
84
+ dtype=tf.float32,
85
+ trainable=False,
86
+ initializer=tf.zeros_initializer(),
87
+ )
88
+ v = tf.get_variable(
89
+ name=six.ensure_str(param_name) + "/adam_v",
90
+ shape=param.shape.as_list(),
91
+ dtype=tf.float32,
92
+ trainable=False,
93
+ initializer=tf.zeros_initializer(),
94
+ )
95
+
96
+ # Standard Adam update.
97
+ next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
98
+ next_v = tf.multiply(self.beta_2, v) + tf.multiply(
99
+ 1.0 - self.beta_2, tf.square(grad)
100
+ )
101
+
102
+ update = next_m / (tf.sqrt(next_v) + self.epsilon)
103
+
104
+ # Just adding the square of the weights to the loss function is *not*
105
+ # the correct way of using L2 regularization/weight decay with Adam,
106
+ # since that will interact with the m and v parameters in strange ways.
107
+ #
108
+ # Instead we want to decay the weights in a manner that doesn't interact
109
+ # with the m/v parameters. This is equivalent to adding the square
110
+ # of the weights to the loss with plain (non-momentum) SGD.
111
+ if self._do_use_weight_decay(param_name):
112
+ update += self.weight_decay_rate * param
113
+
114
+ ratio = 1.0
115
+ if self._do_layer_adaptation(param_name):
116
+ w_norm = linalg_ops.norm(param, ord=2)
117
+ g_norm = linalg_ops.norm(update, ord=2)
118
+ ratio = array_ops.where(
119
+ math_ops.greater(w_norm, 0),
120
+ array_ops.where(
121
+ math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0
122
+ ),
123
+ 1.0,
124
+ )
125
+
126
+ update_with_lr = ratio * self.learning_rate * update
127
+
128
+ next_param = param - update_with_lr
129
+
130
+ assignments.extend(
131
+ [param.assign(next_param), m.assign(next_m), v.assign(next_v)]
132
+ )
133
+ return tf.group(*assignments, name=name)
134
+
135
+ def _do_use_weight_decay(self, param_name):
136
+ """Whether to use L2 weight decay for `param_name`."""
137
+ if not self.weight_decay_rate:
138
+ return False
139
+ if self.exclude_from_weight_decay:
140
+ for r in self.exclude_from_weight_decay:
141
+ if re.search(r, param_name) is not None:
142
+ return False
143
+ return True
144
+
145
+ def _do_layer_adaptation(self, param_name):
146
+ """Whether to do layer-wise learning rate adaptation for `param_name`."""
147
+ if self.exclude_from_layer_adaptation:
148
+ for r in self.exclude_from_layer_adaptation:
149
+ if re.search(r, param_name) is not None:
150
+ return False
151
+ return True
152
+
153
+ def _get_variable_name(self, param_name):
154
+ """Get the variable name from the tensor name."""
155
+ m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
156
+ if m is not None:
157
+ param_name = m.group(1)
158
+ return param_name
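The per-variable update in apply_gradients above can be summarised outside TensorFlow: Adam-style moment estimates, decoupled weight decay added to the update, and a layer-wise trust ratio ||w|| / ||update|| scaling the learning rate. The NumPy sketch below reproduces one such step and is an illustration only, not code from the repository.

import numpy as np

def lamb_step(param, grad, m, v, lr, wd=0.01, beta_1=0.9, beta_2=0.999, eps=1e-6):
    # Adam-style first and second moment estimates (no bias correction, as above).
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2
    update = m / (np.sqrt(v) + eps)
    # Decoupled weight decay, added to the update rather than to the loss.
    update = update + wd * param
    # Layer-wise trust ratio, guarded against zero norms as in the TF code.
    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(update)
    ratio = w_norm / g_norm if (w_norm > 0 and g_norm > 0) else 1.0
    return param - ratio * lr * update, m, v

w, g = np.ones(4), np.full(4, 0.1)
new_w, m, v = lamb_step(w, g, np.zeros(4), np.zeros(4), lr=1e-3)
print(new_w)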
arabert/arabert/modeling.py ADDED
@@ -0,0 +1,1027 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """The main BERT model and related functions."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import collections
22
+ import copy
23
+ import json
24
+ import math
25
+ import re
26
+ import numpy as np
27
+ import six
28
+ import tensorflow as tf
29
+
30
+
31
+ class BertConfig(object):
32
+ """Configuration for `BertModel`."""
33
+
34
+ def __init__(
35
+ self,
36
+ vocab_size,
37
+ hidden_size=768,
38
+ num_hidden_layers=12,
39
+ num_attention_heads=12,
40
+ intermediate_size=3072,
41
+ hidden_act="gelu",
42
+ hidden_dropout_prob=0.1,
43
+ attention_probs_dropout_prob=0.1,
44
+ max_position_embeddings=512,
45
+ type_vocab_size=16,
46
+ initializer_range=0.02,
47
+ ):
48
+ """Constructs BertConfig.
49
+
50
+ Args:
51
+ vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
52
+ hidden_size: Size of the encoder layers and the pooler layer.
53
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
54
+ num_attention_heads: Number of attention heads for each attention layer in
55
+ the Transformer encoder.
56
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
57
+ layer in the Transformer encoder.
58
+ hidden_act: The non-linear activation function (function or string) in the
59
+ encoder and pooler.
60
+ hidden_dropout_prob: The dropout probability for all fully connected
61
+ layers in the embeddings, encoder, and pooler.
62
+ attention_probs_dropout_prob: The dropout ratio for the attention
63
+ probabilities.
64
+ max_position_embeddings: The maximum sequence length that this model might
65
+ ever be used with. Typically set this to something large just in case
66
+ (e.g., 512 or 1024 or 2048).
67
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
68
+ `BertModel`.
69
+ initializer_range: The stdev of the truncated_normal_initializer for
70
+ initializing all weight matrices.
71
+ """
72
+ self.vocab_size = vocab_size
73
+ self.hidden_size = hidden_size
74
+ self.num_hidden_layers = num_hidden_layers
75
+ self.num_attention_heads = num_attention_heads
76
+ self.hidden_act = hidden_act
77
+ self.intermediate_size = intermediate_size
78
+ self.hidden_dropout_prob = hidden_dropout_prob
79
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
80
+ self.max_position_embeddings = max_position_embeddings
81
+ self.type_vocab_size = type_vocab_size
82
+ self.initializer_range = initializer_range
83
+
84
+ @classmethod
85
+ def from_dict(cls, json_object):
86
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
87
+ config = BertConfig(vocab_size=None)
88
+ for (key, value) in six.iteritems(json_object):
89
+ config.__dict__[key] = value
90
+ return config
91
+
92
+ @classmethod
93
+ def from_json_file(cls, json_file):
94
+ """Constructs a `BertConfig` from a json file of parameters."""
95
+ with tf.gfile.GFile(json_file, "r") as reader:
96
+ text = reader.read()
97
+ return cls.from_dict(json.loads(text))
98
+
99
+ def to_dict(self):
100
+ """Serializes this instance to a Python dictionary."""
101
+ output = copy.deepcopy(self.__dict__)
102
+ return output
103
+
104
+ def to_json_string(self):
105
+ """Serializes this instance to a JSON string."""
106
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
107
+
108
+
109
+ class BertModel(object):
110
+ """BERT model ("Bidirectional Encoder Representations from Transformers").
111
+
112
+ Example usage:
113
+
114
+ ```python
115
+ # Already been converted into WordPiece token ids
116
+ input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
117
+ input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
118
+ token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
119
+
120
+ config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
121
+ num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
122
+
123
+ model = modeling.BertModel(config=config, is_training=True,
124
+ input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
125
+
126
+ label_embeddings = tf.get_variable(...)
127
+ pooled_output = model.get_pooled_output()
128
+ logits = tf.matmul(pooled_output, label_embeddings)
129
+ ...
130
+ ```
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ config,
136
+ is_training,
137
+ input_ids,
138
+ input_mask=None,
139
+ token_type_ids=None,
140
+ use_one_hot_embeddings=False,
141
+ scope=None,
142
+ ):
143
+ """Constructor for BertModel.
144
+
145
+ Args:
146
+ config: `BertConfig` instance.
147
+ is_training: bool. true for training model, false for eval model. Controls
148
+ whether dropout will be applied.
149
+ input_ids: int32 Tensor of shape [batch_size, seq_length].
150
+ input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
151
+ token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
152
+ use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
153
+ embeddings or tf.embedding_lookup() for the word embeddings.
154
+ scope: (optional) variable scope. Defaults to "bert".
155
+
156
+ Raises:
157
+ ValueError: The config is invalid or one of the input tensor shapes
158
+ is invalid.
159
+ """
160
+ config = copy.deepcopy(config)
161
+ if not is_training:
162
+ config.hidden_dropout_prob = 0.0
163
+ config.attention_probs_dropout_prob = 0.0
164
+
165
+ input_shape = get_shape_list(input_ids, expected_rank=2)
166
+ batch_size = input_shape[0]
167
+ seq_length = input_shape[1]
168
+
169
+ if input_mask is None:
170
+ input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
171
+
172
+ if token_type_ids is None:
173
+ token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
174
+
175
+ with tf.variable_scope(scope, default_name="bert"):
176
+ with tf.variable_scope("embeddings"):
177
+ # Perform embedding lookup on the word ids.
178
+ (self.embedding_output, self.embedding_table) = embedding_lookup(
179
+ input_ids=input_ids,
180
+ vocab_size=config.vocab_size,
181
+ embedding_size=config.hidden_size,
182
+ initializer_range=config.initializer_range,
183
+ word_embedding_name="word_embeddings",
184
+ use_one_hot_embeddings=use_one_hot_embeddings,
185
+ )
186
+
187
+ # Add positional embeddings and token type embeddings, then layer
188
+ # normalize and perform dropout.
189
+ self.embedding_output = embedding_postprocessor(
190
+ input_tensor=self.embedding_output,
191
+ use_token_type=True,
192
+ token_type_ids=token_type_ids,
193
+ token_type_vocab_size=config.type_vocab_size,
194
+ token_type_embedding_name="token_type_embeddings",
195
+ use_position_embeddings=True,
196
+ position_embedding_name="position_embeddings",
197
+ initializer_range=config.initializer_range,
198
+ max_position_embeddings=config.max_position_embeddings,
199
+ dropout_prob=config.hidden_dropout_prob,
200
+ )
201
+
202
+ with tf.variable_scope("encoder"):
203
+ # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
204
+ # mask of shape [batch_size, seq_length, seq_length] which is used
205
+ # for the attention scores.
206
+ attention_mask = create_attention_mask_from_input_mask(
207
+ input_ids, input_mask
208
+ )
209
+
210
+ # Run the stacked transformer.
211
+ # `sequence_output` shape = [batch_size, seq_length, hidden_size].
212
+ self.all_encoder_layers = transformer_model(
213
+ input_tensor=self.embedding_output,
214
+ attention_mask=attention_mask,
215
+ hidden_size=config.hidden_size,
216
+ num_hidden_layers=config.num_hidden_layers,
217
+ num_attention_heads=config.num_attention_heads,
218
+ intermediate_size=config.intermediate_size,
219
+ intermediate_act_fn=get_activation(config.hidden_act),
220
+ hidden_dropout_prob=config.hidden_dropout_prob,
221
+ attention_probs_dropout_prob=config.attention_probs_dropout_prob,
222
+ initializer_range=config.initializer_range,
223
+ do_return_all_layers=True,
224
+ )
225
+
226
+ self.sequence_output = self.all_encoder_layers[-1]
227
+ # The "pooler" converts the encoded sequence tensor of shape
228
+ # [batch_size, seq_length, hidden_size] to a tensor of shape
229
+ # [batch_size, hidden_size]. This is necessary for segment-level
230
+ # (or segment-pair-level) classification tasks where we need a fixed
231
+ # dimensional representation of the segment.
232
+ with tf.variable_scope("pooler"):
233
+ # We "pool" the model by simply taking the hidden state corresponding
234
+ # to the first token. We assume that this has been pre-trained
235
+ first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
236
+ self.pooled_output = tf.layers.dense(
237
+ first_token_tensor,
238
+ config.hidden_size,
239
+ activation=tf.tanh,
240
+ kernel_initializer=create_initializer(config.initializer_range),
241
+ )
242
+
243
+ def get_pooled_output(self):
244
+ return self.pooled_output
245
+
246
+ def get_sequence_output(self):
247
+ """Gets final hidden layer of encoder.
248
+
249
+ Returns:
250
+ float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
251
+ to the final hidden of the transformer encoder.
252
+ """
253
+ return self.sequence_output
254
+
255
+ def get_all_encoder_layers(self):
256
+ return self.all_encoder_layers
257
+
258
+ def get_embedding_output(self):
259
+ """Gets output of the embedding lookup (i.e., input to the transformer).
260
+
261
+ Returns:
262
+ float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
263
+ to the output of the embedding layer, after summing the word
264
+ embeddings with the positional embeddings and the token type embeddings,
265
+ then performing layer normalization. This is the input to the transformer.
266
+ """
267
+ return self.embedding_output
268
+
269
+ def get_embedding_table(self):
270
+ return self.embedding_table
271
+
272
+
273
+ def gelu(x):
274
+ """Gaussian Error Linear Unit.
275
+
276
+ This is a smoother version of the RELU.
277
+ Original paper: https://arxiv.org/abs/1606.08415
278
+ Args:
279
+ x: float Tensor to perform activation.
280
+
281
+ Returns:
282
+ `x` with the GELU activation applied.
283
+ """
284
+ cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
285
+ return x * cdf
286
+
287
+
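# Aside (illustrative, not part of modeling.py): the tanh expression in gelu() above is
# the standard approximation to the exact GELU, x * Phi(x), where Phi is the CDF of the
# standard normal distribution. A quick check with the Python standard library:
import math
for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    exact = x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    approx = x * 0.5 * (1.0 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
    print(x, round(exact, 6), round(approx, 6))  # the two values agree closely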
288
+ def get_activation(activation_string):
289
+ """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
290
+
291
+ Args:
292
+ activation_string: String name of the activation function.
293
+
294
+ Returns:
295
+ A Python function corresponding to the activation function. If
296
+ `activation_string` is None, empty, or "linear", this will return None.
297
+ If `activation_string` is not a string, it will return `activation_string`.
298
+
299
+ Raises:
300
+ ValueError: The `activation_string` does not correspond to a known
301
+ activation.
302
+ """
303
+
304
+ # We assume that anything that's not a string is already an activation
305
+ # function, so we just return it.
306
+ if not isinstance(activation_string, six.string_types):
307
+ return activation_string
308
+
309
+ if not activation_string:
310
+ return None
311
+
312
+ act = activation_string.lower()
313
+ if act == "linear":
314
+ return None
315
+ elif act == "relu":
316
+ return tf.nn.relu
317
+ elif act == "gelu":
318
+ return gelu
319
+ elif act == "tanh":
320
+ return tf.tanh
321
+ else:
322
+ raise ValueError("Unsupported activation: %s" % act)
323
+
324
+
325
+ def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
326
+ """Compute the union of the current variables and checkpoint variables."""
327
+ assignment_map = {}
328
+ initialized_variable_names = {}
329
+
330
+ name_to_variable = collections.OrderedDict()
331
+ for var in tvars:
332
+ name = var.name
333
+ m = re.match("^(.*):\\d+$", name)
334
+ if m is not None:
335
+ name = m.group(1)
336
+ name_to_variable[name] = var
337
+
338
+ init_vars = tf.train.list_variables(init_checkpoint)
339
+
340
+ assignment_map = collections.OrderedDict()
341
+ for x in init_vars:
342
+ (name, var) = (x[0], x[1])
343
+ if name not in name_to_variable:
344
+ continue
345
+ assignment_map[name] = name
346
+ initialized_variable_names[name] = 1
347
+ initialized_variable_names[name + ":0"] = 1
348
+
349
+ return (assignment_map, initialized_variable_names)
350
+
351
+
352
+ def dropout(input_tensor, dropout_prob):
353
+ """Perform dropout.
354
+
355
+ Args:
356
+ input_tensor: float Tensor.
357
+ dropout_prob: Python float. The probability of dropping out a value (NOT of
358
+ *keeping* a dimension as in `tf.nn.dropout`).
359
+
360
+ Returns:
361
+ A version of `input_tensor` with dropout applied.
362
+ """
363
+ if dropout_prob is None or dropout_prob == 0.0:
364
+ return input_tensor
365
+
366
+ output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
367
+ return output
368
+
369
+
370
+ def layer_norm(input_tensor, name=None):
371
+ """Run layer normalization on the last dimension of the tensor."""
372
+ return tf.contrib.layers.layer_norm(
373
+ inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name
374
+ )
375
+
376
+
377
+ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
378
+ """Runs layer normalization followed by dropout."""
379
+ output_tensor = layer_norm(input_tensor, name)
380
+ output_tensor = dropout(output_tensor, dropout_prob)
381
+ return output_tensor
382
+
383
+
384
+ def create_initializer(initializer_range=0.02):
385
+ """Creates a `truncated_normal_initializer` with the given range."""
386
+ return tf.truncated_normal_initializer(stddev=initializer_range)
387
+
388
+
389
+ def embedding_lookup(
390
+ input_ids,
391
+ vocab_size,
392
+ embedding_size=128,
393
+ initializer_range=0.02,
394
+ word_embedding_name="word_embeddings",
395
+ use_one_hot_embeddings=False,
396
+ ):
397
+ """Looks up word embeddings for an id tensor.
398
+
399
+ Args:
400
+ input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
401
+ ids.
402
+ vocab_size: int. Size of the embedding vocabulary.
403
+ embedding_size: int. Width of the word embeddings.
404
+ initializer_range: float. Embedding initialization range.
405
+ word_embedding_name: string. Name of the embedding table.
406
+ use_one_hot_embeddings: bool. If True, use one-hot method for word
407
+ embeddings. If False, use `tf.gather()`.
408
+
409
+ Returns:
410
+ float Tensor of shape [batch_size, seq_length, embedding_size].
411
+ """
412
+ # This function assumes that the input is of shape [batch_size, seq_length,
413
+ # num_inputs].
414
+ #
415
+ # If the input is a 2D tensor of shape [batch_size, seq_length], we
416
+ # reshape to [batch_size, seq_length, 1].
417
+ if input_ids.shape.ndims == 2:
418
+ input_ids = tf.expand_dims(input_ids, axis=[-1])
419
+
420
+ embedding_table = tf.get_variable(
421
+ name=word_embedding_name,
422
+ shape=[vocab_size, embedding_size],
423
+ initializer=create_initializer(initializer_range),
424
+ )
425
+
426
+ flat_input_ids = tf.reshape(input_ids, [-1])
427
+ if use_one_hot_embeddings:
428
+ one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
429
+ output = tf.matmul(one_hot_input_ids, embedding_table)
430
+ else:
431
+ output = tf.gather(embedding_table, flat_input_ids)
432
+
433
+ input_shape = get_shape_list(input_ids)
434
+
435
+ output = tf.reshape(output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
436
+ return (output, embedding_table)
437
+
438
+
439
+ def embedding_postprocessor(
440
+ input_tensor,
441
+ use_token_type=False,
442
+ token_type_ids=None,
443
+ token_type_vocab_size=16,
444
+ token_type_embedding_name="token_type_embeddings",
445
+ use_position_embeddings=True,
446
+ position_embedding_name="position_embeddings",
447
+ initializer_range=0.02,
448
+ max_position_embeddings=512,
449
+ dropout_prob=0.1,
450
+ ):
451
+ """Performs various post-processing on a word embedding tensor.
452
+
453
+ Args:
454
+ input_tensor: float Tensor of shape [batch_size, seq_length,
455
+ embedding_size].
456
+ use_token_type: bool. Whether to add embeddings for `token_type_ids`.
457
+ token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
458
+ Must be specified if `use_token_type` is True.
459
+ token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
460
+ token_type_embedding_name: string. The name of the embedding table variable
461
+ for token type ids.
462
+ use_position_embeddings: bool. Whether to add position embeddings for the
463
+ position of each token in the sequence.
464
+ position_embedding_name: string. The name of the embedding table variable
465
+ for positional embeddings.
466
+ initializer_range: float. Range of the weight initialization.
467
+ max_position_embeddings: int. Maximum sequence length that might ever be
468
+ used with this model. This can be longer than the sequence length of
469
+ input_tensor, but cannot be shorter.
470
+ dropout_prob: float. Dropout probability applied to the final output tensor.
471
+
472
+ Returns:
473
+ float tensor with same shape as `input_tensor`.
474
+
475
+ Raises:
476
+ ValueError: One of the tensor shapes or input values is invalid.
477
+ """
478
+ input_shape = get_shape_list(input_tensor, expected_rank=3)
479
+ batch_size = input_shape[0]
480
+ seq_length = input_shape[1]
481
+ width = input_shape[2]
482
+
483
+ output = input_tensor
484
+
485
+ if use_token_type:
486
+ if token_type_ids is None:
487
+ raise ValueError(
488
+ "`token_type_ids` must be specified if " "`use_token_type` is True."
489
+ )
490
+ token_type_table = tf.get_variable(
491
+ name=token_type_embedding_name,
492
+ shape=[token_type_vocab_size, width],
493
+ initializer=create_initializer(initializer_range),
494
+ )
495
+ # This vocab will be small so we always do one-hot here, since it is always
496
+ # faster for a small vocabulary.
497
+ flat_token_type_ids = tf.reshape(token_type_ids, [-1])
498
+ one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
499
+ token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
500
+ token_type_embeddings = tf.reshape(
501
+ token_type_embeddings, [batch_size, seq_length, width]
502
+ )
503
+ output += token_type_embeddings
504
+
505
+ if use_position_embeddings:
506
+ assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
507
+ with tf.control_dependencies([assert_op]):
508
+ full_position_embeddings = tf.get_variable(
509
+ name=position_embedding_name,
510
+ shape=[max_position_embeddings, width],
511
+ initializer=create_initializer(initializer_range),
512
+ )
513
+ # Since the position embedding table is a learned variable, we create it
514
+ # using a (long) sequence length `max_position_embeddings`. The actual
515
+ # sequence length might be shorter than this, for faster training of
516
+ # tasks that do not have long sequences.
517
+ #
518
+ # So `full_position_embeddings` is effectively an embedding table
519
+ # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
520
+ # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
521
+ # perform a slice.
522
+ position_embeddings = tf.slice(
523
+ full_position_embeddings, [0, 0], [seq_length, -1]
524
+ )
525
+ num_dims = len(output.shape.as_list())
526
+
527
+ # Only the last two dimensions are relevant (`seq_length` and `width`), so
528
+ # we broadcast among the first dimensions, which is typically just
529
+ # the batch size.
530
+ position_broadcast_shape = []
531
+ for _ in range(num_dims - 2):
532
+ position_broadcast_shape.append(1)
533
+ position_broadcast_shape.extend([seq_length, width])
534
+ position_embeddings = tf.reshape(
535
+ position_embeddings, position_broadcast_shape
536
+ )
537
+ output += position_embeddings
538
+
539
+ output = layer_norm_and_dropout(output, dropout_prob)
540
+ return output
541
+
542
+
543
+ def create_attention_mask_from_input_mask(from_tensor, to_mask):
544
+ """Create 3D attention mask from a 2D tensor mask.
545
+
546
+ Args:
547
+ from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
548
+ to_mask: int32 Tensor of shape [batch_size, to_seq_length].
549
+
550
+ Returns:
551
+ float Tensor of shape [batch_size, from_seq_length, to_seq_length].
552
+ """
553
+ from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
554
+ batch_size = from_shape[0]
555
+ from_seq_length = from_shape[1]
556
+
557
+ to_shape = get_shape_list(to_mask, expected_rank=2)
558
+ to_seq_length = to_shape[1]
559
+
560
+ to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
561
+
562
+ # We don't assume that `from_tensor` is a mask (although it could be). We
563
+ # don't actually care if we attend *from* padding tokens (only *to* padding
564
+ # tokens), so we create a tensor of all ones.
565
+ #
566
+ # `broadcast_ones` = [batch_size, from_seq_length, 1]
567
+ broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
568
+
569
+ # Here we broadcast along two dimensions to create the mask.
570
+ mask = broadcast_ones * to_mask
571
+
572
+ return mask
573
+
574
+
575
+ def attention_layer(
576
+ from_tensor,
577
+ to_tensor,
578
+ attention_mask=None,
579
+ num_attention_heads=1,
580
+ size_per_head=512,
581
+ query_act=None,
582
+ key_act=None,
583
+ value_act=None,
584
+ attention_probs_dropout_prob=0.0,
585
+ initializer_range=0.02,
586
+ do_return_2d_tensor=False,
587
+ batch_size=None,
588
+ from_seq_length=None,
589
+ to_seq_length=None,
590
+ ):
591
+ """Performs multi-headed attention from `from_tensor` to `to_tensor`.
592
+
593
+ This is an implementation of multi-headed attention based on "Attention
594
+ is all you Need". If `from_tensor` and `to_tensor` are the same, then
595
+ this is self-attention. Each timestep in `from_tensor` attends to the
596
+ corresponding sequence in `to_tensor`, and returns a fixed-width vector.
597
+
598
+ This function first projects `from_tensor` into a "query" tensor and
599
+ `to_tensor` into "key" and "value" tensors. These are (effectively) a list
600
+ of tensors of length `num_attention_heads`, where each tensor is of shape
601
+ [batch_size, seq_length, size_per_head].
602
+
603
+ Then, the query and key tensors are dot-producted and scaled. These are
604
+ softmaxed to obtain attention probabilities. The value tensors are then
605
+ interpolated by these probabilities, then concatenated back to a single
606
+ tensor and returned.
607
+
608
+ In practice, the multi-headed attention is done with transposes and
609
+ reshapes rather than actual separate tensors.
610
+
611
+ Args:
612
+ from_tensor: float Tensor of shape [batch_size, from_seq_length,
613
+ from_width].
614
+ to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
615
+ attention_mask: (optional) int32 Tensor of shape [batch_size,
616
+ from_seq_length, to_seq_length]. The values should be 1 or 0. The
617
+ attention scores will effectively be set to -infinity for any positions in
618
+ the mask that are 0, and will be unchanged for positions that are 1.
619
+ num_attention_heads: int. Number of attention heads.
620
+ size_per_head: int. Size of each attention head.
621
+ query_act: (optional) Activation function for the query transform.
622
+ key_act: (optional) Activation function for the key transform.
623
+ value_act: (optional) Activation function for the value transform.
624
+ attention_probs_dropout_prob: (optional) float. Dropout probability of the
625
+ attention probabilities.
626
+ initializer_range: float. Range of the weight initializer.
627
+ do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
628
+ * from_seq_length, num_attention_heads * size_per_head]. If False, the
629
+ output will be of shape [batch_size, from_seq_length, num_attention_heads
630
+ * size_per_head].
631
+ batch_size: (Optional) int. If the input is 2D, this might be the batch size
632
+ of the 3D version of the `from_tensor` and `to_tensor`.
633
+ from_seq_length: (Optional) If the input is 2D, this might be the seq length
634
+ of the 3D version of the `from_tensor`.
635
+ to_seq_length: (Optional) If the input is 2D, this might be the seq length
636
+ of the 3D version of the `to_tensor`.
637
+
638
+ Returns:
639
+ float Tensor of shape [batch_size, from_seq_length,
640
+ num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
641
+ true, this will be of shape [batch_size * from_seq_length,
642
+ num_attention_heads * size_per_head]).
643
+
644
+ Raises:
645
+ ValueError: Any of the arguments or tensor shapes are invalid.
646
+ """
647
+
648
+ def transpose_for_scores(
649
+ input_tensor, batch_size, num_attention_heads, seq_length, width
650
+ ):
651
+ output_tensor = tf.reshape(
652
+ input_tensor, [batch_size, seq_length, num_attention_heads, width]
653
+ )
654
+
655
+ output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
656
+ return output_tensor
657
+
658
+ from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
659
+ to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
660
+
661
+ if len(from_shape) != len(to_shape):
662
+ raise ValueError(
663
+ "The rank of `from_tensor` must match the rank of `to_tensor`."
664
+ )
665
+
666
+ if len(from_shape) == 3:
667
+ batch_size = from_shape[0]
668
+ from_seq_length = from_shape[1]
669
+ to_seq_length = to_shape[1]
670
+ elif len(from_shape) == 2:
671
+ if batch_size is None or from_seq_length is None or to_seq_length is None:
672
+ raise ValueError(
673
+ "When passing in rank 2 tensors to attention_layer, the values "
674
+ "for `batch_size`, `from_seq_length`, and `to_seq_length` "
675
+ "must all be specified."
676
+ )
677
+
678
+ # Scalar dimensions referenced here:
679
+ # B = batch size (number of sequences)
680
+ # F = `from_tensor` sequence length
681
+ # T = `to_tensor` sequence length
682
+ # N = `num_attention_heads`
683
+ # H = `size_per_head`
684
+
685
+ from_tensor_2d = reshape_to_matrix(from_tensor)
686
+ to_tensor_2d = reshape_to_matrix(to_tensor)
687
+
688
+ # `query_layer` = [B*F, N*H]
689
+ query_layer = tf.layers.dense(
690
+ from_tensor_2d,
691
+ num_attention_heads * size_per_head,
692
+ activation=query_act,
693
+ name="query",
694
+ kernel_initializer=create_initializer(initializer_range),
695
+ )
696
+
697
+ # `key_layer` = [B*T, N*H]
698
+ key_layer = tf.layers.dense(
699
+ to_tensor_2d,
700
+ num_attention_heads * size_per_head,
701
+ activation=key_act,
702
+ name="key",
703
+ kernel_initializer=create_initializer(initializer_range),
704
+ )
705
+
706
+ # `value_layer` = [B*T, N*H]
707
+ value_layer = tf.layers.dense(
708
+ to_tensor_2d,
709
+ num_attention_heads * size_per_head,
710
+ activation=value_act,
711
+ name="value",
712
+ kernel_initializer=create_initializer(initializer_range),
713
+ )
714
+
715
+ # `query_layer` = [B, N, F, H]
716
+ query_layer = transpose_for_scores(
717
+ query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head
718
+ )
719
+
720
+ # `key_layer` = [B, N, T, H]
721
+ key_layer = transpose_for_scores(
722
+ key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head
723
+ )
724
+
725
+ # Take the dot product between "query" and "key" to get the raw
726
+ # attention scores.
727
+ # `attention_scores` = [B, N, F, T]
728
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
729
+ attention_scores = tf.multiply(
730
+ attention_scores, 1.0 / math.sqrt(float(size_per_head))
731
+ )
732
+
733
+ if attention_mask is not None:
734
+ # `attention_mask` = [B, 1, F, T]
735
+ attention_mask = tf.expand_dims(attention_mask, axis=[1])
736
+
737
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
738
+ # masked positions, this operation will create a tensor which is 0.0 for
739
+ # positions we want to attend and -10000.0 for masked positions.
740
+ adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
741
+
742
+ # Since we are adding it to the raw scores before the softmax, this is
743
+ # effectively the same as removing these entirely.
744
+ attention_scores += adder
745
+
746
+ # Normalize the attention scores to probabilities.
747
+ # `attention_probs` = [B, N, F, T]
748
+ attention_probs = tf.nn.softmax(attention_scores)
749
+
750
+ # This is actually dropping out entire tokens to attend to, which might
751
+ # seem a bit unusual, but is taken from the original Transformer paper.
752
+ attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
753
+
754
+ # `value_layer` = [B, T, N, H]
755
+ value_layer = tf.reshape(
756
+ value_layer, [batch_size, to_seq_length, num_attention_heads, size_per_head]
757
+ )
758
+
759
+ # `value_layer` = [B, N, T, H]
760
+ value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
761
+
762
+ # `context_layer` = [B, N, F, H]
763
+ context_layer = tf.matmul(attention_probs, value_layer)
764
+
765
+ # `context_layer` = [B, F, N, H]
766
+ context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
767
+
768
+ if do_return_2d_tensor:
769
+ # `context_layer` = [B*F, N*H]
770
+ context_layer = tf.reshape(
771
+ context_layer,
772
+ [batch_size * from_seq_length, num_attention_heads * size_per_head],
773
+ )
774
+ else:
775
+ # `context_layer` = [B, F, N*H]
776
+ context_layer = tf.reshape(
777
+ context_layer,
778
+ [batch_size, from_seq_length, num_attention_heads * size_per_head],
779
+ )
780
+
781
+ return context_layer
782
+
783
+
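# Aside (illustrative, not from the repository): stripped of the reshapes and multi-head
# bookkeeping, attention_layer above is scaled dot-product attention with an additive
# -10000.0 penalty on masked positions. A single-head NumPy sketch:
import numpy as np

def single_head_attention(q, k, v, mask):
    # q: [F, H]; k, v: [T, H]; mask: [F, T] with 1 = attend, 0 = ignore.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores + (1.0 - mask) * -10000.0          # same trick as `adder` above
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)  # softmax over the "to" positions
    return probs @ v                                   # [F, H]

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 4)), rng.normal(size=(5, 4)), rng.normal(size=(5, 4))
mask = np.ones((3, 5)); mask[:, -1] = 0                # e.g. last "to" position is padding
print(single_head_attention(q, k, v, mask).shape)      # (3, 4)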
784
+ def transformer_model(
785
+ input_tensor,
786
+ attention_mask=None,
787
+ hidden_size=768,
788
+ num_hidden_layers=12,
789
+ num_attention_heads=12,
790
+ intermediate_size=3072,
791
+ intermediate_act_fn=gelu,
792
+ hidden_dropout_prob=0.1,
793
+ attention_probs_dropout_prob=0.1,
794
+ initializer_range=0.02,
795
+ do_return_all_layers=False,
796
+ ):
797
+ """Multi-headed, multi-layer Transformer from "Attention is All You Need".
798
+
799
+ This is almost an exact implementation of the original Transformer encoder.
800
+
801
+ See the original paper:
802
+ https://arxiv.org/abs/1706.03762
803
+
804
+ Also see:
805
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
806
+
807
+ Args:
808
+ input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
809
+ attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
810
+ seq_length], with 1 for positions that can be attended to and 0 in
811
+ positions that should not be.
812
+ hidden_size: int. Hidden size of the Transformer.
813
+ num_hidden_layers: int. Number of layers (blocks) in the Transformer.
814
+ num_attention_heads: int. Number of attention heads in the Transformer.
815
+ intermediate_size: int. The size of the "intermediate" (a.k.a., feed
816
+ forward) layer.
817
+ intermediate_act_fn: function. The non-linear activation function to apply
818
+ to the output of the intermediate/feed-forward layer.
819
+ hidden_dropout_prob: float. Dropout probability for the hidden layers.
820
+ attention_probs_dropout_prob: float. Dropout probability of the attention
821
+ probabilities.
822
+ initializer_range: float. Range of the initializer (stddev of truncated
823
+ normal).
824
+ do_return_all_layers: Whether to also return all layers or just the final
825
+ layer.
826
+
827
+ Returns:
828
+ float Tensor of shape [batch_size, seq_length, hidden_size], the final
829
+ hidden layer of the Transformer.
830
+
831
+ Raises:
832
+ ValueError: A Tensor shape or parameter is invalid.
833
+ """
834
+ if hidden_size % num_attention_heads != 0:
835
+ raise ValueError(
836
+ "The hidden size (%d) is not a multiple of the number of attention "
837
+ "heads (%d)" % (hidden_size, num_attention_heads)
838
+ )
839
+
840
+ attention_head_size = int(hidden_size / num_attention_heads)
841
+ input_shape = get_shape_list(input_tensor, expected_rank=3)
842
+ batch_size = input_shape[0]
843
+ seq_length = input_shape[1]
844
+ input_width = input_shape[2]
845
+
846
+ # The Transformer performs sum residuals on all layers so the input needs
847
+ # to be the same as the hidden size.
848
+ if input_width != hidden_size:
849
+ raise ValueError(
850
+ "The width of the input tensor (%d) != hidden size (%d)"
851
+ % (input_width, hidden_size)
852
+ )
853
+
854
+ # We keep the representation as a 2D tensor to avoid re-shaping it back and
855
+ # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
856
+ # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
857
+ # help the optimizer.
858
+ prev_output = reshape_to_matrix(input_tensor)
859
+
860
+ all_layer_outputs = []
861
+ for layer_idx in range(num_hidden_layers):
862
+ with tf.variable_scope("layer_%d" % layer_idx):
863
+ layer_input = prev_output
864
+
865
+ with tf.variable_scope("attention"):
866
+ attention_heads = []
867
+ with tf.variable_scope("self"):
868
+ attention_head = attention_layer(
869
+ from_tensor=layer_input,
870
+ to_tensor=layer_input,
871
+ attention_mask=attention_mask,
872
+ num_attention_heads=num_attention_heads,
873
+ size_per_head=attention_head_size,
874
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
875
+ initializer_range=initializer_range,
876
+ do_return_2d_tensor=True,
877
+ batch_size=batch_size,
878
+ from_seq_length=seq_length,
879
+ to_seq_length=seq_length,
880
+ )
881
+ attention_heads.append(attention_head)
882
+
883
+ attention_output = None
884
+ if len(attention_heads) == 1:
885
+ attention_output = attention_heads[0]
886
+ else:
887
+ # In the case where we have other sequences, we just concatenate
888
+ # them to the self-attention head before the projection.
889
+ attention_output = tf.concat(attention_heads, axis=-1)
890
+
891
+ # Run a linear projection of `hidden_size` then add a residual
892
+ # with `layer_input`.
893
+ with tf.variable_scope("output"):
894
+ attention_output = tf.layers.dense(
895
+ attention_output,
896
+ hidden_size,
897
+ kernel_initializer=create_initializer(initializer_range),
898
+ )
899
+ attention_output = dropout(attention_output, hidden_dropout_prob)
900
+ attention_output = layer_norm(attention_output + layer_input)
901
+
902
+ # The activation is only applied to the "intermediate" hidden layer.
903
+ with tf.variable_scope("intermediate"):
904
+ intermediate_output = tf.layers.dense(
905
+ attention_output,
906
+ intermediate_size,
907
+ activation=intermediate_act_fn,
908
+ kernel_initializer=create_initializer(initializer_range),
909
+ )
910
+
911
+ # Down-project back to `hidden_size` then add the residual.
912
+ with tf.variable_scope("output"):
913
+ layer_output = tf.layers.dense(
914
+ intermediate_output,
915
+ hidden_size,
916
+ kernel_initializer=create_initializer(initializer_range),
917
+ )
918
+ layer_output = dropout(layer_output, hidden_dropout_prob)
919
+ layer_output = layer_norm(layer_output + attention_output)
920
+ prev_output = layer_output
921
+ all_layer_outputs.append(layer_output)
922
+
923
+ if do_return_all_layers:
924
+ final_outputs = []
925
+ for layer_output in all_layer_outputs:
926
+ final_output = reshape_from_matrix(layer_output, input_shape)
927
+ final_outputs.append(final_output)
928
+ return final_outputs
929
+ else:
930
+ final_output = reshape_from_matrix(prev_output, input_shape)
931
+ return final_output
932
+
933
+
934
+ def get_shape_list(tensor, expected_rank=None, name=None):
935
+ """Returns a list of the shape of tensor, preferring static dimensions.
936
+
937
+ Args:
938
+ tensor: A tf.Tensor object to find the shape of.
939
+ expected_rank: (optional) int. The expected rank of `tensor`. If this is
940
+ specified and the `tensor` has a different rank, an exception will be
941
+ thrown.
942
+ name: Optional name of the tensor for the error message.
943
+
944
+ Returns:
945
+ A list of dimensions of the shape of tensor. All static dimensions will
946
+ be returned as python integers, and dynamic dimensions will be returned
947
+ as tf.Tensor scalars.
948
+ """
949
+ if name is None:
950
+ name = tensor.name
951
+
952
+ if expected_rank is not None:
953
+ assert_rank(tensor, expected_rank, name)
954
+
955
+ shape = tensor.shape.as_list()
956
+
957
+ non_static_indexes = []
958
+ for (index, dim) in enumerate(shape):
959
+ if dim is None:
960
+ non_static_indexes.append(index)
961
+
962
+ if not non_static_indexes:
963
+ return shape
964
+
965
+ dyn_shape = tf.shape(tensor)
966
+ for index in non_static_indexes:
967
+ shape[index] = dyn_shape[index]
968
+ return shape
969
+
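# Illustrative sketch (not part of this commit): with an unknown batch dimension,
# get_shape_list mixes dynamic tf.Tensor scalars and static Python ints, e.g.
#   ids = tf.placeholder(tf.int32, shape=[None, 128])     # hypothetical input
#   get_shape_list(ids, expected_rank=2)  # -> [<dynamic batch-size tensor>, 128]
# so callers can reshape without forcing the batch size at graph-build time.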
970
+
971
+ def reshape_to_matrix(input_tensor):
972
+ """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
973
+ ndims = input_tensor.shape.ndims
974
+ if ndims < 2:
975
+ raise ValueError(
976
+ "Input tensor must have at least rank 2. Shape = %s" % (input_tensor.shape)
977
+ )
978
+ if ndims == 2:
979
+ return input_tensor
980
+
981
+ width = input_tensor.shape[-1]
982
+ output_tensor = tf.reshape(input_tensor, [-1, width])
983
+ return output_tensor
984
+
985
+
986
+ def reshape_from_matrix(output_tensor, orig_shape_list):
987
+ """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
988
+ if len(orig_shape_list) == 2:
989
+ return output_tensor
990
+
991
+ output_shape = get_shape_list(output_tensor)
992
+
993
+ orig_dims = orig_shape_list[0:-1]
994
+ width = output_shape[-1]
995
+
996
+ return tf.reshape(output_tensor, orig_dims + [width])
997
+
998
+
999
+ def assert_rank(tensor, expected_rank, name=None):
1000
+ """Raises an exception if the tensor rank is not of the expected rank.
1001
+
1002
+ Args:
1003
+ tensor: A tf.Tensor to check the rank of.
1004
+ expected_rank: Python integer or list of integers, expected rank.
1005
+ name: Optional name of the tensor for the error message.
1006
+
1007
+ Raises:
1008
+ ValueError: If the expected shape doesn't match the actual shape.
1009
+ """
1010
+ if name is None:
1011
+ name = tensor.name
1012
+
1013
+ expected_rank_dict = {}
1014
+ if isinstance(expected_rank, six.integer_types):
1015
+ expected_rank_dict[expected_rank] = True
1016
+ else:
1017
+ for x in expected_rank:
1018
+ expected_rank_dict[x] = True
1019
+
1020
+ actual_rank = tensor.shape.ndims
1021
+ if actual_rank not in expected_rank_dict:
1022
+ scope_name = tf.get_variable_scope().name
1023
+ raise ValueError(
1024
+ "For the tensor `%s` in scope `%s`, the actual rank "
1025
+ "`%d` (shape = %s) is not equal to the expected rank `%s`"
1026
+ % (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))
1027
+ )
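A minimal smoke test of the encoder defined above (an illustrative sketch, not part of
this commit; it assumes a TF 1.x graph-mode environment with modeling.py importable):

    import tensorflow as tf
    import modeling  # the file shown above

    # [batch, seq_len, hidden] inputs and a mask that attends to every position.
    embeddings = tf.random_normal([2, 16, 768])
    attention_mask = tf.ones([2, 16, 16], dtype=tf.int32)

    final_layer = modeling.transformer_model(
        input_tensor=embeddings,
        attention_mask=attention_mask,
        num_hidden_layers=2,  # kept small for the example
    )
    # final_layer has static shape [2, 16, 768]: one hidden vector per input position.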
arabert/arabert/optimization.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Functions and classes related to optimization (weight updates)."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import re
22
+ import tensorflow as tf
23
+ import lamb_optimizer
24
+
25
+ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu,
26
+ optimizer="adamw", poly_power=1.0, start_warmup_step=0,
27
+ colocate_gradients_with_ops=False):
28
+ """Creates an optimizer training op."""
29
+ global_step = tf.train.get_or_create_global_step()
30
+
31
+ learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
32
+
33
+ # Implements linear decay of the learning rate.
34
+ learning_rate = tf.train.polynomial_decay(
35
+ learning_rate,
36
+ global_step,
37
+ num_train_steps,
38
+ end_learning_rate=0.0,
39
+ power=poly_power,
40
+ cycle=False,
41
+ )
42
+
43
+ # Implements linear warmup. I.e., if global_step - start_warmup_step <
44
+ # num_warmup_steps, the learning rate will be
45
+ # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
46
+ if num_warmup_steps:
47
+ tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step)
48
+ + ", for " + str(num_warmup_steps) + " steps ++++++")
49
+ global_steps_int = tf.cast(global_step, tf.int32)
50
+ start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
51
+ global_steps_int = global_steps_int - start_warm_int
52
+ warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
53
+
54
+ global_steps_float = tf.cast(global_steps_int, tf.float32)
55
+ warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
56
+
57
+ warmup_percent_done = global_steps_float / warmup_steps_float
58
+ warmup_learning_rate = init_lr * warmup_percent_done
59
+
60
+ is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
61
+ learning_rate = (
62
+ 1.0 - is_warmup
63
+ ) * learning_rate + is_warmup * warmup_learning_rate
64
+
65
+ # It is OK that you use this optimizer for finetuning, since this
66
+ # is how the model was trained (note that the Adam m/v variables are NOT
67
+ # loaded from init_checkpoint.)
68
+ # It is OK to use AdamW for finetuning even if the model was trained with LAMB.
69
+ # As reported in the public BERT GitHub repo, the learning rate for SQuAD 1.1 finetuning
70
+ # is 3e-5, 4e-5 or 5e-5. For LAMB, users can use 3e-4, 4e-4, or 5e-4 for a
71
+ # batch size of 64 when finetuning.
72
+ if optimizer == "adamw":
73
+ tf.logging.info("using adamw")
74
+ optimizer = AdamWeightDecayOptimizer(
75
+ learning_rate=learning_rate,
76
+ weight_decay_rate=0.01,
77
+ beta_1=0.9,
78
+ beta_2=0.999,
79
+ epsilon=1e-6,
80
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
81
+ elif optimizer == "lamb":
82
+ tf.logging.info("using lamb")
83
+ optimizer = lamb_optimizer.LAMBOptimizer(
84
+ learning_rate=learning_rate,
85
+ weight_decay_rate=0.01,
86
+ beta_1=0.9,
87
+ beta_2=0.999,
88
+ epsilon=1e-6,
89
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
90
+ else:
91
+ raise ValueError("Unsupported optimizer: %s" % optimizer)
92
+
93
+ if use_tpu:
94
+ optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
95
+
96
+ tvars = tf.trainable_variables()
97
+ grads = tf.gradients(loss, tvars)
98
+
99
+ # This is how the model was pre-trained.
100
+ (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
101
+
102
+ train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)
103
+
104
+ # Normally the global step update is done inside of `apply_gradients`.
105
+ # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` do this.
106
+ # But if you use a different optimizer, you should probably take this line
107
+ # out.
108
+ new_global_step = global_step + 1
109
+ train_op = tf.group(train_op, [global_step.assign(new_global_step)])
110
+ return train_op
111
+
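# Worked example of the schedule above (illustrative numbers, not from the repo):
# with init_lr=1e-4, num_train_steps=100000, num_warmup_steps=10000, poly_power=1.0,
#   step   1,000 -> lr = 1e-4 *  1,000/10,000         = 1e-5    (linear warmup)
#   step  10,000 -> lr = 1e-4 * (1 - 10,000/100,000)  = 9e-5    (decay takes over)
#   step  55,000 -> lr = 1e-4 * (1 - 55,000/100,000)  = 4.5e-5
#   step 100,000 -> lr = 0.0                                    (end_learning_rate)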
112
+
113
+ class AdamWeightDecayOptimizer(tf.train.Optimizer):
114
+ """A basic Adam optimizer that includes "correct" L2 weight decay."""
115
+
116
+ def __init__(
117
+ self,
118
+ learning_rate,
119
+ weight_decay_rate=0.0,
120
+ beta_1=0.9,
121
+ beta_2=0.999,
122
+ epsilon=1e-6,
123
+ exclude_from_weight_decay=None,
124
+ name="AdamWeightDecayOptimizer",
125
+ ):
126
+ """Constructs an AdamWeightDecayOptimizer."""
127
+ super(AdamWeightDecayOptimizer, self).__init__(False, name)
128
+
129
+ self.learning_rate = learning_rate
130
+ self.weight_decay_rate = weight_decay_rate
131
+ self.beta_1 = beta_1
132
+ self.beta_2 = beta_2
133
+ self.epsilon = epsilon
134
+ self.exclude_from_weight_decay = exclude_from_weight_decay
135
+
136
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
137
+ """See base class."""
138
+ assignments = []
139
+ for (grad, param) in grads_and_vars:
140
+ if grad is None or param is None:
141
+ continue
142
+
143
+ param_name = self._get_variable_name(param.name)
144
+
145
+ m = tf.get_variable(
146
+ name=param_name + "/adam_m",
147
+ shape=param.shape.as_list(),
148
+ dtype=tf.float32,
149
+ trainable=False,
150
+ initializer=tf.zeros_initializer(),
151
+ )
152
+ v = tf.get_variable(
153
+ name=param_name + "/adam_v",
154
+ shape=param.shape.as_list(),
155
+ dtype=tf.float32,
156
+ trainable=False,
157
+ initializer=tf.zeros_initializer(),
158
+ )
159
+
160
+ # Standard Adam update.
161
+ next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
162
+ next_v = tf.multiply(self.beta_2, v) + tf.multiply(
163
+ 1.0 - self.beta_2, tf.square(grad)
164
+ )
165
+
166
+ update = next_m / (tf.sqrt(next_v) + self.epsilon)
167
+
168
+ # Just adding the square of the weights to the loss function is *not*
169
+ # the correct way of using L2 regularization/weight decay with Adam,
170
+ # since that will interact with the m and v parameters in strange ways.
171
+ #
172
+ # Instead we want to decay the weights in a manner that doesn't interact
173
+ # with the m/v parameters. This is equivalent to adding the square
174
+ # of the weights to the loss with plain (non-momentum) SGD.
175
+ if self._do_use_weight_decay(param_name):
176
+ update += self.weight_decay_rate * param
177
+
178
+ update_with_lr = self.learning_rate * update
179
+
180
+ next_param = param - update_with_lr
181
+
182
+ assignments.extend(
183
+ [param.assign(next_param), m.assign(next_m), v.assign(next_v)]
184
+ )
185
+ return tf.group(*assignments, name=name)
186
+
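# Sketch of the decoupled update above for a single parameter (illustrative only):
# with learning_rate=0.1 and weight_decay_rate=0.01, once the Adam step
# next_m / (sqrt(next_v) + epsilon) is computed,
#   update     = adam_step + 0.01 * param    # decay added to the step, not the loss
#   next_param = param - 0.1 * update
# so the decay term never flows through the m/v moment estimates.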
187
+ def _do_use_weight_decay(self, param_name):
188
+ """Whether to use L2 weight decay for `param_name`."""
189
+ if not self.weight_decay_rate:
190
+ return False
191
+ if self.exclude_from_weight_decay:
192
+ for r in self.exclude_from_weight_decay:
193
+ if re.search(r, param_name) is not None:
194
+ return False
195
+ return True
196
+
197
+ def _get_variable_name(self, param_name):
198
+ """Get the variable name from the tensor name."""
199
+ m = re.match("^(.*):\\d+$", param_name)
200
+ if m is not None:
201
+ param_name = m.group(1)
202
+ return param_name
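A minimal end-to-end usage sketch for the optimizer above (illustrative, not part of this
commit; it assumes a TF 1.x session and that optimization.py and lamb_optimizer.py are on
the import path):

    import tensorflow as tf
    import optimization

    w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
    loss = tf.reduce_sum(tf.square(w - 1.0))

    train_op = optimization.create_optimizer(
        loss, init_lr=1e-3, num_train_steps=1000, num_warmup_steps=100,
        use_tpu=False, optimizer="adamw")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1000):
            sess.run(train_op)  # each step applies AdamW with the warmup+decay schedule above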
arabert/arabert/run_classifier.py ADDED
@@ -0,0 +1,1078 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BERT finetuning runner."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import collections
22
+ import csv
23
+ import os
24
+ import modeling
25
+ import optimization
26
+ import tokenization
27
+ import tensorflow as tf
28
+
29
+ flags = tf.flags
30
+
31
+ FLAGS = flags.FLAGS
32
+
33
+ ## Required parameters
34
+ flags.DEFINE_string(
35
+ "data_dir",
36
+ None,
37
+ "The input data dir. Should contain the .tsv files (or other data files) "
38
+ "for the task.",
39
+ )
40
+
41
+ flags.DEFINE_string(
42
+ "bert_config_file",
43
+ None,
44
+ "The config json file corresponding to the pre-trained BERT model. "
45
+ "This specifies the model architecture.",
46
+ )
47
+
48
+ flags.DEFINE_string("task_name", None, "The name of the task to train.")
49
+
50
+ flags.DEFINE_string(
51
+ "vocab_file", None, "The vocabulary file that the BERT model was trained on."
52
+ )
53
+
54
+ flags.DEFINE_string(
55
+ "output_dir",
56
+ None,
57
+ "The output directory where the model checkpoints will be written.",
58
+ )
59
+
60
+ ## Other parameters
61
+
62
+ flags.DEFINE_string(
63
+ "init_checkpoint",
64
+ None,
65
+ "Initial checkpoint (usually from a pre-trained BERT model).",
66
+ )
67
+
68
+ flags.DEFINE_bool(
69
+ "do_lower_case",
70
+ True,
71
+ "Whether to lower case the input text. Should be True for uncased "
72
+ "models and False for cased models.",
73
+ )
74
+
75
+ flags.DEFINE_integer(
76
+ "max_seq_length",
77
+ 128,
78
+ "The maximum total input sequence length after WordPiece tokenization. "
79
+ "Sequences longer than this will be truncated, and sequences shorter "
80
+ "than this will be padded.",
81
+ )
82
+
83
+ flags.DEFINE_bool("do_train", False, "Whether to run training.")
84
+
85
+ flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
86
+
87
+ flags.DEFINE_bool(
88
+ "do_predict", False, "Whether to run the model in inference mode on the test set."
89
+ )
90
+
91
+ flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
92
+
93
+ flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
94
+
95
+ flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
96
+
97
+ flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
98
+
99
+ flags.DEFINE_float(
100
+ "num_train_epochs", 3.0, "Total number of training epochs to perform."
101
+ )
102
+
103
+ flags.DEFINE_float(
104
+ "warmup_proportion",
105
+ 0.1,
106
+ "Proportion of training to perform linear learning rate warmup for. "
107
+ "E.g., 0.1 = 10% of training.",
108
+ )
109
+
110
+ flags.DEFINE_integer(
111
+ "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
112
+ )
113
+
114
+ flags.DEFINE_integer(
115
+ "iterations_per_loop", 1000, "How many steps to make in each estimator call."
116
+ )
117
+
118
+ flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
119
+
120
+ tf.flags.DEFINE_string(
121
+ "tpu_name",
122
+ None,
123
+ "The Cloud TPU to use for training. This should be either the name "
124
+ "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
125
+ "url.",
126
+ )
127
+
128
+ tf.flags.DEFINE_string(
129
+ "tpu_zone",
130
+ None,
131
+ "[Optional] GCE zone where the Cloud TPU is located in. If not "
132
+ "specified, we will attempt to automatically detect the GCE project from "
133
+ "metadata.",
134
+ )
135
+
136
+ tf.flags.DEFINE_string(
137
+ "gcp_project",
138
+ None,
139
+ "[Optional] Project name for the Cloud TPU-enabled project. If not "
140
+ "specified, we will attempt to automatically detect the GCE project from "
141
+ "metadata.",
142
+ )
143
+
144
+ tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
145
+
146
+ flags.DEFINE_integer(
147
+ "num_tpu_cores",
148
+ 8,
149
+ "Only used if `use_tpu` is True. Total number of TPU cores to use.",
150
+ )
151
+
152
+
153
+ class InputExample(object):
154
+ """A single training/test example for simple sequence classification."""
155
+
156
+ def __init__(self, guid, text_a, text_b=None, label=None):
157
+ """Constructs an InputExample.
158
+
159
+ Args:
160
+ guid: Unique id for the example.
161
+ text_a: string. The untokenized text of the first sequence. For single
162
+ sequence tasks, only this sequence must be specified.
163
+ text_b: (Optional) string. The untokenized text of the second sequence.
164
+ Must only be specified for sequence pair tasks.
165
+ label: (Optional) string. The label of the example. This should be
166
+ specified for train and dev examples, but not for test examples.
167
+ """
168
+ self.guid = guid
169
+ self.text_a = text_a
170
+ self.text_b = text_b
171
+ self.label = label
172
+
173
+
174
+ class PaddingInputExample(object):
175
+ """Fake example so the num input examples is a multiple of the batch size.
176
+
177
+ When running eval/predict on the TPU, we need to pad the number of examples
178
+ to be a multiple of the batch size, because the TPU requires a fixed batch
179
+ size. The alternative is to drop the last batch, which is bad because it means
180
+ the entire output data won't be generated.
181
+
182
+ We use this class instead of `None` because treating `None` as padding
183
+ batches could cause silent errors.
184
+ """
185
+
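# Numeric example of the padding above (illustrative): evaluating 103 dev examples
# with eval_batch_size=8 on a TPU appends one PaddingInputExample to reach 104
# (a multiple of 8); the fake example gets is_real_example=0 and therefore a weight
# of 0.0 in every tf.metrics call, so it never affects the reported numbers.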
186
+
187
+ class InputFeatures(object):
188
+ """A single set of features of data."""
189
+
190
+ def __init__(
191
+ self, input_ids, input_mask, segment_ids, label_id, is_real_example=True
192
+ ):
193
+ self.input_ids = input_ids
194
+ self.input_mask = input_mask
195
+ self.segment_ids = segment_ids
196
+ self.label_id = label_id
197
+ self.is_real_example = is_real_example
198
+
199
+
200
+ class DataProcessor(object):
201
+ """Base class for data converters for sequence classification data sets."""
202
+
203
+ def get_train_examples(self, data_dir):
204
+ """Gets a collection of `InputExample`s for the train set."""
205
+ raise NotImplementedError()
206
+
207
+ def get_dev_examples(self, data_dir):
208
+ """Gets a collection of `InputExample`s for the dev set."""
209
+ raise NotImplementedError()
210
+
211
+ def get_test_examples(self, data_dir):
212
+ """Gets a collection of `InputExample`s for prediction."""
213
+ raise NotImplementedError()
214
+
215
+ def get_labels(self):
216
+ """Gets the list of labels for this data set."""
217
+ raise NotImplementedError()
218
+
219
+ @classmethod
220
+ def _read_tsv(cls, input_file, quotechar=None):
221
+ """Reads a tab separated value file."""
222
+ with tf.gfile.Open(input_file, "r") as f:
223
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
224
+ lines = []
225
+ for line in reader:
226
+ lines.append(line)
227
+ return lines
228
+
229
+
230
+ class XnliProcessor(DataProcessor):
231
+ """Processor for the XNLI data set."""
232
+
233
+ def __init__(self):
234
+ self.language = "ar"
235
+
236
+ def get_train_examples(self, data_dir):
237
+ """See base class."""
238
+ lines = self._read_tsv(
239
+ os.path.join(data_dir, "multinli", "multinli.train.%s.tsv" % self.language)
240
+ )
241
+ examples = []
242
+ for (i, line) in enumerate(lines):
243
+ if i == 0:
244
+ continue
245
+ guid = "train-%d" % (i)
246
+ text_a = tokenization.convert_to_unicode(line[0])
247
+ text_b = tokenization.convert_to_unicode(line[1])
248
+ label = tokenization.convert_to_unicode(line[2])
249
+ if label == tokenization.convert_to_unicode("contradictory"):
250
+ label = tokenization.convert_to_unicode("contradiction")
251
+ examples.append(
252
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
253
+ )
254
+ return examples
255
+
256
+ def get_dev_examples(self, data_dir):
257
+ """See base class."""
258
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
259
+ examples = []
260
+ for (i, line) in enumerate(lines):
261
+ if i == 0:
262
+ continue
263
+ guid = "dev-%d" % (i)
264
+ language = tokenization.convert_to_unicode(line[0])
265
+ if language != tokenization.convert_to_unicode(self.language):
266
+ continue
267
+ text_a = tokenization.convert_to_unicode(line[6])
268
+ text_b = tokenization.convert_to_unicode(line[7])
269
+ label = tokenization.convert_to_unicode(line[1])
270
+ examples.append(
271
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
272
+ )
273
+ return examples
274
+
275
+ def get_labels(self):
276
+ """See base class."""
277
+ return ["contradiction", "entailment", "neutral"]
278
+
279
+
280
+ class MnliProcessor(DataProcessor):
281
+ """Processor for the MultiNLI data set (GLUE version)."""
282
+
283
+ def get_train_examples(self, data_dir):
284
+ """See base class."""
285
+ return self._create_examples(
286
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
287
+ )
288
+
289
+ def get_dev_examples(self, data_dir):
290
+ """See base class."""
291
+ return self._create_examples(
292
+ self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched"
293
+ )
294
+
295
+ def get_test_examples(self, data_dir):
296
+ """See base class."""
297
+ return self._create_examples(
298
+ self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test"
299
+ )
300
+
301
+ def get_labels(self):
302
+ """See base class."""
303
+ return ["contradiction", "entailment", "neutral"]
304
+
305
+ def _create_examples(self, lines, set_type):
306
+ """Creates examples for the training and dev sets."""
307
+ examples = []
308
+ for (i, line) in enumerate(lines):
309
+ if i == 0:
310
+ continue
311
+ guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
312
+ text_a = tokenization.convert_to_unicode(line[8])
313
+ text_b = tokenization.convert_to_unicode(line[9])
314
+ if set_type == "test":
315
+ label = "contradiction"
316
+ else:
317
+ label = tokenization.convert_to_unicode(line[-1])
318
+ examples.append(
319
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
320
+ )
321
+ return examples
322
+
323
+
324
+ class MrpcProcessor(DataProcessor):
325
+ """Processor for the MRPC data set (GLUE version)."""
326
+
327
+ def get_train_examples(self, data_dir):
328
+ """See base class."""
329
+ return self._create_examples(
330
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
331
+ )
332
+
333
+ def get_dev_examples(self, data_dir):
334
+ """See base class."""
335
+ return self._create_examples(
336
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev"
337
+ )
338
+
339
+ def get_test_examples(self, data_dir):
340
+ """See base class."""
341
+ return self._create_examples(
342
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test"
343
+ )
344
+
345
+ def get_labels(self):
346
+ """See base class."""
347
+ return ["0", "1"]
348
+
349
+ def _create_examples(self, lines, set_type):
350
+ """Creates examples for the training and dev sets."""
351
+ examples = []
352
+ for (i, line) in enumerate(lines):
353
+ if i == 0:
354
+ continue
355
+ guid = "%s-%s" % (set_type, i)
356
+ text_a = tokenization.convert_to_unicode(line[3])
357
+ text_b = tokenization.convert_to_unicode(line[4])
358
+ if set_type == "test":
359
+ label = "0"
360
+ else:
361
+ label = tokenization.convert_to_unicode(line[0])
362
+ examples.append(
363
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
364
+ )
365
+ return examples
366
+
367
+
368
+ class ColaProcessor(DataProcessor):
369
+ """Processor for the CoLA data set (GLUE version)."""
370
+
371
+ def get_train_examples(self, data_dir):
372
+ """See base class."""
373
+ return self._create_examples(
374
+ self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
375
+ )
376
+
377
+ def get_dev_examples(self, data_dir):
378
+ """See base class."""
379
+ return self._create_examples(
380
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev"
381
+ )
382
+
383
+ def get_test_examples(self, data_dir):
384
+ """See base class."""
385
+ return self._create_examples(
386
+ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test"
387
+ )
388
+
389
+ def get_labels(self):
390
+ """See base class."""
391
+ return ["0", "1"]
392
+
393
+ def _create_examples(self, lines, set_type):
394
+ """Creates examples for the training and dev sets."""
395
+ examples = []
396
+ for (i, line) in enumerate(lines):
397
+ # Only the test set has a header
398
+ if set_type == "test" and i == 0:
399
+ continue
400
+ guid = "%s-%s" % (set_type, i)
401
+ if set_type == "test":
402
+ text_a = tokenization.convert_to_unicode(line[1])
403
+ label = "0"
404
+ else:
405
+ text_a = tokenization.convert_to_unicode(line[3])
406
+ label = tokenization.convert_to_unicode(line[1])
407
+ examples.append(
408
+ InputExample(guid=guid, text_a=text_a, text_b=None, label=label)
409
+ )
410
+ return examples
411
+
412
+
413
+ def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
414
+ """Converts a single `InputExample` into a single `InputFeatures`."""
415
+
416
+ if isinstance(example, PaddingInputExample):
417
+ return InputFeatures(
418
+ input_ids=[0] * max_seq_length,
419
+ input_mask=[0] * max_seq_length,
420
+ segment_ids=[0] * max_seq_length,
421
+ label_id=0,
422
+ is_real_example=False,
423
+ )
424
+
425
+ label_map = {}
426
+ for (i, label) in enumerate(label_list):
427
+ label_map[label] = i
428
+
429
+ tokens_a = tokenizer.tokenize(example.text_a)
430
+ tokens_b = None
431
+ if example.text_b:
432
+ tokens_b = tokenizer.tokenize(example.text_b)
433
+
434
+ if tokens_b:
435
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
436
+ # length is less than the specified length.
437
+ # Account for [CLS], [SEP], [SEP] with "- 3"
438
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
439
+ else:
440
+ # Account for [CLS] and [SEP] with "- 2"
441
+ if len(tokens_a) > max_seq_length - 2:
442
+ tokens_a = tokens_a[0 : (max_seq_length - 2)]
443
+
444
+ # The convention in BERT is:
445
+ # (a) For sequence pairs:
446
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
447
+ # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
448
+ # (b) For single sequences:
449
+ # tokens: [CLS] the dog is hairy . [SEP]
450
+ # type_ids: 0 0 0 0 0 0 0
451
+ #
452
+ # Where "type_ids" are used to indicate whether this is the first
453
+ # sequence or the second sequence. The embedding vectors for `type=0` and
454
+ # `type=1` were learned during pre-training and are added to the wordpiece
455
+ # embedding vector (and position vector). This is not *strictly* necessary
456
+ # since the [SEP] token unambiguously separates the sequences, but it makes
457
+ # it easier for the model to learn the concept of sequences.
458
+ #
459
+ # For classification tasks, the first vector (corresponding to [CLS]) is
460
+ # used as the "sentence vector". Note that this only makes sense because
461
+ # the entire model is fine-tuned.
462
+ tokens = []
463
+ segment_ids = []
464
+ tokens.append("[CLS]")
465
+ segment_ids.append(0)
466
+ for token in tokens_a:
467
+ tokens.append(token)
468
+ segment_ids.append(0)
469
+ tokens.append("[SEP]")
470
+ segment_ids.append(0)
471
+
472
+ if tokens_b:
473
+ for token in tokens_b:
474
+ tokens.append(token)
475
+ segment_ids.append(1)
476
+ tokens.append("[SEP]")
477
+ segment_ids.append(1)
478
+
479
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
480
+
481
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
482
+ # tokens are attended to.
483
+ input_mask = [1] * len(input_ids)
484
+
485
+ # Zero-pad up to the sequence length.
486
+ while len(input_ids) < max_seq_length:
487
+ input_ids.append(0)
488
+ input_mask.append(0)
489
+ segment_ids.append(0)
490
+
491
+ assert len(input_ids) == max_seq_length
492
+ assert len(input_mask) == max_seq_length
493
+ assert len(segment_ids) == max_seq_length
494
+
495
+ label_id = label_map[example.label]
496
+ if ex_index < 5:
497
+ tf.logging.info("*** Example ***")
498
+ tf.logging.info("guid: %s" % (example.guid))
499
+ tf.logging.info(
500
+ "tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens])
501
+ )
502
+ tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
503
+ tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
504
+ tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
505
+ tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
506
+
507
+ feature = InputFeatures(
508
+ input_ids=input_ids,
509
+ input_mask=input_mask,
510
+ segment_ids=segment_ids,
511
+ label_id=label_id,
512
+ is_real_example=True,
513
+ )
514
+ return feature
515
+
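# Worked example of the conversion above (hypothetical WordPiece output, max_seq_length=10):
#   text_a -> ["hi", "there"], text_b -> ["yes"]
#   tokens      = [CLS] hi there [SEP] yes [SEP]
#   segment_ids =   0   0    0    0    1    1    0 0 0 0   (zero-padded)
#   input_mask  =   1   1    1    1    1    1    0 0 0 0   (1 = real token)
#   input_ids   = vocab ids for the six tokens, followed by four 0s of padding
#   label_id    = label_map[example.label]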
516
+
517
+ def file_based_convert_examples_to_features(
518
+ examples, label_list, max_seq_length, tokenizer, output_file
519
+ ):
520
+ """Convert a set of `InputExample`s to a TFRecord file."""
521
+
522
+ writer = tf.python_io.TFRecordWriter(output_file)
523
+
524
+ for (ex_index, example) in enumerate(examples):
525
+ if ex_index % 10000 == 0:
526
+ tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
527
+
528
+ feature = convert_single_example(
529
+ ex_index, example, label_list, max_seq_length, tokenizer
530
+ )
531
+
532
+ def create_int_feature(values):
533
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
534
+ return f
535
+
536
+ features = collections.OrderedDict()
537
+ features["input_ids"] = create_int_feature(feature.input_ids)
538
+ features["input_mask"] = create_int_feature(feature.input_mask)
539
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
540
+ features["label_ids"] = create_int_feature([feature.label_id])
541
+ features["is_real_example"] = create_int_feature([int(feature.is_real_example)])
542
+
543
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
544
+ writer.write(tf_example.SerializeToString())
545
+ writer.close()
546
+
547
+
548
+ def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
549
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
550
+
551
+ name_to_features = {
552
+ "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
553
+ "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
554
+ "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
555
+ "label_ids": tf.FixedLenFeature([], tf.int64),
556
+ "is_real_example": tf.FixedLenFeature([], tf.int64),
557
+ }
558
+
559
+ def _decode_record(record, name_to_features):
560
+ """Decodes a record to a TensorFlow example."""
561
+ example = tf.parse_single_example(record, name_to_features)
562
+
563
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
564
+ # So cast all int64 to int32.
565
+ for name in list(example.keys()):
566
+ t = example[name]
567
+ if t.dtype == tf.int64:
568
+ t = tf.to_int32(t)
569
+ example[name] = t
570
+
571
+ return example
572
+
573
+ def input_fn(params):
574
+ """The actual input function."""
575
+ batch_size = params["batch_size"]
576
+
577
+ # For training, we want a lot of parallel reading and shuffling.
578
+ # For eval, we want no shuffling and parallel reading doesn't matter.
579
+ d = tf.data.TFRecordDataset(input_file)
580
+ if is_training:
581
+ d = d.repeat()
582
+ d = d.shuffle(buffer_size=100)
583
+
584
+ d = d.apply(
585
+ tf.contrib.data.map_and_batch(
586
+ lambda record: _decode_record(record, name_to_features),
587
+ batch_size=batch_size,
588
+ drop_remainder=drop_remainder,
589
+ )
590
+ )
591
+
592
+ return d
593
+
594
+ return input_fn
595
+
596
+
597
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
598
+ """Truncates a sequence pair in place to the maximum length."""
599
+
600
+ # This is a simple heuristic which will always truncate the longer sequence
601
+ # one token at a time. This makes more sense than truncating an equal percent
602
+ # of tokens from each, since if one sequence is very short then each token
603
+ # that's truncated likely contains more information than a token from a longer sequence.
604
+ while True:
605
+ total_length = len(tokens_a) + len(tokens_b)
606
+ if total_length <= max_length:
607
+ break
608
+ if len(tokens_a) > len(tokens_b):
609
+ tokens_a.pop()
610
+ else:
611
+ tokens_b.pop()
612
+
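# Example of the heuristic above: with len(tokens_a)=10, len(tokens_b)=4 and
# max_length=9, tokens are popped from the longer list only, ending with
# len(tokens_a)=5 and len(tokens_b)=4; the shorter sequence is left untouched.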
613
+
614
+ def create_model(
615
+ bert_config,
616
+ is_training,
617
+ input_ids,
618
+ input_mask,
619
+ segment_ids,
620
+ labels,
621
+ num_labels,
622
+ use_one_hot_embeddings,
623
+ ):
624
+ """Creates a classification model."""
625
+ model = modeling.BertModel(
626
+ config=bert_config,
627
+ is_training=is_training,
628
+ input_ids=input_ids,
629
+ input_mask=input_mask,
630
+ token_type_ids=segment_ids,
631
+ use_one_hot_embeddings=use_one_hot_embeddings,
632
+ )
633
+
634
+ # In the demo, we are doing a simple classification task on the entire
635
+ # segment.
636
+ #
637
+ # If you want to use the token-level output, use model.get_sequence_output()
638
+ # instead.
639
+ output_layer = model.get_pooled_output()
640
+
641
+ hidden_size = output_layer.shape[-1].value
642
+
643
+ output_weights = tf.get_variable(
644
+ "output_weights",
645
+ [num_labels, hidden_size],
646
+ initializer=tf.truncated_normal_initializer(stddev=0.02),
647
+ )
648
+
649
+ output_bias = tf.get_variable(
650
+ "output_bias", [num_labels], initializer=tf.zeros_initializer()
651
+ )
652
+
653
+ with tf.variable_scope("loss"):
654
+ if is_training:
655
+ # I.e., 0.1 dropout
656
+ output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
657
+
658
+ logits = tf.matmul(output_layer, output_weights, transpose_b=True)
659
+ logits = tf.nn.bias_add(logits, output_bias)
660
+ probabilities = tf.nn.softmax(logits, axis=-1)
661
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
662
+
663
+ one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
664
+
665
+ per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
666
+ loss = tf.reduce_mean(per_example_loss)
667
+
668
+ return (loss, per_example_loss, logits, probabilities)
669
+
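# Shape sketch for the classification head above (batch B, hidden H, num_labels K; illustrative):
#   output_layer   : [B, H]  pooled [CLS] representation
#   output_weights : [K, H], output_bias : [K]
#   logits         : [B, K] = output_layer @ output_weights^T + output_bias
#   per_example_loss[b] = -log_softmax(logits[b])[label_ids[b]]
#   loss = mean of per_example_loss over the batch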
670
+
671
+ def model_fn_builder(
672
+ bert_config,
673
+ num_labels,
674
+ init_checkpoint,
675
+ learning_rate,
676
+ num_train_steps,
677
+ num_warmup_steps,
678
+ use_tpu,
679
+ use_one_hot_embeddings,
680
+ ):
681
+ """Returns `model_fn` closure for TPUEstimator."""
682
+
683
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
684
+ """The `model_fn` for TPUEstimator."""
685
+
686
+ tf.logging.info("*** Features ***")
687
+ for name in sorted(features.keys()):
688
+ tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
689
+
690
+ input_ids = features["input_ids"]
691
+ input_mask = features["input_mask"]
692
+ segment_ids = features["segment_ids"]
693
+ label_ids = features["label_ids"]
694
+ is_real_example = None
695
+ if "is_real_example" in features:
696
+ is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
697
+ else:
698
+ is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
699
+
700
+ is_training = mode == tf.estimator.ModeKeys.TRAIN
701
+
702
+ (total_loss, per_example_loss, logits, probabilities) = create_model(
703
+ bert_config,
704
+ is_training,
705
+ input_ids,
706
+ input_mask,
707
+ segment_ids,
708
+ label_ids,
709
+ num_labels,
710
+ use_one_hot_embeddings,
711
+ )
712
+
713
+ tvars = tf.trainable_variables()
714
+ initialized_variable_names = {}
715
+ scaffold_fn = None
716
+ if init_checkpoint:
717
+ (
718
+ assignment_map,
719
+ initialized_variable_names,
720
+ ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
721
+ if use_tpu:
722
+
723
+ def tpu_scaffold():
724
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
725
+ return tf.train.Scaffold()
726
+
727
+ scaffold_fn = tpu_scaffold
728
+ else:
729
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
730
+
731
+ tf.logging.info("**** Trainable Variables ****")
732
+ for var in tvars:
733
+ init_string = ""
734
+ if var.name in initialized_variable_names:
735
+ init_string = ", *INIT_FROM_CKPT*"
736
+ tf.logging.info(
737
+ " name = %s, shape = %s%s", var.name, var.shape, init_string
738
+ )
739
+
740
+ output_spec = None
741
+ if mode == tf.estimator.ModeKeys.TRAIN:
742
+
743
+ train_op = optimization.create_optimizer(
744
+ total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
745
+ )
746
+
747
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
748
+ mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
749
+ )
750
+ elif mode == tf.estimator.ModeKeys.EVAL:
751
+
752
+ def metric_fn(per_example_loss, label_ids, logits, is_real_example):
753
+ predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
754
+ accuracy = tf.metrics.accuracy(
755
+ labels=label_ids, predictions=predictions, weights=is_real_example
756
+ )
757
+ loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
758
+ return {
759
+ "eval_accuracy": accuracy,
760
+ "eval_loss": loss,
761
+ }
762
+
763
+ eval_metrics = (
764
+ metric_fn,
765
+ [per_example_loss, label_ids, logits, is_real_example],
766
+ )
767
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
768
+ mode=mode,
769
+ loss=total_loss,
770
+ eval_metrics=eval_metrics,
771
+ scaffold_fn=scaffold_fn,
772
+ )
773
+ else:
774
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
775
+ mode=mode,
776
+ predictions={"probabilities": probabilities},
777
+ scaffold_fn=scaffold_fn,
778
+ )
779
+ return output_spec
780
+
781
+ return model_fn
782
+
783
+
784
+ # This function is not used by this file but is still used by the Colab and
785
+ # people who depend on it.
786
+ def input_fn_builder(features, seq_length, is_training, drop_remainder):
787
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
788
+
789
+ all_input_ids = []
790
+ all_input_mask = []
791
+ all_segment_ids = []
792
+ all_label_ids = []
793
+
794
+ for feature in features:
795
+ all_input_ids.append(feature.input_ids)
796
+ all_input_mask.append(feature.input_mask)
797
+ all_segment_ids.append(feature.segment_ids)
798
+ all_label_ids.append(feature.label_id)
799
+
800
+ def input_fn(params):
801
+ """The actual input function."""
802
+ batch_size = params["batch_size"]
803
+
804
+ num_examples = len(features)
805
+
806
+ # This is for demo purposes and does NOT scale to large data sets. We do
807
+ # not use Dataset.from_generator() because that uses tf.py_func which is
808
+ # not TPU compatible. The right way to load data is with TFRecordReader.
809
+ d = tf.data.Dataset.from_tensor_slices(
810
+ {
811
+ "input_ids": tf.constant(
812
+ all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32
813
+ ),
814
+ "input_mask": tf.constant(
815
+ all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32
816
+ ),
817
+ "segment_ids": tf.constant(
818
+ all_segment_ids, shape=[num_examples, seq_length], dtype=tf.int32
819
+ ),
820
+ "label_ids": tf.constant(
821
+ all_label_ids, shape=[num_examples], dtype=tf.int32
822
+ ),
823
+ }
824
+ )
825
+
826
+ if is_training:
827
+ d = d.repeat()
828
+ d = d.shuffle(buffer_size=100)
829
+
830
+ d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
831
+ return d
832
+
833
+ return input_fn
834
+
835
+
836
+ # This function is not used by this file but is still used by the Colab and
837
+ # people who depend on it.
838
+ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
839
+ """Convert a set of `InputExample`s to a list of `InputFeatures`."""
840
+
841
+ features = []
842
+ for (ex_index, example) in enumerate(examples):
843
+ if ex_index % 10000 == 0:
844
+ tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
845
+
846
+ feature = convert_single_example(
847
+ ex_index, example, label_list, max_seq_length, tokenizer
848
+ )
849
+
850
+ features.append(feature)
851
+ return features
852
+
853
+
854
+ def main(_):
855
+ tf.logging.set_verbosity(tf.logging.INFO)
856
+ logger = tf.get_logger()
857
+ logger.propagate = False
858
+
859
+ processors = {
860
+ "cola": ColaProcessor,
861
+ "mnli": MnliProcessor,
862
+ "mrpc": MrpcProcessor,
863
+ "xnli": XnliProcessor,
864
+ }
865
+
866
+ tokenization.validate_case_matches_checkpoint(
867
+ FLAGS.do_lower_case, FLAGS.init_checkpoint
868
+ )
869
+
870
+ if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
871
+ raise ValueError(
872
+ "At least one of `do_train`, `do_eval` or `do_predict' must be True."
873
+ )
874
+
875
+ bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
876
+
877
+ if FLAGS.max_seq_length > bert_config.max_position_embeddings:
878
+ raise ValueError(
879
+ "Cannot use sequence length %d because the BERT model "
880
+ "was only trained up to sequence length %d"
881
+ % (FLAGS.max_seq_length, bert_config.max_position_embeddings)
882
+ )
883
+
884
+ tf.gfile.MakeDirs(FLAGS.output_dir)
885
+
886
+ task_name = FLAGS.task_name.lower()
887
+
888
+ if task_name not in processors:
889
+ raise ValueError("Task not found: %s" % (task_name))
890
+
891
+ processor = processors[task_name]()
892
+
893
+ label_list = processor.get_labels()
894
+
895
+ tokenizer = tokenization.FullTokenizer(
896
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
897
+ )
898
+
899
+ tpu_cluster_resolver = None
900
+ if FLAGS.use_tpu and FLAGS.tpu_name:
901
+ tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
902
+ FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
903
+ )
904
+
905
+ is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
906
+ run_config = tf.contrib.tpu.RunConfig(
907
+ cluster=tpu_cluster_resolver,
908
+ master=FLAGS.master,
909
+ model_dir=FLAGS.output_dir,
910
+ save_checkpoints_steps=FLAGS.save_checkpoints_steps,
911
+ tpu_config=tf.contrib.tpu.TPUConfig(
912
+ iterations_per_loop=FLAGS.iterations_per_loop,
913
+ num_shards=FLAGS.num_tpu_cores,
914
+ per_host_input_for_training=is_per_host,
915
+ ),
916
+ )
917
+
918
+ train_examples = None
919
+ num_train_steps = None
920
+ num_warmup_steps = None
921
+ if FLAGS.do_train:
922
+ train_examples = processor.get_train_examples(FLAGS.data_dir)
923
+ num_train_steps = int(
924
+ len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
925
+ )
926
+ num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
927
+
928
+ model_fn = model_fn_builder(
929
+ bert_config=bert_config,
930
+ num_labels=len(label_list),
931
+ init_checkpoint=FLAGS.init_checkpoint,
932
+ learning_rate=FLAGS.learning_rate,
933
+ num_train_steps=num_train_steps,
934
+ num_warmup_steps=num_warmup_steps,
935
+ use_tpu=FLAGS.use_tpu,
936
+ use_one_hot_embeddings=FLAGS.use_tpu,
937
+ )
938
+
939
+ # If TPU is not available, this will fall back to normal Estimator on CPU
940
+ # or GPU.
941
+ estimator = tf.contrib.tpu.TPUEstimator(
942
+ use_tpu=FLAGS.use_tpu,
943
+ model_fn=model_fn,
944
+ config=run_config,
945
+ train_batch_size=FLAGS.train_batch_size,
946
+ eval_batch_size=FLAGS.eval_batch_size,
947
+ predict_batch_size=FLAGS.predict_batch_size,
948
+ )
949
+
950
+ if FLAGS.do_train:
951
+ train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
952
+ file_based_convert_examples_to_features(
953
+ train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
954
+ )
955
+ tf.logging.info("***** Running training *****")
956
+ tf.logging.info(" Num examples = %d", len(train_examples))
957
+ tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
958
+ tf.logging.info(" Num steps = %d", num_train_steps)
959
+ train_input_fn = file_based_input_fn_builder(
960
+ input_file=train_file,
961
+ seq_length=FLAGS.max_seq_length,
962
+ is_training=True,
963
+ drop_remainder=True,
964
+ )
965
+ estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
966
+
967
+ if FLAGS.do_eval:
968
+ eval_examples = processor.get_dev_examples(FLAGS.data_dir)
969
+ num_actual_eval_examples = len(eval_examples)
970
+ if FLAGS.use_tpu:
971
+ # TPU requires a fixed batch size for all batches, therefore the number
972
+ # of examples must be a multiple of the batch size, or else examples
973
+ # will get dropped. So we pad with fake examples which are ignored
974
+ # later on. These do NOT count towards the metric (all tf.metrics
975
+ # support a per-instance weight, and these get a weight of 0.0).
976
+ while len(eval_examples) % FLAGS.eval_batch_size != 0:
977
+ eval_examples.append(PaddingInputExample())
978
+
979
+ eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
980
+ file_based_convert_examples_to_features(
981
+ eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file
982
+ )
983
+
984
+ tf.logging.info("***** Running evaluation *****")
985
+ tf.logging.info(
986
+ " Num examples = %d (%d actual, %d padding)",
987
+ len(eval_examples),
988
+ num_actual_eval_examples,
989
+ len(eval_examples) - num_actual_eval_examples,
990
+ )
991
+ tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
992
+
993
+ # This tells the estimator to run through the entire set.
994
+ eval_steps = None
995
+ # However, if running eval on the TPU, you will need to specify the
996
+ # number of steps.
997
+ if FLAGS.use_tpu:
998
+ assert len(eval_examples) % FLAGS.eval_batch_size == 0
999
+ eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
1000
+
1001
+ eval_drop_remainder = True if FLAGS.use_tpu else False
1002
+ eval_input_fn = file_based_input_fn_builder(
1003
+ input_file=eval_file,
1004
+ seq_length=FLAGS.max_seq_length,
1005
+ is_training=False,
1006
+ drop_remainder=eval_drop_remainder,
1007
+ )
1008
+
1009
+ result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
1010
+
1011
+ output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
1012
+ with tf.gfile.GFile(output_eval_file, "w") as writer:
1013
+ tf.logging.info("***** Eval results *****")
1014
+ for key in sorted(result.keys()):
1015
+ tf.logging.info(" %s = %s", key, str(result[key]))
1016
+ writer.write("%s = %s\n" % (key, str(result[key])))
1017
+
1018
+ if FLAGS.do_predict:
1019
+ predict_examples = processor.get_test_examples(FLAGS.data_dir)
1020
+ num_actual_predict_examples = len(predict_examples)
1021
+ if FLAGS.use_tpu:
1022
+ # TPU requires a fixed batch size for all batches, therefore the number
1023
+ # of examples must be a multiple of the batch size, or else examples
1024
+ # will get dropped. So we pad with fake examples which are ignored
1025
+ # later on.
1026
+ while len(predict_examples) % FLAGS.predict_batch_size != 0:
1027
+ predict_examples.append(PaddingInputExample())
1028
+
1029
+ predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
1030
+ file_based_convert_examples_to_features(
1031
+ predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file
1032
+ )
1033
+
1034
+ tf.logging.info("***** Running prediction*****")
1035
+ tf.logging.info(
1036
+ " Num examples = %d (%d actual, %d padding)",
1037
+ len(predict_examples),
1038
+ num_actual_predict_examples,
1039
+ len(predict_examples) - num_actual_predict_examples,
1040
+ )
1041
+ tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
1042
+
1043
+ predict_drop_remainder = True if FLAGS.use_tpu else False
1044
+ predict_input_fn = file_based_input_fn_builder(
1045
+ input_file=predict_file,
1046
+ seq_length=FLAGS.max_seq_length,
1047
+ is_training=False,
1048
+ drop_remainder=predict_drop_remainder,
1049
+ )
1050
+
1051
+ result = estimator.predict(input_fn=predict_input_fn)
1052
+
1053
+ output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
1054
+ with tf.gfile.GFile(output_predict_file, "w") as writer:
1055
+ num_written_lines = 0
1056
+ tf.logging.info("***** Predict results *****")
1057
+ for (i, prediction) in enumerate(result):
1058
+ probabilities = prediction["probabilities"]
1059
+ if i >= num_actual_predict_examples:
1060
+ break
1061
+ output_line = (
1062
+ "\t".join(
1063
+ str(class_probability) for class_probability in probabilities
1064
+ )
1065
+ + "\n"
1066
+ )
1067
+ writer.write(output_line)
1068
+ num_written_lines += 1
1069
+ assert num_written_lines == num_actual_predict_examples
1070
+
1071
+
1072
+ if __name__ == "__main__":
1073
+ flags.mark_flag_as_required("data_dir")
1074
+ flags.mark_flag_as_required("task_name")
1075
+ flags.mark_flag_as_required("vocab_file")
1076
+ flags.mark_flag_as_required("bert_config_file")
1077
+ flags.mark_flag_as_required("output_dir")
1078
+ tf.app.run()
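A sketch of how a new task could be plugged into the runner above (illustrative, not part
of this commit; it assumes a hypothetical two-column `label<TAB>text` TSV and would live
alongside the processors already defined in run_classifier.py):

    class BinarySentimentProcessor(DataProcessor):
        """Processor for a hypothetical label<TAB>text classification dataset."""

        def get_train_examples(self, data_dir):
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

        def get_dev_examples(self, data_dir):
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

        def get_test_examples(self, data_dir):
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

        def get_labels(self):
            return ["0", "1"]

        def _create_examples(self, lines, set_type):
            examples = []
            for (i, line) in enumerate(lines):
                guid = "%s-%d" % (set_type, i)
                text_a = tokenization.convert_to_unicode(line[1])
                label = tokenization.convert_to_unicode(line[0])
                examples.append(
                    InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            return examples

It would then be registered in main() next to the existing tasks, e.g.
processors = {..., "binarysent": BinarySentimentProcessor}, and selected with
--task_name=binarysent.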
arabert/arabert/run_pretraining.py ADDED
@@ -0,0 +1,593 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Run masked LM/next sentence masked_lm pre-training for BERT."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import os
22
+ import modeling
23
+ import optimization
24
+ import tensorflow as tf
25
+
26
+ flags = tf.flags
27
+
28
+ FLAGS = flags.FLAGS
29
+
30
+ ## Required parameters
31
+ flags.DEFINE_string(
32
+ "bert_config_file",
33
+ None,
34
+ "The config json file corresponding to the pre-trained BERT model. "
35
+ "This specifies the model architecture.",
36
+ )
37
+
38
+ flags.DEFINE_string(
39
+ "input_file", None, "Input TF example files (can be a glob or comma separated)."
40
+ )
41
+
42
+ flags.DEFINE_string(
43
+ "output_dir",
44
+ None,
45
+ "The output directory where the model checkpoints will be written.",
46
+ )
47
+
48
+ ## Other parameters
49
+ flags.DEFINE_string(
50
+ "init_checkpoint",
51
+ None,
52
+ "Initial checkpoint (usually from a pre-trained BERT model).",
53
+ )
54
+
55
+ flags.DEFINE_integer(
56
+ "max_seq_length",
57
+ 128,
58
+ "The maximum total input sequence length after WordPiece tokenization. "
59
+ "Sequences longer than this will be truncated, and sequences shorter "
60
+ "than this will be padded. Must match data generation.",
61
+ )
62
+
63
+ flags.DEFINE_integer(
64
+ "max_predictions_per_seq",
65
+ 20,
66
+ "Maximum number of masked LM predictions per sequence. "
67
+ "Must match data generation.",
68
+ )
69
+
70
+ flags.DEFINE_bool("do_train", False, "Whether to run training.")
71
+
72
+ flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
73
+
74
+ flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
75
+
76
+ flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
77
+
78
+ flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")
79
+
80
+ flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
81
+ "The optimizer for training.")
82
+
83
+ flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
84
+
85
+ flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
86
+
87
+ flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
88
+
89
+ flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")
90
+
91
+ flags.DEFINE_integer(
92
+ "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
93
+ )
94
+
95
+ flags.DEFINE_integer(
96
+ "iterations_per_loop", 1000, "How many steps to make in each estimator call."
97
+ )
98
+
99
+ flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
100
+
101
+ flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
102
+
103
+ tf.flags.DEFINE_string(
104
+ "tpu_name",
105
+ None,
106
+ "The Cloud TPU to use for training. This should be either the name "
107
+ "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
108
+ "url.",
109
+ )
110
+
111
+ tf.flags.DEFINE_string(
112
+ "tpu_zone",
113
+ None,
114
+ "[Optional] GCE zone where the Cloud TPU is located in. If not "
115
+ "specified, we will attempt to automatically detect the GCE project from "
116
+ "metadata.",
117
+ )
118
+
119
+ tf.flags.DEFINE_string(
120
+ "gcp_project",
121
+ None,
122
+ "[Optional] Project name for the Cloud TPU-enabled project. If not "
123
+ "specified, we will attempt to automatically detect the GCE project from "
124
+ "metadata.",
125
+ )
126
+
127
+ tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
128
+
129
+ flags.DEFINE_integer(
130
+ "num_tpu_cores",
131
+ 8,
132
+ "Only used if `use_tpu` is True. Total number of TPU cores to use.",
133
+ )
134
+
135
+ flags.DEFINE_integer("keep_checkpoint_max", 10,
136
+ "How many checkpoints to keep.")
137
+
138
+
139
+ def model_fn_builder(
140
+ bert_config,
141
+ init_checkpoint,
142
+ learning_rate,
143
+ num_train_steps,
144
+ num_warmup_steps,
145
+ use_tpu,
146
+ use_one_hot_embeddings,
147
+ optimizer,
148
+ poly_power,
149
+ start_warmup_step,
150
+ ):
151
+ """Returns `model_fn` closure for TPUEstimator."""
152
+
153
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
154
+ """The `model_fn` for TPUEstimator."""
155
+
156
+ tf.logging.info("*** Features ***")
157
+ for name in sorted(features.keys()):
158
+ tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
159
+
160
+ input_ids = features["input_ids"]
161
+ input_mask = features["input_mask"]
162
+ segment_ids = features["segment_ids"]
163
+ masked_lm_positions = features["masked_lm_positions"]
164
+ masked_lm_ids = features["masked_lm_ids"]
165
+ masked_lm_weights = features["masked_lm_weights"]
166
+ next_sentence_labels = features["next_sentence_labels"]
167
+
168
+ is_training = mode == tf.estimator.ModeKeys.TRAIN
169
+
170
+ model = modeling.BertModel(
171
+ config=bert_config,
172
+ is_training=is_training,
173
+ input_ids=input_ids,
174
+ input_mask=input_mask,
175
+ token_type_ids=segment_ids,
176
+ use_one_hot_embeddings=use_one_hot_embeddings,
177
+ )
178
+
179
+ (
180
+ masked_lm_loss,
181
+ masked_lm_example_loss,
182
+ masked_lm_log_probs,
183
+ ) = get_masked_lm_output(
184
+ bert_config,
185
+ model.get_sequence_output(),
186
+ model.get_embedding_table(),
187
+ masked_lm_positions,
188
+ masked_lm_ids,
189
+ masked_lm_weights,
190
+ )
191
+
192
+ (
193
+ next_sentence_loss,
194
+ next_sentence_example_loss,
195
+ next_sentence_log_probs,
196
+ ) = get_next_sentence_output(
197
+ bert_config, model.get_pooled_output(), next_sentence_labels
198
+ )
199
+
200
+ total_loss = masked_lm_loss + next_sentence_loss
201
+
202
+ tvars = tf.trainable_variables()
203
+
204
+ initialized_variable_names = {}
205
+ scaffold_fn = None
206
+ if init_checkpoint:
207
+ (
208
+ assignment_map,
209
+ initialized_variable_names,
210
+ ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
211
+ if use_tpu:
212
+
213
+ def tpu_scaffold():
214
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
215
+ return tf.train.Scaffold()
216
+
217
+ scaffold_fn = tpu_scaffold
218
+ else:
219
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
220
+
221
+ tf.logging.info("**** Trainable Variables ****")
222
+ for var in tvars:
223
+ init_string = ""
224
+ if var.name in initialized_variable_names:
225
+ init_string = ", *INIT_FROM_CKPT*"
226
+ tf.logging.info(
227
+ " name = %s, shape = %s%s", var.name, var.shape, init_string
228
+ )
229
+
230
+ output_spec = None
231
+ if mode == tf.estimator.ModeKeys.TRAIN:
232
+ train_op = optimization.create_optimizer(
233
+ total_loss,
234
+ learning_rate,
235
+ num_train_steps,
236
+ num_warmup_steps,
237
+ use_tpu,
238
+ optimizer,
239
+ poly_power,
240
+ start_warmup_step,
241
+ )
242
+
243
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
244
+ mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
245
+ )
246
+ elif mode == tf.estimator.ModeKeys.EVAL:
247
+
248
+ def metric_fn(
249
+ masked_lm_example_loss,
250
+ masked_lm_log_probs,
251
+ masked_lm_ids,
252
+ masked_lm_weights,
253
+ next_sentence_example_loss,
254
+ next_sentence_log_probs,
255
+ next_sentence_labels,
256
+ ):
257
+ """Computes the loss and accuracy of the model."""
258
+ masked_lm_log_probs = tf.reshape(
259
+ masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]]
260
+ )
261
+ masked_lm_predictions = tf.argmax(
262
+ masked_lm_log_probs, axis=-1, output_type=tf.int32
263
+ )
264
+ masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
265
+ masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
266
+ masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
267
+ masked_lm_accuracy = tf.metrics.accuracy(
268
+ labels=masked_lm_ids,
269
+ predictions=masked_lm_predictions,
270
+ weights=masked_lm_weights,
271
+ )
272
+ masked_lm_mean_loss = tf.metrics.mean(
273
+ values=masked_lm_example_loss, weights=masked_lm_weights
274
+ )
275
+
276
+ next_sentence_log_probs = tf.reshape(
277
+ next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]
278
+ )
279
+ next_sentence_predictions = tf.argmax(
280
+ next_sentence_log_probs, axis=-1, output_type=tf.int32
281
+ )
282
+ next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
283
+ next_sentence_accuracy = tf.metrics.accuracy(
284
+ labels=next_sentence_labels, predictions=next_sentence_predictions
285
+ )
286
+ next_sentence_mean_loss = tf.metrics.mean(
287
+ values=next_sentence_example_loss
288
+ )
289
+
290
+ return {
291
+ "masked_lm_accuracy": masked_lm_accuracy,
292
+ "masked_lm_loss": masked_lm_mean_loss,
293
+ "next_sentence_accuracy": next_sentence_accuracy,
294
+ "next_sentence_loss": next_sentence_mean_loss,
295
+ }
296
+
297
+ eval_metrics = (
298
+ metric_fn,
299
+ [
300
+ masked_lm_example_loss,
301
+ masked_lm_log_probs,
302
+ masked_lm_ids,
303
+ masked_lm_weights,
304
+ next_sentence_example_loss,
305
+ next_sentence_log_probs,
306
+ next_sentence_labels,
307
+ ],
308
+ )
309
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
310
+ mode=mode,
311
+ loss=total_loss,
312
+ eval_metrics=eval_metrics,
313
+ scaffold_fn=scaffold_fn,
314
+ )
315
+ else:
316
+ raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
317
+
318
+ return output_spec
319
+
320
+ return model_fn
321
+
322
+
323
+ def get_masked_lm_output(
324
+ bert_config, input_tensor, output_weights, positions, label_ids, label_weights
325
+ ):
326
+ """Get loss and log probs for the masked LM."""
327
+ input_tensor = gather_indexes(input_tensor, positions)
328
+
329
+ with tf.variable_scope("cls/predictions"):
330
+ # We apply one more non-linear transformation before the output layer.
331
+ # This matrix is not used after pre-training.
332
+ with tf.variable_scope("transform"):
333
+ input_tensor = tf.layers.dense(
334
+ input_tensor,
335
+ units=bert_config.hidden_size,
336
+ activation=modeling.get_activation(bert_config.hidden_act),
337
+ kernel_initializer=modeling.create_initializer(
338
+ bert_config.initializer_range
339
+ ),
340
+ )
341
+ input_tensor = modeling.layer_norm(input_tensor)
342
+
343
+ # The output weights are the same as the input embeddings, but there is
344
+ # an output-only bias for each token.
345
+ output_bias = tf.get_variable(
346
+ "output_bias",
347
+ shape=[bert_config.vocab_size],
348
+ initializer=tf.zeros_initializer(),
349
+ )
350
+ logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
351
+ logits = tf.nn.bias_add(logits, output_bias)
352
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
353
+
354
+ label_ids = tf.reshape(label_ids, [-1])
355
+ label_weights = tf.reshape(label_weights, [-1])
356
+
357
+ one_hot_labels = tf.one_hot(
358
+ label_ids, depth=bert_config.vocab_size, dtype=tf.float32
359
+ )
360
+
361
+ # The `positions` tensor might be zero-padded (if the sequence is too
362
+ # short to have the maximum number of predictions). The `label_weights`
363
+ # tensor has a value of 1.0 for every real prediction and 0.0 for the
364
+ # padding predictions.
365
+ per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
366
+ numerator = tf.reduce_sum(label_weights * per_example_loss)
367
+ denominator = tf.reduce_sum(label_weights) + 1e-5
368
+ loss = numerator / denominator
369
+
370
+ return (loss, per_example_loss, log_probs)
371
+
372
+
373
+ def get_next_sentence_output(bert_config, input_tensor, labels):
374
+ """Get loss and log probs for the next sentence prediction."""
375
+
376
+ # Simple binary classification. Note that 0 is "next sentence" and 1 is
377
+ # "random sentence". This weight matrix is not used after pre-training.
378
+ with tf.variable_scope("cls/seq_relationship"):
379
+ output_weights = tf.get_variable(
380
+ "output_weights",
381
+ shape=[2, bert_config.hidden_size],
382
+ initializer=modeling.create_initializer(bert_config.initializer_range),
383
+ )
384
+ output_bias = tf.get_variable(
385
+ "output_bias", shape=[2], initializer=tf.zeros_initializer()
386
+ )
387
+
388
+ logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
389
+ logits = tf.nn.bias_add(logits, output_bias)
390
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
391
+ labels = tf.reshape(labels, [-1])
392
+ one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
393
+ per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
394
+ loss = tf.reduce_mean(per_example_loss)
395
+ return (loss, per_example_loss, log_probs)
396
+
397
+
398
+ def gather_indexes(sequence_tensor, positions):
399
+ """Gathers the vectors at the specific positions over a minibatch."""
400
+ sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
401
+ batch_size = sequence_shape[0]
402
+ seq_length = sequence_shape[1]
403
+ width = sequence_shape[2]
404
+
405
+ flat_offsets = tf.reshape(
406
+ tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
407
+ )
408
+ flat_positions = tf.reshape(positions + flat_offsets, [-1])
409
+ flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
410
+ output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
411
+ return output_tensor
412
+
413
+
414
+ def input_fn_builder(
415
+ input_files, max_seq_length, max_predictions_per_seq, is_training, num_cpu_threads=4
416
+ ):
417
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
418
+
419
+ def input_fn(params):
420
+ """The actual input function."""
421
+ batch_size = params["batch_size"]
422
+
423
+ name_to_features = {
424
+ "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
425
+ "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
426
+ "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
427
+ "masked_lm_positions": tf.FixedLenFeature(
428
+ [max_predictions_per_seq], tf.int64
429
+ ),
430
+ "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
431
+ "masked_lm_weights": tf.FixedLenFeature(
432
+ [max_predictions_per_seq], tf.float32
433
+ ),
434
+ "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
435
+ }
436
+
437
+ # For training, we want a lot of parallel reading and shuffling.
438
+ # For eval, we want no shuffling and parallel reading doesn't matter.
439
+ if is_training:
440
+ d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
441
+ d = d.repeat()
442
+ d = d.shuffle(buffer_size=len(input_files))
443
+
444
+ # `cycle_length` is the number of parallel files that get read.
445
+ cycle_length = min(num_cpu_threads, len(input_files))
446
+
447
+ # `sloppy` mode means that the interleaving is not exact. This adds
448
+ # even more randomness to the training pipeline.
449
+ d = d.apply(
450
+ tf.contrib.data.parallel_interleave(
451
+ tf.data.TFRecordDataset,
452
+ sloppy=is_training,
453
+ cycle_length=cycle_length,
454
+ )
455
+ )
456
+ d = d.shuffle(buffer_size=100)
457
+ else:
458
+ d = tf.data.TFRecordDataset(input_files)
459
+ # Since we evaluate for a fixed number of steps we don't want to encounter
460
+ # out-of-range exceptions.
461
+ d = d.repeat()
462
+
463
+ # We must `drop_remainder` on training because the TPU requires fixed
464
+ # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
465
+ # and we *don't* want to drop the remainder, otherwise we won't cover
466
+ # every sample.
467
+ d = d.apply(
468
+ tf.contrib.data.map_and_batch(
469
+ lambda record: _decode_record(record, name_to_features),
470
+ batch_size=batch_size,
471
+ num_parallel_batches=num_cpu_threads,
472
+ drop_remainder=True,
473
+ )
474
+ )
475
+ return d
476
+
477
+ return input_fn
478
+
479
+
480
+ def _decode_record(record, name_to_features):
481
+ """Decodes a record to a TensorFlow example."""
482
+ example = tf.parse_single_example(record, name_to_features)
483
+
484
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
485
+ # So cast all int64 to int32.
486
+ for name in list(example.keys()):
487
+ t = example[name]
488
+ if t.dtype == tf.int64:
489
+ t = tf.to_int32(t)
490
+ example[name] = t
491
+
492
+ return example
493
+
494
+
495
+ def main(_):
496
+ tf.logging.set_verbosity(tf.logging.INFO)
497
+ logger = tf.get_logger()
498
+ logger.propagate = False
499
+
500
+ if not FLAGS.do_train and not FLAGS.do_eval:
501
+ raise ValueError("At least one of `do_train` or `do_eval` must be True.")
502
+
503
+ bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
504
+
505
+ tf.gfile.MakeDirs(FLAGS.output_dir)
506
+
507
+ input_files = []
508
+ for input_pattern in FLAGS.input_file.split(","):
509
+ input_files.extend(tf.gfile.Glob(input_pattern))
510
+
511
+ # tf.logging.info("*** Input Files ***")
512
+ # for input_file in input_files:
513
+ # tf.logging.info(" %s" % input_file)
514
+
515
+ tpu_cluster_resolver = None
516
+ if FLAGS.use_tpu and FLAGS.tpu_name:
517
+ tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
518
+ FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
519
+ )
520
+
521
+ is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
522
+ run_config = tf.contrib.tpu.RunConfig(
523
+ cluster=tpu_cluster_resolver,
524
+ master=FLAGS.master,
525
+ model_dir=FLAGS.output_dir,
526
+ save_checkpoints_steps=FLAGS.save_checkpoints_steps,
527
+ keep_checkpoint_max=FLAGS.keep_checkpoint_max,
528
+ tpu_config=tf.contrib.tpu.TPUConfig(
529
+ iterations_per_loop=FLAGS.iterations_per_loop,
530
+ num_shards=FLAGS.num_tpu_cores,
531
+ per_host_input_for_training=is_per_host,
532
+ ),
533
+ )
534
+
535
+ model_fn = model_fn_builder(
536
+ bert_config=bert_config,
537
+ init_checkpoint=FLAGS.init_checkpoint,
538
+ learning_rate=FLAGS.learning_rate,
539
+ num_train_steps=FLAGS.num_train_steps,
540
+ num_warmup_steps=FLAGS.num_warmup_steps,
541
+ use_tpu=FLAGS.use_tpu,
542
+ use_one_hot_embeddings=FLAGS.use_tpu,
543
+ optimizer=FLAGS.optimizer,
544
+ poly_power=FLAGS.poly_power,
545
+ start_warmup_step=FLAGS.start_warmup_step
546
+ )
547
+
548
+ # If TPU is not available, this will fall back to normal Estimator on CPU
549
+ # or GPU.
550
+ estimator = tf.contrib.tpu.TPUEstimator(
551
+ use_tpu=FLAGS.use_tpu,
552
+ model_fn=model_fn,
553
+ config=run_config,
554
+ train_batch_size=FLAGS.train_batch_size,
555
+ eval_batch_size=FLAGS.eval_batch_size,
556
+ )
557
+
558
+ if FLAGS.do_train:
559
+ tf.logging.info("***** Running training *****")
560
+ tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
561
+ train_input_fn = input_fn_builder(
562
+ input_files=input_files,
563
+ max_seq_length=FLAGS.max_seq_length,
564
+ max_predictions_per_seq=FLAGS.max_predictions_per_seq,
565
+ is_training=True,
566
+ )
567
+ estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
568
+
569
+ if FLAGS.do_eval:
570
+ tf.logging.info("***** Running evaluation *****")
571
+ tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
572
+
573
+ eval_input_fn = input_fn_builder(
574
+ input_files=input_files,
575
+ max_seq_length=FLAGS.max_seq_length,
576
+ max_predictions_per_seq=FLAGS.max_predictions_per_seq,
577
+ is_training=False,
578
+ )
579
+
580
+ result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
581
+
582
+ output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
583
+ with tf.gfile.GFile(output_eval_file, "w") as writer:
584
+ tf.logging.info("***** Eval results *****")
585
+ for key in sorted(result.keys()):
586
+ tf.logging.info(" %s = %s", key, str(result[key]))
587
+ writer.write("%s = %s\n" % (key, str(result[key])))
588
+
589
+ if __name__ == "__main__":
590
+ flags.mark_flag_as_required("input_file")
591
+ flags.mark_flag_as_required("bert_config_file")
592
+ flags.mark_flag_as_required("output_dir")
593
+ tf.app.run()
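A minimal standalone sketch (plain NumPy, not part of the uploaded files) of the two ideas the masked-LM head in run_pretraining.py above relies on: `gather_indexes` flattens the batch and sequence dimensions before picking out the masked positions, and `get_masked_lm_output` averages the per-slot losses with `label_weights` so zero-padded prediction slots contribute nothing. The shapes and values below are made up purely for illustration.

import numpy as np

batch_size, seq_length, width = 2, 5, 3
sequence_output = np.random.rand(batch_size, seq_length, width)    # [batch, seq, hidden]
positions = np.array([[1, 3], [0, 0]])                              # masked slots, zero-padded
label_weights = np.array([[1.0, 1.0], [1.0, 0.0]])                  # 0.0 marks the padded slot

# Flatten batch*seq, then gather the masked positions (mirrors gather_indexes).
flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)  # [[0], [5]]
flat_positions = (positions + flat_offsets).reshape(-1)             # e.g. [1, 3, 5, 5]
gathered = sequence_output.reshape(batch_size * seq_length, width)[flat_positions]

# Weighted mean over real prediction slots only (mirrors the 1e-5-guarded denominator).
per_slot_loss = np.random.rand(batch_size * 2)                      # stand-in per-slot losses
weights = label_weights.reshape(-1)
loss = np.sum(weights * per_slot_loss) / (np.sum(weights) + 1e-5)
print(gathered.shape, loss)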
arabert/arabert/run_squad.py ADDED
@@ -0,0 +1,1440 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Run BERT on SQuAD 1.1 and SQuAD 2.0."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import collections
22
+ import json
23
+ import math
24
+ import os
25
+ import random
26
+ import modeling
27
+ import optimization
28
+ import tokenization
29
+ import six
30
+ import tensorflow as tf
31
+
32
+ flags = tf.flags
33
+
34
+ FLAGS = flags.FLAGS
35
+
36
+ ## Required parameters
37
+ flags.DEFINE_string(
38
+ "bert_config_file",
39
+ None,
40
+ "The config json file corresponding to the pre-trained BERT model. "
41
+ "This specifies the model architecture.",
42
+ )
43
+
44
+ flags.DEFINE_string(
45
+ "vocab_file", None, "The vocabulary file that the BERT model was trained on."
46
+ )
47
+
48
+ flags.DEFINE_string(
49
+ "output_dir",
50
+ None,
51
+ "The output directory where the model checkpoints will be written.",
52
+ )
53
+
54
+ ## Other parameters
55
+ flags.DEFINE_string(
56
+ "train_file", None, "SQuAD json for training. E.g., train-v1.1.json"
57
+ )
58
+
59
+ flags.DEFINE_string(
60
+ "predict_file",
61
+ None,
62
+ "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
63
+ )
64
+
65
+ flags.DEFINE_string(
66
+ "init_checkpoint",
67
+ None,
68
+ "Initial checkpoint (usually from a pre-trained BERT model).",
69
+ )
70
+
71
+ flags.DEFINE_bool(
72
+ "do_lower_case",
73
+ True,
74
+ "Whether to lower case the input text. Should be True for uncased "
75
+ "models and False for cased models.",
76
+ )
77
+
78
+ flags.DEFINE_integer(
79
+ "max_seq_length",
80
+ 384,
81
+ "The maximum total input sequence length after WordPiece tokenization. "
82
+ "Sequences longer than this will be truncated, and sequences shorter "
83
+ "than this will be padded.",
84
+ )
85
+
86
+ flags.DEFINE_integer(
87
+ "doc_stride",
88
+ 128,
89
+ "When splitting up a long document into chunks, how much stride to "
90
+ "take between chunks.",
91
+ )
92
+
93
+ flags.DEFINE_integer(
94
+ "max_query_length",
95
+ 64,
96
+ "The maximum number of tokens for the question. Questions longer than "
97
+ "this will be truncated to this length.",
98
+ )
99
+
100
+ flags.DEFINE_bool("do_train", False, "Whether to run training.")
101
+
102
+ flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
103
+
104
+ flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
105
+
106
+ flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predictions.")
107
+
108
+ flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
109
+
110
+ flags.DEFINE_float(
111
+ "num_train_epochs", 3.0, "Total number of training epochs to perform."
112
+ )
113
+
114
+ flags.DEFINE_float(
115
+ "warmup_proportion",
116
+ 0.1,
117
+ "Proportion of training to perform linear learning rate warmup for. "
118
+ "E.g., 0.1 = 10% of training.",
119
+ )
120
+
121
+ flags.DEFINE_integer(
122
+ "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
123
+ )
124
+
125
+ flags.DEFINE_integer(
126
+ "iterations_per_loop", 1000, "How many steps to make in each estimator call."
127
+ )
128
+
129
+ flags.DEFINE_integer(
130
+ "n_best_size",
131
+ 20,
132
+ "The total number of n-best predictions to generate in the "
133
+ "nbest_predictions.json output file.",
134
+ )
135
+
136
+ flags.DEFINE_integer(
137
+ "max_answer_length",
138
+ 30,
139
+ "The maximum length of an answer that can be generated. This is needed "
140
+ "because the start and end predictions are not conditioned on one another.",
141
+ )
142
+
143
+ flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
144
+
145
+ tf.flags.DEFINE_string(
146
+ "tpu_name",
147
+ None,
148
+ "The Cloud TPU to use for training. This should be either the name "
149
+ "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
150
+ "url.",
151
+ )
152
+
153
+ tf.flags.DEFINE_string(
154
+ "tpu_zone",
155
+ None,
156
+ "[Optional] GCE zone where the Cloud TPU is located in. If not "
157
+ "specified, we will attempt to automatically detect the GCE project from "
158
+ "metadata.",
159
+ )
160
+
161
+ tf.flags.DEFINE_string(
162
+ "gcp_project",
163
+ None,
164
+ "[Optional] Project name for the Cloud TPU-enabled project. If not "
165
+ "specified, we will attempt to automatically detect the GCE project from "
166
+ "metadata.",
167
+ )
168
+
169
+ tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
170
+
171
+ flags.DEFINE_integer(
172
+ "num_tpu_cores",
173
+ 8,
174
+ "Only used if `use_tpu` is True. Total number of TPU cores to use.",
175
+ )
176
+
177
+ flags.DEFINE_bool(
178
+ "verbose_logging",
179
+ False,
180
+ "If true, all of the warnings related to data processing will be printed. "
181
+ "A number of warnings are expected for a normal SQuAD evaluation.",
182
+ )
183
+
184
+ flags.DEFINE_bool(
185
+ "version_2_with_negative",
186
+ False,
187
+ "If true, the SQuAD examples contain some that do not have an answer.",
188
+ )
189
+
190
+ flags.DEFINE_float(
191
+ "null_score_diff_threshold",
192
+ 0.0,
193
+ "If null_score - best_non_null is greater than the threshold predict null.",
194
+ )
195
+
196
+
197
+ class SquadExample(object):
198
+ """A single training/test example for simple sequence classification.
199
+
200
+ For examples without an answer, the start and end position are -1.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ qas_id,
206
+ question_text,
207
+ doc_tokens,
208
+ orig_answer_text=None,
209
+ start_position=None,
210
+ end_position=None,
211
+ is_impossible=False,
212
+ ):
213
+ self.qas_id = qas_id
214
+ self.question_text = question_text
215
+ self.doc_tokens = doc_tokens
216
+ self.orig_answer_text = orig_answer_text
217
+ self.start_position = start_position
218
+ self.end_position = end_position
219
+ self.is_impossible = is_impossible
220
+
221
+ def __str__(self):
222
+ return self.__repr__()
223
+
224
+ def __repr__(self):
225
+ s = ""
226
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
227
+ s += ", question_text: %s" % (tokenization.printable_text(self.question_text))
228
+ s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
229
+ if self.start_position:
230
+ s += ", start_position: %d" % (self.start_position)
231
+ if self.start_position:
232
+ s += ", end_position: %d" % (self.end_position)
233
+ if self.start_position:
234
+ s += ", is_impossible: %r" % (self.is_impossible)
235
+ return s
236
+
237
+
238
+ class InputFeatures(object):
239
+ """A single set of features of data."""
240
+
241
+ def __init__(
242
+ self,
243
+ unique_id,
244
+ example_index,
245
+ doc_span_index,
246
+ tokens,
247
+ token_to_orig_map,
248
+ token_is_max_context,
249
+ input_ids,
250
+ input_mask,
251
+ segment_ids,
252
+ start_position=None,
253
+ end_position=None,
254
+ is_impossible=None,
255
+ ):
256
+ self.unique_id = unique_id
257
+ self.example_index = example_index
258
+ self.doc_span_index = doc_span_index
259
+ self.tokens = tokens
260
+ self.token_to_orig_map = token_to_orig_map
261
+ self.token_is_max_context = token_is_max_context
262
+ self.input_ids = input_ids
263
+ self.input_mask = input_mask
264
+ self.segment_ids = segment_ids
265
+ self.start_position = start_position
266
+ self.end_position = end_position
267
+ self.is_impossible = is_impossible
268
+
269
+
270
+ def read_squad_examples(input_file, is_training):
271
+ """Read a SQuAD json file into a list of SquadExample."""
272
+ with tf.gfile.Open(input_file, "r") as reader:
273
+ input_data = json.load(reader)["data"]
274
+
275
+ def is_whitespace(c):
276
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
277
+ return True
278
+ return False
279
+
280
+ examples = []
281
+ for entry in input_data:
282
+ for paragraph in entry["paragraphs"]:
283
+ paragraph_text = paragraph["context"]
284
+ doc_tokens = []
285
+ char_to_word_offset = []
286
+ prev_is_whitespace = True
287
+ for c in paragraph_text:
288
+ if is_whitespace(c):
289
+ prev_is_whitespace = True
290
+ else:
291
+ if prev_is_whitespace:
292
+ doc_tokens.append(c)
293
+ else:
294
+ doc_tokens[-1] += c
295
+ prev_is_whitespace = False
296
+ char_to_word_offset.append(len(doc_tokens) - 1)
297
+
298
+ for qa in paragraph["qas"]:
299
+ qas_id = qa["id"]
300
+ question_text = qa["question"]
301
+ start_position = None
302
+ end_position = None
303
+ orig_answer_text = None
304
+ is_impossible = False
305
+ if is_training:
306
+
307
+ if FLAGS.version_2_with_negative:
308
+ is_impossible = qa["is_impossible"]
309
+ if (len(qa["answers"]) != 1) and (not is_impossible):
310
+ raise ValueError(
311
+ "For training, each question should have exactly 1 answer."
312
+ )
313
+ if not is_impossible:
314
+ answer = qa["answers"][0]
315
+ orig_answer_text = answer["text"]
316
+ answer_offset = answer["answer_start"]
317
+ answer_length = len(orig_answer_text)
318
+ start_position = char_to_word_offset[answer_offset]
319
+ end_position = char_to_word_offset[
320
+ answer_offset + answer_length - 1
321
+ ]
322
+ # Only add answers where the text can be exactly recovered from the
323
+ # document. If this CAN'T happen it's likely due to weird Unicode
324
+ # stuff so we will just skip the example.
325
+ #
326
+ # Note that this means for training mode, every example is NOT
327
+ # guaranteed to be preserved.
328
+ actual_text = " ".join(
329
+ doc_tokens[start_position : (end_position + 1)]
330
+ )
331
+ cleaned_answer_text = " ".join(
332
+ tokenization.whitespace_tokenize(orig_answer_text)
333
+ )
334
+ if actual_text.find(cleaned_answer_text) == -1:
335
+ tf.logging.warning(
336
+ "Could not find answer: '%s' vs. '%s'",
337
+ actual_text,
338
+ cleaned_answer_text,
339
+ )
340
+ continue
341
+ else:
342
+ start_position = -1
343
+ end_position = -1
344
+ orig_answer_text = ""
345
+
346
+ example = SquadExample(
347
+ qas_id=qas_id,
348
+ question_text=question_text,
349
+ doc_tokens=doc_tokens,
350
+ orig_answer_text=orig_answer_text,
351
+ start_position=start_position,
352
+ end_position=end_position,
353
+ is_impossible=is_impossible,
354
+ )
355
+ examples.append(example)
356
+
357
+ return examples
358
+
359
+
360
+ def convert_examples_to_features(
361
+ examples,
362
+ tokenizer,
363
+ max_seq_length,
364
+ doc_stride,
365
+ max_query_length,
366
+ is_training,
367
+ output_fn,
368
+ ):
369
+ """Loads a data file into a list of `InputBatch`s."""
370
+
371
+ unique_id = 1000000000
372
+
373
+ for (example_index, example) in enumerate(examples):
374
+ query_tokens = tokenizer.tokenize(example.question_text)
375
+
376
+ if len(query_tokens) > max_query_length:
377
+ query_tokens = query_tokens[0:max_query_length]
378
+
379
+ tok_to_orig_index = []
380
+ orig_to_tok_index = []
381
+ all_doc_tokens = []
382
+ for (i, token) in enumerate(example.doc_tokens):
383
+ orig_to_tok_index.append(len(all_doc_tokens))
384
+ sub_tokens = tokenizer.tokenize(token)
385
+ for sub_token in sub_tokens:
386
+ tok_to_orig_index.append(i)
387
+ all_doc_tokens.append(sub_token)
388
+
389
+ tok_start_position = None
390
+ tok_end_position = None
391
+ if is_training and example.is_impossible:
392
+ tok_start_position = -1
393
+ tok_end_position = -1
394
+ if is_training and not example.is_impossible:
395
+ tok_start_position = orig_to_tok_index[example.start_position]
396
+ if example.end_position < len(example.doc_tokens) - 1:
397
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
398
+ else:
399
+ tok_end_position = len(all_doc_tokens) - 1
400
+ (tok_start_position, tok_end_position) = _improve_answer_span(
401
+ all_doc_tokens,
402
+ tok_start_position,
403
+ tok_end_position,
404
+ tokenizer,
405
+ example.orig_answer_text,
406
+ )
407
+
408
+ # The -3 accounts for [CLS], [SEP] and [SEP]
409
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
410
+
411
+ # We can have documents that are longer than the maximum sequence length.
412
+ # To deal with this we do a sliding window approach, where we take chunks
413
+ # of up to our max length with a stride of `doc_stride`.
414
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
415
+ "DocSpan", ["start", "length"]
416
+ )
417
+ doc_spans = []
418
+ start_offset = 0
419
+ while start_offset < len(all_doc_tokens):
420
+ length = len(all_doc_tokens) - start_offset
421
+ if length > max_tokens_for_doc:
422
+ length = max_tokens_for_doc
423
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
424
+ if start_offset + length == len(all_doc_tokens):
425
+ break
426
+ start_offset += min(length, doc_stride)
427
+
428
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
429
+ tokens = []
430
+ token_to_orig_map = {}
431
+ token_is_max_context = {}
432
+ segment_ids = []
433
+ tokens.append("[CLS]")
434
+ segment_ids.append(0)
435
+ for token in query_tokens:
436
+ tokens.append(token)
437
+ segment_ids.append(0)
438
+ tokens.append("[SEP]")
439
+ segment_ids.append(0)
440
+
441
+ for i in range(doc_span.length):
442
+ split_token_index = doc_span.start + i
443
+ token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
444
+
445
+ is_max_context = _check_is_max_context(
446
+ doc_spans, doc_span_index, split_token_index
447
+ )
448
+ token_is_max_context[len(tokens)] = is_max_context
449
+ tokens.append(all_doc_tokens[split_token_index])
450
+ segment_ids.append(1)
451
+ tokens.append("[SEP]")
452
+ segment_ids.append(1)
453
+
454
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
455
+
456
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
457
+ # tokens are attended to.
458
+ input_mask = [1] * len(input_ids)
459
+
460
+ # Zero-pad up to the sequence length.
461
+ while len(input_ids) < max_seq_length:
462
+ input_ids.append(0)
463
+ input_mask.append(0)
464
+ segment_ids.append(0)
465
+
466
+ assert len(input_ids) == max_seq_length
467
+ assert len(input_mask) == max_seq_length
468
+ assert len(segment_ids) == max_seq_length
469
+
470
+ start_position = None
471
+ end_position = None
472
+ if is_training and not example.is_impossible:
473
+ # For training, if our document chunk does not contain an annotation
474
+ # we throw it out, since there is nothing to predict.
475
+ doc_start = doc_span.start
476
+ doc_end = doc_span.start + doc_span.length - 1
477
+ out_of_span = False
478
+ if not (
479
+ tok_start_position >= doc_start and tok_end_position <= doc_end
480
+ ):
481
+ out_of_span = True
482
+ if out_of_span:
483
+ start_position = 0
484
+ end_position = 0
485
+ else:
486
+ doc_offset = len(query_tokens) + 2
487
+ start_position = tok_start_position - doc_start + doc_offset
488
+ end_position = tok_end_position - doc_start + doc_offset
489
+
490
+ if is_training and example.is_impossible:
491
+ start_position = 0
492
+ end_position = 0
493
+
494
+ if example_index < 20:
495
+ tf.logging.info("*** Example ***")
496
+ tf.logging.info("unique_id: %s" % (unique_id))
497
+ tf.logging.info("example_index: %s" % (example_index))
498
+ tf.logging.info("doc_span_index: %s" % (doc_span_index))
499
+ tf.logging.info(
500
+ "tokens: %s"
501
+ % " ".join([tokenization.printable_text(x) for x in tokens])
502
+ )
503
+ tf.logging.info(
504
+ "token_to_orig_map: %s"
505
+ % " ".join(
506
+ [
507
+ "%d:%d" % (x, y)
508
+ for (x, y) in six.iteritems(token_to_orig_map)
509
+ ]
510
+ )
511
+ )
512
+ tf.logging.info(
513
+ "token_is_max_context: %s"
514
+ % " ".join(
515
+ [
516
+ "%d:%s" % (x, y)
517
+ for (x, y) in six.iteritems(token_is_max_context)
518
+ ]
519
+ )
520
+ )
521
+ tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
522
+ tf.logging.info(
523
+ "input_mask: %s" % " ".join([str(x) for x in input_mask])
524
+ )
525
+ tf.logging.info(
526
+ "segment_ids: %s" % " ".join([str(x) for x in segment_ids])
527
+ )
528
+ if is_training and example.is_impossible:
529
+ tf.logging.info("impossible example")
530
+ if is_training and not example.is_impossible:
531
+ answer_text = " ".join(tokens[start_position : (end_position + 1)])
532
+ tf.logging.info("start_position: %d" % (start_position))
533
+ tf.logging.info("end_position: %d" % (end_position))
534
+ tf.logging.info(
535
+ "answer: %s" % (tokenization.printable_text(answer_text))
536
+ )
537
+
538
+ feature = InputFeatures(
539
+ unique_id=unique_id,
540
+ example_index=example_index,
541
+ doc_span_index=doc_span_index,
542
+ tokens=tokens,
543
+ token_to_orig_map=token_to_orig_map,
544
+ token_is_max_context=token_is_max_context,
545
+ input_ids=input_ids,
546
+ input_mask=input_mask,
547
+ segment_ids=segment_ids,
548
+ start_position=start_position,
549
+ end_position=end_position,
550
+ is_impossible=example.is_impossible,
551
+ )
552
+
553
+ # Run callback
554
+ output_fn(feature)
555
+
556
+ unique_id += 1
557
+
558
+
559
+ def _improve_answer_span(
560
+ doc_tokens, input_start, input_end, tokenizer, orig_answer_text
561
+ ):
562
+ """Returns tokenized answer spans that better match the annotated answer."""
563
+
564
+ # The SQuAD annotations are character based. We first project them to
565
+ # whitespace-tokenized words. But then after WordPiece tokenization, we can
566
+ # often find a "better match". For example:
567
+ #
568
+ # Question: What year was John Smith born?
569
+ # Context: The leader was John Smith (1895-1943).
570
+ # Answer: 1895
571
+ #
572
+ # The original whitespace-tokenized answer will be "(1895-1943).". However
573
+ # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
574
+ # the exact answer, 1895.
575
+ #
576
+ # However, this is not always possible. Consider the following:
577
+ #
578
+ # Question: What country is the top exporter of electronics?
579
+ # Context: The Japanese electronics industry is the largest in the world.
580
+ # Answer: Japan
581
+ #
582
+ # In this case, the annotator chose "Japan" as a character sub-span of
583
+ # the word "Japanese". Since our WordPiece tokenizer does not split
584
+ # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
585
+ # in SQuAD, but does happen.
586
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
587
+
588
+ for new_start in range(input_start, input_end + 1):
589
+ for new_end in range(input_end, new_start - 1, -1):
590
+ text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
591
+ if text_span == tok_answer_text:
592
+ return (new_start, new_end)
593
+
594
+ return (input_start, input_end)
595
+
596
+
597
+ def _check_is_max_context(doc_spans, cur_span_index, position):
598
+ """Check if this is the 'max context' doc span for the token."""
599
+
600
+ # Because of the sliding window approach taken to scoring documents, a single
601
+ # token can appear in multiple documents. E.g.
602
+ # Doc: the man went to the store and bought a gallon of milk
603
+ # Span A: the man went to the
604
+ # Span B: to the store and bought
605
+ # Span C: and bought a gallon of
606
+ # ...
607
+ #
608
+ # Now the word 'bought' will have two scores from spans B and C. We only
609
+ # want to consider the score with "maximum context", which we define as
610
+ # the *minimum* of its left and right context (the *sum* of left and
611
+ # right context will always be the same, of course).
612
+ #
613
+ # In the example the maximum context for 'bought' would be span C since
614
+ # it has 1 left context and 3 right context, while span B has 4 left context
615
+ # and 0 right context.
616
+ best_score = None
617
+ best_span_index = None
618
+ for (span_index, doc_span) in enumerate(doc_spans):
619
+ end = doc_span.start + doc_span.length - 1
620
+ if position < doc_span.start:
621
+ continue
622
+ if position > end:
623
+ continue
624
+ num_left_context = position - doc_span.start
625
+ num_right_context = end - position
626
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
627
+ if best_score is None or score > best_score:
628
+ best_score = score
629
+ best_span_index = span_index
630
+
631
+ return cur_span_index == best_span_index
632
+
633
+
634
+ def create_model(
635
+ bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings
636
+ ):
637
+ """Creates a classification model."""
638
+ model = modeling.BertModel(
639
+ config=bert_config,
640
+ is_training=is_training,
641
+ input_ids=input_ids,
642
+ input_mask=input_mask,
643
+ token_type_ids=segment_ids,
644
+ use_one_hot_embeddings=use_one_hot_embeddings,
645
+ )
646
+
647
+ final_hidden = model.get_sequence_output()
648
+
649
+ final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
650
+ batch_size = final_hidden_shape[0]
651
+ seq_length = final_hidden_shape[1]
652
+ hidden_size = final_hidden_shape[2]
653
+
654
+ output_weights = tf.get_variable(
655
+ "cls/squad/output_weights",
656
+ [2, hidden_size],
657
+ initializer=tf.truncated_normal_initializer(stddev=0.02),
658
+ )
659
+
660
+ output_bias = tf.get_variable(
661
+ "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()
662
+ )
663
+
664
+ final_hidden_matrix = tf.reshape(
665
+ final_hidden, [batch_size * seq_length, hidden_size]
666
+ )
667
+ logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
668
+ logits = tf.nn.bias_add(logits, output_bias)
669
+
670
+ logits = tf.reshape(logits, [batch_size, seq_length, 2])
671
+ logits = tf.transpose(logits, [2, 0, 1])
672
+
673
+ unstacked_logits = tf.unstack(logits, axis=0)
674
+
675
+ (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
676
+
677
+ return (start_logits, end_logits)
678
+
679
+
680
+ def model_fn_builder(
681
+ bert_config,
682
+ init_checkpoint,
683
+ learning_rate,
684
+ num_train_steps,
685
+ num_warmup_steps,
686
+ use_tpu,
687
+ use_one_hot_embeddings,
688
+ ):
689
+ """Returns `model_fn` closure for TPUEstimator."""
690
+
691
+ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
692
+ """The `model_fn` for TPUEstimator."""
693
+
694
+ tf.logging.info("*** Features ***")
695
+ for name in sorted(features.keys()):
696
+ tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
697
+
698
+ unique_ids = features["unique_ids"]
699
+ input_ids = features["input_ids"]
700
+ input_mask = features["input_mask"]
701
+ segment_ids = features["segment_ids"]
702
+
703
+ is_training = mode == tf.estimator.ModeKeys.TRAIN
704
+
705
+ (start_logits, end_logits) = create_model(
706
+ bert_config=bert_config,
707
+ is_training=is_training,
708
+ input_ids=input_ids,
709
+ input_mask=input_mask,
710
+ segment_ids=segment_ids,
711
+ use_one_hot_embeddings=use_one_hot_embeddings,
712
+ )
713
+
714
+ tvars = tf.trainable_variables()
715
+
716
+ initialized_variable_names = {}
717
+ scaffold_fn = None
718
+ if init_checkpoint:
719
+ (
720
+ assignment_map,
721
+ initialized_variable_names,
722
+ ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
723
+ if use_tpu:
724
+
725
+ def tpu_scaffold():
726
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
727
+ return tf.train.Scaffold()
728
+
729
+ scaffold_fn = tpu_scaffold
730
+ else:
731
+ tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
732
+
733
+ tf.logging.info("**** Trainable Variables ****")
734
+ for var in tvars:
735
+ init_string = ""
736
+ if var.name in initialized_variable_names:
737
+ init_string = ", *INIT_FROM_CKPT*"
738
+ tf.logging.info(
739
+ " name = %s, shape = %s%s", var.name, var.shape, init_string
740
+ )
741
+
742
+ output_spec = None
743
+ if mode == tf.estimator.ModeKeys.TRAIN:
744
+ seq_length = modeling.get_shape_list(input_ids)[1]
745
+
746
+ def compute_loss(logits, positions):
747
+ one_hot_positions = tf.one_hot(
748
+ positions, depth=seq_length, dtype=tf.float32
749
+ )
750
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
751
+ loss = -tf.reduce_mean(
752
+ tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
753
+ )
754
+ return loss
755
+
756
+ start_positions = features["start_positions"]
757
+ end_positions = features["end_positions"]
758
+
759
+ start_loss = compute_loss(start_logits, start_positions)
760
+ end_loss = compute_loss(end_logits, end_positions)
761
+
762
+ total_loss = (start_loss + end_loss) / 2.0
763
+
764
+ train_op = optimization.create_optimizer(
765
+ total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
766
+ )
767
+
768
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
769
+ mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
770
+ )
771
+ elif mode == tf.estimator.ModeKeys.PREDICT:
772
+ predictions = {
773
+ "unique_ids": unique_ids,
774
+ "start_logits": start_logits,
775
+ "end_logits": end_logits,
776
+ }
777
+ output_spec = tf.contrib.tpu.TPUEstimatorSpec(
778
+ mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
779
+ )
780
+ else:
781
+ raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode))
782
+
783
+ return output_spec
784
+
785
+ return model_fn
786
+
787
+
788
+ def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
789
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
790
+
791
+ name_to_features = {
792
+ "unique_ids": tf.FixedLenFeature([], tf.int64),
793
+ "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
794
+ "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
795
+ "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
796
+ }
797
+
798
+ if is_training:
799
+ name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
800
+ name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)
801
+
802
+ def _decode_record(record, name_to_features):
803
+ """Decodes a record to a TensorFlow example."""
804
+ example = tf.parse_single_example(record, name_to_features)
805
+
806
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
807
+ # So cast all int64 to int32.
808
+ for name in list(example.keys()):
809
+ t = example[name]
810
+ if t.dtype == tf.int64:
811
+ t = tf.to_int32(t)
812
+ example[name] = t
813
+
814
+ return example
815
+
816
+ def input_fn(params):
817
+ """The actual input function."""
818
+ batch_size = params["batch_size"]
819
+
820
+ # For training, we want a lot of parallel reading and shuffling.
821
+ # For eval, we want no shuffling and parallel reading doesn't matter.
822
+ d = tf.data.TFRecordDataset(input_file)
823
+ if is_training:
824
+ d = d.repeat()
825
+ d = d.shuffle(buffer_size=100)
826
+
827
+ d = d.apply(
828
+ tf.contrib.data.map_and_batch(
829
+ lambda record: _decode_record(record, name_to_features),
830
+ batch_size=batch_size,
831
+ drop_remainder=drop_remainder,
832
+ )
833
+ )
834
+
835
+ return d
836
+
837
+ return input_fn
838
+
839
+
840
+ RawResult = collections.namedtuple(
841
+ "RawResult", ["unique_id", "start_logits", "end_logits"]
842
+ )
843
+
844
+
845
+ def write_predictions(
846
+ all_examples,
847
+ all_features,
848
+ all_results,
849
+ n_best_size,
850
+ max_answer_length,
851
+ do_lower_case,
852
+ output_prediction_file,
853
+ output_nbest_file,
854
+ output_null_log_odds_file,
855
+ ):
856
+ """Write final predictions to the json file and log-odds of null if needed."""
857
+ tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
858
+ tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
859
+
860
+ example_index_to_features = collections.defaultdict(list)
861
+ for feature in all_features:
862
+ example_index_to_features[feature.example_index].append(feature)
863
+
864
+ unique_id_to_result = {}
865
+ for result in all_results:
866
+ unique_id_to_result[result.unique_id] = result
867
+
868
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
869
+ "PrelimPrediction",
870
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
871
+ )
872
+
873
+ all_predictions = collections.OrderedDict()
874
+ all_nbest_json = collections.OrderedDict()
875
+ scores_diff_json = collections.OrderedDict()
876
+
877
+ for (example_index, example) in enumerate(all_examples):
878
+ features = example_index_to_features[example_index]
879
+
880
+ prelim_predictions = []
881
+ # keep track of the minimum score of null start+end of position 0
882
+ score_null = 1000000 # large and positive
883
+ min_null_feature_index = 0 # the paragraph slice with min null score
884
+ null_start_logit = 0 # the start logit at the slice with min null score
885
+ null_end_logit = 0 # the end logit at the slice with min null score
886
+ for (feature_index, feature) in enumerate(features):
887
+ result = unique_id_to_result[feature.unique_id]
888
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
889
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
890
+ # if we could have irrelevant answers, get the min score of irrelevant
891
+ if FLAGS.version_2_with_negative:
892
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
893
+ if feature_null_score < score_null:
894
+ score_null = feature_null_score
895
+ min_null_feature_index = feature_index
896
+ null_start_logit = result.start_logits[0]
897
+ null_end_logit = result.end_logits[0]
898
+ for start_index in start_indexes:
899
+ for end_index in end_indexes:
900
+ # We could hypothetically create invalid predictions, e.g., predict
901
+ # that the start of the span is in the question. We throw out all
902
+ # invalid predictions.
903
+ if start_index >= len(feature.tokens):
904
+ continue
905
+ if end_index >= len(feature.tokens):
906
+ continue
907
+ if start_index not in feature.token_to_orig_map:
908
+ continue
909
+ if end_index not in feature.token_to_orig_map:
910
+ continue
911
+ if not feature.token_is_max_context.get(start_index, False):
912
+ continue
913
+ if end_index < start_index:
914
+ continue
915
+ length = end_index - start_index + 1
916
+ if length > max_answer_length:
917
+ continue
918
+ prelim_predictions.append(
919
+ _PrelimPrediction(
920
+ feature_index=feature_index,
921
+ start_index=start_index,
922
+ end_index=end_index,
923
+ start_logit=result.start_logits[start_index],
924
+ end_logit=result.end_logits[end_index],
925
+ )
926
+ )
927
+
928
+ if FLAGS.version_2_with_negative:
929
+ prelim_predictions.append(
930
+ _PrelimPrediction(
931
+ feature_index=min_null_feature_index,
932
+ start_index=0,
933
+ end_index=0,
934
+ start_logit=null_start_logit,
935
+ end_logit=null_end_logit,
936
+ )
937
+ )
938
+ prelim_predictions = sorted(
939
+ prelim_predictions,
940
+ key=lambda x: (x.start_logit + x.end_logit),
941
+ reverse=True,
942
+ )
943
+
944
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
945
+ "NbestPrediction", ["text", "start_logit", "end_logit"]
946
+ )
947
+
948
+ seen_predictions = {}
949
+ nbest = []
950
+ for pred in prelim_predictions:
951
+ if len(nbest) >= n_best_size:
952
+ break
953
+ feature = features[pred.feature_index]
954
+ if pred.start_index > 0: # this is a non-null prediction
955
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
956
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
957
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
958
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
959
+ tok_text = " ".join(tok_tokens)
960
+
961
+ # De-tokenize WordPieces that have been split off.
962
+ tok_text = tok_text.replace(" ##", "")
963
+ tok_text = tok_text.replace("##", "")
964
+
965
+ # Clean whitespace
966
+ tok_text = tok_text.strip()
967
+ tok_text = " ".join(tok_text.split())
968
+ orig_text = " ".join(orig_tokens)
969
+
970
+ final_text = get_final_text(tok_text, orig_text, do_lower_case)
971
+ if final_text in seen_predictions:
972
+ continue
973
+
974
+ seen_predictions[final_text] = True
975
+ else:
976
+ final_text = ""
977
+ seen_predictions[final_text] = True
978
+
979
+ nbest.append(
980
+ _NbestPrediction(
981
+ text=final_text,
982
+ start_logit=pred.start_logit,
983
+ end_logit=pred.end_logit,
984
+ )
985
+ )
986
+
987
+ # if we didn't include the empty option in the n-best, include it
988
+ if FLAGS.version_2_with_negative:
989
+ if "" not in seen_predictions:
990
+ nbest.append(
991
+ _NbestPrediction(
992
+ text="", start_logit=null_start_logit, end_logit=null_end_logit
993
+ )
994
+ )
995
+ # In very rare edge cases we could have no valid predictions. So we
996
+ # just create a nonce prediction in this case to avoid failure.
997
+ if not nbest:
998
+ nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
999
+
1000
+ assert len(nbest) >= 1
1001
+
1002
+ total_scores = []
1003
+ best_non_null_entry = None
1004
+ for entry in nbest:
1005
+ total_scores.append(entry.start_logit + entry.end_logit)
1006
+ if not best_non_null_entry:
1007
+ if entry.text:
1008
+ best_non_null_entry = entry
1009
+
1010
+ probs = _compute_softmax(total_scores)
1011
+
1012
+ nbest_json = []
1013
+ for (i, entry) in enumerate(nbest):
1014
+ output = collections.OrderedDict()
1015
+ output["text"] = entry.text
1016
+ output["probability"] = probs[i]
1017
+ output["start_logit"] = entry.start_logit
1018
+ output["end_logit"] = entry.end_logit
1019
+ nbest_json.append(output)
1020
+
1021
+ assert len(nbest_json) >= 1
1022
+
1023
+ if not FLAGS.version_2_with_negative:
1024
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
1025
+ else:
1026
+ # predict "" iff the null score - the score of best non-null > threshold
1027
+ score_diff = (
1028
+ score_null
1029
+ - best_non_null_entry.start_logit
1030
+ - (best_non_null_entry.end_logit)
1031
+ )
1032
+ scores_diff_json[example.qas_id] = score_diff
1033
+ if score_diff > FLAGS.null_score_diff_threshold:
1034
+ all_predictions[example.qas_id] = ""
1035
+ else:
1036
+ all_predictions[example.qas_id] = best_non_null_entry.text
1037
+
1038
+ all_nbest_json[example.qas_id] = nbest_json
1039
+
1040
+ with tf.gfile.GFile(output_prediction_file, "w") as writer:
1041
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
1042
+
1043
+ with tf.gfile.GFile(output_nbest_file, "w") as writer:
1044
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
1045
+
1046
+ if FLAGS.version_2_with_negative:
1047
+ with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
1048
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
1049
+
1050
+
1051
+ def get_final_text(pred_text, orig_text, do_lower_case):
1052
+ """Project the tokenized prediction back to the original text."""
1053
+
1054
+ # When we created the data, we kept track of the alignment between original
1055
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
1056
+ # now `orig_text` contains the span of our original text corresponding to the
1057
+ # span that we predicted.
1058
+ #
1059
+ # However, `orig_text` may contain extra characters that we don't want in
1060
+ # our prediction.
1061
+ #
1062
+ # For example, let's say:
1063
+ # pred_text = steve smith
1064
+ # orig_text = Steve Smith's
1065
+ #
1066
+ # We don't want to return `orig_text` because it contains the extra "'s".
1067
+ #
1068
+ # We don't want to return `pred_text` because it's already been normalized
1069
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
1070
+ # our tokenizer does additional normalization like stripping accent
1071
+ # characters).
1072
+ #
1073
+ # What we really want to return is "Steve Smith".
1074
+ #
1075
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
1076
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
1077
+ # can fail in certain cases in which case we just return `orig_text`.
1078
+
1079
+ def _strip_spaces(text):
1080
+ ns_chars = []
1081
+ ns_to_s_map = collections.OrderedDict()
1082
+ for (i, c) in enumerate(text):
1083
+ if c == " ":
1084
+ continue
1085
+ ns_to_s_map[len(ns_chars)] = i
1086
+ ns_chars.append(c)
1087
+ ns_text = "".join(ns_chars)
1088
+ return (ns_text, ns_to_s_map)
1089
+
1090
+ # We first tokenize `orig_text`, strip whitespace from the result
1091
+ # and `pred_text`, and check if they are the same length. If they are
1092
+ # NOT the same length, the heuristic has failed. If they are the same
1093
+ # length, we assume the characters are one-to-one aligned.
1094
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
1095
+
1096
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
1097
+
1098
+ start_position = tok_text.find(pred_text)
1099
+ if start_position == -1:
1100
+ if FLAGS.verbose_logging:
1101
+ tf.logging.info(
1102
+ "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)
1103
+ )
1104
+ return orig_text
1105
+ end_position = start_position + len(pred_text) - 1
1106
+
1107
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
1108
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
1109
+
1110
+ if len(orig_ns_text) != len(tok_ns_text):
1111
+ if FLAGS.verbose_logging:
1112
+ tf.logging.info(
1113
+ "Length not equal after stripping spaces: '%s' vs '%s'",
1114
+ orig_ns_text,
1115
+ tok_ns_text,
1116
+ )
1117
+ return orig_text
1118
+
1119
+ # We then project the characters in `pred_text` back to `orig_text` using
1120
+ # the character-to-character alignment.
1121
+ tok_s_to_ns_map = {}
1122
+ for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
1123
+ tok_s_to_ns_map[tok_index] = i
1124
+
1125
+ orig_start_position = None
1126
+ if start_position in tok_s_to_ns_map:
1127
+ ns_start_position = tok_s_to_ns_map[start_position]
1128
+ if ns_start_position in orig_ns_to_s_map:
1129
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
1130
+
1131
+ if orig_start_position is None:
1132
+ if FLAGS.verbose_logging:
1133
+ tf.logging.info("Couldn't map start position")
1134
+ return orig_text
1135
+
1136
+ orig_end_position = None
1137
+ if end_position in tok_s_to_ns_map:
1138
+ ns_end_position = tok_s_to_ns_map[end_position]
1139
+ if ns_end_position in orig_ns_to_s_map:
1140
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
1141
+
1142
+ if orig_end_position is None:
1143
+ if FLAGS.verbose_logging:
1144
+ tf.logging.info("Couldn't map end position")
1145
+ return orig_text
1146
+
1147
+ output_text = orig_text[orig_start_position : (orig_end_position + 1)]
1148
+ return output_text
1149
+
1150
+
1151
+ def _get_best_indexes(logits, n_best_size):
1152
+ """Get the n-best logits from a list."""
1153
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
1154
+
1155
+ best_indexes = []
1156
+ for i in range(len(index_and_score)):
1157
+ if i >= n_best_size:
1158
+ break
1159
+ best_indexes.append(index_and_score[i][0])
1160
+ return best_indexes
1161
+
1162
+
1163
+ def _compute_softmax(scores):
1164
+ """Compute softmax probability over raw logits."""
1165
+ if not scores:
1166
+ return []
1167
+
1168
+ max_score = None
1169
+ for score in scores:
1170
+ if max_score is None or score > max_score:
1171
+ max_score = score
1172
+
1173
+ exp_scores = []
1174
+ total_sum = 0.0
1175
+ for score in scores:
1176
+ x = math.exp(score - max_score)
1177
+ exp_scores.append(x)
1178
+ total_sum += x
1179
+
1180
+ probs = []
1181
+ for score in exp_scores:
1182
+ probs.append(score / total_sum)
1183
+ return probs
1184
+
1185
+
1186
+ class FeatureWriter(object):
1187
+ """Writes InputFeature to TF example file."""
1188
+
1189
+ def __init__(self, filename, is_training):
1190
+ self.filename = filename
1191
+ self.is_training = is_training
1192
+ self.num_features = 0
1193
+ self._writer = tf.python_io.TFRecordWriter(filename)
1194
+
1195
+ def process_feature(self, feature):
1196
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
1197
+ self.num_features += 1
1198
+
1199
+ def create_int_feature(values):
1200
+ feature = tf.train.Feature(
1201
+ int64_list=tf.train.Int64List(value=list(values))
1202
+ )
1203
+ return feature
1204
+
1205
+ features = collections.OrderedDict()
1206
+ features["unique_ids"] = create_int_feature([feature.unique_id])
1207
+ features["input_ids"] = create_int_feature(feature.input_ids)
1208
+ features["input_mask"] = create_int_feature(feature.input_mask)
1209
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
1210
+
1211
+ if self.is_training:
1212
+ features["start_positions"] = create_int_feature([feature.start_position])
1213
+ features["end_positions"] = create_int_feature([feature.end_position])
1214
+ impossible = 0
1215
+ if feature.is_impossible:
1216
+ impossible = 1
1217
+ features["is_impossible"] = create_int_feature([impossible])
1218
+
1219
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
1220
+ self._writer.write(tf_example.SerializeToString())
1221
+
1222
+ def close(self):
1223
+ self._writer.close()
1224
+
1225
+
1226
+ def validate_flags_or_throw(bert_config):
1227
+ """Validate the input FLAGS or throw an exception."""
1228
+ tokenization.validate_case_matches_checkpoint(
1229
+ FLAGS.do_lower_case, FLAGS.init_checkpoint
1230
+ )
1231
+
1232
+ if not FLAGS.do_train and not FLAGS.do_predict:
1233
+ raise ValueError("At least one of `do_train` or `do_predict` must be True.")
1234
+
1235
+ if FLAGS.do_train:
1236
+ if not FLAGS.train_file:
1237
+ raise ValueError(
1238
+ "If `do_train` is True, then `train_file` must be specified."
1239
+ )
1240
+ if FLAGS.do_predict:
1241
+ if not FLAGS.predict_file:
1242
+ raise ValueError(
1243
+ "If `do_predict` is True, then `predict_file` must be specified."
1244
+ )
1245
+
1246
+ if FLAGS.max_seq_length > bert_config.max_position_embeddings:
1247
+ raise ValueError(
1248
+ "Cannot use sequence length %d because the BERT model "
1249
+ "was only trained up to sequence length %d"
1250
+ % (FLAGS.max_seq_length, bert_config.max_position_embeddings)
1251
+ )
1252
+
1253
+ if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
1254
+ raise ValueError(
1255
+ "The max_seq_length (%d) must be greater than max_query_length "
1256
+ "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)
1257
+ )
1258
+
1259
+
1260
+ def main(_):
1261
+ tf.logging.set_verbosity(tf.logging.INFO)
1262
+ logger = tf.get_logger()
1263
+ logger.propagate = False
1264
+
1265
+ bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
1266
+
1267
+ validate_flags_or_throw(bert_config)
1268
+
1269
+ tf.gfile.MakeDirs(FLAGS.output_dir)
1270
+
1271
+ tokenizer = tokenization.FullTokenizer(
1272
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
1273
+ )
1274
+
1275
+ tpu_cluster_resolver = None
1276
+ if FLAGS.use_tpu and FLAGS.tpu_name:
1277
+ tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
1278
+ FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
1279
+ )
1280
+
1281
+ is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
1282
+ run_config = tf.contrib.tpu.RunConfig(
1283
+ cluster=tpu_cluster_resolver,
1284
+ master=FLAGS.master,
1285
+ model_dir=FLAGS.output_dir,
1286
+ save_checkpoints_steps=FLAGS.save_checkpoints_steps,
1287
+ tpu_config=tf.contrib.tpu.TPUConfig(
1288
+ iterations_per_loop=FLAGS.iterations_per_loop,
1289
+ num_shards=FLAGS.num_tpu_cores,
1290
+ per_host_input_for_training=is_per_host,
1291
+ ),
1292
+ )
1293
+
1294
+ train_examples = None
1295
+ num_train_steps = None
1296
+ num_warmup_steps = None
1297
+ if FLAGS.do_train:
1298
+ train_examples = read_squad_examples(
1299
+ input_file=FLAGS.train_file, is_training=True
1300
+ )
1301
+ num_train_steps = int(
1302
+ len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
1303
+ )
1304
+ num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
1305
+
1306
+ # Pre-shuffle the input to avoid having to make a very large shuffle
1307
+ # buffer in the `input_fn`.
1308
+ rng = random.Random(12345)
1309
+ rng.shuffle(train_examples)
1310
+
1311
+ model_fn = model_fn_builder(
1312
+ bert_config=bert_config,
1313
+ init_checkpoint=FLAGS.init_checkpoint,
1314
+ learning_rate=FLAGS.learning_rate,
1315
+ num_train_steps=num_train_steps,
1316
+ num_warmup_steps=num_warmup_steps,
1317
+ use_tpu=FLAGS.use_tpu,
1318
+ use_one_hot_embeddings=FLAGS.use_tpu,
1319
+ )
1320
+
1321
+ # If TPU is not available, this will fall back to normal Estimator on CPU
1322
+ # or GPU.
1323
+ estimator = tf.contrib.tpu.TPUEstimator(
1324
+ use_tpu=FLAGS.use_tpu,
1325
+ model_fn=model_fn,
1326
+ config=run_config,
1327
+ train_batch_size=FLAGS.train_batch_size,
1328
+ predict_batch_size=FLAGS.predict_batch_size,
1329
+ )
1330
+
1331
+ if FLAGS.do_train:
1332
+ # We write to a temporary file to avoid storing very large constant tensors
1333
+ # in memory.
1334
+ train_writer = FeatureWriter(
1335
+ filename=os.path.join(FLAGS.output_dir, "train.tf_record"), is_training=True
1336
+ )
1337
+ convert_examples_to_features(
1338
+ examples=train_examples,
1339
+ tokenizer=tokenizer,
1340
+ max_seq_length=FLAGS.max_seq_length,
1341
+ doc_stride=FLAGS.doc_stride,
1342
+ max_query_length=FLAGS.max_query_length,
1343
+ is_training=True,
1344
+ output_fn=train_writer.process_feature,
1345
+ )
1346
+ train_writer.close()
1347
+
1348
+ tf.logging.info("***** Running training *****")
1349
+ tf.logging.info(" Num orig examples = %d", len(train_examples))
1350
+ tf.logging.info(" Num split examples = %d", train_writer.num_features)
1351
+ tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
1352
+ tf.logging.info(" Num steps = %d", num_train_steps)
1353
+ del train_examples
1354
+
1355
+ train_input_fn = input_fn_builder(
1356
+ input_file=train_writer.filename,
1357
+ seq_length=FLAGS.max_seq_length,
1358
+ is_training=True,
1359
+ drop_remainder=True,
1360
+ )
1361
+ estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
1362
+
1363
+ if FLAGS.do_predict:
1364
+ eval_examples = read_squad_examples(
1365
+ input_file=FLAGS.predict_file, is_training=False
1366
+ )
1367
+
1368
+ eval_writer = FeatureWriter(
1369
+ filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), is_training=False
1370
+ )
1371
+ eval_features = []
1372
+
1373
+ def append_feature(feature):
1374
+ eval_features.append(feature)
1375
+ eval_writer.process_feature(feature)
1376
+
1377
+ convert_examples_to_features(
1378
+ examples=eval_examples,
1379
+ tokenizer=tokenizer,
1380
+ max_seq_length=FLAGS.max_seq_length,
1381
+ doc_stride=FLAGS.doc_stride,
1382
+ max_query_length=FLAGS.max_query_length,
1383
+ is_training=False,
1384
+ output_fn=append_feature,
1385
+ )
1386
+ eval_writer.close()
1387
+
1388
+ tf.logging.info("***** Running predictions *****")
1389
+ tf.logging.info(" Num orig examples = %d", len(eval_examples))
1390
+ tf.logging.info(" Num split examples = %d", len(eval_features))
1391
+ tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
1392
+
1393
+ all_results = []
1394
+
1395
+ predict_input_fn = input_fn_builder(
1396
+ input_file=eval_writer.filename,
1397
+ seq_length=FLAGS.max_seq_length,
1398
+ is_training=False,
1399
+ drop_remainder=False,
1400
+ )
1401
+
1402
+ # If running eval on the TPU, you will need to specify the number of
1403
+ # steps.
1404
+ all_results = []
1405
+ for result in estimator.predict(predict_input_fn, yield_single_examples=True):
1406
+ if len(all_results) % 1000 == 0:
1407
+ tf.logging.info("Processing example: %d" % (len(all_results)))
1408
+ unique_id = int(result["unique_ids"])
1409
+ start_logits = [float(x) for x in result["start_logits"].flat]
1410
+ end_logits = [float(x) for x in result["end_logits"].flat]
1411
+ all_results.append(
1412
+ RawResult(
1413
+ unique_id=unique_id,
1414
+ start_logits=start_logits,
1415
+ end_logits=end_logits,
1416
+ )
1417
+ )
1418
+
1419
+ output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
1420
+ output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
1421
+ output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")
1422
+
1423
+ write_predictions(
1424
+ eval_examples,
1425
+ eval_features,
1426
+ all_results,
1427
+ FLAGS.n_best_size,
1428
+ FLAGS.max_answer_length,
1429
+ FLAGS.do_lower_case,
1430
+ output_prediction_file,
1431
+ output_nbest_file,
1432
+ output_null_log_odds_file,
1433
+ )
1434
+
1435
+
1436
+ if __name__ == "__main__":
1437
+ flags.mark_flag_as_required("vocab_file")
1438
+ flags.mark_flag_as_required("bert_config_file")
1439
+ flags.mark_flag_as_required("output_dir")
1440
+ tf.app.run()
arabert/arabert/sample_text.txt ADDED
@@ -0,0 +1,38 @@
1
+ Text should be one-sentence-per-line, with empty lines between documents.
2
+ This sample text was randomly selected from the pretraining corpus for araBERT with Farasa Tokenization.
3
+
4
+ • « أدعو ال+ جه +ات ال+ مختص +ة في ال+ دول +ة إلى إجراء دراس +ة مقارن +ة بين مستوي +ات ال+ طلب +ة في زمني و+ في هذا ال+ وقت » .
5
+ أريد طرح جانب واحد بسيط في هذه ال+ مساح +ة بناء على حوار بين +ي و+ بين أحد أقربائي عن كيفي +ة تأثير هذه ال+ حقب +ة ال+ رقمي +ة على ال+ حيا +ة ال+ جامعي +ة ، و+ مما لا شك في +ه أن ال+ تقني +ة اقتحم +ت حقل ال+ تعليم +كما فعل +ت مع غير +ه و+ غير +ت +ه ب+ شكل جذري .
6
+ ال+ يوم : ال+ هواتف ال+ ذكي +ة منتشر +ة و+ ال+ معلم +ون متوافر +ون على مدار ال+ ساع +ة ، و+ لن نبالغ إذا قل +نا إن هاتف أو كمبيوتر ال+ مكتب في زمن +نا هذا قد لا يكون ضروري +ا .
7
+ ال+ يوم : ال+ واتساب وسيل +ة ال+ تنسيق و+ ال+ تواصل بين ال+ طلاب ، و+ من خلال +ه تنشأ مجموع +ات تناقش ال+ مشروع +ات ال+ جامعي +ة ، بل إن ال+ واتساب ألغى ال+ حاج +ة إلى ال+ ماسح ال+ ضوئي في حال +ات كثير +ة ، و+ أصبح تصوير ال+ أوراق و+ ال+ مستند +ات يتم عن طريق +ه ب+ استخدام كاميرا ال+ هاتف ال+ محمول .
8
+ ال+ يوم : كل شيء يتم تصوير +ه و+ رصد +ه و+ بث +ه على وسائل ال+ تواصل كاف +ة ، و+ إن انتشر تفتح ال+ سلط +ات تحقيق +ا و+ يتحول أي موضوع ، سواء كان تافه +ا أم كبير +ا ، إلى و+ صم +ة عار في ال+ حال +ة ال+ أولى ، و+ قضي +ة في ال+ حال +ة ال+ ثاني +ة .
9
+ ال+ يوم : ينشغل ال+ طلاب ب+ وسائل ال+ تواصل ال+ اجتماعي داخل ال+ فصل بين مدرس +ين متساهل +ين و+ متشدد +ين ، و+ هناك طلاب يستخدم +ون ال+ يوتيوب وسيل +ة ترفيهي +ة داخل ال+ فصل على حساب ال+ درس و+ ب+ وجود ال+ معلم .
10
+ +كن +نا نمزح مع بعض +نا كثير +ا مزاح +ا خفيف +ا و+ ثقيل +ا ، و+ كان +ت تحدث معارك أحيان +ا ، لكن كان +ت فضيل +ة ال+ ستر منتشر +ة بين +نا ، و+ لم يكن أحد يشي ب+ ال+ آخر ، إلا نادر +ا ، و+ إن حدث ذلك ف+ لا دليل علي +ه ، إلا من صاحب ال+ وشاي +ة و+ كلم +ت +ه قد تصدق ، أو يتم تجاهل +ها .
11
+ لم يكن استخدام ال+ هاتف ال+ محمول مسموح +ا ب+ +ه داخل ال+ فصول ال+ دراسي +ة إلا في حال +ات ال+ طوارئ ، ال+ معلم +ون ال+ أجانب كان +وا يتساهلون مع ذلك بينما ال+ عرب متشدد +ون ، كان مسموح +ا ل+ +نا ب+ فتح كتاب تعليمي عن ال+ ماد +ة نفس +ها و+ ال+ مطالع +ة في +ه ب+ حضور ال+ مدرس .
12
+
13
+ أزم +ات ال+ أندي +ة - ال+ إمار +ات ال+ يومكثر ال+ حديث في ال+ آون +ة ال+ أخير +ة عن ال+ مشكل +ات التي تعاني +ها أندي +ت +نا ، و+ ما تواجه +ه من معوق +ات و+ تحدي +ات فرض +ت علي +ها ب+ سبب عوامل و+ تراكم +ات أسلوب ال+ عمل ال+ إداري ، الذي تنتهج +ه ال+ أغلبي +ة من +ها .
14
+ ال+ أزم +ات التي تظهر بين فتر +ة و+ أخرى مرد +ها غياب ال+ منهجي +ة و+ سوء ال+ تخطيط و+ ال+ صرف ال+ عشوائي .
15
+ أما ال+ أهلي ف+ رغم ما مر ب+ +ه خلال هذا ال+ موسم و+ ما واجه +ه من تحدي +ات ، إلا أن +ه أنعش آمال +ه من جديد و+ بقي +ت ل+ +ه خطو +ة من أجل مرافق +ة ال+ فرق ال+ متأهل +ة ل+ ال+ دور ال+ تالي .
16
+ ف+ هو واقع تعيش +ه هذه ال+ أندي +ة و+ ال+ جميع يدرك تداعي +ات +ه ال+ سلبي +ة علي +ها ، التي تنعكس مباشر +ة على أدائ +ها ال+ مؤسسي و+ مخرج +ات +ه ، و+ هو أمر يخالف ال+ طموح .
17
+ لكن ال+ سؤال الذي يتردد دائم +ا من ال+ متسبب في هذا ؟ و+ ل+ أجل ذلك عاد ال+ عين ب+ مكاسب عد +ة لعل أهم +ها أن +ه استطاع أن يغسل أحزان +ه ال+ محلي +ة في ال+ بطول +ة ال+ آسيوي +ة ، و+ يحيى أمل +ه في ال+ خروج ب+ مكسب آسيوي ينسي +ه خسار +ة ال+ دوري و+ ال+ كأس ، بل قد يعطي +ه مساح +ة أكبر من ال+ تركيز ل+ ال+ منافس +ة على ال+ أبطال و+ إعاد +ة ذكري +ات 2003 .
18
+ و+ لعل ال+ أزم +ات التي تظهر بين فتر +ة و+ أخرى مرد +ها غياب ال+ منهجي +ة و+ سوء ال+ تخطيط ، ب+ ال+ إضاف +ة إلى ال+ صرف ال+ عشوائي الذي كبد ميزاني +ات +ها ال+ كثير ، و+ وضع +ها في خان +ة حرج +ة دفع +ها أحيان +ا ل+ إطلاق صرخ +ات ال+ استغاث +ة ل+ نجد +ت +ها و+ إخراج +ها من تلك ال+ دوام +ات التي تقع في +ها .
19
+ و+ لماذا يستمر هذا ال+ وضع في أغلب ال+ أندي +ة دون حراك نحو ال+ تغيير و+ ال+ تطوير و+ خلع
20
+
21
+ This sample text was randomly selected from the pretraining corpus for araBERT WITHOUT Farasa Tokenization.
22
+
23
+ • " أدعو الجهات المختصة في الدولة إلى إجراء دراسة مقارنة بين مستويات الطلبة في زمني وفي هذا الوقت ".
24
+ أريد طرح جانب واحد بسيط في هذه المساحة بناء على حوار بيني وبين أحد أقربائي عن كيفية تأثير هذه الحقبة الرقمية على الحياة الجامعية ، ومما لا شك فيه أن التقنية اقتحمت حقل التعليم كما فعلت مع غيره وغيرته بشكل جذري.
25
+ اليوم : الهواتف الذكية منتشرة والمعلمون متوافرون على مدار الساعة ، ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا قد لا يكون ضروريا.
26
+ اليوم : الواتساب وسيلة التنسيق والتواصل بين الطلاب ، ومن خلاله تنشأ مجموعات تناقش المشروعات الجامعية ، بل إن الواتساب ألغى الحاجة إلى الماسح الضوئي في حالات كثيرة ، وأصبح تصوير الأوراق والمستندات يتم عن طريقه باستخدام كاميرا الهاتف المحمول.
27
+ اليوم : كل شيء يتم تصويره ورصده وبثه على وسائل التواصل كافة ، وإن انتشر تفتح السلطات تحقيقا ويتحول أي موضوع ، سواء كان تافها أم كبيرا ، إلى وصمة عار في الحالة الأولى ، وقضية في الحالة الثانية.
28
+ اليوم : ينشغل الطلاب بوسائل التواصل الاجتماعي داخل الفصل بين مدرسين متساهلين ومتشددين ، وهناك طلاب يستخدمون اليوتيوب وسيلة ترفيهية داخل الفصل على حساب الدرس وبوجود المعلم.
29
+ كنا نمزح مع بعضنا كثيرا مزاحا خفيفا وثقيلا ، وكانت تحدث معارك أحيانا ، لكن كانت فضيلة الستر منتشرة بيننا ، ولم يكن أحد يشي بالآخر ، إلا نادرا ، وإن حدث ذلك فلا دليل عليه ، إلا من صاحب الوشاية وكلمته قد تصدق ، أو يتم تجاهلها.
30
+ لم يكن استخدام الهاتف المحمول مسموحا به داخل الفصول الدراسية إلا في حالات الطوارئ ، المعلمون الأجانب كانوا يتساهلون مع ذلك بينما العرب متشددون ، كان مسموحا لنا بفتح كتاب تعليمي عن المادة نفسها والمطالعة فيه بحضور المدرس.
31
+
32
+ أزمات الأندية - الإمارات اليومكثر الحديث في الآونة الأخيرة عن المشكلات التي تعانيها أنديتنا ، وما تواجهه من معوقات وتحديات فرضت عليها بسبب عوامل وتراكمات أسلوب العمل الإداري ، الذي تنتهجه الأغلبية منها.
33
+ الأزمات التي تظهر بين فترة وأخرى مردها غياب المنهجية وسوء التخطيط والصرف العشوائي.
34
+ أما الأهلي فرغم ما مر به خلال هذا الموسم وما واجهه من تحديات ، إلا أنه أنعش آماله من جديد وبقيت له خطوة من أجل مرافقة الفرق المتأهلة للدور التالي.
35
+ فهو واقع تعيشه هذه الأندية والجميع يدرك تداعياته السلبية عليها ، التي تنعكس مباشرة على أدائها المؤسسي ومخرجاته ، وهو أمر يخالف الطموح.
36
+ لكن السؤال الذي يتردد دائما من المتسبب في هذا ؟ ولأجل ذلك عاد العين بمكاسب عدة لعل أهمها أنه استطاع أن يغسل أحزانه المحلية في البطولة الآسيوية ، ويحيى أمله في الخروج بمكسب آسيوي ينسيه خسارة الدوري والكأس ، بل قد يعطيه مساحة أكبر من التركيز للمنافسة على الأبطال وإعادة ذكريات 2003.
37
+ ولعل الأزمات التي تظهر بين فترة وأخرى مردها غياب المنهجية وسوء التخطيط ، بالإضافة إلى الصرف العشوائي الذي كبد ميزانياتها الكثير ، ووضعها في خانة حرجة دفعها أحيانا لإطلاق صرخات الاستغاثة لنجدتها وإخراجها من تلك الدوامات التي تقع فيها.
38
+ ولماذا يستمر هذا الوضع في أغلب الأندية دون حراك نحو التغيير والتطوير وخلع الجلباب الإداري القديم؟
arabert/arabert/tokenization.py ADDED
@@ -0,0 +1,414 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ from __future__ import absolute_import
18
+ from __future__ import division
19
+ from __future__ import print_function
20
+
21
+ import collections
22
+ import re
23
+ import unicodedata
24
+ import six
25
+ import tensorflow as tf
26
+
27
+
28
+ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29
+ """Checks whether the casing config is consistent with the checkpoint name."""
30
+
31
+ # The casing has to be passed in by the user and there is no explicit check
32
+ # as to whether it matches the checkpoint. The casing information probably
33
+ # should have been stored in the bert_config.json file, but it's not, so
34
+ # we have to heuristically detect it to validate.
35
+
36
+ if not init_checkpoint:
37
+ return
38
+
39
+ m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40
+ if m is None:
41
+ return
42
+
43
+ model_name = m.group(1)
44
+
45
+ lower_models = [
46
+ "uncased_L-24_H-1024_A-16",
47
+ "uncased_L-12_H-768_A-12",
48
+ "multilingual_L-12_H-768_A-12",
49
+ "chinese_L-12_H-768_A-12",
50
+ ]
51
+
52
+ cased_models = [
53
+ "cased_L-12_H-768_A-12",
54
+ "cased_L-24_H-1024_A-16",
55
+ "multi_cased_L-12_H-768_A-12",
56
+ ]
57
+
58
+ is_bad_config = False
59
+ if model_name in lower_models and not do_lower_case:
60
+ is_bad_config = True
61
+ actual_flag = "False"
62
+ case_name = "lowercased"
63
+ opposite_flag = "True"
64
+
65
+ if model_name in cased_models and do_lower_case:
66
+ is_bad_config = True
67
+ actual_flag = "True"
68
+ case_name = "cased"
69
+ opposite_flag = "False"
70
+
71
+ if is_bad_config:
72
+ raise ValueError(
73
+ "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
74
+ "However, `%s` seems to be a %s model, so you "
75
+ "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
76
+ "how the model was pre-trained. If this error is wrong, please "
77
+ "just comment out this check."
78
+ % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)
79
+ )
80
+
81
+
82
+ def convert_to_unicode(text):
83
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
84
+ if six.PY3:
85
+ if isinstance(text, str):
86
+ return text
87
+ elif isinstance(text, bytes):
88
+ return text.decode("utf-8", "ignore")
89
+ else:
90
+ raise ValueError("Unsupported string type: %s" % (type(text)))
91
+ elif six.PY2:
92
+ if isinstance(text, str):
93
+ return text.decode("utf-8", "ignore")
94
+ elif isinstance(text, unicode):
95
+ return text
96
+ else:
97
+ raise ValueError("Unsupported string type: %s" % (type(text)))
98
+ else:
99
+ raise ValueError("Not running on Python2 or Python 3?")
100
+
101
+
102
+ def printable_text(text):
103
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
104
+
105
+ # These functions want `str` for both Python2 and Python3, but in one case
106
+ # it's a Unicode string and in the other it's a byte string.
107
+ if six.PY3:
108
+ if isinstance(text, str):
109
+ return text
110
+ elif isinstance(text, bytes):
111
+ return text.decode("utf-8", "ignore")
112
+ else:
113
+ raise ValueError("Unsupported string type: %s" % (type(text)))
114
+ elif six.PY2:
115
+ if isinstance(text, str):
116
+ return text
117
+ elif isinstance(text, unicode):
118
+ return text.encode("utf-8")
119
+ else:
120
+ raise ValueError("Unsupported string type: %s" % (type(text)))
121
+ else:
122
+ raise ValueError("Not running on Python2 or Python 3?")
123
+
124
+
125
+ def load_vocab(vocab_file):
126
+ """Loads a vocabulary file into a dictionary."""
127
+ vocab = collections.OrderedDict()
128
+ index = 0
129
+ with tf.gfile.GFile(vocab_file, "r") as reader:
130
+ while True:
131
+ token = convert_to_unicode(reader.readline())
132
+ if not token:
133
+ break
134
+ token = token.strip()
135
+ vocab[token] = index
136
+ index += 1
137
+ return vocab
138
+
139
+
140
+ def convert_by_vocab(vocab, items):
141
+ """Converts a sequence of [tokens|ids] using the vocab."""
142
+ output = []
143
+ for item in items:
144
+ output.append(vocab[item])
145
+ return output
146
+
147
+
148
+ def convert_tokens_to_ids(vocab, tokens):
149
+ return convert_by_vocab(vocab, tokens)
150
+
151
+
152
+ def convert_ids_to_tokens(inv_vocab, ids):
153
+ return convert_by_vocab(inv_vocab, ids)
154
+
155
+
156
+ def whitespace_tokenize(text):
157
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
158
+ text = text.strip()
159
+ if not text:
160
+ return []
161
+ tokens = text.split()
162
+ return tokens
163
+
164
+
165
+ class FullTokenizer(object):
166
+ """Runs end-to-end tokenization."""
167
+
168
+ def __init__(self, vocab_file, do_lower_case=True):
169
+ self.vocab = load_vocab(vocab_file)
170
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
171
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
172
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
173
+
174
+ def tokenize(self, text):
175
+ split_tokens = []
176
+ for token in self.basic_tokenizer.tokenize(text):
177
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
178
+ split_tokens.append(sub_token)
179
+
180
+ return split_tokens
181
+
182
+ def convert_tokens_to_ids(self, tokens):
183
+ return convert_by_vocab(self.vocab, tokens)
184
+
185
+ def convert_ids_to_tokens(self, ids):
186
+ return convert_by_vocab(self.inv_vocab, ids)
187
+
188
+
189
+ class BasicTokenizer(object):
190
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
191
+
192
+ def __init__(self, do_lower_case=True):
193
+ """Constructs a BasicTokenizer.
194
+
195
+ Args:
196
+ do_lower_case: Whether to lower case the input.
197
+ """
198
+ self.do_lower_case = do_lower_case
199
+
200
+ def tokenize(self, text):
201
+ """Tokenizes a piece of text."""
202
+ text = convert_to_unicode(text)
203
+ text = self._clean_text(text)
204
+
205
+ # This was added on November 1st, 2018 for the multilingual and Chinese
206
+ # models. This is also applied to the English models now, but it doesn't
207
+ # matter since the English models were not trained on any Chinese data
208
+ # and generally don't have any Chinese data in them (there are Chinese
209
+ # characters in the vocabulary because Wikipedia does have some Chinese
210
+ # words in the English Wikipedia.).
211
+ text = self._tokenize_chinese_chars(text)
212
+
213
+ orig_tokens = whitespace_tokenize(text)
214
+ split_tokens = []
215
+ for token in orig_tokens:
216
+ if self.do_lower_case:
217
+ token = token.lower()
218
+ token = self._run_strip_accents(token)
219
+ split_tokens.extend(self._run_split_on_punc(token))
220
+
221
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
222
+ return output_tokens
223
+
224
+ def _run_strip_accents(self, text):
225
+ """Strips accents from a piece of text."""
226
+ text = unicodedata.normalize("NFD", text)
227
+ output = []
228
+ for char in text:
229
+ cat = unicodedata.category(char)
230
+ if cat == "Mn":
231
+ continue
232
+ output.append(char)
233
+ return "".join(output)
234
+
235
+ def _run_split_on_punc(self, text):
236
+ """Splits punctuation on a piece of text."""
237
+ chars = list(text)
238
+ i = 0
239
+ start_new_word = True
240
+ output = []
241
+ while i < len(chars):
242
+ char = chars[i]
243
+ if _is_punctuation(char):
244
+ output.append([char])
245
+ start_new_word = True
246
+ else:
247
+ if start_new_word:
248
+ output.append([])
249
+ start_new_word = False
250
+ output[-1].append(char)
251
+ i += 1
252
+
253
+ return ["".join(x) for x in output]
254
+
255
+ def _tokenize_chinese_chars(self, text):
256
+ """Adds whitespace around any CJK character."""
257
+ output = []
258
+ for char in text:
259
+ cp = ord(char)
260
+ if self._is_chinese_char(cp):
261
+ output.append(" ")
262
+ output.append(char)
263
+ output.append(" ")
264
+ else:
265
+ output.append(char)
266
+ return "".join(output)
267
+
268
+ def _is_chinese_char(self, cp):
269
+ """Checks whether CP is the codepoint of a CJK character."""
270
+ # This defines a "chinese character" as anything in the CJK Unicode block:
271
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
272
+ #
273
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
274
+ # despite its name. The modern Korean Hangul alphabet is a different block,
275
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
276
+ # space-separated words, so they are not treated specially and handled
277
+ # like all of the other languages.
278
+ if (
279
+ (cp >= 0x4E00 and cp <= 0x9FFF)
280
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
281
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
282
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
283
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
284
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
285
+ or (cp >= 0xF900 and cp <= 0xFAFF)
286
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
287
+ ): #
288
+ return True
289
+
290
+ return False
291
+
292
+ def _clean_text(self, text):
293
+ """Performs invalid character removal and whitespace cleanup on text."""
294
+ output = []
295
+ for char in text:
296
+ cp = ord(char)
297
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
298
+ continue
299
+ if _is_whitespace(char):
300
+ output.append(" ")
301
+ else:
302
+ output.append(char)
303
+ return "".join(output)
304
+
305
+
306
+ class WordpieceTokenizer(object):
307
+ """Runs WordPiece tokenization."""
308
+
309
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
310
+ self.vocab = vocab
311
+ self.unk_token = unk_token
312
+ self.max_input_chars_per_word = max_input_chars_per_word
313
+
314
+ def tokenize(self, text):
315
+ """Tokenizes a piece of text into its word pieces.
316
+
317
+ This uses a greedy longest-match-first algorithm to perform tokenization
318
+ using the given vocabulary.
319
+
320
+ For example:
321
+ input = "unaffable"
322
+ output = ["un", "##aff", "##able"]
323
+
324
+ Args:
325
+ text: A single token or whitespace separated tokens. This should have
326
+ already been passed through `BasicTokenizer`.
327
+
328
+ Returns:
329
+ A list of wordpiece tokens.
330
+ """
331
+
332
+ text = convert_to_unicode(text)
333
+
334
+ output_tokens = []
335
+ for token in whitespace_tokenize(text):
336
+ chars = list(token)
337
+ if len(chars) > self.max_input_chars_per_word:
338
+ output_tokens.append(self.unk_token)
339
+ continue
340
+
341
+ is_bad = False
342
+ start = 0
343
+ sub_tokens = []
344
+ while start < len(chars):
345
+ end = len(chars)
346
+ cur_substr = None
347
+ while start < end:
348
+ substr = "".join(chars[start:end])
349
+ if start > 0:
350
+ substr = "##" + substr
351
+ if substr in self.vocab:
352
+ cur_substr = substr
353
+ break
354
+ end -= 1
355
+ if cur_substr is None:
356
+ is_bad = True
357
+ break
358
+ sub_tokens.append(cur_substr)
359
+ start = end
360
+
361
+ if is_bad:
362
+ output_tokens.append(self.unk_token)
363
+ else:
364
+ output_tokens.extend(sub_tokens)
365
+ return output_tokens
366
+
367
+
368
+ def _is_whitespace(char):
369
+ """Checks whether `chars` is a whitespace character."""
370
+ # \t, \n, and \r are technically control characters but we treat them
371
+ # as whitespace since they are generally considered as such.
372
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
373
+ return True
374
+ cat = unicodedata.category(char)
375
+ if cat == "Zs":
376
+ return True
377
+ return False
378
+
379
+
380
+ def _is_control(char):
381
+ """Checks whether `chars` is a control character."""
382
+ # These are technically control characters but we count them as whitespace
383
+ # characters.
384
+ if char == "\t" or char == "\n" or char == "\r":
385
+ return False
386
+ cat = unicodedata.category(char)
387
+ if cat in ("Cc", "Cf"):
388
+ return True
389
+ return False
390
+
391
+
392
+ def _is_punctuation(char):
393
+ """Checks whether `chars` is a punctuation character."""
394
+ cp = ord(char)
395
+ # We treat all non-letter/number ASCII as punctuation.
396
+ # Characters such as "^", "$", and "`" are not in the Unicode
397
+ # Punctuation class but we treat them as punctuation anyways, for
398
+ # consistency.
399
+ if (
400
+ cp == 91 or cp == 93 or cp == 43
401
+ ): # [, ], and + are not treated as punctuation, since they are used in [xx] tokens and as the + segmentation marker
402
+ return False
403
+
404
+ if (
405
+ (cp >= 33 and cp <= 47)
406
+ or (cp >= 58 and cp <= 64)
407
+ or (cp >= 91 and cp <= 96)
408
+ or (cp >= 123 and cp <= 126)
409
+ ):
410
+ return True
411
+ cat = unicodedata.category(char)
412
+ if cat.startswith("P"):
413
+ return True
414
+ return False
arabert/arabert_logo.png ADDED
arabert/araelectra/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__
2
+ .vscode/
3
+ data/
4
+ *.bat
arabert/araelectra/LICENSE ADDED
@@ -0,0 +1,76 @@
1
+ ==========================================
2
+ SOFTWARE LICENSE AGREEMENT - AraELECTRA
3
+ ==========================================
4
+
5
+ * NAME: AraELECTRA: Pre-Training Text Discriminators for Arabic Language Understanding
6
+
7
+ * ACKNOWLEDGMENTS
8
+
9
+ This [software] was generated by [American
10
+ University of Beirut] (“Owners”). The statements
11
+ made herein are solely the responsibility of the author[s].
12
+
13
+ The following software programs and programs have been used in the
14
+ generation of [AraELECTRA]:
15
+
16
+ + ELECTRA
17
+ - Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning.
18
+ "ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
19
+ https://openreview.net/pdf?id=r1xMH1BtvB, 2020.
20
+ - License and link : https://github.com/google-research/electra
21
+
22
+ + PyArabic
23
+ - T. Zerrouki, Pyarabic, An Arabic language library for Python,
24
+ https://pypi.python.org/pypi/pyarabic/, 2010
25
+ - License and link: https://github.com/linuxscout/pyarabic/
26
+
27
+ * LICENSE
28
+
29
+ This software and database is being provided to you, the LICENSEE,
30
+ by the Owners under the following license. By obtaining, using and/or
31
+ copying this software and database, you agree that you have read,
32
+ understood, and will comply with these terms and conditions. You
33
+ further agree that you have read and you will abide by the license
34
+ agreements provided in the above links under “acknowledgements”:
35
+ Permission to use, copy, modify and distribute this software and
36
+ database and its documentation for any purpose and without fee or
37
+ royalty is hereby granted, provided that you agree to comply with the
38
+ following copyright notice and statements, including the disclaimer,
39
+ and that the same appear on ALL copies of the software, database and
40
+ documentation, including modifications that you make for internal use
41
+ or for distribution. [AraELECTRA] Copyright 2020 by [American University
42
+ of Beirut]. All rights reserved. If you remix, transform, or build
43
+ upon the material, you must distribute your contributions under the
44
+ same license as this one. You may not apply legal terms or technological
45
+ measures that legally restrict others from doing anything this license
46
+ permits. THIS SOFTWARE IS PROVIDED "AS IS" AND THE OWNERS MAKE NO
47
+ REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE,
48
+ BUT NOT LIMITATION, THE OWNERS MAKE NO REPRESENTATIONS OR WARRANTIES OF
49
+ MERCHANT-ABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF
50
+ THE LICENSED SOFTWARE, DATABASE OR DOCUMENTATION WILL NOT INFRINGE ANY THIRD
51
+ PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS. The name of the
52
+ Owners may not be used in advertising or publicity pertaining to
53
+ distribution of the software and/or database. Title to copyright in
54
+ this software, database and any associated documentation shall at all
55
+ times remain with the Owners and LICENSEE agrees to preserve same.
56
+
57
+ The use of AraELECTRA should be cited as follows:
58
+
59
+ @inproceedings{antoun-etal-2021-araelectra,
60
+ title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
61
+ author = "Antoun, Wissam and
62
+ Baly, Fady and
63
+ Hajj, Hazem",
64
+ booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
65
+ month = apr,
66
+ year = "2021",
67
+ address = "Kyiv, Ukraine (Virtual)",
68
+ publisher = "Association for Computational Linguistics",
69
+ url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
70
+ pages = "191--195",
71
+ }
72
+
73
+ [AraELECTRA] Copyright 2020 by [American University of Beirut].
74
+ All rights reserved.
75
+ ==========================================
76
+
arabert/araelectra/README.md ADDED
@@ -0,0 +1,144 @@
1
+ # ELECTRA
2
+
3
+ ## Introduction
4
+
5
+ **ELECTRA** is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). AraELECTRA achieves state-of-the-art results on Arabic QA datasets.
6
+
7
+ For a detailed description, please refer to the AraELECTRA paper [AraELECTRA: Pre-Training Text Discriminators for Arabic Language Understanding](https://arxiv.org/abs/2012.15516).
8
+
9
+ This repository contains code to pre-train ELECTRA. It also supports fine-tuning ELECTRA on downstream tasks including classification tasks (e.g,. [GLUE](https://gluebenchmark.com/)), QA tasks (e.g., [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)), and sequence tagging tasks (e.g., [text chunking](https://www.clips.uantwerpen.be/conll2000/chunking/)).
10
+
11
+
12
+ ## Released Models
13
+
14
+ We are releasing two pre-trained models:
15
+
16
+ | Model | Layers | Hidden Size | Attention Heads | Params | HuggingFace Model Name |
17
+ | --- | --- | --- | --- | --- | --- |
18
+ | AraELECTRA-base-discriminator | 12 | 768 | 12 | 136M | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) |
19
+ | AraELECTRA-base-generator | 12 | 256 | 4 | 60M | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) |
20
+
21
+ ## Results
22
+
23
+ Model | TyDiQA (EM - F1) | ARCD (EM - F1) |
24
+ |:----|:----:|:----:|
25
+ AraBERTv0.1| 68.51 - 82.86 | 31.62 - 67.45 |
26
+ AraBERTv1| 61.11 - 79.36 | 31.7 - 67.8 |
27
+ AraBERTv0.2-base| 73.07 - 85.41| 32.76 - 66.53|
28
+ AraBERTv2-base| 61.67 - 81.66| 31.34 - 67.23 |
29
+ AraBERTv0.2-large| 73.72 - 86.03| 36.89 - **71.32** |
30
+ AraBERTv2-large| 64.49 - 82.51 | 34.19 - 68.12 |
31
+ ArabicBERT-base| 67.42 - 81.24| 30.48 - 62.24 |
32
+ ArabicBERT-large| 70.03 - 84.12| 33.33 - 67.27 |
33
+ Arabic-ALBERT-base| 67.10 - 80.98| 30.91 - 61.33 |
34
+ Arabic-ALBERT-large| 68.07 - 81.59| 34.19 - 65.41 |
35
+ Arabic-ALBERT-xlarge| 71.12 - 84.59| **37.75** - 68.03 |
36
+ AraELECTRA| **74.91 - 86.68**| 37.03 - 71.22 |
37
+
38
+ ## Requirements
39
+ * Python 3
40
+ * [TensorFlow](https://www.tensorflow.org/) 1.15 (although we hope to support TensorFlow 2.0 at a future date)
41
+ * [NumPy](https://numpy.org/)
42
+ * [scikit-learn](https://scikit-learn.org/stable/) and [SciPy](https://www.scipy.org/) (for computing some evaluation metrics).
43
+
44
+ ## Pre-training
45
+ Use `build_pretraining_dataset.py` or `build_arabert_pretraining_data.py` to create a pre-training dataset from a dump of raw text. `build_pretraining_dataset.py` takes the following arguments (a sample invocation is shown after the list):
46
+
47
+ * `--corpus-dir`: A directory containing raw text files to turn into ELECTRA examples. A text file can contain multiple documents with empty lines separating them.
48
+ * `--vocab-file`: File defining the wordpiece vocabulary.
49
+ * `--output-dir`: Where to write out ELECTRA examples.
50
+ * `--max-seq-length`: The number of tokens per example (128 by default).
51
+ * `--num-processes`: If >1 parallelize across multiple processes (1 by default).
52
+ * `--blanks-separate-docs`: Whether blank lines indicate document boundaries (True by default).
53
+ * `--do-lower-case/--no-lower-case`: Whether to lower case the input text (True by default).
54
+
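+ For example, a dataset-building run could look like the following. The directory layout under `$DATA_DIR` and the flag values are only illustrative assumptions; point them at your own corpus, vocabulary, and output location:
+
+ ```
+ python3 build_pretraining_dataset.py \
+   --corpus-dir $DATA_DIR/corpus \
+   --vocab-file $DATA_DIR/vocab.txt \
+   --output-dir $DATA_DIR/pretrain_tfrecords \
+   --max-seq-length 512 \
+   --num-processes 8
+ ```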
55
+ Use `run_pretraining.py` to pre-train an ELECTRA model. It has the following arguments:
56
+
57
+ * `--data-dir`: a directory where pre-training data, model weights, etc. are stored. By default, the training loads examples from `<data-dir>/pretrain_tfrecords` and a vocabulary from `<data-dir>/vocab.txt`.
58
+ * `--model-name`: a name for the model being trained. Model weights will be saved in `<data-dir>/models/<model-name>` by default.
59
+ * `--hparams` (optional): a JSON dict or path to a JSON file containing model hyperparameters, data paths, etc. See `configure_pretraining.py` for the supported hyperparameters.
60
+
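+ A minimal launch could then look like this. The model name `araelectra_base` and the hparams shown are placeholders, not the exact configuration used for the released models:
+
+ ```
+ python3 run_pretraining.py --data-dir $DATA_DIR --model-name araelectra_base \
+   --hparams '{"model_size": "base", "max_seq_length": 512}'
+ ```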
61
+ If training is halted, re-running the `run_pretraining.py` with the same arguments will continue the training where it left off.
62
+
63
+ You can continue pre-training from the released ELECTRA checkpoints by
64
+ 1. Setting the model-name to point to a downloaded model (e.g., `--model-name electra_small` if you downloaded weights to `$DATA_DIR/electra_small`).
65
+ 2. Setting `num_train_steps` by (for example) adding `"num_train_steps": 4010000` to the `--hparams`. This will continue training the small model for 10000 more steps (it has already been trained for 4e6 steps).
66
+ 3. Increase the learning rate to account for the linear learning rate decay. For example, to start with a learning rate of 2e-4 you should set the `learning_rate` hparam to 2e-4 * (4e6 + 10000) / 10000.
67
+ 4. For ELECTRA-Small, you also need to specify `"generator_hidden_size": 1.0` in the `hparams` because we did not use a small generator for that model.
68
+
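+ Putting these steps together for ELECTRA-Small, continuing for 10,000 extra steps would look roughly like this. The learning rate follows the formula above (2e-4 * (4e6 + 10000) / 10000 = 8.02e-2); the command is a sketch, not a tested recipe:
+
+ ```
+ python3 run_pretraining.py --data-dir $DATA_DIR --model-name electra_small \
+   --hparams '{"num_train_steps": 4010000, "learning_rate": 8.02e-2, "generator_hidden_size": 1.0}'
+ ```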
69
+ #### Evaluating the pre-trained model.
70
+
71
+ To evaluate the model on a downstream task, see the finetuning instructions below. To evaluate the generator/discriminator on the openwebtext data, run `python3 run_pretraining.py --data-dir $DATA_DIR --model-name electra_small_owt --hparams '{"do_train": false, "do_eval": true}'`. This will print out eval metrics such as the accuracy of the generator and discriminator, and also write the metrics out to `data-dir/model-name/results`.
72
+
73
+ ## Fine-tuning
74
+
75
+ Use `run_finetuning.py` to fine-tune and evaluate an ELECTRA model on a downstream NLP task. It expects three arguments:
76
+
77
+ * `--data-dir`: a directory where data, model weights, etc. are stored. By default, the script loads finetuning data from `<data-dir>/finetuning_data/<task-name>` and a vocabulary from `<data-dir>/vocab.txt`.
78
+ * `--model-name`: a name of the pre-trained model: the pre-trained weights should exist in `data-dir/models/model-name`.
79
+ * `--hparams`: a JSON dict containing model hyperparameters, data paths, etc. (e.g., `--hparams '{"task_names": ["rte"], "model_size": "base", "learning_rate": 1e-4, ...}'`). See `configure_pretraining.py` for the supported hyperparameters. Instead of a dict, this can also be a path to a `.json` file containing the hyperparameters. You must specify the `"task_names"` and `"model_size"` (see examples below).
80
+
81
+ Eval metrics will be saved in `data-dir/model-name/results` and model weights will be saved in `data-dir/model-name/finetuning_models` by default. Evaluation is done on the dev set by default. To customize the training, add `--hparams '{"hparam1": value1, "hparam2": value2, ...}'` to the run command. Some particularly useful options:
82
+
83
+ * `"debug": true` fine-tunes a tiny ELECTRA model for a few steps.
84
+ * `"task_names": ["task_name"]`: specifies the tasks to train on. A list because the codebase nominally supports multi-task learning (although be warned this has not been thoroughly tested).
85
+ * `"model_size": one of "small", "base", or "large"`: determines the size of the model; you must set this to the same size as the pre-trained model.
86
+ * `"do_train" and "do_eval"`: train and/or evaluate a model (both are set to true by default). For using `"do_eval": true` with `"do_train": false`, you need to specify the `init_checkpoint`, e.g., `python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["mnli"], "do_train": false, "do_eval": true, "init_checkpoint": "<data-dir>/models/electra_base/finetuning_models/mnli_model_1"}'`
87
+ * `"num_trials": n`: If >1, does multiple fine-tuning/evaluation runs with different random seeds.
88
+ * `"learning_rate": lr, "train_batch_size": n`, etc. can be used to change training hyperparameters.
89
+ * `"model_hparam_overrides": {"hidden_size": n, "num_hidden_layers": m}`, etc. can be used to changed the hyperparameters for the underlying transformer (the `"model_size"` flag sets the default values).
90
+
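+ For instance, a fine-tuning run on the `squad` task that combines several of these options might look like the following (the hyperparameter values are illustrative only):
+
+ ```
+ python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base \
+   --hparams '{"model_size": "base", "task_names": ["squad"], "num_trials": 5, "learning_rate": 3e-5, "train_batch_size": 16}'
+ ```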
91
+ ### Setup
92
+ Get a pre-trained ELECTRA model either by training your own (see the pre-training instructions above) or by downloading the released ELECTRA weights and unzipping them under `$DATA_DIR/models` (e.g., you should have a directory `$DATA_DIR/models/electra_large` if you are using the large model).
93
+
94
+
95
+ ### Finetune ELECTRA on question answering
96
+
97
+ The code supports [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) 1.1 and 2.0, as well as datasets in [the 2019 MRQA shared task](https://github.com/mrqa/MRQA-Shared-Task-2019)
98
+
99
+ * **ARCD**: Download the train/dev datasets from `https://github.com/husseinmozannar/SOQAL` and move them under `$DATA_DIR/finetuning_data/squadv1/(train|dev).json`
100
+
101
+ Then run (for example)
102
+ ```
103
+ python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["squad"]}'
104
+ ```
105
+
106
+ This repository uses the official evaluation code released by the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) authors
107
+
108
+ Alternatively, you can use the `transformers` library as shown in the notebooks `ARCD_pytorch.ipynb` or `Tydiqa_ar_pytorch.ipynb` in the examples folder.
109
+
110
+ ### Finetune ELECTRA on sequence tagging
111
+
112
+ Download the CoNLL-2000 text chunking dataset from [here](https://www.clips.uantwerpen.be/conll2000/chunking/) and put it under `$DATA_DIR/finetuning_data/chunk/(train|dev).txt`. Then run
113
+ ```
114
+ python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["chunk"]}'
115
+ ```
116
+
117
+ ### Adding a new task
118
+ The easiest way to run on a new task is to implement a new `finetune.task.Task`, add it to `finetune.task_builder.py`, and then use `run_finetuning.py` as normal. For classification/qa/sequence tagging, you can inherit from a `finetune.classification.classification_tasks.ClassificationTask`, `finetune.qa.qa_tasks.QATask`, or `finetune.tagging.tagging_tasks.TaggingTask`.
119
+ For preprocessing data, we use the same tokenizer as [BERT](https://github.com/google-research/bert).
120
+
121
+
122
+
123
+
124
+ ## Citation
125
+
126
+ If you used this model, please cite us as:
127
+ ```
128
+ @inproceedings{antoun-etal-2021-araelectra,
129
+ title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
130
+ author = "Antoun, Wissam and
131
+ Baly, Fady and
132
+ Hajj, Hazem",
133
+ booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
134
+ month = apr,
135
+ year = "2021",
136
+ address = "Kyiv, Ukraine (Virtual)",
137
+ publisher = "Association for Computational Linguistics",
138
+ url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
139
+ pages = "191--195",
140
+ }
141
+ ```
142
+
143
+
144
+
arabert/araelectra/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # coding=utf-8
arabert/araelectra/build_openwebtext_pretraining_dataset.py ADDED
@@ -0,0 +1,103 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Preprocesses the Open WebText corpus for ELECTRA pre-training."""
17
+
18
+ import argparse
19
+ import multiprocessing
20
+ import os
21
+ import random
22
+ import tarfile
23
+ import time
24
+ import tensorflow as tf
25
+
26
+ import build_pretraining_dataset
27
+ from util import utils
28
+
29
+
30
+ def write_examples(job_id, args):
31
+ """A single process creating and writing out pre-processed examples."""
32
+ job_tmp_dir = os.path.join(args.data_dir, "tmp", "job_" + str(job_id))
33
+ owt_dir = os.path.join(args.data_dir, "openwebtext")
34
+
35
+ def log(*args):
36
+ msg = " ".join(map(str, args))
37
+ print("Job {}:".format(job_id), msg)
38
+
39
+ log("Creating example writer")
40
+ example_writer = build_pretraining_dataset.ExampleWriter(
41
+ job_id=job_id,
42
+ vocab_file=os.path.join(args.data_dir, "vocab.txt"),
43
+ output_dir=os.path.join(args.data_dir, "pretrain_tfrecords"),
44
+ max_seq_length=args.max_seq_length,
45
+ num_jobs=args.num_processes,
46
+ blanks_separate_docs=False,
47
+ do_lower_case=args.do_lower_case
48
+ )
49
+ log("Writing tf examples")
50
+ fnames = sorted(tf.io.gfile.listdir(owt_dir))
51
+ fnames = [f for (i, f) in enumerate(fnames)
52
+ if i % args.num_processes == job_id]
53
+ random.shuffle(fnames)
54
+ start_time = time.time()
55
+ for file_no, fname in enumerate(fnames):
56
+ if file_no > 0 and file_no % 10 == 0:
57
+ elapsed = time.time() - start_time
58
+ log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, "
59
+ "{:} examples written".format(
60
+ file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed),
61
+ int((len(fnames) - file_no) / (file_no / elapsed)),
62
+ example_writer.n_written))
63
+ utils.rmkdir(job_tmp_dir)
64
+ with tarfile.open(os.path.join(owt_dir, fname)) as f:
65
+ f.extractall(job_tmp_dir)
66
+ extracted_files = tf.io.gfile.listdir(job_tmp_dir)
67
+ random.shuffle(extracted_files)
68
+ for txt_fname in extracted_files:
69
+ example_writer.write_examples(os.path.join(job_tmp_dir, txt_fname))
70
+ example_writer.finish()
71
+ log("Done!")
72
+
73
+
74
+ def main():
75
+ parser = argparse.ArgumentParser(description=__doc__)
76
+ parser.add_argument("--data-dir", required=True,
77
+ help="Location of data (vocab file, corpus, etc).")
78
+ parser.add_argument("--max-seq-length", default=128, type=int,
79
+ help="Number of tokens per example.")
80
+ parser.add_argument("--num-processes", default=1, type=int,
81
+ help="Parallelize across multiple processes.")
82
+ parser.add_argument("--do-lower-case", dest='do_lower_case',
83
+ action='store_true', help="Lower case input text.")
84
+ parser.add_argument("--no-lower-case", dest='do_lower_case',
85
+ action='store_false', help="Don't lower case input text.")
86
+ parser.set_defaults(do_lower_case=True)
87
+ args = parser.parse_args()
88
+
89
+ utils.rmkdir(os.path.join(args.data_dir, "pretrain_tfrecords"))
90
+ if args.num_processes == 1:
91
+ write_examples(0, args)
92
+ else:
93
+ jobs = []
94
+ for i in range(args.num_processes):
95
+ job = multiprocessing.Process(target=write_examples, args=(i, args))
96
+ jobs.append(job)
97
+ job.start()
98
+ for job in jobs:
99
+ job.join()
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()
arabert/araelectra/build_pretraining_dataset.py ADDED
@@ -0,0 +1,230 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Writes out text data as tfrecords that ELECTRA can be pre-trained on."""
17
+
18
+ import argparse
19
+ import multiprocessing
20
+ import os
21
+ import random
22
+ import time
23
+ import tensorflow as tf
24
+
25
+ from model import tokenization
26
+ from util import utils
27
+
28
+
29
+ def create_int_feature(values):
30
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
31
+ return feature
32
+
33
+
34
+ class ExampleBuilder(object):
35
+ """Given a stream of input text, creates pretraining examples."""
36
+
37
+ def __init__(self, tokenizer, max_length):
38
+ self._tokenizer = tokenizer
39
+ self._current_sentences = []
40
+ self._current_length = 0
41
+ self._max_length = max_length
42
+ self._target_length = max_length
43
+
44
+ def add_line(self, line):
45
+ """Adds a line of text to the current example being built."""
46
+ line = line.strip().replace("\n", " ")
47
+ if (not line) and self._current_length != 0: # empty lines separate docs
48
+ return self._create_example()
49
+ bert_tokens = self._tokenizer.tokenize(line)
50
+ bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens)
51
+ self._current_sentences.append(bert_tokids)
52
+ self._current_length += len(bert_tokids)
53
+ if self._current_length >= self._target_length:
54
+ return self._create_example()
55
+ return None
56
+
57
+ def _create_example(self):
58
+ """Creates a pre-training example from the current list of sentences."""
59
+ # small chance to only have one segment as in classification tasks
60
+ if random.random() < 0.1:
61
+ first_segment_target_length = 100000
62
+ else:
63
+ # -3 due to not yet having [CLS]/[SEP] tokens in the input text
64
+ first_segment_target_length = (self._target_length - 3) // 2
65
+
66
+ first_segment = []
67
+ second_segment = []
68
+ for sentence in self._current_sentences:
69
+ # the sentence goes to the first segment if (1) the first segment is
70
+ # empty, (2) the sentence doesn't put the first segment over length or
71
+ # (3) 50% of the time when it does put the first segment over length
72
+ if (len(first_segment) == 0 or
73
+ len(first_segment) + len(sentence) < first_segment_target_length or
74
+ (len(second_segment) == 0 and
75
+ len(first_segment) < first_segment_target_length and
76
+ random.random() < 0.5)):
77
+ first_segment += sentence
78
+ else:
79
+ second_segment += sentence
80
+
81
+ # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
82
+ first_segment = first_segment[:self._max_length - 2]
83
+ second_segment = second_segment[:max(0, self._max_length -
84
+ len(first_segment) - 3)]
85
+
86
+ # prepare to start building the next example
87
+ self._current_sentences = []
88
+ self._current_length = 0
89
+ # small chance for random-length instead of max_length-length example
90
+ if random.random() < 0.05:
91
+ self._target_length = random.randint(5, self._max_length)
92
+ else:
93
+ self._target_length = self._max_length
94
+
95
+ return self._make_tf_example(first_segment, second_segment)
96
+
97
+ def _make_tf_example(self, first_segment, second_segment):
98
+ """Converts two "segments" of text into a tf.train.Example."""
99
+ vocab = self._tokenizer.vocab
100
+ input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
101
+ segment_ids = [0] * len(input_ids)
102
+ if second_segment:
103
+ input_ids += second_segment + [vocab["[SEP]"]]
104
+ segment_ids += [1] * (len(second_segment) + 1)
105
+ input_mask = [1] * len(input_ids)
106
+ input_ids += [0] * (self._max_length - len(input_ids))
107
+ input_mask += [0] * (self._max_length - len(input_mask))
108
+ segment_ids += [0] * (self._max_length - len(segment_ids))
109
+ tf_example = tf.train.Example(features=tf.train.Features(feature={
110
+ "input_ids": create_int_feature(input_ids),
111
+ "input_mask": create_int_feature(input_mask),
112
+ "segment_ids": create_int_feature(segment_ids)
113
+ }))
114
+ return tf_example
115
+
116
+
117
+ class ExampleWriter(object):
118
+ """Writes pre-training examples to disk."""
119
+
120
+ def __init__(self, job_id, vocab_file, output_dir, max_seq_length,
121
+ num_jobs, blanks_separate_docs, do_lower_case,
122
+ num_out_files=1000):
123
+ self._blanks_separate_docs = blanks_separate_docs
124
+ tokenizer = tokenization.FullTokenizer(
125
+ vocab_file=vocab_file,
126
+ do_lower_case=do_lower_case)
127
+ self._example_builder = ExampleBuilder(tokenizer, max_seq_length)
128
+ self._writers = []
129
+ for i in range(num_out_files):
130
+ if i % num_jobs == job_id:
131
+ output_fname = os.path.join(
132
+ output_dir, "pretrain_data.tfrecord-{:}-of-{:}".format(
133
+ i, num_out_files))
134
+ self._writers.append(tf.io.TFRecordWriter(output_fname))
135
+ self.n_written = 0
136
+
137
+ def write_examples(self, input_file):
138
+ """Writes out examples from the provided input file."""
139
+ with tf.io.gfile.GFile(input_file) as f:
140
+ for line in f:
141
+ line = line.strip()
142
+ if line or self._blanks_separate_docs:
143
+ example = self._example_builder.add_line(line)
144
+ if example:
145
+ self._writers[self.n_written % len(self._writers)].write(
146
+ example.SerializeToString())
147
+ self.n_written += 1
148
+ example = self._example_builder.add_line("")
149
+ if example:
150
+ self._writers[self.n_written % len(self._writers)].write(
151
+ example.SerializeToString())
152
+ self.n_written += 1
153
+
154
+ def finish(self):
155
+ for writer in self._writers:
156
+ writer.close()
157
+
158
+
159
+ def write_examples(job_id, args):
160
+ """A single process creating and writing out pre-processed examples."""
161
+
162
+ def log(*args):
163
+ msg = " ".join(map(str, args))
164
+ print("Job {}:".format(job_id), msg)
165
+
166
+ log("Creating example writer")
167
+ example_writer = ExampleWriter(
168
+ job_id=job_id,
169
+ vocab_file=args.vocab_file,
170
+ output_dir=args.output_dir,
171
+ max_seq_length=args.max_seq_length,
172
+ num_jobs=args.num_processes,
173
+ blanks_separate_docs=args.blanks_separate_docs,
174
+ do_lower_case=args.do_lower_case
175
+ )
176
+ log("Writing tf examples")
177
+ fnames = sorted(tf.io.gfile.listdir(args.corpus_dir))
178
+ fnames = [f for (i, f) in enumerate(fnames)
179
+ if i % args.num_processes == job_id]
180
+ random.shuffle(fnames)
181
+ start_time = time.time()
182
+ for file_no, fname in enumerate(fnames):
183
+ if file_no > 0:
184
+ elapsed = time.time() - start_time
185
+ log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, "
186
+ "{:} examples written".format(
187
+ file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed),
188
+ int((len(fnames) - file_no) / (file_no / elapsed)),
189
+ example_writer.n_written))
190
+ example_writer.write_examples(os.path.join(args.corpus_dir, fname))
191
+ example_writer.finish()
192
+ log("Done!")
193
+
194
+
195
+ def main():
196
+ parser = argparse.ArgumentParser(description=__doc__)
197
+ parser.add_argument("--corpus-dir", required=True,
198
+ help="Location of pre-training text files.")
199
+ parser.add_argument("--vocab-file", required=True,
200
+ help="Location of vocabulary file.")
201
+ parser.add_argument("--output-dir", required=True,
202
+ help="Where to write out the tfrecords.")
203
+ parser.add_argument("--max-seq-length", default=128, type=int,
204
+ help="Number of tokens per example.")
205
+ parser.add_argument("--num-processes", default=1, type=int,
206
+ help="Parallelize across multiple processes.")
207
+ parser.add_argument("--blanks-separate-docs", default=True, type=bool,
208
+ help="Whether blank lines indicate document boundaries.")
209
+ parser.add_argument("--do-lower-case", dest='do_lower_case',
210
+ action='store_true', help="Lower case input text.")
211
+ parser.add_argument("--no-lower-case", dest='do_lower_case',
212
+ action='store_false', help="Don't lower case input text.")
213
+ parser.set_defaults(do_lower_case=True)
214
+ args = parser.parse_args()
215
+
216
+ utils.rmkdir(args.output_dir)
217
+ if args.num_processes == 1:
218
+ write_examples(0, args)
219
+ else:
220
+ jobs = []
221
+ for i in range(args.num_processes):
222
+ job = multiprocessing.Process(target=write_examples, args=(i, args))
223
+ jobs.append(job)
224
+ job.start()
225
+ for job in jobs:
226
+ job.join()
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
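For reference, a minimal invocation sketch for the script above. The flag names mirror the argparse definitions in main(); the corpus, vocabulary, and output paths are placeholders rather than files shipped with this repo.

```python
# Sketch only: builds the same argparse.Namespace that main() produces, then runs a
# single-process conversion. All paths are hypothetical placeholders.
import argparse

args = argparse.Namespace(
    corpus_dir="data/corpus",               # directory of plain-text shards
    vocab_file="data/vocab.txt",            # WordPiece vocabulary
    output_dir="data/pretrain_tfrecords",   # where the tfrecords are written
    max_seq_length=512,
    num_processes=1,
    blanks_separate_docs=True,
    do_lower_case=True,                     # parser default; --no-lower-case sets False
)
# With the repo's model/ and util/ packages importable, this is equivalent to
# `python build_pretraining_dataset.py --corpus-dir ... --num-processes 1`:
# write_examples(0, args)
```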
arabert/araelectra/build_pretraining_dataset_single_file.py ADDED
@@ -0,0 +1,90 @@
1
+ # coding=utf-8
2
+
3
+ import argparse
4
+ import os
5
+ import tensorflow as tf
6
+
7
+ import build_pretraining_dataset
8
+ from model import tokenization
9
+
10
+ class ExampleWriter(object):
11
+ """Writes pre-training examples to disk."""
12
+
13
+ def __init__(self, input_fname, vocab_file, output_dir, max_seq_length,
14
+ blanks_separate_docs, do_lower_case):
15
+ self._blanks_separate_docs = blanks_separate_docs
16
+ tokenizer = tokenization.FullTokenizer(
17
+ vocab_file=vocab_file,
18
+ do_lower_case=do_lower_case)
19
+ self._example_builder = build_pretraining_dataset.ExampleBuilder(tokenizer, max_seq_length)
20
+ output_fname = os.path.join(output_dir, "{}.tfrecord".format(input_fname.split("/")[-1]))
21
+ self._writer = tf.io.TFRecordWriter(output_fname)
22
+ self.n_written = 0
23
+
24
+ def write_examples(self, input_file):
25
+ """Writes out examples from the provided input file."""
26
+ with tf.io.gfile.GFile(input_file) as f:
27
+ for line in f:
28
+ line = line.strip()
29
+ if line or self._blanks_separate_docs:
30
+ example = self._example_builder.add_line(line)
31
+ if example:
32
+ self._writer.write(example.SerializeToString())
33
+ self.n_written += 1
34
+ example = self._example_builder.add_line("")
35
+ if example:
36
+ self._writer.write(example.SerializeToString())
37
+ self.n_written += 1
38
+
39
+ def finish(self):
40
+ self._writer.close()
41
+
42
+ def write_examples(args):
43
+ """A single process creating and writing out pre-processed examples."""
44
+
45
+ def log(*args):
46
+ msg = " ".join(map(str, args))
47
+ print(msg)
48
+
49
+ log("Creating example writer")
50
+ example_writer = ExampleWriter(
51
+ input_fname=args.input_file,
52
+ vocab_file=args.vocab_file,
53
+ output_dir=args.output_dir,
54
+ max_seq_length=args.max_seq_length,
55
+ blanks_separate_docs=args.blanks_separate_docs,
56
+ do_lower_case=args.do_lower_case
57
+ )
58
+ log("Writing tf examples")
59
+
60
+ example_writer.write_examples(args.input_file)
61
+ example_writer.finish()
62
+ log("Done!")
63
+ return
64
+
65
+
66
+ def main():
67
+ parser = argparse.ArgumentParser(description=__doc__)
68
+ parser.add_argument("--input-file", required=True,
69
+ help="Location of pre-training text files.")
70
+ parser.add_argument("--vocab-file", required=True,
71
+ help="Location of vocabulary file.")
72
+ parser.add_argument("--output-dir", required=True,
73
+ help="Where to write out the tfrecords.")
74
+ parser.add_argument("--max-seq-length", default=128, type=int,
75
+ help="Number of tokens per example.")
76
+ parser.add_argument("--blanks-separate-docs", default=True, type=bool,
77
+ help="Whether blank lines indicate document boundaries.")
78
+ parser.add_argument("--do-lower-case", dest='do_lower_case',
79
+ action='store_true', help="Lower case input text.")
80
+ parser.add_argument("--no-lower-case", dest='do_lower_case',
81
+ action='store_false', help="Don't lower case input text.")
82
+ parser.set_defaults(do_lower_case=True)
83
+ args = parser.parse_args()
84
+
85
+ write_examples(args)
86
+
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
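The single-file variant above names its output after the input file. A standalone sketch of that naming logic (the paths are placeholders):

```python
# Mirrors the output_fname construction in ExampleWriter.__init__ above.
import os

input_fname = "data/shards/shard_00.txt"    # hypothetical --input-file
output_dir = "data/pretrain_tfrecords"      # hypothetical --output-dir
output_fname = os.path.join(output_dir, "{}.tfrecord".format(input_fname.split("/")[-1]))
print(output_fname)  # data/pretrain_tfrecords/shard_00.txt.tfrecord
```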
arabert/araelectra/configure_finetuning.py ADDED
@@ -0,0 +1,172 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Config controlling hyperparameters for fine-tuning ELECTRA."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import os
23
+
24
+ import tensorflow as tf
25
+
26
+
27
+ class FinetuningConfig(object):
28
+ """Fine-tuning hyperparameters."""
29
+
30
+ def __init__(self, model_name, data_dir, **kwargs):
31
+ # general
32
+ self.model_name = model_name
33
+ self.debug = False # debug mode for quickly running things
34
+ self.log_examples = False # print out some train examples for debugging
35
+ self.num_trials = 1 # how many train+eval runs to perform
36
+ self.do_train = True # train a model
37
+ self.do_eval = True # evaluate the model
38
+ self.keep_all_models = True # if False, only keep the last trial's ckpt
39
+
40
+ # model
41
+ self.model_size = "base" # one of "small", "base", or "large"
42
+ self.task_names = ["chunk"] # which tasks to learn
43
+ # override the default transformer hparams for the provided model size; see
44
+ # modeling.BertConfig for the possible hparams and util.training_utils for
45
+ # the defaults
46
+ self.model_hparam_overrides = (
47
+ kwargs["model_hparam_overrides"]
48
+ if "model_hparam_overrides" in kwargs else {})
49
+ self.embedding_size = None # bert hidden size by default
50
+ self.vocab_size = 64000 # number of tokens in the vocabulary
51
+ self.do_lower_case = True
52
+
53
+ # training
54
+ self.learning_rate = 1e-4
55
+ self.weight_decay_rate = 0.01
56
+ self.layerwise_lr_decay = 0.8 # if > 0, the learning rate for a layer is
57
+ # lr * lr_decay^(depth - max_depth) i.e.,
58
+ # shallower layers have lower learning rates
59
+ self.num_train_epochs = 3.0 # passes over the dataset during training
60
+ self.warmup_proportion = 0.1 # how much of training to warm up the LR for
61
+ self.save_checkpoints_steps = 1000000
62
+ self.iterations_per_loop = 1000
63
+ self.use_tfrecords_if_existing = True # don't make tfrecords and write them
64
+ # to disc if existing ones are found
65
+
66
+ # writing model outputs to disc
67
+ self.write_test_outputs = False # whether to write test set outputs,
68
+ # currently supported for GLUE + SQuAD 2.0
69
+ self.n_writes_test = 5 # write test set predictions for the first n trials
70
+
71
+ # sizing
72
+ self.max_seq_length = 128
73
+ self.train_batch_size = 32
74
+ self.eval_batch_size = 32
75
+ self.predict_batch_size = 32
76
+ self.double_unordered = True # for tasks like paraphrase where sentence
77
+ # order doesn't matter, train the model on
78
+ # both sentence orderings for each example
79
+ # for qa tasks
80
+ self.max_query_length = 64 # max tokens in q as opposed to context
81
+ self.doc_stride = 128 # stride when splitting doc into multiple examples
82
+ self.n_best_size = 20 # number of predictions per example to save
83
+ self.max_answer_length = 30 # filter out answers longer than this length
84
+ self.answerable_classifier = True # answerable classifier for SQuAD 2.0
85
+ self.answerable_uses_start_logits = True # more advanced answerable
86
+ # classifier using predicted start
87
+ self.answerable_weight = 0.5 # weight for answerability loss
88
+ self.joint_prediction = True # jointly predict the start and end positions
89
+ # of the answer span
90
+ self.beam_size = 20 # beam size when doing joint predictions
91
+ self.qa_na_threshold = -2.75 # threshold for "no answer" when writing SQuAD
92
+ # 2.0 test outputs
93
+
94
+ # TPU settings
95
+ self.use_tpu = False
96
+ self.num_tpu_cores = 1
97
+ self.tpu_job_name = None
98
+ self.tpu_name = None # cloud TPU to use for training
99
+ self.tpu_zone = None # GCE zone where the Cloud TPU is located in
100
+ self.gcp_project = None # project name for the Cloud TPU-enabled project
101
+
102
+ # default locations of data files
103
+ self.data_dir = data_dir
104
+ pretrained_model_dir = os.path.join(data_dir, "models", model_name)
105
+ self.raw_data_dir = os.path.join(data_dir, "finetuning_data", "{:}").format
106
+ self.vocab_file = os.path.join(pretrained_model_dir, "vocab.txt")
107
+ if not tf.io.gfile.exists(self.vocab_file):
108
+ self.vocab_file = os.path.join(self.data_dir, "vocab.txt")
109
+ task_names_str = ",".join(
110
+ kwargs["task_names"] if "task_names" in kwargs else self.task_names)
111
+ self.init_checkpoint = None if self.debug else pretrained_model_dir
112
+ self.model_dir = os.path.join(pretrained_model_dir, "finetuning_models",
113
+ task_names_str + "_model")
114
+ results_dir = os.path.join(pretrained_model_dir, "results")
115
+ self.results_txt = os.path.join(results_dir,
116
+ task_names_str + "_results.txt")
117
+ self.results_pkl = os.path.join(results_dir,
118
+ task_names_str + "_results.pkl")
119
+ qa_topdir = os.path.join(results_dir, task_names_str + "_qa")
120
+ self.qa_eval_file = os.path.join(qa_topdir, "{:}_eval.json").format
121
+ self.qa_preds_file = os.path.join(qa_topdir, "{:}_preds.json").format
122
+ self.qa_na_file = os.path.join(qa_topdir, "{:}_null_odds.json").format
123
+ self.preprocessed_data_dir = os.path.join(
124
+ pretrained_model_dir, "finetuning_tfrecords",
125
+ task_names_str + "_tfrecords" + ("-debug" if self.debug else ""))
126
+ self.test_predictions = os.path.join(
127
+ pretrained_model_dir, "test_predictions",
128
+ "{:}_{:}_{:}_predictions.pkl").format
129
+
130
+ # update defaults with passed-in hyperparameters
131
+ self.update(kwargs)
132
+
133
+ # default hyperparameters for single-task models
134
+ if len(self.task_names) == 1:
135
+ task_name = self.task_names[0]
136
+ if task_name == "rte" or task_name == "sts":
137
+ self.num_train_epochs = 10.0
138
+ elif "squad" in task_name or "qa" in task_name:
139
+ self.max_seq_length = 512
140
+ self.num_train_epochs = 2.0
141
+ self.write_distill_outputs = False
142
+ self.write_test_outputs = False
143
+ elif task_name == "chunk":
144
+ self.max_seq_length = 256
145
+ else:
146
+ self.num_train_epochs = 3.0
147
+
148
+ # default hyperparameters for different model sizes
149
+ if self.model_size == "large":
150
+ self.learning_rate = 5e-5
151
+ self.layerwise_lr_decay = 0.9
152
+ elif self.model_size == "small":
153
+ self.embedding_size = 128
154
+
155
+ # debug-mode settings
156
+ if self.debug:
157
+ self.save_checkpoints_steps = 1000000
158
+ self.use_tfrecords_if_existing = False
159
+ self.num_trials = 1
160
+ self.iterations_per_loop = 1
161
+ self.train_batch_size = 32
162
+ self.num_train_epochs = 3.0
163
+ self.log_examples = True
164
+
165
+ # passed-in-arguments override (for example) debug-mode defaults
166
+ self.update(kwargs)
167
+
168
+ def update(self, kwargs):
169
+ for k, v in kwargs.items():
170
+ if k not in self.__dict__:
171
+ raise ValueError("Unknown hparam " + k)
172
+ self.__dict__[k] = v
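A short construction sketch for the config above; the model name, data directory, and task name are placeholders, and any keyword that does not match an existing attribute makes update() raise ValueError. Assumes TensorFlow is installed and the araelectra directory is on sys.path.

```python
import configure_finetuning

config = configure_finetuning.FinetuningConfig(
    "araelectra-base",          # hypothetical model name
    "electra_data",             # hypothetical data_dir
    task_names=["sentiment"],   # overrides the default ["chunk"]; name is a placeholder
    max_seq_length=256,
    train_batch_size=16,
)
print(config.model_dir)
# electra_data/models/araelectra-base/finetuning_models/sentiment_model

# Unknown keys are rejected:
# configure_finetuning.FinetuningConfig("m", "d", not_a_hparam=1) -> ValueError("Unknown hparam not_a_hparam")
```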
arabert/araelectra/configure_pretraining.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Config controlling hyperparameters for pre-training ELECTRA."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import os
23
+
24
+
25
+ class PretrainingConfig(object):
26
+ """Defines pre-training hyperparameters."""
27
+
28
+ def __init__(self, model_name, data_dir, **kwargs):
29
+ self.model_name = model_name
30
+ self.debug = False # debug mode for quickly running things
31
+ self.do_train = True # pre-train ELECTRA
32
+ self.do_eval = False # evaluate generator/discriminator on unlabeled data
33
+
34
+ # loss functions
35
+ # train ELECTRA or Electric? if both are false, trains a masked LM like BERT
36
+ self.electra_objective = True
37
+ self.electric_objective = False
38
+ self.gen_weight = 1.0 # masked language modeling / generator loss
39
+ self.disc_weight = 50.0 # discriminator loss
40
+ self.mask_prob = 0.15 # percent of input tokens to mask out / replace
41
+
42
+ # optimization
43
+ self.learning_rate = 2e-4
44
+ self.lr_decay_power = 1.0 # linear weight decay by default
45
+ self.weight_decay_rate = 0.01
46
+ self.num_warmup_steps = 10000
47
+
48
+ # training settings
49
+ self.iterations_per_loop = 5000
50
+ self.save_checkpoints_steps = 25000
51
+ self.num_train_steps = 2000000
52
+ self.num_eval_steps = 10000
53
+ self.keep_checkpoint_max = 0 # maximum number of recent checkpoint files to keep;
54
+ # change to 0 or None to keep all checkpoints
55
+
56
+ # model settings
57
+ self.model_size = "base" # one of "small", "base", or "large"
58
+ # override the default transformer hparams for the provided model size; see
59
+ # modeling.BertConfig for the possible hparams and util.training_utils for
60
+ # the defaults
61
+ self.model_hparam_overrides = (
62
+ kwargs["model_hparam_overrides"]
63
+ if "model_hparam_overrides" in kwargs else {})
64
+ self.embedding_size = None # bert hidden size by default
65
+ self.vocab_size = 64000 # number of tokens in the vocabulary
66
+ self.do_lower_case = False # lowercase the input?
67
+
68
+ # generator settings
69
+ self.uniform_generator = False # generator is uniform at random
70
+ self.two_tower_generator = False # generator is a two-tower cloze model
71
+ self.untied_generator_embeddings = False # tie generator/discriminator
72
+ # token embeddings?
73
+ self.untied_generator = True # tie all generator/discriminator weights?
74
+ self.generator_layers = 1.0 # frac of discriminator layers for generator
75
+ self.generator_hidden_size = 0.25 # frac of discrim hidden size for gen
76
+ self.disallow_correct = False # force the generator to sample incorrect
77
+ # tokens (so 15% of tokens are always
78
+ # fake)
79
+ self.temperature = 1.0 # temperature for sampling from generator
80
+
81
+ # batch sizes
82
+ self.max_seq_length = 512
83
+ self.train_batch_size = 256
84
+ self.eval_batch_size = 256
85
+
86
+ # TPU settings
87
+ self.use_tpu = True
88
+ self.num_tpu_cores = 8
89
+ self.tpu_job_name = None
90
+ self.tpu_name = "" # cloud TPU to use for training
91
+ self.tpu_zone = "" # GCE zone where the Cloud TPU is located in
92
+ self.gcp_project = "" # project name for the Cloud TPU-enabled project
93
+
94
+ # default locations of data files
95
+ self.pretrain_tfrecords = os.path.join(
96
+ data_dir, "pretraining_data/512/*")
97
+ self.vocab_file = os.path.join(data_dir, "bertvocab_final.txt")
98
+ self.model_dir = os.path.join(data_dir, "models", model_name)
99
+ results_dir = os.path.join(self.model_dir, "results")
100
+ self.results_txt = os.path.join(results_dir, "unsup_results.txt")
101
+ self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")
102
+
103
+ # update defaults with passed-in hyperparameters
104
+ self.update(kwargs)
105
+
106
+ self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
107
+ self.max_seq_length)
108
+
109
+ # debug-mode settings
110
+ if self.debug:
111
+ self.train_batch_size = 8
112
+ self.num_train_steps = 20
113
+ self.eval_batch_size = 4
114
+ self.iterations_per_loop = 1
115
+ self.num_eval_steps = 2
116
+
117
+ # defaults for different-sized model
118
+ if self.model_size == "small":
119
+ self.embedding_size = 128
120
+ # Here are the hyperparameters we used for larger models; see Table 6 in the
121
+ # paper for the full hyperparameters
122
+ else:
123
+ self.max_seq_length = 512
124
+ self.learning_rate = 2e-4
125
+ if self.model_size == "base":
126
+ self.embedding_size = 768
127
+ self.generator_hidden_size = 0.33333
128
+ self.train_batch_size = 256
129
+ else:
130
+ self.embedding_size = 1024
131
+ self.mask_prob = 0.25
132
+ self.train_batch_size = 2048
133
+ if self.electric_objective:
134
+ self.two_tower_generator = True # electric requires a two-tower generator
135
+
136
+ # passed-in-arguments override (for example) debug-mode defaults
137
+ self.update(kwargs)
138
+
139
+ def update(self, kwargs):
140
+ for k, v in kwargs.items():
141
+ if k not in self.__dict__:
142
+ raise ValueError("Unknown hparam " + k)
143
+ self.__dict__[k] = v
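One derived value in the pre-training config is worth spelling out: max_predictions_per_seq is computed from mask_prob and max_seq_length. A standalone arithmetic sketch (no TensorFlow needed):

```python
# Reproduces the derivation in PretrainingConfig.__init__ above.
mask_prob, max_seq_length = 0.15, 512                # base-model defaults
print(int((mask_prob + 0.005) * max_seq_length))     # 79
# For the "large" settings (mask_prob = 0.25): int(0.255 * 512) == 130
```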
arabert/araelectra/finetune/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
arabert/araelectra/finetune/classification/classification_metrics.py ADDED
@@ -0,0 +1,116 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Evaluation metrics for classification tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ import numpy as np
24
+ import scipy
25
+ import sklearn
26
+
27
+ from finetune import scorer
28
+
29
+
30
+ class SentenceLevelScorer(scorer.Scorer):
31
+ """Abstract scorer for classification/regression tasks."""
32
+
33
+ __metaclass__ = abc.ABCMeta
34
+
35
+ def __init__(self):
36
+ super(SentenceLevelScorer, self).__init__()
37
+ self._total_loss = 0
38
+ self._true_labels = []
39
+ self._preds = []
40
+
41
+ def update(self, results):
42
+ super(SentenceLevelScorer, self).update(results)
43
+ self._total_loss += results['loss']
44
+ self._true_labels.append(results['label_ids'] if 'label_ids' in results
45
+ else results['targets'])
46
+ self._preds.append(results['predictions'])
47
+
48
+ def get_loss(self):
49
+ return self._total_loss / len(self._true_labels)
50
+
51
+
52
+ class AccuracyScorer(SentenceLevelScorer):
53
+
54
+ def _get_results(self):
55
+ correct, count = 0, 0
56
+ for y_true, pred in zip(self._true_labels, self._preds):
57
+ count += 1
58
+ correct += (1 if y_true == pred else 0)
59
+ return [
60
+ ('accuracy', 100.0 * correct / count),
61
+ ('loss', self.get_loss()),
62
+ ]
63
+
64
+
65
+ class F1Scorer(SentenceLevelScorer):
66
+ """Computes F1 for classification tasks."""
67
+
68
+ def __init__(self):
69
+ super(F1Scorer, self).__init__()
70
+ self._positive_label = 1
71
+
72
+ def _get_results(self):
73
+ n_correct, n_predicted, n_gold = 0, 0, 0
74
+ for y_true, pred in zip(self._true_labels, self._preds):
75
+ if y_true == self._positive_label:
76
+ n_gold += 1
77
+ if pred == self._positive_label:
78
+ n_predicted += 1
79
+ if pred == y_true:
80
+ n_correct += 1
81
+ if n_correct == 0:
82
+ p, r, f1 = 0, 0, 0
83
+ else:
84
+ p = 100.0 * n_correct / n_predicted
85
+ r = 100.0 * n_correct / n_gold
86
+ f1 = 2 * p * r / (p + r)
87
+ return [
88
+ ('precision', p),
89
+ ('recall', r),
90
+ ('f1', f1),
91
+ ('loss', self.get_loss()),
92
+ ]
93
+
94
+
95
+ class MCCScorer(SentenceLevelScorer):
96
+
97
+ def _get_results(self):
98
+ return [
99
+ ('mcc', 100 * sklearn.metrics.matthews_corrcoef(
100
+ self._true_labels, self._preds)),
101
+ ('loss', self.get_loss()),
102
+ ]
103
+
104
+
105
+ class RegressionScorer(SentenceLevelScorer):
106
+
107
+ def _get_results(self):
108
+ preds = np.array(self._preds).flatten()
109
+ return [
110
+ ('pearson', 100.0 * scipy.stats.pearsonr(
111
+ self._true_labels, preds)[0]),
112
+ ('spearman', 100.0 * scipy.stats.spearmanr(
113
+ self._true_labels, preds)[0]),
114
+ ('mse', np.mean(np.square(np.array(self._true_labels) - self._preds))),
115
+ ('loss', self.get_loss()),
116
+ ]
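To make the positive-class bookkeeping in F1Scorer concrete, here is a standalone recomputation on a toy prediction list (it does not import the repo; the label values are illustrative):

```python
# Binary F1 for the positive class (label 1), following F1Scorer._get_results above.
true_labels = [1, 0, 1, 1, 0]
preds       = [1, 1, 0, 1, 0]

n_gold      = sum(1 for y in true_labels if y == 1)                      # 3 gold positives
n_predicted = sum(1 for p in preds if p == 1)                            # 3 predicted positives
n_correct   = sum(1 for y, p in zip(true_labels, preds) if y == p == 1)  # 2 true positives

precision = 100.0 * n_correct / n_predicted        # 66.67
recall    = 100.0 * n_correct / n_gold             # 66.67
print(round(2 * precision * recall / (precision + recall), 2))  # 66.67
```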
arabert/araelectra/finetune/classification/classification_tasks.py ADDED
@@ -0,0 +1,439 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Text classification and regression tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ import csv
24
+ import os
25
+ import tensorflow as tf
26
+
27
+ import configure_finetuning
28
+ from finetune import feature_spec
29
+ from finetune import task
30
+ from finetune.classification import classification_metrics
31
+ from model import tokenization
32
+ from util import utils
33
+
34
+
35
+ class InputExample(task.Example):
36
+ """A single training/test example for simple sequence classification."""
37
+
38
+ def __init__(self, eid, task_name, text_a, text_b=None, label=None):
39
+ super(InputExample, self).__init__(task_name)
40
+ self.eid = eid
41
+ self.text_a = text_a
42
+ self.text_b = text_b
43
+ self.label = label
44
+
45
+
46
+ class SingleOutputTask(task.Task):
47
+ """Task with a single prediction per example (e.g., text classification)."""
48
+
49
+ __metaclass__ = abc.ABCMeta
50
+
51
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
52
+ tokenizer):
53
+ super(SingleOutputTask, self).__init__(config, name)
54
+ self._tokenizer = tokenizer
55
+
56
+ def get_examples(self, split):
57
+ return self._create_examples(read_tsv(
58
+ os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"),
59
+ max_lines=100 if self.config.debug else None), split)
60
+
61
+ @abc.abstractmethod
62
+ def _create_examples(self, lines, split):
63
+ pass
64
+
65
+ def featurize(self, example: InputExample, is_training, log=False):
66
+ """Turn an InputExample into a dict of features."""
67
+ tokens_a = self._tokenizer.tokenize(example.text_a)
68
+ tokens_b = None
69
+ if example.text_b:
70
+ tokens_b = self._tokenizer.tokenize(example.text_b)
71
+
72
+ if tokens_b:
73
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
74
+ # length is less than the specified length.
75
+ # Account for [CLS], [SEP], [SEP] with "- 3"
76
+ _truncate_seq_pair(tokens_a, tokens_b, self.config.max_seq_length - 3)
77
+ else:
78
+ # Account for [CLS] and [SEP] with "- 2"
79
+ if len(tokens_a) > self.config.max_seq_length - 2:
80
+ tokens_a = tokens_a[0:(self.config.max_seq_length - 2)]
81
+
82
+ # The convention in BERT is:
83
+ # (a) For sequence pairs:
84
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
85
+ # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
86
+ # (b) For single sequences:
87
+ # tokens: [CLS] the dog is hairy . [SEP]
88
+ # type_ids: 0 0 0 0 0 0 0
89
+ #
90
+ # Where "type_ids" are used to indicate whether this is the first
91
+ # sequence or the second sequence. The embedding vectors for `type=0` and
92
+ # `type=1` were learned during pre-training and are added to the wordpiece
93
+ # embedding vector (and position vector). This is not *strictly* necessary
94
+ # since the [SEP] token unambiguously separates the sequences, but it
95
+ # makes it easier for the model to learn the concept of sequences.
96
+ #
97
+ # For classification tasks, the first vector (corresponding to [CLS]) is
98
+ # used as the "sentence vector". Note that this only makes sense because
99
+ # the entire model is fine-tuned.
100
+ tokens = []
101
+ segment_ids = []
102
+ tokens.append("[CLS]")
103
+ segment_ids.append(0)
104
+ for token in tokens_a:
105
+ tokens.append(token)
106
+ segment_ids.append(0)
107
+ tokens.append("[SEP]")
108
+ segment_ids.append(0)
109
+
110
+ if tokens_b:
111
+ for token in tokens_b:
112
+ tokens.append(token)
113
+ segment_ids.append(1)
114
+ tokens.append("[SEP]")
115
+ segment_ids.append(1)
116
+
117
+ input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
118
+
119
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
120
+ # tokens are attended to.
121
+ input_mask = [1] * len(input_ids)
122
+
123
+ # Zero-pad up to the sequence length.
124
+ while len(input_ids) < self.config.max_seq_length:
125
+ input_ids.append(0)
126
+ input_mask.append(0)
127
+ segment_ids.append(0)
128
+
129
+ assert len(input_ids) == self.config.max_seq_length
130
+ assert len(input_mask) == self.config.max_seq_length
131
+ assert len(segment_ids) == self.config.max_seq_length
132
+
133
+ if log:
134
+ utils.log(" Example {:}".format(example.eid))
135
+ utils.log(" tokens: {:}".format(" ".join(
136
+ [tokenization.printable_text(x) for x in tokens])))
137
+ utils.log(" input_ids: {:}".format(" ".join(map(str, input_ids))))
138
+ utils.log(" input_mask: {:}".format(" ".join(map(str, input_mask))))
139
+ utils.log(" segment_ids: {:}".format(" ".join(map(str, segment_ids))))
140
+
141
+ eid = example.eid
142
+ features = {
143
+ "input_ids": input_ids,
144
+ "input_mask": input_mask,
145
+ "segment_ids": segment_ids,
146
+ "task_id": self.config.task_names.index(self.name),
147
+ self.name + "_eid": eid,
148
+ }
149
+ self._add_features(features, example, log)
150
+ return features
151
+
152
+ def _load_glue(self, lines, split, text_a_loc, text_b_loc, label_loc,
153
+ skip_first_line=False, eid_offset=0, swap=False):
154
+ examples = []
155
+ for (i, line) in enumerate(lines):
156
+ try:
157
+ if i == 0 and skip_first_line:
158
+ continue
159
+ eid = i - (1 if skip_first_line else 0) + eid_offset
160
+ text_a = tokenization.convert_to_unicode(line[text_a_loc])
161
+ if text_b_loc is None:
162
+ text_b = None
163
+ else:
164
+ text_b = tokenization.convert_to_unicode(line[text_b_loc])
165
+ if "test" in split or "diagnostic" in split:
166
+ label = self._get_dummy_label()
167
+ else:
168
+ label = tokenization.convert_to_unicode(line[label_loc])
169
+ if swap:
170
+ text_a, text_b = text_b, text_a
171
+ examples.append(InputExample(eid=eid, task_name=self.name,
172
+ text_a=text_a, text_b=text_b, label=label))
173
+ except Exception as ex:
174
+ utils.log("Error constructing example from line", i,
175
+ "for task", self.name + ":", ex)
176
+ utils.log("Input causing the error:", line)
177
+ return examples
178
+
179
+ @abc.abstractmethod
180
+ def _get_dummy_label(self):
181
+ pass
182
+
183
+ @abc.abstractmethod
184
+ def _add_features(self, features, example, log):
185
+ pass
186
+
187
+
188
+ class RegressionTask(SingleOutputTask):
189
+ """Task where the output is a real-valued score for the input text."""
190
+
191
+ __metaclass__ = abc.ABCMeta
192
+
193
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
194
+ tokenizer, min_value, max_value):
195
+ super(RegressionTask, self).__init__(config, name, tokenizer)
196
+ self._tokenizer = tokenizer
197
+ self._min_value = min_value
198
+ self._max_value = max_value
199
+
200
+ def _get_dummy_label(self):
201
+ return 0.0
202
+
203
+ def get_feature_specs(self):
204
+ feature_specs = [feature_spec.FeatureSpec(self.name + "_eid", []),
205
+ feature_spec.FeatureSpec(self.name + "_targets", [],
206
+ is_int_feature=False)]
207
+ return feature_specs
208
+
209
+ def _add_features(self, features, example, log):
210
+ label = float(example.label)
211
+ assert self._min_value <= label <= self._max_value
212
+ # simple normalization of the label
213
+ label = (label - self._min_value) / self._max_value
214
+ if log:
215
+ utils.log(" label: {:}".format(label))
216
+ features[example.task_name + "_targets"] = label
217
+
218
+ def get_prediction_module(self, bert_model, features, is_training,
219
+ percent_done):
220
+ reprs = bert_model.get_pooled_output()
221
+ if is_training:
222
+ reprs = tf.nn.dropout(reprs, keep_prob=0.9)
223
+
224
+ predictions = tf.layers.dense(reprs, 1)
225
+ predictions = tf.squeeze(predictions, -1)
226
+
227
+ targets = features[self.name + "_targets"]
228
+ losses = tf.square(predictions - targets)
229
+ outputs = dict(
230
+ loss=losses,
231
+ predictions=predictions,
232
+ targets=features[self.name + "_targets"],
233
+ eid=features[self.name + "_eid"]
234
+ )
235
+ return losses, outputs
236
+
237
+ def get_scorer(self):
238
+ return classification_metrics.RegressionScorer()
239
+
240
+
241
+ class ClassificationTask(SingleOutputTask):
242
+ """Task where the output is a single categorical label for the input text."""
243
+ __metaclass__ = abc.ABCMeta
244
+
245
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
246
+ tokenizer, label_list):
247
+ super(ClassificationTask, self).__init__(config, name, tokenizer)
248
+ self._tokenizer = tokenizer
249
+ self._label_list = label_list
250
+
251
+ def _get_dummy_label(self):
252
+ return self._label_list[0]
253
+
254
+ def get_feature_specs(self):
255
+ return [feature_spec.FeatureSpec(self.name + "_eid", []),
256
+ feature_spec.FeatureSpec(self.name + "_label_ids", [])]
257
+
258
+ def _add_features(self, features, example, log):
259
+ label_map = {}
260
+ for (i, label) in enumerate(self._label_list):
261
+ label_map[label] = i
262
+ label_id = label_map[example.label]
263
+ if log:
264
+ utils.log(" label: {:} (id = {:})".format(example.label, label_id))
265
+ features[example.task_name + "_label_ids"] = label_id
266
+
267
+ def get_prediction_module(self, bert_model, features, is_training,
268
+ percent_done):
269
+ num_labels = len(self._label_list)
270
+ reprs = bert_model.get_pooled_output()
271
+
272
+ if is_training:
273
+ reprs = tf.nn.dropout(reprs, keep_prob=0.9)
274
+
275
+ logits = tf.layers.dense(reprs, num_labels)
276
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
277
+
278
+ label_ids = features[self.name + "_label_ids"]
279
+ labels = tf.one_hot(label_ids, depth=num_labels, dtype=tf.float32)
280
+
281
+ losses = -tf.reduce_sum(labels * log_probs, axis=-1)
282
+
283
+ outputs = dict(
284
+ loss=losses,
285
+ logits=logits,
286
+ predictions=tf.argmax(logits, axis=-1),
287
+ label_ids=label_ids,
288
+ eid=features[self.name + "_eid"],
289
+ )
290
+ return losses, outputs
291
+
292
+ def get_scorer(self):
293
+ return classification_metrics.AccuracyScorer()
294
+
295
+
296
+ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
297
+ """Truncates a sequence pair in place to the maximum length."""
298
+
299
+ # This is a simple heuristic which will always truncate the longer sequence
300
+ # one token at a time. This makes more sense than truncating an equal percent
301
+ # of tokens from each, since if one sequence is very short then each token
302
+ # that's truncated likely contains more information than a longer sequence.
303
+ while True:
304
+ total_length = len(tokens_a) + len(tokens_b)
305
+ if total_length <= max_length:
306
+ break
307
+ if len(tokens_a) > len(tokens_b):
308
+ tokens_a.pop()
309
+ else:
310
+ tokens_b.pop()
311
+
312
+
313
+ def read_tsv(input_file, quotechar=None, max_lines=None):
314
+ """Reads a tab separated value file."""
315
+ with tf.io.gfile.GFile(input_file, "r") as f:
316
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
317
+ lines = []
318
+ for i, line in enumerate(reader):
319
+ if max_lines and i >= max_lines:
320
+ break
321
+ lines.append(line)
322
+ return lines
323
+
324
+
325
+ class MNLI(ClassificationTask):
326
+ """Multi-NLI."""
327
+
328
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
329
+ super(MNLI, self).__init__(config, "mnli", tokenizer,
330
+ ["contradiction", "entailment", "neutral"])
331
+
332
+ def get_examples(self, split):
333
+ if split == "dev":
334
+ split += "_matched"
335
+ return self._create_examples(read_tsv(
336
+ os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"),
337
+ max_lines=100 if self.config.debug else None), split)
338
+
339
+ def _create_examples(self, lines, split):
340
+ if split == "diagnostic":
341
+ return self._load_glue(lines, split, 1, 2, None, True)
342
+ else:
343
+ return self._load_glue(lines, split, 8, 9, -1, True)
344
+
345
+ def get_test_splits(self):
346
+ return ["test_matched", "test_mismatched", "diagnostic"]
347
+
348
+
349
+ class MRPC(ClassificationTask):
350
+ """Microsoft Research Paraphrase Corpus."""
351
+
352
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
353
+ super(MRPC, self).__init__(config, "mrpc", tokenizer, ["0", "1"])
354
+
355
+ def _create_examples(self, lines, split):
356
+ examples = []
357
+ examples += self._load_glue(lines, split, 3, 4, 0, True)
358
+ if self.config.double_unordered and split == "train":
359
+ examples += self._load_glue(
360
+ lines, split, 3, 4, 0, True, len(examples), True)
361
+ return examples
362
+
363
+
364
+ class CoLA(ClassificationTask):
365
+ """Corpus of Linguistic Acceptability."""
366
+
367
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
368
+ super(CoLA, self).__init__(config, "cola", tokenizer, ["0", "1"])
369
+
370
+ def _create_examples(self, lines, split):
371
+ return self._load_glue(lines, split, 1 if split == "test" else 3,
372
+ None, 1, split == "test")
373
+
374
+ def get_scorer(self):
375
+ return classification_metrics.MCCScorer()
376
+
377
+
378
+ class SST(ClassificationTask):
379
+ """Stanford Sentiment Treebank."""
380
+
381
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
382
+ super(SST, self).__init__(config, "sst", tokenizer, ["0", "1"])
383
+
384
+ def _create_examples(self, lines, split):
385
+ if "test" in split:
386
+ return self._load_glue(lines, split, 1, None, None, True)
387
+ else:
388
+ return self._load_glue(lines, split, 0, None, 1, True)
389
+
390
+
391
+ class QQP(ClassificationTask):
392
+ """Quora Question Pair."""
393
+
394
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
395
+ super(QQP, self).__init__(config, "qqp", tokenizer, ["0", "1"])
396
+
397
+ def _create_examples(self, lines, split):
398
+ return self._load_glue(lines, split, 1 if split == "test" else 3,
399
+ 2 if split == "test" else 4, 5, True)
400
+
401
+
402
+ class RTE(ClassificationTask):
403
+ """Recognizing Textual Entailment."""
404
+
405
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
406
+ super(RTE, self).__init__(config, "rte", tokenizer,
407
+ ["entailment", "not_entailment"])
408
+
409
+ def _create_examples(self, lines, split):
410
+ return self._load_glue(lines, split, 1, 2, 3, True)
411
+
412
+
413
+ class QNLI(ClassificationTask):
414
+ """Question NLI."""
415
+
416
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
417
+ super(QNLI, self).__init__(config, "qnli", tokenizer,
418
+ ["entailment", "not_entailment"])
419
+
420
+ def _create_examples(self, lines, split):
421
+ return self._load_glue(lines, split, 1, 2, 3, True)
422
+
423
+
424
+ class STS(RegressionTask):
425
+ """Semantic Textual Similarity."""
426
+
427
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
428
+ super(STS, self).__init__(config, "sts", tokenizer, 0.0, 5.0)
429
+
430
+ def _create_examples(self, lines, split):
431
+ examples = []
432
+ if split == "test":
433
+ examples += self._load_glue(lines, split, -2, -1, None, True)
434
+ else:
435
+ examples += self._load_glue(lines, split, -3, -2, -1, True)
436
+ if self.config.double_unordered and split == "train":
437
+ examples += self._load_glue(
438
+ lines, split, -3, -2, -1, True, len(examples), True)
439
+ return examples
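The pair truncation used by featurize() always trims the longer segment first, one token at a time. A condensed, standalone equivalent of _truncate_seq_pair (token values are placeholders):

```python
# Same effect as _truncate_seq_pair(tokens_a, tokens_b, max_length) above.
tokens_a = ["tok_a"] * 10
tokens_b = ["tok_b"] * 4
max_length = 8  # e.g. max_seq_length - 3, leaving room for [CLS]/[SEP]/[SEP]
while len(tokens_a) + len(tokens_b) > max_length:
    (tokens_a if len(tokens_a) > len(tokens_b) else tokens_b).pop()
print(len(tokens_a), len(tokens_b))  # 4 4
```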
arabert/araelectra/finetune/feature_spec.py ADDED
@@ -0,0 +1,56 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Defines the inputs used when fine-tuning a model."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import numpy as np
23
+ import tensorflow as tf
24
+
25
+ import configure_finetuning
26
+
27
+
28
+ def get_shared_feature_specs(config: configure_finetuning.FinetuningConfig):
29
+ """Non-task-specific model inputs."""
30
+ return [
31
+ FeatureSpec("input_ids", [config.max_seq_length]),
32
+ FeatureSpec("input_mask", [config.max_seq_length]),
33
+ FeatureSpec("segment_ids", [config.max_seq_length]),
34
+ FeatureSpec("task_id", []),
35
+ ]
36
+
37
+
38
+ class FeatureSpec(object):
39
+ """Defines a feature passed as input to the model."""
40
+
41
+ def __init__(self, name, shape, default_value_fn=None, is_int_feature=True):
42
+ self.name = name
43
+ self.shape = shape
44
+ self.default_value_fn = default_value_fn
45
+ self.is_int_feature = is_int_feature
46
+
47
+ def get_parsing_spec(self):
48
+ return tf.io.FixedLenFeature(
49
+ self.shape, tf.int64 if self.is_int_feature else tf.float32)
50
+
51
+ def get_default_values(self):
52
+ if self.default_value_fn:
53
+ return self.default_value_fn(self.shape)
54
+ else:
55
+ return np.zeros(
56
+ self.shape, np.int64 if self.is_int_feature else np.float32)
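A small usage sketch for FeatureSpec (assumes TensorFlow is installed and the araelectra directory is on sys.path; the feature name and shape are illustrative):

```python
from finetune import feature_spec

spec = feature_spec.FeatureSpec("input_ids", [128])
print(spec.get_parsing_spec())           # FixedLenFeature([128], tf.int64)
print(spec.get_default_values().shape)   # (128,) of zeros, used for padding examples
```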
arabert/araelectra/finetune/preprocessing.py ADDED
@@ -0,0 +1,173 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Code for serializing raw fine-tuning data into tfrecords"""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import collections
23
+ import os
24
+ import random
25
+ import numpy as np
26
+ import tensorflow as tf
27
+
28
+ import configure_finetuning
29
+ from finetune import feature_spec
30
+ from util import utils
31
+
32
+
33
+ class Preprocessor(object):
34
+ """Class for loading, preprocessing, and serializing fine-tuning datasets."""
35
+
36
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tasks):
37
+ self._config = config
38
+ self._tasks = tasks
39
+ self._name_to_task = {task.name: task for task in tasks}
40
+
41
+ self._feature_specs = feature_spec.get_shared_feature_specs(config)
42
+ for task in tasks:
43
+ self._feature_specs += task.get_feature_specs()
44
+ self._name_to_feature_config = {
45
+ spec.name: spec.get_parsing_spec()
46
+ for spec in self._feature_specs
47
+ }
48
+ assert len(self._name_to_feature_config) == len(self._feature_specs)
49
+
50
+ def prepare_train(self):
51
+ return self._serialize_dataset(self._tasks, True, "train")
52
+
53
+ def prepare_predict(self, tasks, split):
54
+ return self._serialize_dataset(tasks, False, split)
55
+
56
+ def _serialize_dataset(self, tasks, is_training, split):
57
+ """Write out the dataset as tfrecords."""
58
+ dataset_name = "_".join(sorted([task.name for task in tasks]))
59
+ dataset_name += "_" + split
60
+ dataset_prefix = os.path.join(
61
+ self._config.preprocessed_data_dir, dataset_name)
62
+ tfrecords_path = dataset_prefix + ".tfrecord"
63
+ metadata_path = dataset_prefix + ".metadata"
64
+ batch_size = (self._config.train_batch_size if is_training else
65
+ self._config.eval_batch_size)
66
+
67
+ utils.log("Loading dataset", dataset_name)
68
+ n_examples = None
69
+ if (self._config.use_tfrecords_if_existing and
70
+ tf.io.gfile.exists(metadata_path)):
71
+ n_examples = utils.load_json(metadata_path)["n_examples"]
72
+
73
+ if n_examples is None:
74
+ utils.log("Existing tfrecords not found so creating")
75
+ examples = []
76
+ for task in tasks:
77
+ task_examples = task.get_examples(split)
78
+ examples += task_examples
79
+ if is_training:
80
+ random.shuffle(examples)
81
+ utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
82
+ n_examples = self.serialize_examples(
83
+ examples, is_training, tfrecords_path, batch_size)
84
+ utils.write_json({"n_examples": n_examples}, metadata_path)
85
+
86
+ input_fn = self._input_fn_builder(tfrecords_path, is_training)
87
+ if is_training:
88
+ steps = int(n_examples // batch_size * self._config.num_train_epochs)
89
+ else:
90
+ steps = n_examples // batch_size
91
+
92
+ return input_fn, steps
93
+
94
+ def serialize_examples(self, examples, is_training, output_file, batch_size):
95
+ """Convert a set of `InputExample`s to a TFRecord file."""
96
+ n_examples = 0
97
+ with tf.io.TFRecordWriter(output_file) as writer:
98
+ for (ex_index, example) in enumerate(examples):
99
+ if ex_index % 2000 == 0:
100
+ utils.log("Writing example {:} of {:}".format(
101
+ ex_index, len(examples)))
102
+ for tf_example in self._example_to_tf_example(
103
+ example, is_training,
104
+ log=self._config.log_examples and ex_index < 1):
105
+ writer.write(tf_example.SerializeToString())
106
+ n_examples += 1
107
+ # add padding so the dataset is a multiple of batch_size
108
+ while n_examples % batch_size != 0:
109
+ writer.write(self._make_tf_example(task_id=len(self._config.task_names))
110
+ .SerializeToString())
111
+ n_examples += 1
112
+ return n_examples
113
+
114
+ def _example_to_tf_example(self, example, is_training, log=False):
115
+ examples = self._name_to_task[example.task_name].featurize(
116
+ example, is_training, log)
117
+ if not isinstance(examples, list):
118
+ examples = [examples]
119
+ for example in examples:
120
+ yield self._make_tf_example(**example)
121
+
122
+ def _make_tf_example(self, **kwargs):
123
+ """Make a tf.train.Example from the provided features."""
124
+ for k in kwargs:
125
+ if k not in self._name_to_feature_config:
126
+ raise ValueError("Unknown feature", k)
127
+ features = collections.OrderedDict()
128
+ for spec in self._feature_specs:
129
+ if spec.name in kwargs:
130
+ values = kwargs[spec.name]
131
+ else:
132
+ values = spec.get_default_values()
133
+ if (isinstance(values, int) or isinstance(values, bool) or
134
+ isinstance(values, float) or isinstance(values, np.float32) or
135
+ (isinstance(values, np.ndarray) and values.size == 1)):
136
+ values = [values]
137
+ if spec.is_int_feature:
138
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(
139
+ value=list(values)))
140
+ else:
141
+ feature = tf.train.Feature(float_list=tf.train.FloatList(
142
+ value=list(values)))
143
+ features[spec.name] = feature
144
+ return tf.train.Example(features=tf.train.Features(feature=features))
145
+
146
+ def _input_fn_builder(self, input_file, is_training):
147
+ """Creates an `input_fn` closure to be passed to TPUEstimator."""
148
+
149
+ def input_fn(params):
150
+ """The actual input function."""
151
+ d = tf.data.TFRecordDataset(input_file)
152
+ if is_training:
153
+ d = d.repeat()
154
+ d = d.shuffle(buffer_size=100)
155
+ return d.apply(
156
+ tf.data.experimental.map_and_batch(
157
+ self._decode_tfrecord,
158
+ batch_size=params["batch_size"],
159
+ drop_remainder=True))
160
+
161
+ return input_fn
162
+
163
+ def _decode_tfrecord(self, record):
164
+ """Decodes a record to a TensorFlow example."""
165
+ example = tf.io.parse_single_example(record, self._name_to_feature_config)
166
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
167
+ # So cast all int64 to int32.
168
+ for name, tensor in example.items():
169
+ if tensor.dtype == tf.int64:
170
+ example[name] = tf.cast(tensor, tf.int32)
171
+ else:
172
+ example[name] = tensor
173
+ return example
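serialize_examples() pads the tfrecord file up to a multiple of the batch size with dummy examples (task_id set to len(task_names)), so the fixed-size TPU batches never drop real examples. The padding count is simple modular arithmetic:

```python
# Standalone sketch of the padding amount; the example counts are hypothetical.
n_real, batch_size = 1003, 32
n_padding = (-n_real) % batch_size
print(n_padding, n_real + n_padding)  # 21 1024
```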
arabert/araelectra/finetune/qa/mrqa_official_eval.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Official evaluation script for the MRQA Workshop Shared Task.
17
+ Adapted from the SQuAD v1.1 official evaluation script.
18
+ Modified slightly for the ELECTRA codebase.
19
+ """
20
+ from __future__ import absolute_import
21
+ from __future__ import division
22
+ from __future__ import print_function
23
+
24
+ import os
25
+ import string
26
+ import re
27
+ import json
28
+ import tensorflow as tf
29
+ from collections import Counter
30
+
31
+ import configure_finetuning
32
+
33
+
34
+ def normalize_answer(s):
35
+ """Lower text and remove punctuation, articles and extra whitespace."""
36
+ def remove_articles(text):
37
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
38
+
39
+ def white_space_fix(text):
40
+ return ' '.join(text.split())
41
+
42
+ def remove_punc(text):
43
+ exclude = set(string.punctuation)
44
+ return ''.join(ch for ch in text if ch not in exclude)
45
+
46
+ def lower(text):
47
+ return text.lower()
48
+
49
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
50
+
51
+
52
+ def f1_score(prediction, ground_truth):
53
+ prediction_tokens = normalize_answer(prediction).split()
54
+ ground_truth_tokens = normalize_answer(ground_truth).split()
55
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
56
+ num_same = sum(common.values())
57
+ if num_same == 0:
58
+ return 0
59
+ precision = 1.0 * num_same / len(prediction_tokens)
60
+ recall = 1.0 * num_same / len(ground_truth_tokens)
61
+ f1 = (2 * precision * recall) / (precision + recall)
62
+ return f1
63
+
64
+
65
+ def exact_match_score(prediction, ground_truth):
66
+ return (normalize_answer(prediction) == normalize_answer(ground_truth))
67
+
68
+
69
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
70
+ scores_for_ground_truths = []
71
+ for ground_truth in ground_truths:
72
+ score = metric_fn(prediction, ground_truth)
73
+ scores_for_ground_truths.append(score)
74
+ return max(scores_for_ground_truths)
75
+
76
+
77
+ def read_predictions(prediction_file):
78
+ with tf.io.gfile.GFile(prediction_file) as f:
79
+ predictions = json.load(f)
80
+ return predictions
81
+
82
+
83
+ def read_answers(gold_file):
84
+ answers = {}
85
+ with tf.io.gfile.GFile(gold_file, 'r') as f:
86
+ for i, line in enumerate(f):
87
+ example = json.loads(line)
88
+ if i == 0 and 'header' in example:
89
+ continue
90
+ for qa in example['qas']:
91
+ answers[qa['qid']] = qa['answers']
92
+ return answers
93
+
94
+
95
+ def evaluate(answers, predictions, skip_no_answer=False):
96
+ f1 = exact_match = total = 0
97
+ for qid, ground_truths in answers.items():
98
+ if qid not in predictions:
99
+ if not skip_no_answer:
100
+ message = 'Unanswered question %s will receive score 0.' % qid
101
+ print(message)
102
+ total += 1
103
+ continue
104
+ total += 1
105
+ prediction = predictions[qid]
106
+ exact_match += metric_max_over_ground_truths(
107
+ exact_match_score, prediction, ground_truths)
108
+ f1 += metric_max_over_ground_truths(
109
+ f1_score, prediction, ground_truths)
110
+
111
+ exact_match = 100.0 * exact_match / total
112
+ f1 = 100.0 * f1 / total
113
+
114
+ return {'exact_match': exact_match, 'f1': f1}
115
+
116
+
117
+ def main(config: configure_finetuning.FinetuningConfig, split, task_name):
118
+ answers = read_answers(os.path.join(config.raw_data_dir(task_name), split + ".jsonl"))
119
+ predictions = read_predictions(config.qa_preds_file(task_name))
120
+ return evaluate(answers, predictions, True)
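A quick, illustrative use of the metric helpers defined above; the import path is an assumption and presumes the araelectra directory is on `sys.path`.

from finetune.qa.mrqa_official_eval import exact_match_score, f1_score

# Case, punctuation, and English articles are stripped by normalize_answer.
print(exact_match_score("The Eiffel Tower", "eiffel tower"))               # True
# 2 shared tokens, 4 predicted vs. 2 gold -> precision 0.5, recall 1.0.
print(round(f1_score("Eiffel Tower in Paris", "the Eiffel Tower"), 2))     # 0.67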
arabert/araelectra/finetune/qa/qa_metrics.py ADDED
@@ -0,0 +1,401 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Evaluation metrics for question-answering tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import collections
23
+ import numpy as np
24
+ import six
25
+
26
+ import configure_finetuning
27
+ from finetune import scorer
28
+ from finetune.qa import mrqa_official_eval
29
+ from finetune.qa import squad_official_eval
30
+ from finetune.qa import squad_official_eval_v1
31
+ from model import tokenization
32
+ from util import utils
33
+
34
+
35
+ RawResult = collections.namedtuple("RawResult", [
36
+ "unique_id", "start_logits", "end_logits", "answerable_logit",
37
+ "start_top_log_probs", "start_top_index", "end_top_log_probs",
38
+ "end_top_index"
39
+ ])
40
+
41
+
42
+ class SpanBasedQAScorer(scorer.Scorer):
43
+ """Runs evaluation for SQuAD 1.1, SQuAD 2.0, and MRQA tasks."""
44
+
45
+ def __init__(self, config: configure_finetuning.FinetuningConfig, task, split,
46
+ v2):
47
+ super(SpanBasedQAScorer, self).__init__()
48
+ self._config = config
49
+ self._task = task
50
+ self._name = task.name
51
+ self._split = split
52
+ self._v2 = v2
53
+ self._all_results = []
54
+ self._total_loss = 0
55
+ self._split = split
56
+ self._eval_examples = task.get_examples(split)
57
+
58
+ def update(self, results):
59
+ super(SpanBasedQAScorer, self).update(results)
60
+ self._all_results.append(
61
+ RawResult(
62
+ unique_id=results["eid"],
63
+ start_logits=results["start_logits"],
64
+ end_logits=results["end_logits"],
65
+ answerable_logit=results["answerable_logit"],
66
+ start_top_log_probs=results["start_top_log_probs"],
67
+ start_top_index=results["start_top_index"],
68
+ end_top_log_probs=results["end_top_log_probs"],
69
+ end_top_index=results["end_top_index"],
70
+ ))
71
+ self._total_loss += results["loss"]
72
+
73
+ def get_loss(self):
74
+ return self._total_loss / len(self._all_results)
75
+
76
+ def _get_results(self):
77
+ self.write_predictions()
78
+ if self._name == "squad":
79
+ squad_official_eval.set_opts(self._config, self._split)
80
+ squad_official_eval.main()
81
+ return sorted(utils.load_json(
82
+ self._config.qa_eval_file(self._name)).items())
83
+ elif self._name == "squadv1":
84
+ return sorted(squad_official_eval_v1.main(
85
+ self._config, self._split).items())
86
+ else:
87
+ return sorted(mrqa_official_eval.main(
88
+ self._config, self._split, self._name).items())
89
+
90
+ def write_predictions(self):
91
+ """Write final predictions to the json file."""
92
+ unique_id_to_result = {}
93
+ for result in self._all_results:
94
+ unique_id_to_result[result.unique_id] = result
95
+
96
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
97
+ "PrelimPrediction",
98
+ ["feature_index", "start_index", "end_index", "start_logit",
99
+ "end_logit"])
100
+
101
+ all_predictions = collections.OrderedDict()
102
+ all_nbest_json = collections.OrderedDict()
103
+ scores_diff_json = collections.OrderedDict()
104
+
105
+ for example in self._eval_examples:
106
+ example_id = example.qas_id if "squad" in self._name else example.qid
107
+ features = self._task.featurize(example, False, for_eval=True)
108
+
109
+ prelim_predictions = []
110
+ # keep track of the minimum score of null start+end of position 0
111
+ score_null = 1000000 # large and positive
112
+ for (feature_index, feature) in enumerate(features):
113
+ result = unique_id_to_result[feature[self._name + "_eid"]]
114
+ if self._config.joint_prediction:
115
+ start_indexes = result.start_top_index
116
+ end_indexes = result.end_top_index
117
+ else:
118
+ start_indexes = _get_best_indexes(result.start_logits,
119
+ self._config.n_best_size)
120
+ end_indexes = _get_best_indexes(result.end_logits,
121
+ self._config.n_best_size)
122
+ # if we could have irrelevant answers, get the min score of irrelevant
123
+ if self._v2:
124
+ if self._config.answerable_classifier:
125
+ feature_null_score = result.answerable_logit
126
+ else:
127
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
128
+ if feature_null_score < score_null:
129
+ score_null = feature_null_score
130
+ for i, start_index in enumerate(start_indexes):
131
+ for j, end_index in enumerate(
132
+ end_indexes[i] if self._config.joint_prediction else end_indexes):
133
+ # We could hypothetically create invalid predictions, e.g., predict
134
+ # that the start of the span is in the question. We throw out all
135
+ # invalid predictions.
136
+ if start_index >= len(feature[self._name + "_tokens"]):
137
+ continue
138
+ if end_index >= len(feature[self._name + "_tokens"]):
139
+ continue
140
+ if start_index == 0:
141
+ continue
142
+ if start_index not in feature[self._name + "_token_to_orig_map"]:
143
+ continue
144
+ if end_index not in feature[self._name + "_token_to_orig_map"]:
145
+ continue
146
+ if not feature[self._name + "_token_is_max_context"].get(
147
+ start_index, False):
148
+ continue
149
+ if end_index < start_index:
150
+ continue
151
+ length = end_index - start_index + 1
152
+ if length > self._config.max_answer_length:
153
+ continue
154
+ start_logit = (result.start_top_log_probs[i] if
155
+ self._config.joint_prediction else
156
+ result.start_logits[start_index])
157
+ end_logit = (result.end_top_log_probs[i, j] if
158
+ self._config.joint_prediction else
159
+ result.end_logits[end_index])
160
+ prelim_predictions.append(
161
+ _PrelimPrediction(
162
+ feature_index=feature_index,
163
+ start_index=start_index,
164
+ end_index=end_index,
165
+ start_logit=start_logit,
166
+ end_logit=end_logit))
167
+
168
+ if self._v2:
169
+ if len(prelim_predictions) == 0 and self._config.debug:
170
+ tokid = sorted(feature[self._name + "_token_to_orig_map"].keys())[0]
171
+ prelim_predictions.append(_PrelimPrediction(
172
+ feature_index=0,
173
+ start_index=tokid,
174
+ end_index=tokid + 1,
175
+ start_logit=1.0,
176
+ end_logit=1.0))
177
+ prelim_predictions = sorted(
178
+ prelim_predictions,
179
+ key=lambda x: (x.start_logit + x.end_logit),
180
+ reverse=True)
181
+
182
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
183
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
184
+
185
+ seen_predictions = {}
186
+ nbest = []
187
+ for pred in prelim_predictions:
188
+ if len(nbest) >= self._config.n_best_size:
189
+ break
190
+ feature = features[pred.feature_index]
191
+ tok_tokens = feature[self._name + "_tokens"][
192
+ pred.start_index:(pred.end_index + 1)]
193
+ orig_doc_start = feature[
194
+ self._name + "_token_to_orig_map"][pred.start_index]
195
+ orig_doc_end = feature[
196
+ self._name + "_token_to_orig_map"][pred.end_index]
197
+ orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
198
+ tok_text = " ".join(tok_tokens)
199
+
200
+ # De-tokenize WordPieces that have been split off.
201
+ tok_text = tok_text.replace(" ##", "")
202
+ tok_text = tok_text.replace("##", "")
203
+
204
+ # Clean whitespace
205
+ tok_text = tok_text.strip()
206
+ tok_text = " ".join(tok_text.split())
207
+ orig_text = " ".join(orig_tokens)
208
+
209
+ final_text = get_final_text(self._config, tok_text, orig_text)
210
+ if final_text in seen_predictions:
211
+ continue
212
+
213
+ seen_predictions[final_text] = True
214
+
215
+ nbest.append(
216
+ _NbestPrediction(
217
+ text=final_text,
218
+ start_logit=pred.start_logit,
219
+ end_logit=pred.end_logit))
220
+
221
+ # In very rare edge cases we could have no valid predictions. So we
222
+ # just create a nonce prediction in this case to avoid failure.
223
+ if not nbest:
224
+ nbest.append(
225
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
226
+
227
+ assert len(nbest) >= 1
228
+
229
+ total_scores = []
230
+ best_non_null_entry = None
231
+ for entry in nbest:
232
+ total_scores.append(entry.start_logit + entry.end_logit)
233
+ if not best_non_null_entry:
234
+ if entry.text:
235
+ best_non_null_entry = entry
236
+
237
+ probs = _compute_softmax(total_scores)
238
+
239
+ nbest_json = []
240
+ for (i, entry) in enumerate(nbest):
241
+ output = collections.OrderedDict()
242
+ output["text"] = entry.text
243
+ output["probability"] = probs[i]
244
+ output["start_logit"] = entry.start_logit
245
+ output["end_logit"] = entry.end_logit
246
+ nbest_json.append(dict(output))
247
+
248
+ assert len(nbest_json) >= 1
249
+
250
+ if not self._v2:
251
+ all_predictions[example_id] = nbest_json[0]["text"]
252
+ else:
253
+ # predict "" iff the null score - the score of best non-null > threshold
254
+ if self._config.answerable_classifier:
255
+ score_diff = score_null
256
+ else:
257
+ score_diff = score_null - best_non_null_entry.start_logit - (
258
+ best_non_null_entry.end_logit)
259
+ scores_diff_json[example_id] = score_diff
260
+ all_predictions[example_id] = best_non_null_entry.text
261
+
262
+ all_nbest_json[example_id] = nbest_json
263
+
264
+ utils.write_json(dict(all_predictions),
265
+ self._config.qa_preds_file(self._name))
266
+ if self._v2:
267
+ utils.write_json({
268
+ k: float(v) for k, v in six.iteritems(scores_diff_json)},
269
+ self._config.qa_na_file(self._name))
270
+
271
+
272
+ def _get_best_indexes(logits, n_best_size):
273
+ """Get the n-best logits from a list."""
274
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
275
+
276
+ best_indexes = []
277
+ for i in range(len(index_and_score)):
278
+ if i >= n_best_size:
279
+ break
280
+ best_indexes.append(index_and_score[i][0])
281
+ return best_indexes
282
+
283
+
284
+ def _compute_softmax(scores):
285
+ """Compute softmax probability over raw logits."""
286
+ if not scores:
287
+ return []
288
+
289
+ max_score = None
290
+ for score in scores:
291
+ if max_score is None or score > max_score:
292
+ max_score = score
293
+
294
+ exp_scores = []
295
+ total_sum = 0.0
296
+ for score in scores:
297
+ x = np.exp(score - max_score)
298
+ exp_scores.append(x)
299
+ total_sum += x
300
+
301
+ probs = []
302
+ for score in exp_scores:
303
+ probs.append(score / total_sum)
304
+ return probs
305
+
306
+
307
+ def get_final_text(config: configure_finetuning.FinetuningConfig, pred_text,
308
+ orig_text):
309
+ """Project the tokenized prediction back to the original text."""
310
+
311
+ # When we created the data, we kept track of the alignment between original
312
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
313
+ # now `orig_text` contains the span of our original text corresponding to the
314
+ # span that we predicted.
315
+ #
316
+ # However, `orig_text` may contain extra characters that we don't want in
317
+ # our prediction.
318
+ #
319
+ # For example, let's say:
320
+ # pred_text = steve smith
321
+ # orig_text = Steve Smith's
322
+ #
323
+ # We don't want to return `orig_text` because it contains the extra "'s".
324
+ #
325
+ # We don't want to return `pred_text` because it's already been normalized
326
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
327
+ # our tokenizer does additional normalization like stripping accent
328
+ # characters).
329
+ #
330
+ # What we really want to return is "Steve Smith".
331
+ #
332
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
333
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
334
+ # can fail in certain cases in which case we just return `orig_text`.
335
+
336
+ def _strip_spaces(text):
337
+ ns_chars = []
338
+ ns_to_s_map = collections.OrderedDict()
339
+ for i, c in enumerate(text):
340
+ if c == " ":
341
+ continue
342
+ ns_to_s_map[len(ns_chars)] = i
343
+ ns_chars.append(c)
344
+ ns_text = "".join(ns_chars)
345
+ return ns_text, dict(ns_to_s_map)
346
+
347
+ # We first tokenize `orig_text`, strip whitespace from the result
348
+ # and `pred_text`, and check if they are the same length. If they are
349
+ # NOT the same length, the heuristic has failed. If they are the same
350
+ # length, we assume the characters are one-to-one aligned.
351
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=config.do_lower_case)
352
+
353
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
354
+
355
+ start_position = tok_text.find(pred_text)
356
+ if start_position == -1:
357
+ if config.debug:
358
+ utils.log(
359
+ "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
360
+ return orig_text
361
+ end_position = start_position + len(pred_text) - 1
362
+
363
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
364
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
365
+
366
+ if len(orig_ns_text) != len(tok_ns_text):
367
+ if config.debug:
368
+ utils.log("Length not equal after stripping spaces: '%s' vs '%s'",
369
+ orig_ns_text, tok_ns_text)
370
+ return orig_text
371
+
372
+ # We then project the characters in `pred_text` back to `orig_text` using
373
+ # the character-to-character alignment.
374
+ tok_s_to_ns_map = {}
375
+ for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
376
+ tok_s_to_ns_map[tok_index] = i
377
+
378
+ orig_start_position = None
379
+ if start_position in tok_s_to_ns_map:
380
+ ns_start_position = tok_s_to_ns_map[start_position]
381
+ if ns_start_position in orig_ns_to_s_map:
382
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
383
+
384
+ if orig_start_position is None:
385
+ if config.debug:
386
+ utils.log("Couldn't map start position")
387
+ return orig_text
388
+
389
+ orig_end_position = None
390
+ if end_position in tok_s_to_ns_map:
391
+ ns_end_position = tok_s_to_ns_map[end_position]
392
+ if ns_end_position in orig_ns_to_s_map:
393
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
394
+
395
+ if orig_end_position is None:
396
+ if config.debug:
397
+ utils.log("Couldn't map end position")
398
+ return orig_text
399
+
400
+ output_text = orig_text[orig_start_position:(orig_end_position + 1)]
401
+ return output_text
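A small worked example for `_compute_softmax` above, which fills the `probability` field of the n-best JSON in `write_predictions`; the import path is an assumption.

from finetune.qa.qa_metrics import _compute_softmax

# The max score is subtracted before exponentiating, for numerical stability.
probs = _compute_softmax([2.0, 1.0, 0.0])
print([round(p, 3) for p in probs])  # [0.665, 0.245, 0.09]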
arabert/araelectra/finetune/qa/qa_tasks.py ADDED
@@ -0,0 +1,628 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Question answering tasks. SQuAD 1.1/2.0 and 2019 MRQA tasks are supported."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ import collections
24
+ import json
25
+ import os
26
+ import six
27
+ import tensorflow as tf
28
+
29
+ import configure_finetuning
30
+ from finetune import feature_spec
31
+ from finetune import task
32
+ from finetune.qa import qa_metrics
33
+ from model import modeling
34
+ from model import tokenization
35
+ from util import utils
36
+
37
+
38
+ class QAExample(task.Example):
39
+ """Question-answering example."""
40
+
41
+ def __init__(self,
42
+ task_name,
43
+ eid,
44
+ qas_id,
45
+ qid,
46
+ question_text,
47
+ doc_tokens,
48
+ orig_answer_text=None,
49
+ start_position=None,
50
+ end_position=None,
51
+ is_impossible=False):
52
+ super(QAExample, self).__init__(task_name)
53
+ self.eid = eid
54
+ self.qas_id = qas_id
55
+ self.qid = qid
56
+ self.question_text = question_text
57
+ self.doc_tokens = doc_tokens
58
+ self.orig_answer_text = orig_answer_text
59
+ self.start_position = start_position
60
+ self.end_position = end_position
61
+ self.is_impossible = is_impossible
62
+
63
+ def __str__(self):
64
+ return self.__repr__()
65
+
66
+ def __repr__(self):
67
+ s = ""
68
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
69
+ s += ", question_text: %s" % (
70
+ tokenization.printable_text(self.question_text))
71
+ s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
72
+ if self.start_position:
73
+ s += ", start_position: %d" % self.start_position
74
+ if self.start_position:
75
+ s += ", end_position: %d" % self.end_position
76
+ if self.start_position:
77
+ s += ", is_impossible: %r" % self.is_impossible
78
+ return s
79
+
80
+
81
+ def _check_is_max_context(doc_spans, cur_span_index, position):
82
+ """Check if this is the 'max context' doc span for the token."""
83
+
84
+ # Because of the sliding window approach taken to scoring documents, a single
85
+ # token can appear in multiple document spans. E.g.
86
+ # Doc: the man went to the store and bought a gallon of milk
87
+ # Span A: the man went to the
88
+ # Span B: to the store and bought
89
+ # Span C: and bought a gallon of
90
+ # ...
91
+ #
92
+ # Now the word 'bought' will have two scores from spans B and C. We only
93
+ # want to consider the score with "maximum context", which we define as
94
+ # the *minimum* of its left and right context (the *sum* of left and
95
+ # right context will always be the same, of course).
96
+ #
97
+ # In the example the maximum context for 'bought' would be span C since
98
+ # it has 1 left context and 3 right context, while span B has 4 left context
99
+ # and 0 right context.
100
+ best_score = None
101
+ best_span_index = None
102
+ for (span_index, doc_span) in enumerate(doc_spans):
103
+ end = doc_span.start + doc_span.length - 1
104
+ if position < doc_span.start:
105
+ continue
106
+ if position > end:
107
+ continue
108
+ num_left_context = position - doc_span.start
109
+ num_right_context = end - position
110
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
111
+ if best_score is None or score > best_score:
112
+ best_score = score
113
+ best_span_index = span_index
114
+
115
+ return cur_span_index == best_span_index
116
+
117
+
118
+ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
119
+ orig_answer_text):
120
+ """Returns tokenized answer spans that better match the annotated answer."""
121
+
122
+ # The SQuAD annotations are character based. We first project them to
123
+ # whitespace-tokenized words. But then after WordPiece tokenization, we can
124
+ # often find a "better match". For example:
125
+ #
126
+ # Question: What year was John Smith born?
127
+ # Context: The leader was John Smith (1895-1943).
128
+ # Answer: 1895
129
+ #
130
+ # The original whitespace-tokenized answer will be "(1895-1943).". However
131
+ # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
132
+ # the exact answer, 1895.
133
+ #
134
+ # However, this is not always possible. Consider the following:
135
+ #
136
+ # Question: What country is the top exporter of electornics?
137
+ # Context: The Japanese electronics industry is the lagest in the world.
138
+ # Answer: Japan
139
+ #
140
+ # In this case, the annotator chose "Japan" as a character sub-span of
141
+ # the word "Japanese". Since our WordPiece tokenizer does not split
142
+ # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
143
+ # in SQuAD, but does happen.
144
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
145
+
146
+ for new_start in range(input_start, input_end + 1):
147
+ for new_end in range(input_end, new_start - 1, -1):
148
+ text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
149
+ if text_span == tok_answer_text:
150
+ return new_start, new_end
151
+
152
+ return input_start, input_end
153
+
154
+
155
+ def is_whitespace(c):
156
+ return c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F
157
+
158
+
159
+ class QATask(task.Task):
160
+ """A span-based question answering tasks (e.g., SQuAD)."""
161
+
162
+ __metaclass__ = abc.ABCMeta
163
+
164
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
165
+ tokenizer, v2=False):
166
+ super(QATask, self).__init__(config, name)
167
+ self._tokenizer = tokenizer
168
+ self._examples = {}
169
+ self.v2 = v2
170
+
171
+ def _add_examples(self, examples, example_failures, paragraph, split):
172
+ paragraph_text = paragraph["context"]
173
+ doc_tokens = []
174
+ char_to_word_offset = []
175
+ prev_is_whitespace = True
176
+ for c in paragraph_text:
177
+ if is_whitespace(c):
178
+ prev_is_whitespace = True
179
+ else:
180
+ if prev_is_whitespace:
181
+ doc_tokens.append(c)
182
+ else:
183
+ doc_tokens[-1] += c
184
+ prev_is_whitespace = False
185
+ char_to_word_offset.append(len(doc_tokens) - 1)
186
+
187
+ for qa in paragraph["qas"]:
188
+ qas_id = qa["id"] if "id" in qa else None
189
+ qid = qa["qid"] if "qid" in qa else None
190
+ question_text = qa["question"]
191
+ start_position = None
192
+ end_position = None
193
+ orig_answer_text = None
194
+ is_impossible = False
195
+ if split == "train":
196
+ if self.v2:
197
+ is_impossible = qa["is_impossible"]
198
+ if not is_impossible:
199
+ if "detected_answers" in qa: # MRQA format
200
+ answer = qa["detected_answers"][0]
201
+ answer_offset = answer["char_spans"][0][0]
202
+ else: # SQuAD format
203
+ answer = qa["answers"][0]
204
+ answer_offset = answer["answer_start"]
205
+ orig_answer_text = answer["text"]
206
+ answer_length = len(orig_answer_text)
207
+ start_position = char_to_word_offset[answer_offset]
208
+ if answer_offset + answer_length - 1 >= len(char_to_word_offset):
209
+ utils.log("End position is out of document!")
210
+ example_failures[0] += 1
211
+ continue
212
+ end_position = char_to_word_offset[answer_offset + answer_length - 1]
213
+
214
+ # Only add answers where the text can be exactly recovered from the
215
+ # document. If this CAN'T happen it's likely due to weird Unicode
216
+ # stuff so we will just skip the example.
217
+ #
218
+ # Note that this means for training mode, every example is NOT
219
+ # guaranteed to be preserved.
220
+ actual_text = " ".join(
221
+ doc_tokens[start_position:(end_position + 1)])
222
+ cleaned_answer_text = " ".join(
223
+ tokenization.whitespace_tokenize(orig_answer_text))
224
+ actual_text = actual_text.lower()
225
+ cleaned_answer_text = cleaned_answer_text.lower()
226
+ if actual_text.find(cleaned_answer_text) == -1:
227
+ utils.log("Could not find answer: '{:}' in doc vs. "
228
+ "'{:}' in provided answer".format(
229
+ tokenization.printable_text(actual_text),
230
+ tokenization.printable_text(cleaned_answer_text)))
231
+ example_failures[0] += 1
232
+ continue
233
+ else:
234
+ start_position = -1
235
+ end_position = -1
236
+ orig_answer_text = ""
237
+
238
+ example = QAExample(
239
+ task_name=self.name,
240
+ eid=len(examples),
241
+ qas_id=qas_id,
242
+ qid=qid,
243
+ question_text=question_text,
244
+ doc_tokens=doc_tokens,
245
+ orig_answer_text=orig_answer_text,
246
+ start_position=start_position,
247
+ end_position=end_position,
248
+ is_impossible=is_impossible)
249
+ examples.append(example)
250
+
251
+ def get_feature_specs(self):
252
+ return [
253
+ feature_spec.FeatureSpec(self.name + "_eid", []),
254
+ feature_spec.FeatureSpec(self.name + "_start_positions", []),
255
+ feature_spec.FeatureSpec(self.name + "_end_positions", []),
256
+ feature_spec.FeatureSpec(self.name + "_is_impossible", []),
257
+ ]
258
+
259
+ def featurize(self, example: QAExample, is_training, log=False,
260
+ for_eval=False):
261
+ all_features = []
262
+ query_tokens = self._tokenizer.tokenize(example.question_text)
263
+
264
+ if len(query_tokens) > self.config.max_query_length:
265
+ query_tokens = query_tokens[0:self.config.max_query_length]
266
+
267
+ tok_to_orig_index = []
268
+ orig_to_tok_index = []
269
+ all_doc_tokens = []
270
+ for (i, token) in enumerate(example.doc_tokens):
271
+ orig_to_tok_index.append(len(all_doc_tokens))
272
+ sub_tokens = self._tokenizer.tokenize(token)
273
+ for sub_token in sub_tokens:
274
+ tok_to_orig_index.append(i)
275
+ all_doc_tokens.append(sub_token)
276
+
277
+ tok_start_position = None
278
+ tok_end_position = None
279
+ if is_training and example.is_impossible:
280
+ tok_start_position = -1
281
+ tok_end_position = -1
282
+ if is_training and not example.is_impossible:
283
+ tok_start_position = orig_to_tok_index[example.start_position]
284
+ if example.end_position < len(example.doc_tokens) - 1:
285
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
286
+ else:
287
+ tok_end_position = len(all_doc_tokens) - 1
288
+ (tok_start_position, tok_end_position) = _improve_answer_span(
289
+ all_doc_tokens, tok_start_position, tok_end_position, self._tokenizer,
290
+ example.orig_answer_text)
291
+
292
+ # The -3 accounts for [CLS], [SEP] and [SEP]
293
+ max_tokens_for_doc = self.config.max_seq_length - len(query_tokens) - 3
294
+
295
+ # We can have documents that are longer than the maximum sequence length.
296
+ # To deal with this we do a sliding window approach, where we take chunks
297
+ # of the up to our max length with a stride of `doc_stride`.
298
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
299
+ "DocSpan", ["start", "length"])
300
+ doc_spans = []
301
+ start_offset = 0
302
+ while start_offset < len(all_doc_tokens):
303
+ length = len(all_doc_tokens) - start_offset
304
+ if length > max_tokens_for_doc:
305
+ length = max_tokens_for_doc
306
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
307
+ if start_offset + length == len(all_doc_tokens):
308
+ break
309
+ start_offset += min(length, self.config.doc_stride)
310
+
311
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
312
+ tokens = []
313
+ token_to_orig_map = {}
314
+ token_is_max_context = {}
315
+ segment_ids = []
316
+ tokens.append("[CLS]")
317
+ segment_ids.append(0)
318
+ for token in query_tokens:
319
+ tokens.append(token)
320
+ segment_ids.append(0)
321
+ tokens.append("[SEP]")
322
+ segment_ids.append(0)
323
+
324
+ for i in range(doc_span.length):
325
+ split_token_index = doc_span.start + i
326
+ token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
327
+
328
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
329
+ split_token_index)
330
+ token_is_max_context[len(tokens)] = is_max_context
331
+ tokens.append(all_doc_tokens[split_token_index])
332
+ segment_ids.append(1)
333
+ tokens.append("[SEP]")
334
+ segment_ids.append(1)
335
+
336
+ input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
337
+
338
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
339
+ # tokens are attended to.
340
+ input_mask = [1] * len(input_ids)
341
+
342
+ # Zero-pad up to the sequence length.
343
+ while len(input_ids) < self.config.max_seq_length:
344
+ input_ids.append(0)
345
+ input_mask.append(0)
346
+ segment_ids.append(0)
347
+
348
+ assert len(input_ids) == self.config.max_seq_length
349
+ assert len(input_mask) == self.config.max_seq_length
350
+ assert len(segment_ids) == self.config.max_seq_length
351
+
352
+ start_position = None
353
+ end_position = None
354
+ if is_training and not example.is_impossible:
355
+ # For training, if our document chunk does not contain an annotation
356
+ # we throw it out, since there is nothing to predict.
357
+ doc_start = doc_span.start
358
+ doc_end = doc_span.start + doc_span.length - 1
359
+ out_of_span = False
360
+ if not (tok_start_position >= doc_start and
361
+ tok_end_position <= doc_end):
362
+ out_of_span = True
363
+ if out_of_span:
364
+ start_position = 0
365
+ end_position = 0
366
+ else:
367
+ doc_offset = len(query_tokens) + 2
368
+ start_position = tok_start_position - doc_start + doc_offset
369
+ end_position = tok_end_position - doc_start + doc_offset
370
+
371
+ if is_training and example.is_impossible:
372
+ start_position = 0
373
+ end_position = 0
374
+
375
+ if log:
376
+ utils.log("*** Example ***")
377
+ utils.log("doc_span_index: %s" % doc_span_index)
378
+ utils.log("tokens: %s" % " ".join(
379
+ [tokenization.printable_text(x) for x in tokens]))
380
+ utils.log("token_to_orig_map: %s" % " ".join(
381
+ ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
382
+ utils.log("token_is_max_context: %s" % " ".join([
383
+ "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
384
+ ]))
385
+ utils.log("input_ids: %s" % " ".join([str(x) for x in input_ids]))
386
+ utils.log("input_mask: %s" % " ".join([str(x) for x in input_mask]))
387
+ utils.log("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
388
+ if is_training and example.is_impossible:
389
+ utils.log("impossible example")
390
+ if is_training and not example.is_impossible:
391
+ answer_text = " ".join(tokens[start_position:(end_position + 1)])
392
+ utils.log("start_position: %d" % start_position)
393
+ utils.log("end_position: %d" % end_position)
394
+ utils.log("answer: %s" % (tokenization.printable_text(answer_text)))
395
+
396
+ features = {
397
+ "task_id": self.config.task_names.index(self.name),
398
+ self.name + "_eid": (1000 * example.eid) + doc_span_index,
399
+ "input_ids": input_ids,
400
+ "input_mask": input_mask,
401
+ "segment_ids": segment_ids,
402
+ }
403
+ if for_eval:
404
+ features.update({
405
+ self.name + "_doc_span_index": doc_span_index,
406
+ self.name + "_tokens": tokens,
407
+ self.name + "_token_to_orig_map": token_to_orig_map,
408
+ self.name + "_token_is_max_context": token_is_max_context,
409
+ })
410
+ if is_training:
411
+ features.update({
412
+ self.name + "_start_positions": start_position,
413
+ self.name + "_end_positions": end_position,
414
+ self.name + "_is_impossible": example.is_impossible
415
+ })
416
+ all_features.append(features)
417
+ return all_features
418
+
419
+ def get_prediction_module(self, bert_model, features, is_training,
420
+ percent_done):
421
+ final_hidden = bert_model.get_sequence_output()
422
+
423
+ final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
424
+ batch_size = final_hidden_shape[0]
425
+ seq_length = final_hidden_shape[1]
426
+
427
+ answer_mask = tf.cast(features["input_mask"], tf.float32)
428
+ answer_mask *= tf.cast(features["segment_ids"], tf.float32)
429
+ answer_mask += tf.one_hot(0, seq_length)
430
+
431
+ start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
432
+
433
+ start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
434
+ start_top_index = tf.zeros([batch_size, self.config.beam_size], tf.int32)
435
+ end_top_log_probs = tf.zeros([batch_size, self.config.beam_size,
436
+ self.config.beam_size])
437
+ end_top_index = tf.zeros([batch_size, self.config.beam_size,
438
+ self.config.beam_size], tf.int32)
439
+ if self.config.joint_prediction:
440
+ start_logits += 1000.0 * (answer_mask - 1)
441
+ start_log_probs = tf.nn.log_softmax(start_logits)
442
+ start_top_log_probs, start_top_index = tf.nn.top_k(
443
+ start_log_probs, k=self.config.beam_size)
444
+
445
+ if not is_training:
446
+ # batch, beam, length, hidden
447
+ end_features = tf.tile(tf.expand_dims(final_hidden, 1),
448
+ [1, self.config.beam_size, 1, 1])
449
+ # batch, beam, length
450
+ start_index = tf.one_hot(start_top_index,
451
+ depth=seq_length, axis=-1, dtype=tf.float32)
452
+ # batch, beam, hidden
453
+ start_features = tf.reduce_sum(
454
+ tf.expand_dims(final_hidden, 1) *
455
+ tf.expand_dims(start_index, -1), axis=-2)
456
+ # batch, beam, length, hidden
457
+ start_features = tf.tile(tf.expand_dims(start_features, 2),
458
+ [1, 1, seq_length, 1])
459
+ else:
460
+ start_index = tf.one_hot(
461
+ features[self.name + "_start_positions"], depth=seq_length,
462
+ axis=-1, dtype=tf.float32)
463
+ start_features = tf.reduce_sum(tf.expand_dims(start_index, -1) *
464
+ final_hidden, axis=1)
465
+ start_features = tf.tile(tf.expand_dims(start_features, 1),
466
+ [1, seq_length, 1])
467
+ end_features = final_hidden
468
+
469
+ final_repr = tf.concat([start_features, end_features], -1)
470
+ final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu,
471
+ name="qa_hidden")
472
+ # batch, beam, length (batch, length when training)
473
+ end_logits = tf.squeeze(tf.layers.dense(final_repr, 1), -1,
474
+ name="qa_logits")
475
+ if is_training:
476
+ end_logits += 1000.0 * (answer_mask - 1)
477
+ else:
478
+ end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)
479
+
480
+ if not is_training:
481
+ end_log_probs = tf.nn.log_softmax(end_logits)
482
+ end_top_log_probs, end_top_index = tf.nn.top_k(
483
+ end_log_probs, k=self.config.beam_size)
484
+ end_logits = tf.zeros([batch_size, seq_length])
485
+ else:
486
+ end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
487
+ start_logits += 1000.0 * (answer_mask - 1)
488
+ end_logits += 1000.0 * (answer_mask - 1)
489
+
490
+ def compute_loss(logits, positions):
491
+ one_hot_positions = tf.one_hot(
492
+ positions, depth=seq_length, dtype=tf.float32)
493
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
494
+ loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
495
+ return loss
496
+
497
+ start_positions = features[self.name + "_start_positions"]
498
+ end_positions = features[self.name + "_end_positions"]
499
+
500
+ start_loss = compute_loss(start_logits, start_positions)
501
+ end_loss = compute_loss(end_logits, end_positions)
502
+
503
+ losses = (start_loss + end_loss) / 2.0
504
+
505
+ answerable_logit = tf.zeros([batch_size])
506
+ if self.config.answerable_classifier:
507
+ final_repr = final_hidden[:, 0]
508
+ if self.config.answerable_uses_start_logits:
509
+ start_p = tf.nn.softmax(start_logits)
510
+ start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) *
511
+ final_hidden, axis=1)
512
+ final_repr = tf.concat([final_repr, start_feature], -1)
513
+ final_repr = tf.layers.dense(final_repr, 512,
514
+ activation=modeling.gelu)
515
+ answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
516
+ answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
517
+ labels=tf.cast(features[self.name + "_is_impossible"], tf.float32),
518
+ logits=answerable_logit)
519
+ losses += answerable_loss * self.config.answerable_weight
520
+
521
+ return losses, dict(
522
+ loss=losses,
523
+ start_logits=start_logits,
524
+ end_logits=end_logits,
525
+ answerable_logit=answerable_logit,
526
+ start_positions=features[self.name + "_start_positions"],
527
+ end_positions=features[self.name + "_end_positions"],
528
+ start_top_log_probs=start_top_log_probs,
529
+ start_top_index=start_top_index,
530
+ end_top_log_probs=end_top_log_probs,
531
+ end_top_index=end_top_index,
532
+ eid=features[self.name + "_eid"],
533
+ )
534
+
535
+ def get_scorer(self, split="dev"):
536
+ return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)
537
+
538
+
539
+ class MRQATask(QATask):
540
+ """Class for finetuning tasks from the 2019 MRQA shared task."""
541
+
542
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
543
+ tokenizer):
544
+ super(MRQATask, self).__init__(config, name, tokenizer)
545
+
546
+ def get_examples(self, split):
547
+ if split in self._examples:
548
+ utils.log("N EXAMPLES", split, len(self._examples[split]))
549
+ return self._examples[split]
550
+
551
+ examples = []
552
+ example_failures = [0]
553
+ with tf.io.gfile.GFile(os.path.join(
554
+ self.config.raw_data_dir(self.name), split + ".jsonl"), "r") as f:
555
+ for i, line in enumerate(f):
556
+ if self.config.debug and i > 10:
557
+ break
558
+ paragraph = json.loads(line.strip())
559
+ if "header" in paragraph:
560
+ continue
561
+ self._add_examples(examples, example_failures, paragraph, split)
562
+ self._examples[split] = examples
563
+ utils.log("{:} examples created, {:} failures".format(
564
+ len(examples), example_failures[0]))
565
+ return examples
566
+
567
+ def get_scorer(self, split="dev"):
568
+ return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)
569
+
570
+
571
+ class SQuADTask(QATask):
572
+ """Class for finetuning on SQuAD 2.0 or 1.1."""
573
+
574
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
575
+ tokenizer, v2=False):
576
+ super(SQuADTask, self).__init__(config, name, tokenizer, v2=v2)
577
+
578
+ def get_examples(self, split):
579
+ if split in self._examples:
580
+ return self._examples[split]
581
+
582
+ with tf.io.gfile.GFile(os.path.join(
583
+ self.config.raw_data_dir(self.name),
584
+ split + ("-debug" if self.config.debug else "") + ".json"), "r") as f:
585
+ input_data = json.load(f)["data"]
586
+
587
+ examples = []
588
+ example_failures = [0]
589
+ for entry in input_data:
590
+ for paragraph in entry["paragraphs"]:
591
+ self._add_examples(examples, example_failures, paragraph, split)
592
+ self._examples[split] = examples
593
+ utils.log("{:} examples created, {:} failures".format(
594
+ len(examples), example_failures[0]))
595
+ return examples
596
+
597
+ def get_scorer(self, split="dev"):
598
+ return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)
599
+
600
+
601
+ class SQuAD(SQuADTask):
602
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
603
+ super(SQuAD, self).__init__(config, "squad", tokenizer, v2=True)
604
+
605
+
606
+ class SQuADv1(SQuADTask):
607
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
608
+ super(SQuADv1, self).__init__(config, "squadv1", tokenizer)
609
+
610
+
611
+ class NewsQA(MRQATask):
612
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
613
+ super(NewsQA, self).__init__(config, "newsqa", tokenizer)
614
+
615
+
616
+ class NaturalQuestions(MRQATask):
617
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
618
+ super(NaturalQuestions, self).__init__(config, "naturalqs", tokenizer)
619
+
620
+
621
+ class SearchQA(MRQATask):
622
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
623
+ super(SearchQA, self).__init__(config, "searchqa", tokenizer)
624
+
625
+
626
+ class TriviaQA(MRQATask):
627
+ def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
628
+ super(TriviaQA, self).__init__(config, "triviaqa", tokenizer)
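To make the sliding-window logic in `featurize` concrete, here is a self-contained re-implementation of just the doc-span loop with made-up numbers; in the real code the limits come from `config.max_seq_length`, the query length, and `config.doc_stride`.

import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    spans, start = [], 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append(DocSpan(start=start, length=length))
        if start + length == num_doc_tokens:
            break
        start += min(length, doc_stride)
    return spans

print(make_doc_spans(num_doc_tokens=600, max_tokens_for_doc=384, doc_stride=128))
# [DocSpan(start=0, length=384), DocSpan(start=128, length=384),
#  DocSpan(start=256, length=344)]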
arabert/araelectra/finetune/qa/squad_official_eval.py ADDED
@@ -0,0 +1,317 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Official evaluation script for SQuAD version 2.0.
17
+
18
+ In addition to basic functionality, we also compute additional statistics and
19
+ plot precision-recall curves if an additional na_prob.json file is provided.
20
+ This file is expected to map question IDs to the model's predicted probability
21
+ that a question is unanswerable.
22
+
23
+ Modified slightly for the ELECTRA codebase.
24
+ """
25
+ from __future__ import absolute_import
26
+ from __future__ import division
27
+ from __future__ import print_function
28
+
29
+ import argparse
30
+ import collections
31
+ import json
32
+ import numpy as np
33
+ import os
34
+ import re
35
+ import string
36
+ import sys
37
+ import tensorflow as tf
38
+
39
+ import configure_finetuning
40
+
41
+ OPTS = None
42
+
43
+ def parse_args():
44
+ parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
45
+ parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
46
+ parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
47
+ parser.add_argument('--out-file', '-o', metavar='eval.json',
48
+ help='Write accuracy metrics to file (default is stdout).')
49
+ parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
50
+ help='Model estimates of probability of no answer.')
51
+ parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
52
+ help='Predict "" if no-answer probability exceeds this (default = 1.0).')
53
+ parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
54
+ help='Save precision-recall curves to directory.')
55
+ parser.add_argument('--verbose', '-v', action='store_true')
56
+ if len(sys.argv) == 1:
57
+ parser.print_help()
58
+ sys.exit(1)
59
+ return parser.parse_args()
60
+
61
+ def set_opts(config: configure_finetuning.FinetuningConfig, split):
62
+ global OPTS
63
+ Options = collections.namedtuple("Options", [
64
+ "data_file", "pred_file", "out_file", "na_prob_file", "na_prob_thresh",
65
+ "out_image_dir", "verbose"])
66
+ OPTS = Options(
67
+ data_file=os.path.join(
68
+ config.raw_data_dir("squad"),
69
+ split + ("-debug" if config.debug else "") + ".json"),
70
+ pred_file=config.qa_preds_file("squad"),
71
+ out_file=config.qa_eval_file("squad"),
72
+ na_prob_file=config.qa_na_file("squad"),
73
+ na_prob_thresh=config.qa_na_threshold,
74
+ out_image_dir=None,
75
+ verbose=False
76
+ )
77
+
78
+ def make_qid_to_has_ans(dataset):
79
+ qid_to_has_ans = {}
80
+ for article in dataset:
81
+ for p in article['paragraphs']:
82
+ for qa in p['qas']:
83
+ qid_to_has_ans[qa['id']] = bool(qa['answers'])
84
+ return qid_to_has_ans
85
+
86
+ def normalize_answer(s):
87
+ """Lower text and remove punctuation, articles and extra whitespace."""
88
+ def remove_articles(text):
89
+ regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
90
+ return re.sub(regex, ' ', text)
91
+ def white_space_fix(text):
92
+ return ' '.join(text.split())
93
+ def remove_punc(text):
94
+ exclude = set(string.punctuation)
95
+ return ''.join(ch for ch in text if ch not in exclude)
96
+ def lower(text):
97
+ return text.lower()
98
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
99
+
100
+ def get_tokens(s):
101
+ if not s: return []
102
+ return normalize_answer(s).split()
103
+
104
+ def compute_exact(a_gold, a_pred):
105
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
106
+
107
+ def compute_f1(a_gold, a_pred):
108
+ gold_toks = get_tokens(a_gold)
109
+ pred_toks = get_tokens(a_pred)
110
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
111
+ num_same = sum(common.values())
112
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
113
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
114
+ return int(gold_toks == pred_toks)
115
+ if num_same == 0:
116
+ return 0
117
+ precision = 1.0 * num_same / len(pred_toks)
118
+ recall = 1.0 * num_same / len(gold_toks)
119
+ f1 = (2 * precision * recall) / (precision + recall)
120
+ return f1
121
+
122
+ def get_raw_scores(dataset, preds):
123
+ exact_scores = {}
124
+ f1_scores = {}
125
+ for article in dataset:
126
+ for p in article['paragraphs']:
127
+ for qa in p['qas']:
128
+ qid = qa['id']
129
+ gold_answers = [a['text'] for a in qa['answers']
130
+ if normalize_answer(a['text'])]
131
+ if not gold_answers:
132
+ # For unanswerable questions, only correct answer is empty string
133
+ gold_answers = ['']
134
+ if qid not in preds:
135
+ print('Missing prediction for %s' % qid)
136
+ continue
137
+ a_pred = preds[qid]
138
+ # Take max over all gold answers
139
+ exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
140
+ f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
141
+ return exact_scores, f1_scores
142
+
143
+ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
144
+ new_scores = {}
145
+ for qid, s in scores.items():
146
+ pred_na = na_probs[qid] > na_prob_thresh
147
+ if pred_na:
148
+ new_scores[qid] = float(not qid_to_has_ans[qid])
149
+ else:
150
+ new_scores[qid] = s
151
+ return new_scores
152
+
153
+ def make_eval_dict(exact_scores, f1_scores, qid_list=None):
154
+ if not qid_list:
155
+ total = len(exact_scores)
156
+ return collections.OrderedDict([
157
+ ('exact', 100.0 * sum(exact_scores.values()) / total),
158
+ ('f1', 100.0 * sum(f1_scores.values()) / total),
159
+ ('total', total),
160
+ ])
161
+ else:
162
+ total = len(qid_list)
163
+ return collections.OrderedDict([
164
+ ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
165
+ ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
166
+ ('total', total),
167
+ ])
168
+
169
+ def merge_eval(main_eval, new_eval, prefix):
170
+ for k in new_eval:
171
+ main_eval['%s_%s' % (prefix, k)] = new_eval[k]
172
+
173
+ def plot_pr_curve(precisions, recalls, out_image, title):
174
+ plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
175
+ plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
176
+ plt.xlabel('Recall')
177
+ plt.ylabel('Precision')
178
+ plt.xlim([0.0, 1.05])
179
+ plt.ylim([0.0, 1.05])
180
+ plt.title(title)
181
+ plt.savefig(out_image)
182
+ plt.clf()
183
+
184
+ def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
185
+ out_image=None, title=None):
186
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
187
+ true_pos = 0.0
188
+ cur_p = 1.0
189
+ cur_r = 0.0
190
+ precisions = [1.0]
191
+ recalls = [0.0]
192
+ avg_prec = 0.0
193
+ for i, qid in enumerate(qid_list):
194
+ if qid_to_has_ans[qid]:
195
+ true_pos += scores[qid]
196
+ cur_p = true_pos / float(i+1)
197
+ cur_r = true_pos / float(num_true_pos)
198
+ if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
199
+ # i.e., if we can put a threshold after this point
200
+ avg_prec += cur_p * (cur_r - recalls[-1])
201
+ precisions.append(cur_p)
202
+ recalls.append(cur_r)
203
+ if out_image:
204
+ plot_pr_curve(precisions, recalls, out_image, title)
205
+ return {'ap': 100.0 * avg_prec}
206
+
207
+ def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
208
+ qid_to_has_ans, out_image_dir):
209
+ if out_image_dir and not os.path.exists(out_image_dir):
210
+ os.makedirs(out_image_dir)
211
+ num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
212
+ if num_true_pos == 0:
213
+ return
214
+ pr_exact = make_precision_recall_eval(
215
+ exact_raw, na_probs, num_true_pos, qid_to_has_ans,
216
+ out_image=os.path.join(out_image_dir, 'pr_exact.png'),
217
+ title='Precision-Recall curve for Exact Match score')
218
+ pr_f1 = make_precision_recall_eval(
219
+ f1_raw, na_probs, num_true_pos, qid_to_has_ans,
220
+ out_image=os.path.join(out_image_dir, 'pr_f1.png'),
221
+ title='Precision-Recall curve for F1 score')
222
+ oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
223
+ pr_oracle = make_precision_recall_eval(
224
+ oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
225
+ out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
226
+ title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
227
+ merge_eval(main_eval, pr_exact, 'pr_exact')
228
+ merge_eval(main_eval, pr_f1, 'pr_f1')
229
+ merge_eval(main_eval, pr_oracle, 'pr_oracle')
230
+
231
+ def histogram_na_prob(na_probs, qid_list, image_dir, name):
232
+ if not qid_list:
233
+ return
234
+ x = [na_probs[k] for k in qid_list]
235
+ weights = np.ones_like(x) / float(len(x))
236
+ plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
237
+ plt.xlabel('Model probability of no-answer')
238
+ plt.ylabel('Proportion of dataset')
239
+ plt.title('Histogram of no-answer probability: %s' % name)
240
+ plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
241
+ plt.clf()
242
+
243
+ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
244
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
245
+ cur_score = num_no_ans
246
+ best_score = cur_score
247
+ best_thresh = 0.0
248
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
249
+ for i, qid in enumerate(qid_list):
250
+ if qid not in scores: continue
251
+ if qid_to_has_ans[qid]:
252
+ diff = scores[qid]
253
+ else:
254
+ if preds[qid]:
255
+ diff = -1
256
+ else:
257
+ diff = 0
258
+ cur_score += diff
259
+ if cur_score > best_score:
260
+ best_score = cur_score
261
+ best_thresh = na_probs[qid]
262
+ return 100.0 * best_score / len(scores), best_thresh
263
+
264
+ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
265
+ best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
266
+ best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
267
+ main_eval['best_exact'] = best_exact
268
+ main_eval['best_exact_thresh'] = exact_thresh
269
+ main_eval['best_f1'] = best_f1
270
+ main_eval['best_f1_thresh'] = f1_thresh
271
+
272
+ def main():
273
+ with tf.io.gfile.GFile(OPTS.data_file) as f:
274
+ dataset_json = json.load(f)
275
+ dataset = dataset_json['data']
276
+ with tf.io.gfile.GFile(OPTS.pred_file) as f:
277
+ preds = json.load(f)
278
+ if OPTS.na_prob_file:
279
+ with tf.io.gfile.GFile(OPTS.na_prob_file) as f:
280
+ na_probs = json.load(f)
281
+ else:
282
+ na_probs = {k: 0.0 for k in preds}
283
+ qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
284
+ has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
285
+ no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
286
+ exact_raw, f1_raw = get_raw_scores(dataset, preds)
287
+ exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
288
+ OPTS.na_prob_thresh)
289
+ f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
290
+ OPTS.na_prob_thresh)
291
+ out_eval = make_eval_dict(exact_thresh, f1_thresh)
292
+ if has_ans_qids:
293
+ has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
294
+ merge_eval(out_eval, has_ans_eval, 'HasAns')
295
+ if no_ans_qids:
296
+ no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
297
+ merge_eval(out_eval, no_ans_eval, 'NoAns')
298
+ if OPTS.na_prob_file:
299
+ find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
300
+ if OPTS.na_prob_file and OPTS.out_image_dir:
301
+ run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
302
+ qid_to_has_ans, OPTS.out_image_dir)
303
+ histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
304
+ histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
305
+ if OPTS.out_file:
306
+ with tf.io.gfile.GFile(OPTS.out_file, 'w') as f:
307
+ json.dump(out_eval, f)
308
+ else:
309
+ print(json.dumps(out_eval, indent=2))
310
+
311
+ if __name__ == '__main__':
312
+ OPTS = parse_args()
313
+ if OPTS.out_image_dir:
314
+ import matplotlib
315
+ matplotlib.use('Agg')
316
+ import matplotlib.pyplot as plt
317
+ main()
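During finetuning this script is not run from the command line; `qa_metrics._get_results` drives it roughly as sketched below, where `config` is assumed to be the run's `FinetuningConfig`.

from finetune.qa import squad_official_eval
from util import utils

squad_official_eval.set_opts(config, "dev")  # populate the module-level OPTS namedtuple
squad_official_eval.main()                   # writes metrics to config.qa_eval_file("squad")
print(utils.load_json(config.qa_eval_file("squad")))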
arabert/araelectra/finetune/qa/squad_official_eval_v1.py ADDED
@@ -0,0 +1,126 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Official evaluation script for v1.1 of the SQuAD dataset.
18
+ Modified slightly for the ELECTRA codebase.
19
+ """
20
+ from __future__ import absolute_import
21
+ from __future__ import division
22
+ from __future__ import print_function
23
+ from collections import Counter
24
+ import string
25
+ import re
26
+ import json
27
+ import sys
28
+ import os
29
+ import collections
30
+ import tensorflow as tf
31
+
32
+ import configure_finetuning
33
+
34
+
35
+ def normalize_answer(s):
36
+ """Lower text and remove punctuation, articles and extra whitespace."""
37
+ def remove_articles(text):
38
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
39
+
40
+ def white_space_fix(text):
41
+ return ' '.join(text.split())
42
+
43
+ def remove_punc(text):
44
+ exclude = set(string.punctuation)
45
+ return ''.join(ch for ch in text if ch not in exclude)
46
+
47
+ def lower(text):
48
+ return text.lower()
49
+
50
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
51
+
52
+
53
+ def f1_score(prediction, ground_truth):
54
+ prediction_tokens = normalize_answer(prediction).split()
55
+ ground_truth_tokens = normalize_answer(ground_truth).split()
56
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
57
+ num_same = sum(common.values())
58
+ if num_same == 0:
59
+ return 0
60
+ precision = 1.0 * num_same / len(prediction_tokens)
61
+ recall = 1.0 * num_same / len(ground_truth_tokens)
62
+ f1 = (2 * precision * recall) / (precision + recall)
63
+ return f1
64
+
65
+
66
+ def exact_match_score(prediction, ground_truth):
67
+ return (normalize_answer(prediction) == normalize_answer(ground_truth))
68
+
69
+
70
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
71
+ scores_for_ground_truths = []
72
+ for ground_truth in ground_truths:
73
+ score = metric_fn(prediction, ground_truth)
74
+ scores_for_ground_truths.append(score)
75
+ return max(scores_for_ground_truths)
76
+
77
+
78
+ def evaluate(dataset, predictions):
79
+ f1 = exact_match = total = 0
80
+ for article in dataset:
81
+ for paragraph in article['paragraphs']:
82
+ for qa in paragraph['qas']:
83
+ total += 1
84
+ if qa['id'] not in predictions:
85
+ message = 'Unanswered question ' + qa['id'] + \
86
+ ' will receive score 0.'
87
+ print(message, file=sys.stderr)
88
+ continue
89
+ ground_truths = list(map(lambda x: x['text'], qa['answers']))
90
+ prediction = predictions[qa['id']]
91
+ exact_match += metric_max_over_ground_truths(
92
+ exact_match_score, prediction, ground_truths)
93
+ f1 += metric_max_over_ground_truths(
94
+ f1_score, prediction, ground_truths)
95
+
96
+ exact_match = 100.0 * exact_match / total
97
+ f1 = 100.0 * f1 / total
98
+
99
+ return {'exact_match': exact_match, 'f1': f1}
100
+
101
+
102
+ def main(config: configure_finetuning.FinetuningConfig, split):
103
+ expected_version = '1.1'
104
+ # parser = argparse.ArgumentParser(
105
+ # description='Evaluation for SQuAD ' + expected_version)
106
+ # parser.add_argument('dataset_file', help='Dataset file')
107
+ # parser.add_argument('prediction_file', help='Prediction File')
108
+ # args = parser.parse_args()
109
+ Args = collections.namedtuple("Args", [
110
+ "dataset_file", "prediction_file"
111
+ ])
112
+ args = Args(dataset_file=os.path.join(
113
+ config.raw_data_dir("squadv1"),
114
+ split + ("-debug" if config.debug else "") + ".json"),
115
+ prediction_file=config.qa_preds_file("squadv1"))
116
+ with tf.io.gfile.GFile(args.dataset_file) as dataset_file:
117
+ dataset_json = json.load(dataset_file)
118
+ if dataset_json['version'] != expected_version:
119
+ print('Evaluation expects v-' + expected_version +
120
+ ', but got dataset with v-' + dataset_json['version'],
121
+ file=sys.stderr)
122
+ dataset = dataset_json['data']
123
+ with tf.io.gfile.GFile(args.prediction_file) as prediction_file:
124
+ predictions = json.load(prediction_file)
125
+ return evaluate(dataset, predictions)
126
+
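As a quick sanity check on the v1.1 metrics above, the following hedged, self-contained sketch condenses normalize_answer and f1_score (same steps: lowercase, strip punctuation, drop articles, collapse whitespace) and applies them to a made-up prediction/answer pair:

import re
import string
from collections import Counter

def _normalize(s):
    # Condensed copy of the normalization steps used above.
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def _f1(prediction, ground_truth):
    pred_toks = _normalize(prediction).split()
    gold_toks = _normalize(ground_truth).split()
    num_same = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

print(_f1("The Nile river", "Nile"))  # ~0.667: one shared token out of two predicted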
arabert/araelectra/finetune/scorer.py ADDED
@@ -0,0 +1,54 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Base class for evaluation metrics."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+
24
+
25
+ class Scorer(object):
26
+ """Abstract base class for computing evaluation metrics."""
27
+
28
+ __metaclass__ = abc.ABCMeta
29
+
30
+ def __init__(self):
31
+ self._updated = False
32
+ self._cached_results = {}
33
+
34
+ @abc.abstractmethod
35
+ def update(self, results):
36
+ self._updated = True
37
+
38
+ @abc.abstractmethod
39
+ def get_loss(self):
40
+ pass
41
+
42
+ @abc.abstractmethod
43
+ def _get_results(self):
44
+ return []
45
+
46
+ def get_results(self, prefix=""):
47
+ results = self._get_results() if self._updated else self._cached_results
48
+ self._cached_results = results
49
+ self._updated = False
50
+ return [(prefix + k, v) for k, v in results]
51
+
52
+ def results_str(self):
53
+ return " - ".join(["{:}: {:.2f}".format(k, v)
54
+ for k, v in self.get_results()])
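The Scorer base class above caches results between updates; the following hedged sketch shows a minimal, made-up subclass exercising that contract, assuming the araelectra directory is on PYTHONPATH so that finetune.scorer is importable:

from finetune.scorer import Scorer

class MeanLossScorer(Scorer):
    # Toy scorer (illustration only): averages a single "loss" value.
    def __init__(self):
        super(MeanLossScorer, self).__init__()
        self._total, self._count = 0.0, 0

    def update(self, results):
        super(MeanLossScorer, self).update(results)  # marks cached results stale
        self._total += results["loss"]
        self._count += 1

    def get_loss(self):
        return self._total / max(1, self._count)

    def _get_results(self):
        return [("loss", self.get_loss())]

s = MeanLossScorer()
s.update({"loss": 2.0})
s.update({"loss": 4.0})
print(s.results_str())  # loss: 3.00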
arabert/araelectra/finetune/tagging/tagging_metrics.py ADDED
@@ -0,0 +1,116 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Metrics for sequence tagging tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ import six
24
+
25
+ import numpy as np
26
+
27
+ from finetune import scorer
28
+ from finetune.tagging import tagging_utils
29
+
30
+
31
+ class WordLevelScorer(scorer.Scorer):
32
+ """Base class for tagging scorers."""
33
+ __metaclass__ = abc.ABCMeta
34
+
35
+ def __init__(self):
36
+ super(WordLevelScorer, self).__init__()
37
+ self._total_loss = 0
38
+ self._total_words = 0
39
+ self._labels = []
40
+ self._preds = []
41
+
42
+ def update(self, results):
43
+ super(WordLevelScorer, self).update(results)
44
+ self._total_loss += results['loss']
45
+ n_words = int(round(np.sum(results['labels_mask'])))
46
+ self._labels.append(results['labels'][:n_words])
47
+ self._preds.append(results['predictions'][:n_words])
48
+ self._total_loss += np.sum(results['loss'])
49
+ self._total_words += n_words
50
+
51
+ def get_loss(self):
52
+ return self._total_loss / max(1, self._total_words)
53
+
54
+
55
+ class AccuracyScorer(WordLevelScorer):
56
+ """Computes accuracy scores."""
57
+
58
+ def __init__(self, auto_fail_label=None):
59
+ super(AccuracyScorer, self).__init__()
60
+ self._auto_fail_label = auto_fail_label
61
+
62
+ def _get_results(self):
63
+ correct, count = 0, 0
64
+ for labels, preds in zip(self._labels, self._preds):
65
+ for y_true, y_pred in zip(labels, preds):
66
+ count += 1
67
+ correct += (1 if y_pred == y_true and y_true != self._auto_fail_label
68
+ else 0)
69
+ return [
70
+ ('accuracy', 100.0 * correct / count),
71
+ ('loss', self.get_loss())
72
+ ]
73
+
74
+
75
+ class F1Scorer(WordLevelScorer):
76
+ """Computes F1 scores."""
77
+
78
+ __metaclass__ = abc.ABCMeta
79
+
80
+ def __init__(self):
81
+ super(F1Scorer, self).__init__()
82
+ self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
83
+
84
+ def _get_results(self):
85
+ if self._n_correct == 0:
86
+ p, r, f1 = 0, 0, 0
87
+ else:
88
+ p = 100.0 * self._n_correct / self._n_predicted
89
+ r = 100.0 * self._n_correct / self._n_gold
90
+ f1 = 2 * p * r / (p + r)
91
+ return [
92
+ ('precision', p),
93
+ ('recall', r),
94
+ ('f1', f1),
95
+ ('loss', self.get_loss()),
96
+ ]
97
+
98
+
99
+ class EntityLevelF1Scorer(F1Scorer):
100
+ """Computes F1 score for entity-level tasks such as NER."""
101
+
102
+ def __init__(self, label_mapping):
103
+ super(EntityLevelF1Scorer, self).__init__()
104
+ self._inv_label_mapping = {v: k for k, v in six.iteritems(label_mapping)}
105
+
106
+ def _get_results(self):
107
+ self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
108
+ for labels, preds in zip(self._labels, self._preds):
109
+ sent_spans = set(tagging_utils.get_span_labels(
110
+ labels, self._inv_label_mapping))
111
+ span_preds = set(tagging_utils.get_span_labels(
112
+ preds, self._inv_label_mapping))
113
+ self._n_correct += len(sent_spans & span_preds)
114
+ self._n_gold += len(sent_spans)
115
+ self._n_predicted += len(span_preds)
116
+ return super(EntityLevelF1Scorer, self)._get_results()
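A hedged usage sketch for the word-level scorers above: the results dict below is a fabricated stand-in for what a tagging prediction module emits, and the import assumes numpy, six, and the araelectra directory are available on PYTHONPATH:

import numpy as np
from finetune.tagging.tagging_metrics import AccuracyScorer

acc = AccuracyScorer()
acc.update({
    "loss": 0.4,
    "labels_mask": np.array([1.0, 1.0, 1.0, 0.0]),  # three real words, one pad
    "labels": np.array([0, 1, 2, 0]),
    "predictions": np.array([0, 1, 1, 0]),
})
print(acc.results_str())  # accuracy is computed over the 3 unpadded positions (66.67)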
arabert/araelectra/finetune/tagging/tagging_tasks.py ADDED
@@ -0,0 +1,253 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Sequence tagging tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ import collections
24
+ import os
25
+ import tensorflow as tf
26
+
27
+ import configure_finetuning
28
+ from finetune import feature_spec
29
+ from finetune import task
30
+ from finetune.tagging import tagging_metrics
31
+ from finetune.tagging import tagging_utils
32
+ from model import tokenization
33
+ from pretrain import pretrain_helpers
34
+ from util import utils
35
+
36
+
37
+ LABEL_ENCODING = "BIOES"
38
+
39
+
40
+ class TaggingExample(task.Example):
41
+ """A single tagged input sequence."""
42
+
43
+ def __init__(self, eid, task_name, words, tags, is_token_level,
44
+ label_mapping):
45
+ super(TaggingExample, self).__init__(task_name)
46
+ self.eid = eid
47
+ self.words = words
48
+ if is_token_level:
49
+ labels = tags
50
+ else:
51
+ span_labels = tagging_utils.get_span_labels(tags)
52
+ labels = tagging_utils.get_tags(
53
+ span_labels, len(words), LABEL_ENCODING)
54
+ self.labels = [label_mapping[l] for l in labels]
55
+
56
+
57
+ class TaggingTask(task.Task):
58
+ """Defines a sequence tagging task (e.g., part-of-speech tagging)."""
59
+
60
+ __metaclass__ = abc.ABCMeta
61
+
62
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name,
63
+ tokenizer, is_token_level):
64
+ super(TaggingTask, self).__init__(config, name)
65
+ self._tokenizer = tokenizer
66
+ self._label_mapping_path = os.path.join(
67
+ self.config.preprocessed_data_dir,
68
+ ("debug_" if self.config.debug else "") + self.name +
69
+ "_label_mapping.pkl")
70
+ self._is_token_level = is_token_level
71
+ self._label_mapping = None
72
+
73
+ def get_examples(self, split):
74
+ sentences = self._get_labeled_sentences(split)
75
+ examples = []
76
+ label_mapping = self._get_label_mapping(split, sentences)
77
+ for i, (words, tags) in enumerate(sentences):
78
+ examples.append(TaggingExample(
79
+ i, self.name, words, tags, self._is_token_level, label_mapping
80
+ ))
81
+ return examples
82
+
83
+ def _get_label_mapping(self, provided_split=None, provided_sentences=None):
84
+ if self._label_mapping is not None:
85
+ return self._label_mapping
86
+ if tf.io.gfile.exists(self._label_mapping_path):
87
+ self._label_mapping = utils.load_pickle(self._label_mapping_path)
88
+ return self._label_mapping
89
+ utils.log("Writing label mapping for task", self.name)
90
+ tag_counts = collections.Counter()
91
+ train_tags = set()
92
+ for split in ["train", "dev", "test"]:
93
+ if not tf.io.gfile.exists(os.path.join(
94
+ self.config.raw_data_dir(self.name), split + ".txt")):
95
+ continue
96
+ if split == provided_split:
97
+ split_sentences = provided_sentences
98
+ else:
99
+ split_sentences = self._get_labeled_sentences(split)
100
+ for _, tags in split_sentences:
101
+ if not self._is_token_level:
102
+ span_labels = tagging_utils.get_span_labels(tags)
103
+ tags = tagging_utils.get_tags(span_labels, len(tags), LABEL_ENCODING)
104
+ for tag in tags:
105
+ tag_counts[tag] += 1
106
+ if provided_split == "train":
107
+ train_tags.add(tag)
108
+ if self.name == "ccg":
109
+ infrequent_tags = []
110
+ for tag in tag_counts:
111
+ if tag not in train_tags:
112
+ infrequent_tags.append(tag)
113
+ label_mapping = {
114
+ label: i for i, label in enumerate(sorted(filter(
115
+ lambda t: t not in infrequent_tags, tag_counts.keys())))
116
+ }
117
+ n = len(label_mapping)
118
+ for tag in infrequent_tags:
119
+ label_mapping[tag] = n
120
+ else:
121
+ labels = sorted(tag_counts.keys())
122
+ label_mapping = {label: i for i, label in enumerate(labels)}
123
+ utils.write_pickle(label_mapping, self._label_mapping_path)
124
+ self._label_mapping = label_mapping
125
+ return label_mapping
126
+
127
+ def featurize(self, example: TaggingExample, is_training, log=False):
128
+ words_to_tokens = tokenize_and_align(self._tokenizer, example.words)
129
+ input_ids = []
130
+ tagged_positions = []
131
+ for word_tokens in words_to_tokens:
132
+ if len(words_to_tokens) + len(input_ids) + 1 > self.config.max_seq_length:
133
+ input_ids.append(self._tokenizer.vocab["[SEP]"])
134
+ break
135
+ if "[CLS]" not in word_tokens and "[SEP]" not in word_tokens:
136
+ tagged_positions.append(len(input_ids))
137
+ for token in word_tokens:
138
+ input_ids.append(self._tokenizer.vocab[token])
139
+
140
+ pad = lambda x: x + [0] * (self.config.max_seq_length - len(x))
141
+ labels = pad(example.labels[:self.config.max_seq_length])
142
+ labeled_positions = pad(tagged_positions)
143
+ labels_mask = pad([1.0] * len(tagged_positions))
144
+ segment_ids = pad([1] * len(input_ids))
145
+ input_mask = pad([1] * len(input_ids))
146
+ input_ids = pad(input_ids)
147
+ assert len(input_ids) == self.config.max_seq_length
148
+ assert len(input_mask) == self.config.max_seq_length
149
+ assert len(segment_ids) == self.config.max_seq_length
150
+ assert len(labels) == self.config.max_seq_length
151
+ assert len(labels_mask) == self.config.max_seq_length
152
+
153
+ return {
154
+ "input_ids": input_ids,
155
+ "input_mask": input_mask,
156
+ "segment_ids": segment_ids,
157
+ "task_id": self.config.task_names.index(self.name),
158
+ self.name + "_eid": example.eid,
159
+ self.name + "_labels": labels,
160
+ self.name + "_labels_mask": labels_mask,
161
+ self.name + "_labeled_positions": labeled_positions
162
+ }
163
+
164
+ def _get_labeled_sentences(self, split):
165
+ sentences = []
166
+ with tf.io.gfile.GFile(os.path.join(self.config.raw_data_dir(self.name),
167
+ split + ".txt"), "r") as f:
168
+ sentence = []
169
+ for line in f:
170
+ line = line.strip().split()
171
+ if not line:
172
+ if sentence:
173
+ words, tags = zip(*sentence)
174
+ sentences.append((words, tags))
175
+ sentence = []
176
+ if self.config.debug and len(sentences) > 100:
177
+ return sentences
178
+ continue
179
+ if line[0] == "-DOCSTART-":
180
+ continue
181
+ word, tag = line[0], line[-1]
182
+ sentence.append((word, tag))
183
+ return sentences
184
+
185
+ def get_scorer(self):
186
+ return tagging_metrics.AccuracyScorer() if self._is_token_level else \
187
+ tagging_metrics.EntityLevelF1Scorer(self._get_label_mapping())
188
+
189
+ def get_feature_specs(self):
190
+ return [
191
+ feature_spec.FeatureSpec(self.name + "_eid", []),
192
+ feature_spec.FeatureSpec(self.name + "_labels",
193
+ [self.config.max_seq_length]),
194
+ feature_spec.FeatureSpec(self.name + "_labels_mask",
195
+ [self.config.max_seq_length],
196
+ is_int_feature=False),
197
+ feature_spec.FeatureSpec(self.name + "_labeled_positions",
198
+ [self.config.max_seq_length]),
199
+ ]
200
+
201
+ def get_prediction_module(
202
+ self, bert_model, features, is_training, percent_done):
203
+ n_classes = len(self._get_label_mapping())
204
+ reprs = bert_model.get_sequence_output()
205
+ reprs = pretrain_helpers.gather_positions(
206
+ reprs, features[self.name + "_labeled_positions"])
207
+ logits = tf.layers.dense(reprs, n_classes)
208
+ losses = tf.nn.softmax_cross_entropy_with_logits(
209
+ labels=tf.one_hot(features[self.name + "_labels"], n_classes),
210
+ logits=logits)
211
+ losses *= features[self.name + "_labels_mask"]
212
+ losses = tf.reduce_sum(losses, axis=-1)
213
+ return losses, dict(
214
+ loss=losses,
215
+ logits=logits,
216
+ predictions=tf.argmax(logits, axis=-1),
217
+ labels=features[self.name + "_labels"],
218
+ labels_mask=features[self.name + "_labels_mask"],
219
+ eid=features[self.name + "_eid"],
220
+ )
221
+
222
+ def _create_examples(self, lines, split):
223
+ pass
224
+
225
+
226
+ def tokenize_and_align(tokenizer, words, cased=False):
227
+ """Splits up words into subword-level tokens."""
228
+ words = ["[CLS]"] + list(words) + ["[SEP]"]
229
+ basic_tokenizer = tokenizer.basic_tokenizer
230
+ tokenized_words = []
231
+ for word in words:
232
+ word = tokenization.convert_to_unicode(word)
233
+ word = basic_tokenizer._clean_text(word)
234
+ if word == "[CLS]" or word == "[SEP]":
235
+ word_toks = [word]
236
+ else:
237
+ if not cased:
238
+ word = word.lower()
239
+ word = basic_tokenizer._run_strip_accents(word)
240
+ word_toks = basic_tokenizer._run_split_on_punc(word)
241
+ tokenized_word = []
242
+ for word_tok in word_toks:
243
+ tokenized_word += tokenizer.wordpiece_tokenizer.tokenize(word_tok)
244
+ tokenized_words.append(tokenized_word)
245
+ assert len(tokenized_words) == len(words)
246
+ return tokenized_words
247
+
248
+
249
+ class Chunking(TaggingTask):
250
+ """Text chunking."""
251
+
252
+ def __init__(self, config, tokenizer):
253
+ super(Chunking, self).__init__(config, "chunk", tokenizer, False)
arabert/araelectra/finetune/tagging/tagging_utils.py ADDED
@@ -0,0 +1,58 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Utilities for sequence tagging tasks."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+
23
+ def get_span_labels(sentence_tags, inv_label_mapping=None):
24
+ """Go from token-level labels to list of entities (start, end, class)."""
25
+ if inv_label_mapping:
26
+ sentence_tags = [inv_label_mapping[i] for i in sentence_tags]
27
+ span_labels = []
28
+ last = 'O'
29
+ start = -1
30
+ for i, tag in enumerate(sentence_tags):
31
+ pos, _ = (None, 'O') if tag == 'O' else tag.split('-')
32
+ if (pos == 'S' or pos == 'B' or tag == 'O') and last != 'O':
33
+ span_labels.append((start, i - 1, last.split('-')[-1]))
34
+ if pos == 'B' or pos == 'S' or last == 'O':
35
+ start = i
36
+ last = tag
37
+ if sentence_tags[-1] != 'O':
38
+ span_labels.append((start, len(sentence_tags) - 1,
39
+ sentence_tags[-1].split('-')[-1]))
40
+ return span_labels
41
+
42
+
43
+ def get_tags(span_labels, length, encoding):
44
+ """Converts a list of entities to token-label labels based on the provided
45
+ encoding (e.g., BIOES).
46
+ """
47
+
48
+ tags = ['O' for _ in range(length)]
49
+ for s, e, t in span_labels:
50
+ for i in range(s, e + 1):
51
+ tags[i] = 'I-' + t
52
+ if 'E' in encoding:
53
+ tags[e] = 'E-' + t
54
+ if 'B' in encoding:
55
+ tags[s] = 'B-' + t
56
+ if 'S' in encoding and s - e == 0:
57
+ tags[s] = 'S-' + t
58
+ return tags
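The two helpers above are pure Python, so they can be exercised directly; this hedged round-trip sketch uses a made-up BIO sequence and assumes the araelectra directory is on PYTHONPATH:

from finetune.tagging.tagging_utils import get_span_labels, get_tags

bio_tags = ["B-PER", "I-PER", "O", "B-LOC"]
spans = get_span_labels(bio_tags)
print(spans)                                     # [(0, 1, 'PER'), (3, 3, 'LOC')]
print(get_tags(spans, len(bio_tags), "BIOES"))   # ['B-PER', 'E-PER', 'O', 'S-LOC']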
arabert/araelectra/finetune/task.py ADDED
@@ -0,0 +1,74 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Defines a supervised NLP task."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import abc
23
+ from typing import List, Tuple
24
+
25
+ import configure_finetuning
26
+ from finetune import feature_spec
27
+ from finetune import scorer
28
+ from model import modeling
29
+
30
+
31
+ class Example(object):
32
+ __metaclass__ = abc.ABCMeta
33
+
34
+ def __init__(self, task_name):
35
+ self.task_name = task_name
36
+
37
+
38
+ class Task(object):
39
+ """Override this class to add a new fine-tuning task."""
40
+
41
+ __metaclass__ = abc.ABCMeta
42
+
43
+ def __init__(self, config: configure_finetuning.FinetuningConfig, name):
44
+ self.config = config
45
+ self.name = name
46
+
47
+ def get_test_splits(self):
48
+ return ["test"]
49
+
50
+ @abc.abstractmethod
51
+ def get_examples(self, split):
52
+ pass
53
+
54
+ @abc.abstractmethod
55
+ def get_scorer(self) -> scorer.Scorer:
56
+ pass
57
+
58
+ @abc.abstractmethod
59
+ def get_feature_specs(self) -> List[feature_spec.FeatureSpec]:
60
+ pass
61
+
62
+ @abc.abstractmethod
63
+ def featurize(self, example: Example, is_training: bool,
64
+ log: bool=False):
65
+ pass
66
+
67
+ @abc.abstractmethod
68
+ def get_prediction_module(
69
+ self, bert_model: modeling.BertModel, features: dict, is_training: bool,
70
+ percent_done: float) -> Tuple:
71
+ pass
72
+
73
+ def __repr__(self):
74
+ return "Task(" + self.name + ")"
arabert/araelectra/finetune/task_builder.py ADDED
@@ -0,0 +1,70 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Returns task instances given the task name."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import configure_finetuning
23
+ from finetune.classification import classification_tasks
24
+ from finetune.qa import qa_tasks
25
+ from finetune.tagging import tagging_tasks
26
+ from model import tokenization
27
+
28
+
29
+ def get_tasks(config: configure_finetuning.FinetuningConfig):
30
+ tokenizer = tokenization.FullTokenizer(vocab_file=config.vocab_file,
31
+ do_lower_case=config.do_lower_case)
32
+ return [get_task(config, task_name, tokenizer)
33
+ for task_name in config.task_names]
34
+
35
+
36
+ def get_task(config: configure_finetuning.FinetuningConfig, task_name,
37
+ tokenizer):
38
+ """Get an instance of a task based on its name."""
39
+ if task_name == "cola":
40
+ return classification_tasks.CoLA(config, tokenizer)
41
+ elif task_name == "mrpc":
42
+ return classification_tasks.MRPC(config, tokenizer)
43
+ elif task_name == "mnli":
44
+ return classification_tasks.MNLI(config, tokenizer)
45
+ elif task_name == "sst":
46
+ return classification_tasks.SST(config, tokenizer)
47
+ elif task_name == "rte":
48
+ return classification_tasks.RTE(config, tokenizer)
49
+ elif task_name == "qnli":
50
+ return classification_tasks.QNLI(config, tokenizer)
51
+ elif task_name == "qqp":
52
+ return classification_tasks.QQP(config, tokenizer)
53
+ elif task_name == "sts":
54
+ return classification_tasks.STS(config, tokenizer)
55
+ elif task_name == "squad":
56
+ return qa_tasks.SQuAD(config, tokenizer)
57
+ elif task_name == "squadv1":
58
+ return qa_tasks.SQuADv1(config, tokenizer)
59
+ elif task_name == "newsqa":
60
+ return qa_tasks.NewsQA(config, tokenizer)
61
+ elif task_name == "naturalqs":
62
+ return qa_tasks.NaturalQuestions(config, tokenizer)
63
+ elif task_name == "triviaqa":
64
+ return qa_tasks.TriviaQA(config, tokenizer)
65
+ elif task_name == "searchqa":
66
+ return qa_tasks.SearchQA(config, tokenizer)
67
+ elif task_name == "chunk":
68
+ return tagging_tasks.Chunking(config, tokenizer)
69
+ else:
70
+ raise ValueError("Unknown task " + task_name)