Upload 106 files
This view is limited to 50 files because it contains too many changes.
- arabert/.gitignore +142 -0
- arabert/AJGT.xlsx +0 -0
- arabert/README.md +227 -0
- arabert/__init__.py +1 -0
- arabert/__pycache__/__init__.cpython-310.pyc +0 -0
- arabert/__pycache__/__init__.cpython-38.pyc +0 -0
- arabert/__pycache__/__init__.cpython-39.pyc +0 -0
- arabert/__pycache__/preprocess.cpython-310.pyc +0 -0
- arabert/__pycache__/preprocess.cpython-38.pyc +0 -0
- arabert/__pycache__/preprocess.cpython-39.pyc +0 -0
- arabert/arabert/LICENSE +75 -0
- arabert/arabert/Readme.md +75 -0
- arabert/arabert/__init__.py +14 -0
- arabert/arabert/create_classification_data.py +260 -0
- arabert/arabert/create_pretraining_data.py +534 -0
- arabert/arabert/extract_features.py +444 -0
- arabert/arabert/lamb_optimizer.py +158 -0
- arabert/arabert/modeling.py +1027 -0
- arabert/arabert/optimization.py +202 -0
- arabert/arabert/run_classifier.py +1078 -0
- arabert/arabert/run_pretraining.py +593 -0
- arabert/arabert/run_squad.py +1440 -0
- arabert/arabert/sample_text.txt +38 -0
- arabert/arabert/tokenization.py +414 -0
- arabert/arabert_logo.png +0 -0
- arabert/araelectra/.gitignore +4 -0
- arabert/araelectra/LICENSE +76 -0
- arabert/araelectra/README.md +144 -0
- arabert/araelectra/__init__.py +1 -0
- arabert/araelectra/build_openwebtext_pretraining_dataset.py +103 -0
- arabert/araelectra/build_pretraining_dataset.py +230 -0
- arabert/araelectra/build_pretraining_dataset_single_file.py +90 -0
- arabert/araelectra/configure_finetuning.py +172 -0
- arabert/araelectra/configure_pretraining.py +143 -0
- arabert/araelectra/finetune/__init__.py +14 -0
- arabert/araelectra/finetune/classification/classification_metrics.py +116 -0
- arabert/araelectra/finetune/classification/classification_tasks.py +439 -0
- arabert/araelectra/finetune/feature_spec.py +56 -0
- arabert/araelectra/finetune/preprocessing.py +173 -0
- arabert/araelectra/finetune/qa/mrqa_official_eval.py +120 -0
- arabert/araelectra/finetune/qa/qa_metrics.py +401 -0
- arabert/araelectra/finetune/qa/qa_tasks.py +628 -0
- arabert/araelectra/finetune/qa/squad_official_eval.py +317 -0
- arabert/araelectra/finetune/qa/squad_official_eval_v1.py +126 -0
- arabert/araelectra/finetune/scorer.py +54 -0
- arabert/araelectra/finetune/tagging/tagging_metrics.py +116 -0
- arabert/araelectra/finetune/tagging/tagging_tasks.py +253 -0
- arabert/araelectra/finetune/tagging/tagging_utils.py +58 -0
- arabert/araelectra/finetune/task.py +74 -0
- arabert/araelectra/finetune/task_builder.py +70 -0
arabert/.gitignore
ADDED
@@ -0,0 +1,142 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

#vscode stuff
.vscode/

# Local History for Visual Studio Code
.history/
# Pyre type checker
.pyre/
testing_squad.py
FarasaSegmenterJar.jar
data/
testing/
*.tsv
*.zip
model_cards/
optuna/
arabert/AJGT.xlsx
ADDED
Binary file (107 kB).
arabert/README.md
ADDED
@@ -0,0 +1,227 @@
# AraBERTv2 / AraGPT2 / AraELECTRA

<img src="https://github.com/aub-mind/arabert/blob/master/arabert_logo.png" width="100" align="right"/>

This repository now contains code and implementation for:
- **AraBERT v0.1/v1**: Original
- **AraBERT v0.2/v2**: Base and large versions with better vocabulary, more data, more training [Read More...](#AraBERT)
- **AraGPT2**: base, medium, large and MEGA. Trained from scratch on Arabic [Read More...](#AraGPT2)
- **AraELECTRA**: Trained from scratch on Arabic [Read More...](#AraELECTRA)

If you want to clone the old repository:
```bash
git clone https://github.com/aub-mind/arabert/
cd arabert && git checkout 6a58ca118911ef311cbe8cdcdcc1d03601123291
```
# Update

- **02-Apr-2021:** AraELECTRA powered Arabic Wikipedia QA system: https://share.streamlit.io/wissamantoun/arabic-wikipedia-qa-streamlit/main

# AraBERTv2

## What's New!

AraBERT now comes in 4 new variants to replace the old v1 versions:

More detail in the AraBERT folder, in the [README](https://github.com/aub-mind/arabert/tree/master/arabert), and in the [AraBERT Paper](https://arxiv.org/abs/2003.00104)

Model | HuggingFace Model Name | Size (MB/Params) | Pre-Segmentation | DataSet (Sentences/Size/nWords)
---|:---:|:---:|:---:|:---:
AraBERTv0.2-base | [bert-base-arabertv02](https://huggingface.co/aubmindlab/bert-base-arabertv02) | 543MB / 136M | No | 200M / 77GB / 8.6B
AraBERTv0.2-large | [bert-large-arabertv02](https://huggingface.co/aubmindlab/bert-large-arabertv02) | 1.38G / 371M | No | 200M / 77GB / 8.6B
AraBERTv2-base | [bert-base-arabertv2](https://huggingface.co/aubmindlab/bert-base-arabertv2) | 543MB / 136M | Yes | 200M / 77GB / 8.6B
AraBERTv2-large | [bert-large-arabertv2](https://huggingface.co/aubmindlab/bert-large-arabertv2) | 1.38G / 371M | Yes | 200M / 77GB / 8.6B
AraBERTv0.1-base | [bert-base-arabertv01](https://huggingface.co/aubmindlab/bert-base-arabertv01) | 543MB / 136M | No | 77M / 23GB / 2.7B
AraBERTv1-base | [bert-base-arabert](https://huggingface.co/aubmindlab/bert-base-arabert) | 543MB / 136M | Yes | 77M / 23GB / 2.7B

All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.

## Better Pre-Processing and New Vocab

We identified an issue with AraBERTv1's wordpiece vocabulary: punctuation and numbers were still attached to words when the wordpiece vocab was learned. We now insert a space between numbers and characters and around punctuation characters, as illustrated in the sketch below.

The new vocabulary was learned using the `BertWordpieceTokenizer` from the `tokenizers` library, and should now support the Fast tokenizer implementation from the `transformers` library.

**P.S.**: All the old BERT code should work with the new BERT; just change the model name and check the new preprocessing function.

**Please read the section on how to use the [preprocessing function](#Preprocessing)**
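
To illustrate the space-insertion rule described above, here is a toy sketch of the idea only; it is not the library's actual implementation (the real logic lives in `preprocess.py`):

```python
import re

def space_out(text: str) -> str:
    """Toy version of the v2 cleaning rule: pad punctuation and digit runs with spaces."""
    text = re.sub(r"([؟!:;,،.])", r" \1 ", text)  # spaces around punctuation
    text = re.sub(r"(\d+)", r" \1 ", text)        # spaces between numbers and characters
    return re.sub(r"\s+", " ", text).strip()      # collapse the extra whitespace

print(space_out("ولد عام1990!"))  # -> ولد عام 1990 !
```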

## Bigger Dataset and More Compute

We used ~3.5 times more data and trained for longer.
For dataset sources see the [Dataset Section](#Dataset)

Model | Hardware | num of examples with seq len (128 / 512) | 128 (Batch Size / Num of Steps) | 512 (Batch Size / Num of Steps) | Total Steps | Total Time (in Days)
---|:---:|:---:|:---:|:---:|:---:|:---:
AraBERTv0.2-base | TPUv3-8 | 420M / 207M | 2560 / 1M | 384 / 2M | 3M | 36
AraBERTv0.2-large | TPUv3-128 | 420M / 207M | 13440 / 250K | 2056 / 300K | 550K | 7
AraBERTv2-base | TPUv3-8 | 420M / 207M | 2560 / 1M | 384 / 2M | 3M | 36
AraBERTv2-large | TPUv3-128 | 520M / 245M | 13440 / 250K | 2056 / 300K | 550K | 7
AraBERT-base (v1/v0.1) | TPUv2-8 | - | 512 / 900K | 128 / 300K | 1.2M | 4

# AraGPT2

More details and code are available in the AraGPT2 folder and [README](https://github.com/aub-mind/arabert/blob/master/aragpt2/README.md)

## Model

Model | HuggingFace Model Name | Size / Params
---|:---:|:---:
AraGPT2-base | [aragpt2-base](https://huggingface.co/aubmindlab/aragpt2-base) | 527MB / 135M
AraGPT2-medium | [aragpt2-medium](https://huggingface.co/aubmindlab/aragpt2-medium) | 1.38G / 370M
AraGPT2-large | [aragpt2-large](https://huggingface.co/aubmindlab/aragpt2-large) | 2.98GB / 792M
AraGPT2-mega | [aragpt2-mega](https://huggingface.co/aubmindlab/aragpt2-mega) | 5.5GB / 1.46B
AraGPT2-mega-detector-long | [aragpt2-mega-detector-long](https://huggingface.co/aubmindlab/aragpt2-mega-detector-long) | 516MB / 135M

All models are available on the `HuggingFace` model page under the [aubmindlab](https://huggingface.co/aubmindlab/) name. Checkpoints are available in PyTorch, TF2 and TF1 formats.

## Dataset and Compute

For dataset sources see the [Dataset Section](#Dataset)

Model | Hardware | num of examples (seq len = 1024) | Batch Size | Num of Steps | Time (in days)
---|:---:|:---:|:---:|:---:|:---:
AraGPT2-base | TPUv3-128 | 9.7M | 1792 | 125K | 1.5
AraGPT2-medium | TPUv3-128 | 9.7M | 1152 | 85K | 1.5
AraGPT2-large | TPUv3-128 | 9.7M | 256 | 220K | 3
AraGPT2-mega | TPUv3-128 | 9.7M | 256 | 800K | 9

# AraELECTRA

More details and code are available in the AraELECTRA folder and [README](https://github.com/aub-mind/arabert/blob/master/araelectra/README.md)

## Model

Model | HuggingFace Model Name | Size (MB/Params)
---|:---:|:---:
AraELECTRA-base-generator | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) | 227MB / 60M
AraELECTRA-base-discriminator | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) | 516MB / 135M

## Dataset and Compute

Model | Hardware | num of examples (seq len = 512) | Batch Size | Num of Steps | Time (in days)
---|:---:|:---:|:---:|:---:|:---:
ELECTRA-base | TPUv3-8 | - | 256 | 2M | 24

# Dataset

The pretraining data used for the new AraBERT model is also used for **AraGPT2 and AraELECTRA**.

The dataset consists of 77GB, or 200,095,961 lines, or 8,655,948,860 words, or 82,232,988,358 chars (before applying Farasa segmentation).

For the new dataset we added the unshuffled OSCAR corpus, after thoroughly filtering it, to the dataset used in AraBERTv1, minus the websites that we previously crawled:
- OSCAR unshuffled and filtered.
- [Arabic Wikipedia dump](https://archive.org/details/arwiki-20190201) from 2020/09/01
- [The 1.5B words Arabic Corpus](https://www.semanticscholar.org/paper/1.5-billion-words-Arabic-Corpus-El-Khair/f3eeef4afb81223df96575adadf808fe7fe440b4)
- [The OSIAN Corpus](https://www.aclweb.org/anthology/W19-4619)
- Assafir news articles. A huge thank you to Assafir for the data.

# Preprocessing

It is recommended to apply our preprocessing function before training/testing on any dataset.
**Install farasapy to segment text for AraBERT v1 & v2: `pip install farasapy`**

```python
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv2"
arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
arabert_prep.preprocess(text)
>>> "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
```

You can also use the `unpreprocess()` function to reverse the preprocessing changes; it fixes the spacing around non-alphabetical characters and de-segments the text if the selected model needs pre-segmentation. We highly recommend unpreprocessing content generated by the `AraGPT2` models to make it look more natural.
```python
output_text = "و+ لن نبالغ إذا قل +نا : إن ' هاتف ' أو ' كمبيوتر ال+ مكتب ' في زمن +نا هذا ضروري"
arabert_prep.unpreprocess(output_text)
>>> "ولن نبالغ إذا قلنا: إن 'هاتف' أو 'كمبيوتر المكتب' في زمننا هذا ضروري"
```
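
As a rough sketch of that generate-then-unpreprocess workflow (the prompt and generation settings here are illustrative; assumes `transformers` with the PyTorch backend):

```python
from transformers import pipeline
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/aragpt2-base"
prep = ArabertPreprocessor(model_name=model_name)
generator = pipeline("text-generation", model=model_name)

prompt = prep.preprocess("يحكى أن مزارعا")
raw = generator(prompt, max_new_tokens=40)[0]["generated_text"]
print(prep.unpreprocess(raw))  # restore natural spacing before showing the output
```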

### Accepted Model Names:
The `ArabertPreprocessor` class expects one of the following model names:

Note: You can also use the same model name from the `HuggingFace` model repository without removing `aubmindlab/`. Defaults to `bert-base-arabertv02` with no pre-segmentation.

```
bert-base-arabertv01
bert-base-arabert
bert-base-arabertv02
bert-base-arabertv2
bert-large-arabertv02
bert-large-arabertv2
araelectra-base-discriminator
araelectra-base-generator
aragpt2-base
aragpt2-medium
aragpt2-large
aragpt2-mega
```
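
A quick sketch of that note in practice; `keep_emojis` is the same flag used later in this upload's `create_classification_data.py`:

```python
from arabert.preprocess import ArabertPreprocessor

# The full HuggingFace name, "aubmindlab/" included, is accepted too
prep = ArabertPreprocessor(model_name="aubmindlab/bert-base-arabertv02", keep_emojis=False)
print(prep.preprocess("أهلاً وسهلاً 🙂"))
```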

# Examples Notebooks

- You can find the old examples that work with AraBERTv1 in the `examples/old` folder
- Check the [Readme.md](https://github.com/aub-mind/arabert/tree/master/examples) file in the examples folder for new links to Colab notebooks

# TensorFlow 1.x models

**You can find the PyTorch, TF2 and TF1 models in HuggingFace's Transformer Library under the `aubmindlab` username**

- `wget https://huggingface.co/aubmindlab/MODEL_NAME/resolve/main/tf1_model.tar.gz` where `MODEL_NAME` is any model under the `aubmindlab` name, as in the example below.
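
For example, for one of the base models (the same URL pattern should apply to any `aubmindlab` model):

```bash
wget https://huggingface.co/aubmindlab/bert-base-arabertv02/resolve/main/tf1_model.tar.gz
tar -xzf tf1_model.tar.gz   # unpacks the TF1 checkpoint files
```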

# If you used this model please cite us as:
## AraBERT
Google Scholar has our Bibtex wrong (missing name), use this instead
```
@inproceedings{antoun2020arabert,
  title={AraBERT: Transformer-based Model for Arabic Language Understanding},
  author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
  booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
  pages={9}
}
```
## AraGPT2
```
@inproceedings{antoun-etal-2021-aragpt2,
    title = "{A}ra{GPT}2: Pre-Trained Transformer for {A}rabic Language Generation",
    author = "Antoun, Wissam and Baly, Fady and Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.21",
    pages = "196--207",
}
```

## AraELECTRA
```
@inproceedings{antoun-etal-2021-araelectra,
    title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
    author = "Antoun, Wissam and Baly, Fady and Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
    pages = "191--195",
}
```

# Acknowledgments
Thanks to TensorFlow Research Cloud (TFRC) for the free access to Cloud TPUs; we couldn't have done it without this program. Thanks also to the [AUB MIND Lab](https://sites.aub.edu.lb/mindlab/) members for the continuous support, to [Yakshof](https://www.yakshof.com/#/) and Assafir for data and storage access, and to [Habib Rahal](https://www.behance.net/rahalhabib) for putting a face to AraBERT.

# Contacts
**Wissam Antoun**: [Linkedin](https://www.linkedin.com/in/wissam-antoun-622142b4/) | [Twitter](https://twitter.com/wissam_antoun) | [Github](https://github.com/WissamAntoun) | wfa07 (AT) mail (DOT) aub (DOT) edu | wissam.antoun (AT) gmail (DOT) com

**Fady Baly**: [Linkedin](https://www.linkedin.com/in/fadybaly/) | [Twitter](https://twitter.com/fadybaly) | [Github](https://github.com/fadybaly) | fgb06 (AT) mail (DOT) aub (DOT) edu | baly.fady (AT) gmail (DOT) com
arabert/__init__.py
ADDED
@@ -0,0 +1 @@
# coding=utf-8
arabert/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (117 Bytes).
arabert/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (143 Bytes).
arabert/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (139 Bytes).
arabert/__pycache__/preprocess.cpython-310.pyc
ADDED
Binary file (21.8 kB).
arabert/__pycache__/preprocess.cpython-38.pyc
ADDED
Binary file (21.9 kB).
arabert/__pycache__/preprocess.cpython-39.pyc
ADDED
Binary file (21.8 kB).
arabert/arabert/LICENSE
ADDED
@@ -0,0 +1,75 @@
==========================================
SOFTWARE LICENSE AGREEMENT - AraBERT
==========================================

* NAME: AraBERT : Arabic Bidirectional Encoder Representations from Transformers

* ACKNOWLEDGMENTS

This [software] was generated by [American
University of Beirut] (“Owners”). The statements
made herein are solely the responsibility of the author[s].

The following software programs and programs have been used in the
generation of [AraBERT]:

+ Farasa Segmenter
 - Abdelali, Ahmed, Kareem Darwish, Nadir Durrani, and Hamdy Mubarak.
 "Farasa: A fast and furious segmenter for arabic." In Proceedings of
 the 2016 Conference of the North American Chapter of the Association
 for Computational Linguistics: Demonstrations, pp. 11-16. 2016.
 - License and link : http://alt.qcri.org/farasa/segmenter.html

+ BERT
 - Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova.
 "Bert: Pre-training of deep bidirectional transformers for language
 understanding." arXiv preprint arXiv:1810.04805 (2018).
 - License and link : https://github.com/google-research/bert

+ PyArabic
 - T. Zerrouki, Pyarabic, An Arabic language library for Python,
 https://pypi.python.org/pypi/pyarabic/, 2010
 - License and link: https://github.com/linuxscout/pyarabic/

* LICENSE

This software and database is being provided to you, the LICENSEE,
by the Owners under the following license. By obtaining, using and/or
copying this software and database, you agree that you have read,
understood, and will comply with these terms and conditions. You
further agree that you have read and you will abide by the license
agreements provided in the above links under “acknowledgements”:
Permission to use, copy, modify and distribute this software and
database and its documentation for any purpose and without fee or
royalty is hereby granted, provided that you agree to comply with the
following copyright notice and statements, including the disclaimer,
and that the same appear on ALL copies of the software, database and
documentation, including modifications that you make for internal use
or for distribution. [AraBERT] Copyright 2020 by [American University
of Beirut]. All rights reserved. If you remix, transform, or build
upon the material, you must distribute your contributions under the
same license as this one. You may not apply legal terms or technological
measures that legally restrict others from doing anything this license
permits. THIS SOFTWARE IS PROVIDED "AS IS" AND THE OWNERS MAKE NO
REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE,
BUT NOT LIMITATION, THE OWNERS MAKE NO REPRESENTATIONS OR WARRANTIES OF
MERCHANT-ABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF
THE LICENSED SOFTWARE, DATABASE OR DOCUMENTATION WILL NOT INFRINGE ANY THIRD
PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS. The name of the
Owners may not be used in advertising or publicity pertaining to
distribution of the software and/or database. Title to copyright in
this software, database and any associated documentation shall at all
times remain with the Owners and LICENSEE agrees to preserve same.

The use of AraBERT should be cited as follows:
@inproceedings{antoun2020arabert,
  title={AraBERT: Transformer-based Model for Arabic Language Understanding},
  author={Antoun, Wissam and Baly, Fady and Hajj, Hazem},
  booktitle={LREC 2020 Workshop Language Resources and Evaluation Conference 11--16 May 2020},
  pages={9}
}

[AraBERT] Copyright 2020 by [American University of Beirut].
All rights reserved.
==========================================
arabert/arabert/Readme.md
ADDED
@@ -0,0 +1,75 @@
# AraBERT v1 & v2 : Pre-training BERT for Arabic Language Understanding
<img src="https://github.com/aub-mind/arabert/blob/master/arabert_logo.png" width="100" align="left"/>

**AraBERT** is an Arabic pretrained language model based on [Google's BERT architecture](https://github.com/google-research/bert). AraBERT uses the same BERT-Base config. More details are available in the [AraBERT Paper](https://arxiv.org/abs/2003.00104v2) and in the [AraBERT Meetup](https://github.com/WissamAntoun/pydata_khobar_meetup)

There are two versions of the model, AraBERTv0.1 and AraBERTv1, the difference being that AraBERTv1 uses pre-segmented text where prefixes and suffixes were split using the [Farasa Segmenter](http://alt.qcri.org/farasa/segmenter.html).

We evaluate AraBERT models on different downstream tasks and compare them to [mBERT](https://github.com/google-research/bert/blob/master/multilingual.md) and other state-of-the-art models (*to the extent of our knowledge*). The tasks were Sentiment Analysis on 6 different datasets ([HARD](https://github.com/elnagara/HARD-Arabic-Dataset), [ASTD-Balanced](https://www.aclweb.org/anthology/D15-1299), [ArsenTD-Lev](https://staff.aub.edu.lb/~we07/Publications/ArSentD-LEV_Sentiment_Corpus.pdf), [LABR](https://github.com/mohamedadaly/LABR)), Named Entity Recognition with the [ANERcorp](http://curtis.ml.cmu.edu/w/courses/index.php/ANERcorp), and Arabic Question Answering on [Arabic-SQuAD and ARCD](https://github.com/husseinmozannar/SOQAL)

## Results
Task | Metric | AraBERTv0.1 | AraBERTv1 | AraBERTv0.2-base | AraBERTv2-base | AraBERTv0.2-large | AraBERTv2-large | AraELECTRA-base
:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:
HARD | Acc. | **96.2** | 96.1 | - | - | - | - | -
ASTD | Acc. | 92.2 | **92.6** | - | - | - | - | -
ArsenTD-Lev | macro-F1 | 53.56 | - | 55.71 | - | 56.94 | - | **57.20**
AJGT | Acc. | 93.1 | **93.8** | - | - | - | - | -
LABR | Acc. | 85.9 | **86.7** | - | - | - | - | -
ANERcorp | macro-F1 | 83.1 | 82.4 | 83.70 | - | 83.08 | - | **83.95**
ARCD | EM - F1 | 31.62 - 67.45 | 31.7 - 67.8 | 32.76 - 66.53 | 31.34 - 67.23 | 36.89 - **71.32** | 34.19 - 68.12 | **37.03** - 71.22
TyDiQA-ar | EM - F1 | 68.51 - 82.86 | - | 73.07 - 85.41 | - | 73.72 - 86.03 | - | **74.91 - 86.68**

## How to use

You can easily use AraBERT since it is almost fully compatible with existing codebases. (Use this repo instead of the official BERT one; the only difference is in the `tokenization.py` file, where we modify the `_is_punctuation` function to make it compatible with the "+" symbol and the "[" and "]" characters.)

**AraBERTv1 and v2 always need pre-segmentation**
```python
from transformers import AutoTokenizer, AutoModel
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv2"
arabert_tokenizer = AutoTokenizer.from_pretrained(model_name)
arabert_model = AutoModel.from_pretrained(model_name)

arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري"
text_preprocessed = arabert_prep.preprocess(text)
>>> "و+ لن نبالغ إذا قل +نا إن هاتف أو كمبيوتر ال+ مكتب في زمن +نا هذا ضروري"

arabert_tokenizer.tokenize(text_preprocessed)

>>> ['و+', 'لن', 'نبال', '##غ', 'إذا', 'قل', '+نا', 'إن', 'هاتف', 'أو', 'كمبيوتر', 'ال+', 'مكتب', 'في', 'زمن', '+نا', 'هذا', 'ضروري']
```

**AraBERTv0.1 and v0.2 need no pre-segmentation.**
```python
from transformers import AutoTokenizer, AutoModel
from arabert.preprocess import ArabertPreprocessor

model_name = "aubmindlab/bert-base-arabertv01"
arabert_tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
arabert_model = AutoModel.from_pretrained(model_name)

arabert_prep = ArabertPreprocessor(model_name=model_name)

text = "ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري"
text_preprocessed = arabert_prep.preprocess(text)

arabert_tokenizer.tokenize(text_preprocessed)

>>> ['ولن', 'ن', '##بالغ', 'إذا', 'قلنا', 'إن', 'هاتف', 'أو', 'كمبيوتر', 'المكتب', 'في', 'زمن', '##ن', '##ا', 'هذا', 'ضروري']
```
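
After tokenization, a minimal sketch of a full forward pass (assumes `transformers` with the PyTorch backend; v0.2 is used here since it needs no segmentation step):

```python
import torch
from transformers import AutoTokenizer, AutoModel

model_name = "aubmindlab/bert-base-arabertv02"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Encode one sentence and run it through the encoder without gradients
inputs = tokenizer("ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا ضروري", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
```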

## Model Weights and Vocab Download

**You can find the PyTorch, TF2 and TF1 models in HuggingFace's Transformer Library under the `aubmindlab` username**

- `wget https://huggingface.co/aubmindlab/MODEL_NAME/resolve/main/tf1_model.tar.gz` where `MODEL_NAME` is any model under the `aubmindlab` name
arabert/arabert/__init__.py
ADDED
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
arabert/arabert/create_classification_data.py
ADDED
@@ -0,0 +1,260 @@
# Scripts used to pre-process and create the data for classifier evaluation
#%%
import pandas as pd
from sklearn.model_selection import train_test_split

import sys

sys.path.append("..")

from arabert.preprocess import ArabertPreprocessor

from tqdm import tqdm

tqdm.pandas()

import run_classifier  # needed for InputExample and convert_examples_to_features below
from tokenization import FullTokenizer
from run_classifier import input_fn_builder, model_fn_builder


model_name = "bert-base-arabert"
arabert_prep = ArabertPreprocessor(model_name=model_name, keep_emojis=False)


class Dataset:
    def __init__(
        self,
        name,
        train,
        test,
        label_list,
        train_InputExamples=None,
        test_InputExamples=None,
        train_features=None,
        test_features=None,
    ):
        self.name = name
        self.train = train
        self.test = test
        self.label_list = label_list
        self.train_InputExamples = train_InputExamples
        self.test_InputExamples = test_InputExamples
        self.train_features = train_features
        self.test_features = test_features


all_datasets = []
#%%
# *************HARD************
df_HARD = pd.read_csv("Datasets\\HARD\\balanced-reviews-utf8.tsv", sep="\t", header=0)

df_HARD = df_HARD[["rating", "review"]]  # we are interested in rating and review only
# code rating as +ve if > 3, -ve if less; no 3s in dataset
df_HARD["rating"] = df_HARD["rating"].apply(lambda x: 0 if x < 3 else 1)
# rename columns to fit default constructor in fastai
df_HARD.columns = ["label", "text"]
df_HARD["text"] = df_HARD["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
train_HARD, test_HARD = train_test_split(df_HARD, test_size=0.2, random_state=42)
label_list_HARD = [0, 1]

data_Hard = Dataset("HARD", train_HARD, test_HARD, label_list_HARD)
all_datasets.append(data_Hard)

#%%
# *************ASTD-Unbalanced************
df_ASTD_UN = pd.read_csv(
    "Datasets\\ASTD-master\\data\\Tweets.txt", sep="\t", header=None
)

DATA_COLUMN = "text"
LABEL_COLUMN = "label"
df_ASTD_UN.columns = [DATA_COLUMN, LABEL_COLUMN]

# map the string labels to integer ids
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(lambda x: 0 if (x == "NEG") else x)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(lambda x: 1 if (x == "POS") else x)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(lambda x: 2 if (x == "NEUTRAL") else x)
df_ASTD_UN[LABEL_COLUMN] = df_ASTD_UN[LABEL_COLUMN].apply(lambda x: 3 if (x == "OBJ") else x)
df_ASTD_UN["text"] = df_ASTD_UN["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
train_ASTD_UN, test_ASTD_UN = train_test_split(
    df_ASTD_UN, test_size=0.2, random_state=42
)
label_list_ASTD_UN = [0, 1, 2, 3]

data_ASTD_UN = Dataset(
    "ASTD-Unbalanced", train_ASTD_UN, test_ASTD_UN, label_list_ASTD_UN
)
all_datasets.append(data_ASTD_UN)
#%%
# *************ASTD-Dahou-Balanced************

df_ASTD_B = pd.read_csv(
    "Datasets\\Dahou\\data_csv_balanced\\ASTD-balanced-not-linked.csv",
    sep=",",
    header=0,
)

df_ASTD_B.columns = [DATA_COLUMN, LABEL_COLUMN]

df_ASTD_B[LABEL_COLUMN] = df_ASTD_B[LABEL_COLUMN].apply(lambda x: 0 if (x == -1) else x)
df_ASTD_B["text"] = df_ASTD_B["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
train_ASTD_B, test_ASTD_B = train_test_split(df_ASTD_B, test_size=0.2, random_state=42)
label_list_ASTD_B = [0, 1]

data_ASTD_B = Dataset(
    "ASTD-Dahou-Balanced", train_ASTD_B, test_ASTD_B, label_list_ASTD_B
)
all_datasets.append(data_ASTD_B)

#%%
# *************ArSenTD-LEV************
df_ArSenTD = pd.read_csv(
    "Datasets\\ArSenTD-LEV\\ArSenTD-LEV-processed-no-emojis2.csv", sep=",", header=0
)

df_ArSenTD.columns = [DATA_COLUMN, LABEL_COLUMN]

df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(lambda x: 0 if (x == "very_negative") else x)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(lambda x: 1 if (x == "negative") else x)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(lambda x: 2 if (x == "neutral") else x)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(lambda x: 3 if (x == "positive") else x)
df_ArSenTD[LABEL_COLUMN] = df_ArSenTD[LABEL_COLUMN].apply(lambda x: 4 if (x == "very_positive") else x)
df_ArSenTD["text"] = df_ArSenTD["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
label_list_ArSenTD = [0, 1, 2, 3, 4]

train_ArSenTD, test_ArSenTD = train_test_split(
    df_ArSenTD, test_size=0.2, random_state=42
)

data_ArSenTD = Dataset("ArSenTD-LEV", train_ArSenTD, test_ArSenTD, label_list_ArSenTD)
all_datasets.append(data_ArSenTD)

#%%
# *************AJGT************
df_AJGT = pd.read_excel("Datasets\\Ajgt\\AJGT.xlsx", header=0)

df_AJGT = df_AJGT[["Feed", "Sentiment"]]
df_AJGT.columns = [DATA_COLUMN, LABEL_COLUMN]

df_AJGT[LABEL_COLUMN] = df_AJGT[LABEL_COLUMN].apply(lambda x: 0 if (x == "Negative") else x)
df_AJGT[LABEL_COLUMN] = df_AJGT[LABEL_COLUMN].apply(lambda x: 1 if (x == "Positive") else x)
df_AJGT["text"] = df_AJGT["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
train_AJGT, test_AJGT = train_test_split(df_AJGT, test_size=0.2, random_state=42)
label_list_AJGT = [0, 1]

data_AJGT = Dataset("AJGT", train_AJGT, test_AJGT, label_list_AJGT)
all_datasets.append(data_AJGT)
#%%
# *************LABR-UN-Binary************
from labr import LABR

labr_helper = LABR()

(d_train, y_train, d_test, y_test) = labr_helper.get_train_test(
    klass="2", balanced="unbalanced"
)

train_LABR_B_U = pd.DataFrame({"text": d_train, "label": y_train})
test_LABR_B_U = pd.DataFrame({"text": d_test, "label": y_test})

train_LABR_B_U["text"] = train_LABR_B_U["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
test_LABR_B_U["text"] = test_LABR_B_U["text"].progress_apply(lambda x: arabert_prep.preprocess(x))
label_list_LABR_B_U = [0, 1]

data_LABR_B_U = Dataset(
    "LABR-UN-Binary", train_LABR_B_U, test_LABR_B_U, label_list_LABR_B_U
)
# all_datasets.append(data_LABR_B_U)

#%%
for data in tqdm(all_datasets):
    # Use the InputExample class from BERT's run_classifier code to create examples from the data
    data.train_InputExamples = data.train.apply(
        lambda x: run_classifier.InputExample(
            guid=None,  # Globally unique ID for bookkeeping, unused in this example
            text_a=x[DATA_COLUMN],
            text_b=None,
            label=x[LABEL_COLUMN],
        ),
        axis=1,
    )

    data.test_InputExamples = data.test.apply(
        lambda x: run_classifier.InputExample(
            guid=None, text_a=x[DATA_COLUMN], text_b=None, label=x[LABEL_COLUMN]
        ),
        axis=1,
    )
#%%
# We'll set sequences to be at most 256 tokens long.
MAX_SEQ_LENGTH = 256

VOC_FNAME = "./64000_vocab_sp_70m.txt"
tokenizer = FullTokenizer(VOC_FNAME)

for data in tqdm(all_datasets):
    # Convert our train and test examples to InputFeatures that BERT understands.
    data.train_features = run_classifier.convert_examples_to_features(
        data.train_InputExamples, data.label_list, MAX_SEQ_LENGTH, tokenizer
    )
    data.test_features = run_classifier.convert_examples_to_features(
        data.test_InputExamples, data.label_list, MAX_SEQ_LENGTH, tokenizer
    )

# %%
import pickle

with open("all_datasets_64k_farasa_256.pickle", "wb") as fp:  # Pickling
    pickle.dump(all_datasets, fp)


# %%
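
A minimal sketch of reloading the pickle this script writes; note that unpickling needs the `Dataset` class above to be defined or importable in the loading session:

```python
import pickle

with open("all_datasets_64k_farasa_256.pickle", "rb") as fp:
    all_datasets = pickle.load(fp)

for data in all_datasets:
    print(data.name, len(data.train_features), len(data.test_features))
```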
arabert/arabert/create_pretraining_data.py
ADDED
@@ -0,0 +1,534 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
|
16 |
+
|
17 |
+
from __future__ import absolute_import
|
18 |
+
from __future__ import division
|
19 |
+
from __future__ import print_function
|
20 |
+
|
21 |
+
import collections
|
22 |
+
import random
|
23 |
+
import tokenization
|
24 |
+
import tensorflow as tf
|
25 |
+
|
26 |
+
flags = tf.flags
|
27 |
+
|
28 |
+
FLAGS = flags.FLAGS
|
29 |
+
|
30 |
+
flags.DEFINE_string(
|
31 |
+
"input_file", None, "Input raw text file (or comma-separated list of files)."
|
32 |
+
)
|
33 |
+
|
34 |
+
flags.DEFINE_string(
|
35 |
+
"output_file", None, "Output TF example file (or comma-separated list of files)."
|
36 |
+
)
|
37 |
+
|
38 |
+
flags.DEFINE_string(
|
39 |
+
"vocab_file", None, "The vocabulary file that the BERT model was trained on."
|
40 |
+
)
|
41 |
+
|
42 |
+
flags.DEFINE_bool(
|
43 |
+
"do_lower_case",
|
44 |
+
True,
|
45 |
+
"Whether to lower case the input text. Should be True for uncased "
|
46 |
+
"models and False for cased models.",
|
47 |
+
)
|
48 |
+
|
49 |
+
flags.DEFINE_bool(
|
50 |
+
"do_whole_word_mask",
|
51 |
+
False,
|
52 |
+
"Whether to use whole word masking rather than per-WordPiece masking.",
|
53 |
+
)
|
54 |
+
|
55 |
+
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
|
56 |
+
|
57 |
+
flags.DEFINE_integer(
|
58 |
+
"max_predictions_per_seq",
|
59 |
+
20,
|
60 |
+
"Maximum number of masked LM predictions per sequence.",
|
61 |
+
)
|
62 |
+
|
63 |
+
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
|
64 |
+
|
65 |
+
flags.DEFINE_integer(
|
66 |
+
"dupe_factor",
|
67 |
+
10,
|
68 |
+
"Number of times to duplicate the input data (with different masks).",
|
69 |
+
)
|
70 |
+
|
71 |
+
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
|
72 |
+
|
73 |
+
flags.DEFINE_float(
|
74 |
+
"short_seq_prob",
|
75 |
+
0.1,
|
76 |
+
"Probability of creating sequences which are shorter than the " "maximum length.",
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
class TrainingInstance(object):
|
81 |
+
"""A single training instance (sentence pair)."""
|
82 |
+
|
83 |
+
def __init__(
|
84 |
+
self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next
|
85 |
+
):
|
86 |
+
self.tokens = tokens
|
87 |
+
self.segment_ids = segment_ids
|
88 |
+
self.is_random_next = is_random_next
|
89 |
+
self.masked_lm_positions = masked_lm_positions
|
90 |
+
self.masked_lm_labels = masked_lm_labels
|
91 |
+
|
92 |
+
def __str__(self):
|
93 |
+
s = ""
|
94 |
+
s += "tokens: %s\n" % (
|
95 |
+
" ".join([tokenization.printable_text(x) for x in self.tokens])
|
96 |
+
)
|
97 |
+
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
|
98 |
+
s += "is_random_next: %s\n" % self.is_random_next
|
99 |
+
s += "masked_lm_positions: %s\n" % (
|
100 |
+
" ".join([str(x) for x in self.masked_lm_positions])
|
101 |
+
)
|
102 |
+
s += "masked_lm_labels: %s\n" % (
|
103 |
+
" ".join([tokenization.printable_text(x) for x in self.masked_lm_labels])
|
104 |
+
)
|
105 |
+
s += "\n"
|
106 |
+
return s
|
107 |
+
|
108 |
+
def __repr__(self):
|
109 |
+
return self.__str__()
|
110 |
+
|
111 |
+
|
112 |
+
def write_instance_to_example_files(
|
113 |
+
instances, tokenizer, max_seq_length, max_predictions_per_seq, output_files
|
114 |
+
):
|
115 |
+
"""Create TF example files from `TrainingInstance`s."""
|
116 |
+
writers = []
|
117 |
+
for output_file in output_files:
|
118 |
+
writers.append(tf.python_io.TFRecordWriter(output_file))
|
119 |
+
|
120 |
+
writer_index = 0
|
121 |
+
|
122 |
+
total_written = 0
|
123 |
+
for (inst_index, instance) in enumerate(instances):
|
124 |
+
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
|
125 |
+
input_mask = [1] * len(input_ids)
|
126 |
+
segment_ids = list(instance.segment_ids)
|
127 |
+
assert len(input_ids) <= max_seq_length
|
128 |
+
|
129 |
+
while len(input_ids) < max_seq_length:
|
130 |
+
input_ids.append(0)
|
131 |
+
input_mask.append(0)
|
132 |
+
segment_ids.append(0)
|
133 |
+
|
134 |
+
assert len(input_ids) == max_seq_length
|
135 |
+
assert len(input_mask) == max_seq_length
|
136 |
+
assert len(segment_ids) == max_seq_length
|
137 |
+
|
138 |
+
masked_lm_positions = list(instance.masked_lm_positions)
|
139 |
+
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
|
140 |
+
masked_lm_weights = [1.0] * len(masked_lm_ids)
|
141 |
+
|
142 |
+
while len(masked_lm_positions) < max_predictions_per_seq:
|
143 |
+
masked_lm_positions.append(0)
|
144 |
+
masked_lm_ids.append(0)
|
145 |
+
masked_lm_weights.append(0.0)
|
146 |
+
|
147 |
+
next_sentence_label = 1 if instance.is_random_next else 0
|
148 |
+
|
149 |
+
features = collections.OrderedDict()
|
150 |
+
features["input_ids"] = create_int_feature(input_ids)
|
151 |
+
features["input_mask"] = create_int_feature(input_mask)
|
152 |
+
features["segment_ids"] = create_int_feature(segment_ids)
|
153 |
+
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
|
154 |
+
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
|
155 |
+
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
|
156 |
+
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
|
157 |
+
|
158 |
+
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
|
159 |
+
|
160 |
+
writers[writer_index].write(tf_example.SerializeToString())
|
161 |
+
writer_index = (writer_index + 1) % len(writers)
|
162 |
+
|
163 |
+
total_written += 1
|
164 |
+
|
165 |
+
if inst_index < 20:
|
166 |
+
tf.logging.info("*** Example ***")
|
167 |
+
tf.logging.info(
|
168 |
+
"tokens: %s"
|
169 |
+
% " ".join([tokenization.printable_text(x) for x in instance.tokens])
|
170 |
+
)
|
171 |
+
|
172 |
+
for feature_name in features.keys():
|
173 |
+
feature = features[feature_name]
|
174 |
+
values = []
|
175 |
+
if feature.int64_list.value:
|
176 |
+
values = feature.int64_list.value
|
177 |
+
elif feature.float_list.value:
|
178 |
+
values = feature.float_list.value
|
179 |
+
tf.logging.info(
|
180 |
+
"%s: %s" % (feature_name, " ".join([str(x) for x in values]))
|
181 |
+
)
|
182 |
+
|
183 |
+
for writer in writers:
|
184 |
+
writer.close()
|
185 |
+
|
186 |
+
tf.logging.info("Wrote %d total instances", total_written)
|
187 |
+
|
188 |
+
|
189 |
+
def create_int_feature(values):
|
190 |
+
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
|
191 |
+
return feature
|
192 |
+
|
193 |
+
|
194 |
+
def create_float_feature(values):
|
195 |
+
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
|
196 |
+
return feature
|
197 |
+
|
198 |
+
|
199 |
+
def create_training_instances(
|
200 |
+
input_files,
|
201 |
+
tokenizer,
|
202 |
+
max_seq_length,
|
203 |
+
dupe_factor,
|
204 |
+
short_seq_prob,
|
205 |
+
masked_lm_prob,
|
206 |
+
max_predictions_per_seq,
|
207 |
+
rng,
|
208 |
+
):
|
209 |
+
"""Create `TrainingInstance`s from raw text."""
|
210 |
+
all_documents = [[]]
|
211 |
+
|
212 |
+
# Input file format:
|
213 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
214 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
215 |
+
# sentence boundaries for the "next sentence prediction" task).
|
216 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
217 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
218 |
+
for input_file in input_files:
|
219 |
+
with tf.gfile.GFile(input_file, "r") as reader:
|
220 |
+
while True:
|
221 |
+
line = tokenization.convert_to_unicode(reader.readline())
|
222 |
+
if not line:
|
223 |
+
break
|
224 |
+
line = line.strip()
|
225 |
+
|
226 |
+
# Empty lines are used as document delimiters
|
227 |
+
if not line:
|
228 |
+
all_documents.append([])
|
229 |
+
tokens = tokenizer.tokenize(line)
|
230 |
+
if tokens:
|
231 |
+
all_documents[-1].append(tokens)
|
232 |
+
|
233 |
+
# Remove empty documents
|
234 |
+
all_documents = [x for x in all_documents if x]
|
235 |
+
rng.shuffle(all_documents)
|
236 |
+
|
237 |
+
vocab_words = list(tokenizer.vocab.keys())
|
238 |
+
instances = []
|
239 |
+
for _ in range(dupe_factor):
|
240 |
+
for document_index in range(len(all_documents)):
|
241 |
+
instances.extend(
|
242 |
+
create_instances_from_document(
|
243 |
+
all_documents,
|
244 |
+
document_index,
|
245 |
+
max_seq_length,
|
246 |
+
short_seq_prob,
|
247 |
+
masked_lm_prob,
|
248 |
+
max_predictions_per_seq,
|
249 |
+
vocab_words,
|
250 |
+
rng,
|
251 |
+
)
|
252 |
+
)
|
253 |
+
|
254 |
+
rng.shuffle(instances)
|
255 |
+
return instances
|
256 |
+
|
257 |
+
|
258 |
+
def create_instances_from_document(
    all_documents,
    document_index,
    max_seq_length,
    short_seq_prob,
    masked_lm_prob,
    max_predictions_per_seq,
    vocab_words,
    rng,
):
    """Creates `TrainingInstance`s for a single document."""
    document = all_documents[document_index]

    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3

    # We *usually* want to fill up the entire sequence since we are padding
    # to `max_seq_length` anyways, so short sequences are generally wasted
    # computation. However, we *sometimes*
    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
    # sequences to minimize the mismatch between pre-training and fine-tuning.
    # The `target_seq_length` is just a rough target however, whereas
    # `max_seq_length` is a hard limit.
    target_seq_length = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)

    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []
                # Random next
                is_random_next = False
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # This should rarely go for more than one iteration for large
                    # corpora. However, just to be careful, we try to make sure that
                    # the random document is not the same as the document
                    # we're processing.
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(all_documents) - 1)
                        if random_document_index != document_index:
                            break

                    random_document = all_documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # We didn't actually use these segments so we "put them back" so
                    # they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                tokens = []
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)

                tokens.append("[SEP]")
                segment_ids.append(0)

                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                (
                    tokens,
                    masked_lm_positions,
                    masked_lm_labels,
                ) = create_masked_lm_predictions(
                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
                )
                instance = TrainingInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels,
                )
                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1

    return instances


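# Illustrative sketch (not from the original script): the packed layout of one
# instance produced above, using toy wordpiece tokens.
def _demo_packed_instance():
    toy_tokens_a = ["is", "this", "jack", "##son"]
    toy_tokens_b = ["no", "it", "is", "not"]
    toy_tokens = ["[CLS]"] + toy_tokens_a + ["[SEP]"] + toy_tokens_b + ["[SEP]"]
    toy_segment_ids = [0] * (len(toy_tokens_a) + 2) + [1] * (len(toy_tokens_b) + 1)
    assert len(toy_tokens) == len(toy_segment_ids)
    # toy_tokens:      [CLS] is this jack ##son [SEP] no it is not [SEP]
    # toy_segment_ids:   0    0   0    0    0     0   1  1  1   1    1
    return toy_tokens, toy_segment_ids

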
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def create_masked_lm_predictions(
    tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng
):
    """Creates the predictions for the masked LM objective."""

    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        # Whole Word Masking means that we mask all of the wordpieces
        # corresponding to an original word. When a word has been split into
        # WordPieces, the first token does not have any marker and any subsequent
        # tokens are prefixed with ##. So whenever we see the ## token, we
        # append it to the previous set of word indexes.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (
            FLAGS.do_whole_word_mask
            and len(cand_indexes) >= 1
            and token.startswith("##")
        ):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    rng.shuffle(cand_indexes)

    output_tokens = list(tokens)

    num_to_predict = min(
        max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))
    )

    masked_lms = []
    covered_indexes = set()
    for index_set in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if rng.random() < 0.8:
                masked_token = "[MASK]"
            else:
                # 10% of the time, keep original
                if rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict
    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions, masked_lm_labels)


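# Illustrative sketch (not from the original script): the 80/10/10 masking rule
# above, applied to a single selected position with a toy rng and vocabulary.
def _demo_mask_one_position():
    import random

    rng = random.Random(12345)
    vocab_words = ["the", "dog", "ran", "cat", "sat"]
    tokens = ["[CLS]", "the", "dog", "ran", "[SEP]"]
    index = 2  # position selected for prediction
    if rng.random() < 0.8:  # 80%: replace with [MASK]
        masked = "[MASK]"
    elif rng.random() < 0.5:  # 10%: keep the original token
        masked = tokens[index]
    else:  # 10%: replace with a random vocab word
        masked = vocab_words[rng.randint(0, len(vocab_words) - 1)]
    return index, tokens[index], masked

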
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break

        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1

        # We want to sometimes truncate from the front and sometimes from the
        # back to add more randomness and avoid biases.
        if rng.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()


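# Illustrative sketch (not from the original script): truncate_seq_pair trims
# the longer side one token at a time, randomly from the front or the back.
def _demo_truncate_pair():
    import random

    a = list("abcdef")  # 6 toy tokens
    b = list("xyz")     # 3 toy tokens
    truncate_seq_pair(a, b, 7, random.Random(0))
    assert len(a) + len(b) <= 7  # `a` loses two tokens; which ends is up to rng
    return a, b

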
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Reading from input files ***")
    for input_file in input_files:
        tf.logging.info("  %s", input_file)

    rng = random.Random(FLAGS.random_seed)
    instances = create_training_instances(
        input_files,
        tokenizer,
        FLAGS.max_seq_length,
        FLAGS.dupe_factor,
        FLAGS.short_seq_prob,
        FLAGS.masked_lm_prob,
        FLAGS.max_predictions_per_seq,
        rng,
    )

    output_files = FLAGS.output_file.split(",")
    tf.logging.info("*** Writing to output files ***")
    for output_file in output_files:
        tf.logging.info("  %s", output_file)

    write_instance_to_example_files(
        instances,
        tokenizer,
        FLAGS.max_seq_length,
        FLAGS.max_predictions_per_seq,
        output_files,
    )


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("output_file")
    flags.mark_flag_as_required("vocab_file")
    tf.app.run()
arabert/arabert/extract_features.py
ADDED
@@ -0,0 +1,444 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from BERT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import collections
import json
import re

import modeling
import tokenization
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string("input_file", None, "")

flags.DEFINE_string("output_file", None, "")

flags.DEFINE_string("layers", "-1,-2,-3,-4", "")

flags.DEFINE_string(
    "bert_config_file",
    None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.",
)

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.",
)

flags.DEFINE_string(
    "init_checkpoint",
    None,
    "Initial checkpoint (usually from a pre-trained BERT model).",
)

flags.DEFINE_string(
    "vocab_file", None, "The vocabulary file that the BERT model was trained on."
)

flags.DEFINE_bool(
    "do_lower_case",
    True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.",
)

flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_string("master", None, "If using a TPU, the address of the master.")

flags.DEFINE_integer(
    "num_tpu_cores",
    8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
)

flags.DEFINE_bool(
    "use_one_hot_embeddings",
    False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.",
)


class InputExample(object):
    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids


def input_fn_builder(features, seq_length):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    all_unique_ids = []
    all_input_ids = []
    all_input_mask = []
    all_input_type_ids = []

    for feature in features:
        all_unique_ids.append(feature.unique_id)
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_input_type_ids.append(feature.input_type_ids)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        d = tf.data.Dataset.from_tensor_slices(
            {
                "unique_ids": tf.constant(
                    all_unique_ids, shape=[num_examples], dtype=tf.int32
                ),
                "input_ids": tf.constant(
                    all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32
                ),
                "input_mask": tf.constant(
                    all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32
                ),
                "input_type_ids": tf.constant(
                    all_input_type_ids, shape=[num_examples, seq_length], dtype=tf.int32
                ),
            }
        )

        d = d.batch(batch_size=batch_size, drop_remainder=False)
        return d

    return input_fn


def model_fn_builder(
    bert_config, init_checkpoint, layer_indexes, use_tpu, use_one_hot_embeddings
):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        input_type_ids = features["input_type_ids"]

        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT modes are supported: %s" % (mode))

        tvars = tf.trainable_variables()
        scaffold_fn = None
        (
            assignment_map,
            initialized_variable_names,
        ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(
                "  name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        all_layers = model.get_all_encoder_layers()

        predictions = {
            "unique_id": unique_ids,
        }

        for (i, layer_index) in enumerate(layer_indexes):
            predictions["layer_output_%d" % i] = all_layers[layer_index]

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
        )
        return output_spec

    return model_fn


def convert_examples_to_features(examples, seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > seq_length - 2:
                tokens_a = tokens_a[0 : (seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length

        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("unique_id: %s" % (example.unique_id))
            tf.logging.info(
                "tokens: %s"
                % " ".join([tokenization.printable_text(x) for x in tokens])
            )
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info(
                "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])
            )

        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids,
            )
        )
    return features


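# Illustrative sketch (not from the original script): zero-padding and the
# attention mask built above, for a toy 3-token input and seq_length = 6.
def _demo_padding():
    seq_length = 6
    input_ids = [2, 941, 3]  # hypothetical ids for [CLS] token [SEP]
    input_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding
    input_type_ids = [0] * len(input_ids)
    while len(input_ids) < seq_length:
        input_ids.append(0)
        input_mask.append(0)
        input_type_ids.append(0)
    assert input_mask == [1, 1, 1, 0, 0, 0]
    return input_ids, input_mask, input_type_ids

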
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def read_examples(input_file):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    unique_id = 0
    with tf.gfile.GFile(input_file, "r") as reader:
        while True:
            line = tokenization.convert_to_unicode(reader.readline())
            if not line:
                break
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)
            )
            unique_id += 1
    return examples


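# Illustrative sketch (not from the original script): the line format that
# read_examples parses -- "text_a ||| text_b" for pairs, plain text otherwise.
def _demo_line_format():
    pairs = []
    for line in ["first sentence ||| second sentence", "a single sentence"]:
        m = re.match(r"^(.*) \|\|\| (.*)$", line)
        if m is None:
            pairs.append((line, None))
        else:
            pairs.append((m.group(1), m.group(2)))
    assert pairs[0] == ("first sentence", "second sentence")
    assert pairs[1] == ("a single sentence", None)
    return pairs

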
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    layer_indexes = [int(x) for x in FLAGS.layers.split(",")]

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores, per_host_input_for_training=is_per_host
        ),
    )

    examples = read_examples(FLAGS.input_file)

    features = convert_examples_to_features(
        examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer
    )

    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        layer_indexes=layer_indexes,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        predict_batch_size=FLAGS.batch_size,
    )

    input_fn = input_fn_builder(features=features, seq_length=FLAGS.max_seq_length)

    with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, "w")) as writer:
        for result in estimator.predict(input_fn, yield_single_examples=True):
            unique_id = int(result["unique_id"])
            feature = unique_id_to_feature[unique_id]
            output_json = collections.OrderedDict()
            output_json["linex_index"] = unique_id
            all_features = []
            for (i, token) in enumerate(feature.tokens):
                all_layers = []
                for (j, layer_index) in enumerate(layer_indexes):
                    layer_output = result["layer_output_%d" % j]
                    layers = collections.OrderedDict()
                    layers["index"] = layer_index
                    layers["values"] = [
                        round(float(x), 6) for x in layer_output[i : (i + 1)].flat
                    ]
                    all_layers.append(layers)
                features = collections.OrderedDict()
                features["token"] = token
                features["layers"] = all_layers
                all_features.append(features)
            output_json["features"] = all_features
            writer.write(json.dumps(output_json) + "\n")


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("init_checkpoint")
    flags.mark_flag_as_required("output_file")
    tf.app.run()
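# Illustrative sketch (not part of either file): the shape of one JSON line
# written by extract_features.py above; the token and values are hypothetical.
_example_output_record = {
    "linex_index": 0,
    "features": [
        {
            "token": "[CLS]",
            "layers": [
                {"index": -1, "values": [0.123456, -0.047110]},  # last layer
                {"index": -2, "values": [0.002345, 0.098765]},   # second-to-last
            ],
        }
    ],
}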
arabert/arabert/lamb_optimizer.py
ADDED
@@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import six
import tensorflow as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

# pylint: enable=g-direct-tensorflow-import


class LAMBOptimizer(tf.train.Optimizer):
    """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""

    # A new optimizer that includes correct L2 weight decay, adaptive
    # element-wise updating, and layer-wise justification. The LAMB optimizer
    # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
    # James Demmel, and Cho-Jui Hsieh in a paper titled "Reducing BERT
    # Pre-Training Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962).

    def __init__(
        self,
        learning_rate,
        weight_decay_rate=0.0,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        name="LAMBOptimizer",
    ):
        """Constructs a LAMBOptimizer."""
        super(LAMBOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
        # arg is None.
        # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=six.ensure_str(param_name) + "/adam_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer(),
            )
            v = tf.get_variable(
                name=six.ensure_str(param_name) + "/adam_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer(),
            )

            # Standard Adam update.
            next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
            next_v = tf.multiply(self.beta_2, v) + tf.multiply(
                1.0 - self.beta_2, tf.square(grad)
            )

            update = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            ratio = 1.0
            if self._do_layer_adaptation(param_name):
                w_norm = linalg_ops.norm(param, ord=2)
                g_norm = linalg_ops.norm(update, ord=2)
                ratio = array_ops.where(
                    math_ops.greater(w_norm, 0),
                    array_ops.where(
                        math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0
                    ),
                    1.0,
                )

            update_with_lr = ratio * self.learning_rate * update

            next_param = param - update_with_lr

            assignments.extend(
                [param.assign(next_param), m.assign(next_m), v.assign(next_v)]
            )
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
        if m is not None:
            param_name = m.group(1)
        return param_name
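# Illustrative sketch (not part of the original file): one LAMB step on a
# single parameter in plain NumPy, mirroring apply_gradients above.
def _demo_lamb_step():
    import numpy as np

    lr, wd, beta_1, beta_2, eps = 0.001, 0.01, 0.9, 0.999, 1e-6
    param = np.array([0.5, -0.3])
    grad = np.array([0.1, 0.2])
    m = np.zeros_like(param)
    v = np.zeros_like(param)

    m = beta_1 * m + (1.0 - beta_1) * grad        # first moment
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2   # second moment
    update = m / (np.sqrt(v) + eps) + wd * param  # Adam direction + weight decay
    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(update)
    ratio = (w_norm / g_norm) if (w_norm > 0 and g_norm > 0) else 1.0
    return param - ratio * lr * update            # layer-adapted update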
arabert/arabert/modeling.py
ADDED
@@ -0,0 +1,1027 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The main BERT model and related functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow as tf


class BertConfig(object):
    """Configuration for `BertModel`."""

    def __init__(
        self,
        vocab_size,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        initializer_range=0.02,
    ):
        """Constructs BertConfig.

        Args:
          vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
          hidden_size: Size of the encoder layers and the pooler layer.
          num_hidden_layers: Number of hidden layers in the Transformer encoder.
          num_attention_heads: Number of attention heads for each attention layer in
            the Transformer encoder.
          intermediate_size: The size of the "intermediate" (i.e., feed-forward)
            layer in the Transformer encoder.
          hidden_act: The non-linear activation function (function or string) in the
            encoder and pooler.
          hidden_dropout_prob: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.
          attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
          max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
          type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `BertModel`.
          initializer_range: The stdev of the truncated_normal_initializer for
            initializing all weight matrices.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size=None)
        for (key, value) in six.iteritems(json_object):
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with tf.gfile.GFile(json_file, "r") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


class BertModel(object):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Example usage:

    ```python
    # Already been converted into WordPiece token ids
    input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
    input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
    token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

    model = modeling.BertModel(config=config, is_training=True,
        input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

    label_embeddings = tf.get_variable(...)
    pooled_output = model.get_pooled_output()
    logits = tf.matmul(pooled_output, label_embeddings)
    ...
    ```
    """

    def __init__(
        self,
        config,
        is_training,
        input_ids,
        input_mask=None,
        token_type_ids=None,
        use_one_hot_embeddings=False,
        scope=None,
    ):
        """Constructor for BertModel.

        Args:
          config: `BertConfig` instance.
          is_training: bool. true for training model, false for eval model. Controls
            whether dropout will be applied.
          input_ids: int32 Tensor of shape [batch_size, seq_length].
          input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
          token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
          use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
            embeddings or tf.embedding_lookup() for the word embeddings.
          scope: (optional) variable scope. Defaults to "bert".

        Raises:
          ValueError: The config is invalid or one of the input tensor shapes
            is invalid.
        """
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        input_shape = get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

        with tf.variable_scope(scope, default_name="bert"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                (self.embedding_output, self.embedding_table) = embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name="word_embeddings",
                    use_one_hot_embeddings=use_one_hot_embeddings,
                )

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob,
                )

            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.
                attention_mask = create_attention_mask_from_input_mask(
                    input_ids, input_mask
                )

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=get_activation(config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True,
                )

            self.sequence_output = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained.
                first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=create_initializer(config.initializer_range),
                )

    def get_pooled_output(self):
        return self.pooled_output

    def get_sequence_output(self):
        """Gets final hidden layer of encoder.

        Returns:
          float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
          to the final hidden layer of the transformer encoder.
        """
        return self.sequence_output

    def get_all_encoder_layers(self):
        return self.all_encoder_layers

    def get_embedding_output(self):
        """Gets output of the embedding lookup (i.e., input to the transformer).

        Returns:
          float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
          to the output of the embedding layer, after summing the word
          embeddings with the positional embeddings and the token type embeddings,
          then performing layer normalization. This is the input to the transformer.
        """
        return self.embedding_output

    def get_embedding_table(self):
        return self.embedding_table


def gelu(x):
    """Gaussian Error Linear Unit.

    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
      x: float Tensor to perform activation.

    Returns:
      `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


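# Illustrative sketch (not from the original file): the same tanh-based GELU
# approximation in plain NumPy, checked at a few sample points.
def _demo_gelu_numpy():
    xs = np.array([-1.0, 0.0, 1.0])
    cdf = 0.5 * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (xs + 0.044715 * xs ** 3)))
    return xs * cdf  # approx [-0.1587, 0.0, 0.8413]

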
def get_activation(activation_string):
    """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

    Args:
      activation_string: String name of the activation function.

    Returns:
      A Python function corresponding to the activation function. If
      `activation_string` is None, empty, or "linear", this will return None.
      If `activation_string` is not a string, it will return `activation_string`.

    Raises:
      ValueError: The `activation_string` does not correspond to a known
        activation.
    """

    # We assume that anything that's not a string is already an activation
    # function, so we just return it.
    if not isinstance(activation_string, six.string_types):
        return activation_string

    if not activation_string:
        return None

    act = activation_string.lower()
    if act == "linear":
        return None
    elif act == "relu":
        return tf.nn.relu
    elif act == "gelu":
        return gelu
    elif act == "tanh":
        return tf.tanh
    else:
        raise ValueError("Unsupported activation: %s" % act)


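# Illustrative sketch (not from the original file): get_activation resolves
# strings to callables and passes non-strings (or falsy values) through.
def _demo_get_activation():
    assert get_activation("gelu") is gelu
    assert get_activation("linear") is None
    assert get_activation(None) is None   # non-string, returned unchanged
    assert get_activation(gelu) is gelu   # already a callable

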
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return (assignment_map, initialized_variable_names)


def dropout(input_tensor, dropout_prob):
    """Perform dropout.

    Args:
      input_tensor: float Tensor.
      dropout_prob: Python float. The probability of dropping out a value (NOT of
        *keeping* a dimension as in `tf.nn.dropout`).

    Returns:
      A version of `input_tensor` with dropout applied.
    """
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor

    output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
    return output


def layer_norm(input_tensor, name=None):
    """Run layer normalization on the last dimension of the tensor."""
    return tf.contrib.layers.layer_norm(
        inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name
    )


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
    """Runs layer normalization followed by dropout."""
    output_tensor = layer_norm(input_tensor, name)
    output_tensor = dropout(output_tensor, dropout_prob)
    return output_tensor


def create_initializer(initializer_range=0.02):
    """Creates a `truncated_normal_initializer` with the given range."""
    return tf.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(
    input_ids,
    vocab_size,
    embedding_size=128,
    initializer_range=0.02,
    word_embedding_name="word_embeddings",
    use_one_hot_embeddings=False,
):
    """Looks up word embeddings for an id tensor.

    Args:
      input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
        ids.
      vocab_size: int. Size of the embedding vocabulary.
      embedding_size: int. Width of the word embeddings.
      initializer_range: float. Embedding initialization range.
      word_embedding_name: string. Name of the embedding table.
      use_one_hot_embeddings: bool. If True, use one-hot method for word
        embeddings. If False, use `tf.gather()`.

    Returns:
      float Tensor of shape [batch_size, seq_length, embedding_size].
    """
    # This function assumes that the input is of shape [batch_size, seq_length,
    # num_inputs].
    #
    # If the input is a 2D tensor of shape [batch_size, seq_length], we
    # reshape to [batch_size, seq_length, 1].
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range),
    )

    flat_input_ids = tf.reshape(input_ids, [-1])
    if use_one_hot_embeddings:
        one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
        output = tf.matmul(one_hot_input_ids, embedding_table)
    else:
        output = tf.gather(embedding_table, flat_input_ids)

    input_shape = get_shape_list(input_ids)

    output = tf.reshape(output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return (output, embedding_table)


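# Illustrative sketch (not from the original file): the one-hot matmul path and
# the gather path above compute the same lookup, shown here in NumPy.
def _demo_lookup_equivalence():
    table = np.arange(15, dtype=np.float32).reshape(5, 3)  # [vocab, width]
    flat_input_ids = np.array([2, 0, 4])
    one_hot = np.eye(5, dtype=np.float32)[flat_input_ids]
    out_one_hot = one_hot @ table          # one-hot path (faster on TPUs)
    out_gather = table[flat_input_ids]     # gather path
    assert np.allclose(out_one_hot, out_gather)
    return out_gather

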
def embedding_postprocessor(
    input_tensor,
    use_token_type=False,
    token_type_ids=None,
    token_type_vocab_size=16,
    token_type_embedding_name="token_type_embeddings",
    use_position_embeddings=True,
    position_embedding_name="position_embeddings",
    initializer_range=0.02,
    max_position_embeddings=512,
    dropout_prob=0.1,
):
    """Performs various post-processing on a word embedding tensor.

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length,
            embedding_size].
        use_token_type: bool. Whether to add embeddings for `token_type_ids`.
        token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
            Must be specified if `use_token_type` is True.
        token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
        token_type_embedding_name: string. The name of the embedding table
            variable for token type ids.
        use_position_embeddings: bool. Whether to add position embeddings for
            the position of each token in the sequence.
        position_embedding_name: string. The name of the embedding table
            variable for positional embeddings.
        initializer_range: float. Range of the weight initialization.
        max_position_embeddings: int. Maximum sequence length that might ever be
            used with this model. This can be longer than the sequence length of
            input_tensor, but cannot be shorter.
        dropout_prob: float. Dropout probability applied to the final output
            tensor.

    Returns:
        float tensor with same shape as `input_tensor`.

    Raises:
        ValueError: One of the tensor shapes or input values is invalid.
    """
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            raise ValueError(
                "`token_type_ids` must be specified if `use_token_type` is True."
            )
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=create_initializer(initializer_range),
        )
        # This vocab will be small so we always do one-hot here, since it is
        # always faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(
            token_type_embeddings, [batch_size, seq_length, width]
        )
        output += token_type_embeddings

    if use_position_embeddings:
        assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
        with tf.control_dependencies([assert_op]):
            full_position_embeddings = tf.get_variable(
                name=position_embedding_name,
                shape=[max_position_embeddings, width],
                initializer=create_initializer(initializer_range),
            )
            # Since the position embedding table is a learned variable, we create
            # it using a (long) sequence length `max_position_embeddings`. The
            # actual sequence length might be shorter than this, for faster
            # training of tasks that do not have long sequences.
            #
            # So `full_position_embeddings` is effectively an embedding table
            # for positions [0, 1, 2, ..., max_position_embeddings-1], and the
            # current sequence has positions [0, 1, 2, ..., seq_length-1], so we
            # can just perform a slice.
            position_embeddings = tf.slice(
                full_position_embeddings, [0, 0], [seq_length, -1]
            )
            num_dims = len(output.shape.as_list())

            # Only the last two dimensions are relevant (`seq_length` and
            # `width`), so we broadcast among the first dimensions, which is
            # typically just the batch size.
            position_broadcast_shape = []
            for _ in range(num_dims - 2):
                position_broadcast_shape.append(1)
            position_broadcast_shape.extend([seq_length, width])
            position_embeddings = tf.reshape(
                position_embeddings, position_broadcast_shape
            )
            output += position_embeddings

    output = layer_norm_and_dropout(output, dropout_prob)
    return output

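# --- Editor's note: illustrative sketch, not part of the original file. ---
# How the two embedding steps are typically chained (token embeddings first,
# then segment/position embeddings plus layer norm and dropout). The helper
# name and concrete sizes here are example choices, not fixed API.
def _example_embedding_pipeline(input_ids, token_type_ids):
    embedding_output, _ = embedding_lookup(
        input_ids=input_ids,
        vocab_size=30000,
        embedding_size=768,
        initializer_range=0.02,
        word_embedding_name="word_embeddings",
        use_one_hot_embeddings=False,
    )
    # Adds segment (token-type) and learned position embeddings, then applies
    # layer normalization and dropout to the summed tensor.
    return embedding_postprocessor(
        input_tensor=embedding_output,
        use_token_type=True,
        token_type_ids=token_type_ids,
        token_type_vocab_size=2,  # BERT uses two segment ids
        use_position_embeddings=True,
        max_position_embeddings=512,
        dropout_prob=0.1,
    )
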
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Create 3D attention mask from a 2D tensor mask.

    Args:
        from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
        to_mask: int32 Tensor of shape [batch_size, to_seq_length].

    Returns:
        float Tensor of shape [batch_size, from_seq_length, to_seq_length].
    """
    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask

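# --- Editor's note: illustrative worked example, not part of the original file. ---
# A tiny instance of the broadcast above. With batch_size=1, seq_length=4 and
# one padding token, the 2D mask [1, 1, 1, 0] becomes a [1, 4, 4] tensor in
# which every query row may attend to the first three key positions but never
# to the padded fourth one:
#
#     to_mask = [[1, 1, 1, 0]]                               # [1, 4]
#     mask = create_attention_mask_from_input_mask(x, to_mask)
#     # mask[0] == [[1, 1, 1, 0],
#     #             [1, 1, 1, 0],
#     #             [1, 1, 1, 0],
#     #             [1, 1, 1, 0]]
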
def attention_layer(
    from_tensor,
    to_tensor,
    attention_mask=None,
    num_attention_heads=1,
    size_per_head=512,
    query_act=None,
    key_act=None,
    value_act=None,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    do_return_2d_tensor=False,
    batch_size=None,
    from_seq_length=None,
    to_seq_length=None,
):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention
    Is All You Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
        from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width].
        to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any
            positions in the mask that are 0, and will be unchanged for
            positions that are 1.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        query_act: (optional) Activation function for the query transform.
        key_act: (optional) Activation function for the key transform.
        value_act: (optional) Activation function for the value transform.
        attention_probs_dropout_prob: (optional) float. Dropout probability of
            the attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True, the output will be of shape
            [batch_size * from_seq_length, num_attention_heads * size_per_head].
            If False, the output will be of shape [batch_size, from_seq_length,
            num_attention_heads * size_per_head].
        batch_size: (Optional) int. If the input is 2D, this might be the batch
            size of the 3D version of the `from_tensor` and `to_tensor`.
        from_seq_length: (Optional) If the input is 2D, this might be the seq
            length of the 3D version of the `from_tensor`.
        to_seq_length: (Optional) If the input is 2D, this might be the seq
            length of the 3D version of the `to_tensor`.

    Returns:
        float Tensor of shape [batch_size, from_seq_length,
        num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
        true, this will be of shape [batch_size * from_seq_length,
        num_attention_heads * size_per_head]).

    Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(
        input_tensor, batch_size, num_attention_heads, seq_length, width
    ):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width]
        )

        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`."
        )

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if batch_size is None or from_seq_length is None or to_seq_length is None:
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified."
            )

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        kernel_initializer=create_initializer(initializer_range),
    )

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        kernel_initializer=create_initializer(initializer_range),
    )

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        kernel_initializer=create_initializer(initializer_range),
    )

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(
        query_layer, batch_size, num_attention_heads, from_seq_length, size_per_head
    )

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(
        key_layer, batch_size, num_attention_heads, to_seq_length, size_per_head
    )

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(
        attention_scores, 1.0 / math.sqrt(float(size_per_head))
    )

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0
        # for masked positions, this operation will create a tensor which is
        # 0.0 for positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer, [batch_size, to_seq_length, num_attention_heads, size_per_head]
    )

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head],
        )
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head],
        )

    return context_layer

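# --- Editor's note: illustrative sketch, not part of the original file. ---
# The layer above computes, per head: softmax(Q K^T / sqrt(H) + mask) V.
# A minimal self-attention call with made-up sizes (hidden width 768 split
# across 12 heads of 64 dims each); the helper name is hypothetical.
def _example_self_attention(layer_input_2d, attention_mask, batch_size, seq_length):
    # `layer_input_2d` is [batch_size * seq_length, 768], as produced by
    # reshape_to_matrix().
    return attention_layer(
        from_tensor=layer_input_2d,
        to_tensor=layer_input_2d,
        attention_mask=attention_mask,
        num_attention_heads=12,
        size_per_head=64,
        do_return_2d_tensor=True,
        batch_size=batch_size,
        from_seq_length=seq_length,
        to_seq_length=seq_length,
    )
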
def transformer_model(
    input_tensor,
    attention_mask=None,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    intermediate_act_fn=gelu,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    initializer_range=0.02,
    do_return_all_layers=False,
):
    """Multi-headed, multi-layer Transformer from "Attention Is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 in
            positions that should not be.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to
            apply to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the
            attention probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        do_return_all_layers: Whether to also return all layers or just the
            final layer.

    Returns:
        float Tensor of shape [batch_size, seq_length, hidden_size], the final
        hidden layer of the Transformer.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads)
        )

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            "The width of the input tensor (%d) != hidden size (%d)"
            % (input_width, hidden_size)
        )

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            with tf.variable_scope("attention"):
                attention_heads = []
                with tf.variable_scope("self"):
                    attention_head = attention_layer(
                        from_tensor=layer_input,
                        to_tensor=layer_input,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length,
                    )
                    attention_heads.append(attention_head)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    # In the case where we have other sequences, we just
                    # concatenate them to the self-attention head before the
                    # projection.
                    attention_output = tf.concat(attention_heads, axis=-1)

                # Run a linear projection of `hidden_size` then add a residual
                # with `layer_input`.
                with tf.variable_scope("output"):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=create_initializer(initializer_range),
                    )
                    attention_output = dropout(attention_output, hidden_dropout_prob)
                    attention_output = layer_norm(attention_output + layer_input)

            # The activation is only applied to the "intermediate" hidden layer.
            with tf.variable_scope("intermediate"):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=create_initializer(initializer_range),
                )

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope("output"):
                layer_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=create_initializer(initializer_range),
                )
                layer_output = dropout(layer_output, hidden_dropout_prob)
                layer_output = layer_norm(layer_output + attention_output)
                prev_output = layer_output
                all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output

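# --- Editor's note: illustrative sketch, not part of the original file. ---
# End-to-end wiring of the pieces above into a BERT-base-sized encoder; the
# helper names and the concrete sizes are example choices, not fixed API.
def _example_encoder(input_ids, input_mask, token_type_ids):
    embedding_output = _example_embedding_pipeline(input_ids, token_type_ids)
    attention_mask = create_attention_mask_from_input_mask(input_ids, input_mask)
    # Returns the final layer, [batch_size, seq_length, 768].
    return transformer_model(
        input_tensor=embedding_output,
        attention_mask=attention_mask,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        intermediate_act_fn=gelu,
    )
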
def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
        tensor: A tf.Tensor object to find the shape of.
        expected_rank: (optional) int. The expected rank of `tensor`. If this
            is specified and the `tensor` has a different rank, an exception
            will be thrown.
        name: Optional name of the tensor for the error message.

    Returns:
        A list of dimensions of the shape of tensor. All static dimensions will
        be returned as python integers, and dynamic dimensions will be returned
        as tf.Tensor scalars.
    """
    if name is None:
        name = tensor.name

    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    non_static_indexes = []
    for (index, dim) in enumerate(shape):
        if dim is None:
            non_static_indexes.append(index)

    if not non_static_indexes:
        return shape

    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape


def reshape_to_matrix(input_tensor):
    """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
    ndims = input_tensor.shape.ndims
    if ndims < 2:
        raise ValueError(
            "Input tensor must have at least rank 2. Shape = %s" % (input_tensor.shape)
        )
    if ndims == 2:
        return input_tensor

    width = input_tensor.shape[-1]
    output_tensor = tf.reshape(input_tensor, [-1, width])
    return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
    """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
    if len(orig_shape_list) == 2:
        return output_tensor

    output_shape = get_shape_list(output_tensor)

    orig_dims = orig_shape_list[0:-1]
    width = output_shape[-1]

    return tf.reshape(output_tensor, orig_dims + [width])


def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
        tensor: A tf.Tensor to check the rank of.
        expected_rank: Python integer or list of integers, expected rank.
        name: Optional name of the tensor for the error message.

    Raises:
        ValueError: If the expected shape doesn't match the actual shape.
    """
    if name is None:
        name = tensor.name

    expected_rank_dict = {}
    if isinstance(expected_rank, six.integer_types):
        expected_rank_dict[expected_rank] = True
    else:
        for x in expected_rank:
            expected_rank_dict[x] = True

    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_rank_dict:
        scope_name = tf.get_variable_scope().name
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%d` (shape = %s) is not equal to the expected rank `%s`"
            % (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))
        )
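# --- Editor's note: illustrative sketch, not part of the original file. ---
# `get_shape_list` mixes static and dynamic dimensions. For a placeholder
# with an unknown batch dimension:
#
#     x = tf.placeholder(tf.float32, shape=[None, 128, 768])
#     get_shape_list(x)  # -> [<scalar Tensor from tf.shape(x)[0]>, 128, 768]
#
# so downstream reshapes can use Python ints where the graph knows the size
# and fall back to runtime tensors only where it does not.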
arabert/arabert/optimization.py
ADDED
@@ -0,0 +1,202 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow as tf
import lamb_optimizer


def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu,
                     optimizer="adamw", poly_power=1.0, start_warmup_step=0,
                     colocate_gradients_with_ops=False):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=poly_power,
        cycle=False,
    )

    # Implements linear warmup. I.e., if global_step - start_warmup_step <
    # num_warmup_steps, the learning rate will be
    # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step)
                        + ", for " + str(num_warmup_steps) + " steps ++++++")
        global_steps_int = tf.cast(global_step, tf.int32)
        start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
        global_steps_int = global_steps_int - start_warm_int
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = (
            1.0 - is_warmup
        ) * learning_rate + is_warmup * warmup_learning_rate

    # It is OK that you use this optimizer for finetuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    # It is also OK to use AdamW for finetuning even if the model was trained
    # with LAMB. As reported in the public BERT GitHub repo, the learning rate
    # for SQuAD 1.1 finetuning is 3e-5, 4e-5 or 5e-5. For LAMB, users can use
    # 3e-4, 4e-4, or 5e-4 for a batch size of 64 in the finetune.
    if optimizer == "adamw":
        tf.logging.info("using adamw")
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "lamb":
        tf.logging.info("using lamb")
        optimizer = lamb_optimizer.LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    else:
        raise ValueError("Not supported optimizer: ", optimizer)

    if use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does
    # this. But if you use a different optimizer, you should probably take
    # this line out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op

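# --- Editor's note: illustrative usage, not part of the original file. ---
# Typical call from a model_fn during pre-training or fine-tuning; the
# numbers are example values (10% of the steps spent in linear warmup):
#
#     train_op = create_optimizer(
#         loss=total_loss,
#         init_lr=5e-5,
#         num_train_steps=10000,
#         num_warmup_steps=1000,
#         use_tpu=False,
#     )
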
class AdamWeightDecayOptimizer(tf.train.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""

    def __init__(
        self,
        learning_rate,
        weight_decay_rate=0.0,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=None,
        name="AdamWeightDecayOptimizer",
    ):
        """Constructs an AdamWeightDecayOptimizer."""
        super(AdamWeightDecayOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=param_name + "/adam_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer(),
            )
            v = tf.get_variable(
                name=param_name + "/adam_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer(),
            )

            # Standard Adam update.
            next_m = tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)
            next_v = tf.multiply(self.beta_2, v) + tf.multiply(
                1.0 - self.beta_2, tf.square(grad)
            )

            update = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            update_with_lr = self.learning_rate * update

            next_param = param - update_with_lr

            assignments.extend(
                [param.assign(next_param), m.assign(next_m), v.assign(next_v)]
            )
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
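# --- Editor's note: illustrative summary, not part of the original file. ---
# The decoupled weight-decay step above computes, per parameter w with
# gradient g:
#
#     m_t = beta_1 * m + (1 - beta_1) * g
#     v_t = beta_2 * v + (1 - beta_2) * g^2
#     w  <- w - lr * (m_t / (sqrt(v_t) + eps) + wd * w)
#
# with the `wd * w` term skipped for any variable whose name matches one of
# the `exclude_from_weight_decay` patterns; e.g.
# `_do_use_weight_decay("bert/encoder/layer_0/output/LayerNorm/gamma")`
# returns False because the name contains "LayerNorm".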
arabert/arabert/run_classifier.py
ADDED
@@ -0,0 +1,1078 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import os
import modeling
import optimization
import tokenization
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "data_dir",
    None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.",
)

flags.DEFINE_string(
    "bert_config_file",
    None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.",
)

flags.DEFINE_string("task_name", None, "The name of the task to train.")

flags.DEFINE_string(
    "vocab_file", None, "The vocabulary file that the BERT model was trained on."
)

flags.DEFINE_string(
    "output_dir",
    None,
    "The output directory where the model checkpoints will be written.",
)

## Other parameters

flags.DEFINE_string(
    "init_checkpoint",
    None,
    "Initial checkpoint (usually from a pre-trained BERT model).",
)

flags.DEFINE_bool(
    "do_lower_case",
    True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.",
)

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.",
)

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool(
    "do_predict", False, "Whether to run the model in inference mode on the test set."
)

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float(
    "num_train_epochs", 3.0, "Total number of training epochs to perform."
)

flags.DEFINE_float(
    "warmup_proportion",
    0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.",
)

flags.DEFINE_integer(
    "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
)

flags.DEFINE_integer(
    "iterations_per_loop", 1000, "How many steps to make in each estimator call."
)

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

tf.flags.DEFINE_string(
    "tpu_zone",
    None,
    "[Optional] GCE zone where the Cloud TPU is located. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string(
    "gcp_project",
    None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores",
    8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
)

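# --- Editor's note: illustrative invocation, not part of the original file. ---
# The flags above are wired together on the command line; the task name and
# all paths here are hypothetical placeholders:
#
#     python run_classifier.py \
#         --task_name=cola --do_train=true --do_eval=true \
#         --data_dir=/tmp/task_data \
#         --vocab_file=/tmp/model/vocab.txt \
#         --bert_config_file=/tmp/model/bert_config.json \
#         --init_checkpoint=/tmp/model/bert_model.ckpt \
#         --max_seq_length=128 --train_batch_size=32 \
#         --learning_rate=5e-5 --num_train_epochs=3.0 \
#         --output_dir=/tmp/output/
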
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For
                single sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second
                sequence. Must be specified only for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class PaddingInputExample(object):
    """Fake example so that the number of input examples is a multiple of
    the batch size.

    When running eval/predict on the TPU, we need to pad the number of examples
    to be a multiple of the batch size, because the TPU requires a fixed batch
    size. The alternative is to drop the last batch, which is bad because it
    means the entire output data won't be generated.

    We use this class instead of `None` because treating `None` as padding
    batches could cause silent errors.
    """


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(
        self, input_ids, input_mask, segment_ids, label_id, is_real_example=True
    ):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.is_real_example = is_real_example


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for prediction."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines

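# --- Editor's note: illustrative usage, not part of the original file. ---
# Constructing an example by hand, e.g. for a quick single-sentence test; the
# Arabic text and label are made-up sample values:
#
#     example = InputExample(
#         guid="dev-1",
#         text_a="النص الأول",
#         text_b=None,
#         label="1",
#     )
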
class XnliProcessor(DataProcessor):
    """Processor for the XNLI data set."""

    def __init__(self):
        self.language = "ar"

    def get_train_examples(self, data_dir):
        """See base class."""
        lines = self._read_tsv(
            os.path.join(data_dir, "multinli", "multinli.train.%s.tsv" % self.language)
        )
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "train-%d" % (i)
            text_a = tokenization.convert_to_unicode(line[0])
            text_b = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[2])
            if label == tokenization.convert_to_unicode("contradictory"):
                label = tokenization.convert_to_unicode("contradiction")
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples

    def get_dev_examples(self, data_dir):
        """See base class."""
        lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "dev-%d" % (i)
            language = tokenization.convert_to_unicode(line[0])
            if language != tokenization.convert_to_unicode(self.language):
                continue
            text_a = tokenization.convert_to_unicode(line[6])
            text_b = tokenization.convert_to_unicode(line[7])
            label = tokenization.convert_to_unicode(line[1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
        )

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched"
        )

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test"
        )

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
            text_a = tokenization.convert_to_unicode(line[8])
            text_b = tokenization.convert_to_unicode(line[9])
            if set_type == "test":
                label = "contradiction"
            else:
                label = tokenization.convert_to_unicode(line[-1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples


class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
        )

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev"
        )

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test"
        )

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            text_b = tokenization.convert_to_unicode(line[4])
            if set_type == "test":
                label = "0"
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
            )
        return examples


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train"
        )

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev"
        )

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test"
        )

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # Only the test set has a header
            if set_type == "test" and i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            if set_type == "test":
                text_a = tokenization.convert_to_unicode(line[1])
                label = "0"
            else:
                text_a = tokenization.convert_to_unicode(line[3])
                label = tokenization.convert_to_unicode(line[1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label)
            )
        return examples

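# --- Editor's note: illustrative sketch, not part of the original file. ---
# A minimal custom processor following the same pattern as the classes above,
# assuming a headerless TSV with `label<TAB>text` rows; the class name, file
# name, and labels are hypothetical.
class _ExampleBinaryProcessor(DataProcessor):
    def get_train_examples(self, data_dir):
        examples = []
        for (i, line) in enumerate(self._read_tsv(os.path.join(data_dir, "train.tsv"))):
            examples.append(
                InputExample(
                    guid="train-%d" % i,
                    text_a=tokenization.convert_to_unicode(line[1]),
                    text_b=None,
                    label=tokenization.convert_to_unicode(line[0]),
                )
            )
        return examples

    def get_labels(self):
        return ["0", "1"]
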
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""

    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            label_id=0,
            is_real_example=False,
        )

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = None
    if example.text_b:
        tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0 : (max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    label_id = label_map[example.label]
    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info(
            "tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens])
        )
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id,
        is_real_example=True,
    )
    return feature

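# --- Editor's note: illustrative walk-through, not part of the original file. ---
# With max_seq_length=8 and a single-sentence example that tokenizes into
# three wordpieces, the function above produces (ids are schematic; the real
# values depend on the vocabulary file):
#
#     tokens      = [CLS] tok1 tok2 tok3 [SEP]
#     input_ids   = [101, 17, 42, 99, 102, 0, 0, 0]
#     input_mask  = [1,   1,  1,  1,  1,   0, 0, 0]
#     segment_ids = [0,   0,  0,  0,  0,   0, 0, 0]
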
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file
):
    """Convert a set of `InputExample`s to a TFRecord file."""

    writer = tf.python_io.TFRecordWriter(output_file)

    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        feature = convert_single_example(
            ex_index, example, label_list, max_seq_length, tokenizer
        )

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["label_ids"] = create_int_feature([feature.label_id])
        features["is_real_example"] = create_int_feature([int(feature.is_real_example)])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
    writer.close()

def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
|
549 |
+
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
550 |
+
|
551 |
+
name_to_features = {
|
552 |
+
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
|
553 |
+
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
|
554 |
+
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
|
555 |
+
"label_ids": tf.FixedLenFeature([], tf.int64),
|
556 |
+
"is_real_example": tf.FixedLenFeature([], tf.int64),
|
557 |
+
}
|
558 |
+
|
559 |
+
def _decode_record(record, name_to_features):
|
560 |
+
"""Decodes a record to a TensorFlow example."""
|
561 |
+
example = tf.parse_single_example(record, name_to_features)
|
562 |
+
|
563 |
+
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
564 |
+
# So cast all int64 to int32.
|
565 |
+
for name in list(example.keys()):
|
566 |
+
t = example[name]
|
567 |
+
if t.dtype == tf.int64:
|
568 |
+
t = tf.to_int32(t)
|
569 |
+
example[name] = t
|
570 |
+
|
571 |
+
return example
|
572 |
+
|
573 |
+
def input_fn(params):
|
574 |
+
"""The actual input function."""
|
575 |
+
batch_size = params["batch_size"]
|
576 |
+
|
577 |
+
# For training, we want a lot of parallel reading and shuffling.
|
578 |
+
# For eval, we want no shuffling and parallel reading doesn't matter.
|
579 |
+
d = tf.data.TFRecordDataset(input_file)
|
580 |
+
if is_training:
|
581 |
+
d = d.repeat()
|
582 |
+
d = d.shuffle(buffer_size=100)
|
583 |
+
|
584 |
+
d = d.apply(
|
585 |
+
tf.contrib.data.map_and_batch(
|
586 |
+
lambda record: _decode_record(record, name_to_features),
|
587 |
+
batch_size=batch_size,
|
588 |
+
drop_remainder=drop_remainder,
|
589 |
+
)
|
590 |
+
)
|
591 |
+
|
592 |
+
return d
|
593 |
+
|
594 |
+
return input_fn
|
595 |
+
|
596 |
+
|
597 |
+
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
598 |
+
"""Truncates a sequence pair in place to the maximum length."""
|
599 |
+
|
600 |
+
# This is a simple heuristic which will always truncate the longer sequence
|
601 |
+
# one token at a time. This makes more sense than truncating an equal percent
|
602 |
+
# of tokens from each, since if one sequence is very short then each token
|
603 |
+
# that's truncated likely contains more information than a longer sequence.
|
604 |
+
while True:
|
605 |
+
total_length = len(tokens_a) + len(tokens_b)
|
606 |
+
if total_length <= max_length:
|
607 |
+
break
|
608 |
+
if len(tokens_a) > len(tokens_b):
|
609 |
+
tokens_a.pop()
|
610 |
+
else:
|
611 |
+
tokens_b.pop()
|
612 |
+
|
613 |
+
|
614 |
+
def create_model(
|
615 |
+
bert_config,
|
616 |
+
is_training,
|
617 |
+
input_ids,
|
618 |
+
input_mask,
|
619 |
+
segment_ids,
|
620 |
+
labels,
|
621 |
+
num_labels,
|
622 |
+
use_one_hot_embeddings,
|
623 |
+
):
|
624 |
+
"""Creates a classification model."""
|
625 |
+
model = modeling.BertModel(
|
626 |
+
config=bert_config,
|
627 |
+
is_training=is_training,
|
628 |
+
input_ids=input_ids,
|
629 |
+
input_mask=input_mask,
|
630 |
+
token_type_ids=segment_ids,
|
631 |
+
use_one_hot_embeddings=use_one_hot_embeddings,
|
632 |
+
)
|
633 |
+
|
634 |
+
# In the demo, we are doing a simple classification task on the entire
|
635 |
+
# segment.
|
636 |
+
#
|
637 |
+
# If you want to use the token-level output, use model.get_sequence_output()
|
638 |
+
# instead.
|
639 |
+
output_layer = model.get_pooled_output()
|
640 |
+
|
641 |
+
hidden_size = output_layer.shape[-1].value
|
642 |
+
|
643 |
+
output_weights = tf.get_variable(
|
644 |
+
"output_weights",
|
645 |
+
[num_labels, hidden_size],
|
646 |
+
initializer=tf.truncated_normal_initializer(stddev=0.02),
|
647 |
+
)
|
648 |
+
|
649 |
+
output_bias = tf.get_variable(
|
650 |
+
"output_bias", [num_labels], initializer=tf.zeros_initializer()
|
651 |
+
)
|
652 |
+
|
653 |
+
with tf.variable_scope("loss"):
|
654 |
+
if is_training:
|
655 |
+
# I.e., 0.1 dropout
|
656 |
+
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
|
657 |
+
|
658 |
+
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
|
659 |
+
logits = tf.nn.bias_add(logits, output_bias)
|
660 |
+
probabilities = tf.nn.softmax(logits, axis=-1)
|
661 |
+
log_probs = tf.nn.log_softmax(logits, axis=-1)
|
662 |
+
|
663 |
+
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
|
664 |
+
|
665 |
+
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
|
666 |
+
loss = tf.reduce_mean(per_example_loss)
|
667 |
+
|
668 |
+
return (loss, per_example_loss, logits, probabilities)
|
669 |
+
|
670 |
+
|
671 |
+
def model_fn_builder(
|
672 |
+
bert_config,
|
673 |
+
num_labels,
|
674 |
+
init_checkpoint,
|
675 |
+
learning_rate,
|
676 |
+
num_train_steps,
|
677 |
+
num_warmup_steps,
|
678 |
+
use_tpu,
|
679 |
+
use_one_hot_embeddings,
|
680 |
+
):
|
681 |
+
"""Returns `model_fn` closure for TPUEstimator."""
|
682 |
+
|
683 |
+
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
|
684 |
+
"""The `model_fn` for TPUEstimator."""
|
685 |
+
|
686 |
+
tf.logging.info("*** Features ***")
|
687 |
+
for name in sorted(features.keys()):
|
688 |
+
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
|
689 |
+
|
690 |
+
input_ids = features["input_ids"]
|
691 |
+
input_mask = features["input_mask"]
|
692 |
+
segment_ids = features["segment_ids"]
|
693 |
+
label_ids = features["label_ids"]
|
694 |
+
is_real_example = None
|
695 |
+
if "is_real_example" in features:
|
696 |
+
is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
|
697 |
+
else:
|
698 |
+
is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
|
699 |
+
|
700 |
+
is_training = mode == tf.estimator.ModeKeys.TRAIN
|
701 |
+
|
702 |
+
(total_loss, per_example_loss, logits, probabilities) = create_model(
|
703 |
+
bert_config,
|
704 |
+
is_training,
|
705 |
+
input_ids,
|
706 |
+
input_mask,
|
707 |
+
segment_ids,
|
708 |
+
label_ids,
|
709 |
+
num_labels,
|
710 |
+
use_one_hot_embeddings,
|
711 |
+
)
|
712 |
+
|
713 |
+
tvars = tf.trainable_variables()
|
714 |
+
initialized_variable_names = {}
|
715 |
+
scaffold_fn = None
|
716 |
+
if init_checkpoint:
|
717 |
+
(
|
718 |
+
assignment_map,
|
719 |
+
initialized_variable_names,
|
720 |
+
) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
|
721 |
+
if use_tpu:
|
722 |
+
|
723 |
+
def tpu_scaffold():
|
724 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
725 |
+
return tf.train.Scaffold()
|
726 |
+
|
727 |
+
scaffold_fn = tpu_scaffold
|
728 |
+
else:
|
729 |
+
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
|
730 |
+
|
731 |
+
tf.logging.info("**** Trainable Variables ****")
|
732 |
+
for var in tvars:
|
733 |
+
init_string = ""
|
734 |
+
if var.name in initialized_variable_names:
|
735 |
+
init_string = ", *INIT_FROM_CKPT*"
|
736 |
+
tf.logging.info(
|
737 |
+
" name = %s, shape = %s%s", var.name, var.shape, init_string
|
738 |
+
)
|
739 |
+
|
740 |
+
output_spec = None
|
741 |
+
if mode == tf.estimator.ModeKeys.TRAIN:
|
742 |
+
|
743 |
+
train_op = optimization.create_optimizer(
|
744 |
+
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
|
745 |
+
)
|
746 |
+
|
747 |
+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
748 |
+
mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
|
749 |
+
)
|
750 |
+
elif mode == tf.estimator.ModeKeys.EVAL:
|
751 |
+
|
752 |
+
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
|
753 |
+
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
754 |
+
accuracy = tf.metrics.accuracy(
|
755 |
+
labels=label_ids, predictions=predictions, weights=is_real_example
|
756 |
+
)
|
757 |
+
loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
|
758 |
+
return {
|
759 |
+
"eval_accuracy": accuracy,
|
760 |
+
"eval_loss": loss,
|
761 |
+
}
|
762 |
+
|
763 |
+
eval_metrics = (
|
764 |
+
metric_fn,
|
765 |
+
[per_example_loss, label_ids, logits, is_real_example],
|
766 |
+
)
|
767 |
+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
768 |
+
mode=mode,
|
769 |
+
loss=total_loss,
|
770 |
+
eval_metrics=eval_metrics,
|
771 |
+
scaffold_fn=scaffold_fn,
|
772 |
+
)
|
773 |
+
else:
|
774 |
+
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
775 |
+
mode=mode,
|
776 |
+
predictions={"probabilities": probabilities},
|
777 |
+
scaffold_fn=scaffold_fn,
|
778 |
+
)
|
779 |
+
return output_spec
|
780 |
+
|
781 |
+
return model_fn
|
782 |
+
|
783 |
+
|
784 |
+
# This function is not used by this file but is still used by the Colab and
|
785 |
+
# people who depend on it.
|
786 |
+
def input_fn_builder(features, seq_length, is_training, drop_remainder):
|
787 |
+
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
|
788 |
+
|
789 |
+
all_input_ids = []
|
790 |
+
all_input_mask = []
|
791 |
+
all_segment_ids = []
|
792 |
+
all_label_ids = []
|
793 |
+
|
794 |
+
for feature in features:
|
795 |
+
all_input_ids.append(feature.input_ids)
|
796 |
+
all_input_mask.append(feature.input_mask)
|
797 |
+
all_segment_ids.append(feature.segment_ids)
|
798 |
+
all_label_ids.append(feature.label_id)
|
799 |
+
|
800 |
+
def input_fn(params):
|
801 |
+
"""The actual input function."""
|
802 |
+
batch_size = params["batch_size"]
|
803 |
+
|
804 |
+
num_examples = len(features)
|
805 |
+
|
806 |
+
# This is for demo purposes and does NOT scale to large data sets. We do
|
807 |
+
# not use Dataset.from_generator() because that uses tf.py_func which is
|
808 |
+
# not TPU compatible. The right way to load data is with TFRecordReader.
|
809 |
+
d = tf.data.Dataset.from_tensor_slices(
|
810 |
+
{
|
811 |
+
"input_ids": tf.constant(
|
812 |
+
all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32
|
813 |
+
),
|
814 |
+
"input_mask": tf.constant(
|
815 |
+
all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32
|
816 |
+
),
|
817 |
+
"segment_ids": tf.constant(
|
818 |
+
all_segment_ids, shape=[num_examples, seq_length], dtype=tf.int32
|
819 |
+
),
|
820 |
+
"label_ids": tf.constant(
|
821 |
+
all_label_ids, shape=[num_examples], dtype=tf.int32
|
822 |
+
),
|
823 |
+
}
|
824 |
+
)
|
825 |
+
|
826 |
+
if is_training:
|
827 |
+
d = d.repeat()
|
828 |
+
d = d.shuffle(buffer_size=100)
|
829 |
+
|
830 |
+
d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
|
831 |
+
return d
|
832 |
+
|
833 |
+
return input_fn
|
834 |
+
|
835 |
+
|
836 |
+
# This function is not used by this file but is still used by the Colab and
|
837 |
+
# people who depend on it.
|
838 |
+
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
839 |
+
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
|
840 |
+
|
841 |
+
features = []
|
842 |
+
for (ex_index, example) in enumerate(examples):
|
843 |
+
if ex_index % 10000 == 0:
|
844 |
+
tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
|
845 |
+
|
846 |
+
feature = convert_single_example(
|
847 |
+
ex_index, example, label_list, max_seq_length, tokenizer
|
848 |
+
)
|
849 |
+
|
850 |
+
features.append(feature)
|
851 |
+
return features
|
852 |
+
|
853 |
+
|
854 |
+
def main(_):
|
855 |
+
tf.logging.set_verbosity(tf.logging.INFO)
|
856 |
+
logger = tf.get_logger()
|
857 |
+
logger.propagate = False
|
858 |
+
|
859 |
+
processors = {
|
860 |
+
"cola": ColaProcessor,
|
861 |
+
"mnli": MnliProcessor,
|
862 |
+
"mrpc": MrpcProcessor,
|
863 |
+
"xnli": XnliProcessor,
|
864 |
+
}
|
865 |
+
|
866 |
+
tokenization.validate_case_matches_checkpoint(
|
867 |
+
FLAGS.do_lower_case, FLAGS.init_checkpoint
|
868 |
+
)
|
869 |
+
|
870 |
+
if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
|
871 |
+
raise ValueError(
|
872 |
+
"At least one of `do_train`, `do_eval` or `do_predict' must be True."
|
873 |
+
)
|
874 |
+
|
875 |
+
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
|
876 |
+
|
877 |
+
if FLAGS.max_seq_length > bert_config.max_position_embeddings:
|
878 |
+
raise ValueError(
|
879 |
+
"Cannot use sequence length %d because the BERT model "
|
880 |
+
"was only trained up to sequence length %d"
|
881 |
+
% (FLAGS.max_seq_length, bert_config.max_position_embeddings)
|
882 |
+
)
|
883 |
+
|
884 |
+
tf.gfile.MakeDirs(FLAGS.output_dir)
|
885 |
+
|
886 |
+
task_name = FLAGS.task_name.lower()
|
887 |
+
|
888 |
+
if task_name not in processors:
|
889 |
+
raise ValueError("Task not found: %s" % (task_name))
|
890 |
+
|
891 |
+
processor = processors[task_name]()
|
892 |
+
|
893 |
+
label_list = processor.get_labels()
|
894 |
+
|
895 |
+
tokenizer = tokenization.FullTokenizer(
|
896 |
+
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
|
897 |
+
)
|
898 |
+
|
899 |
+
tpu_cluster_resolver = None
|
900 |
+
if FLAGS.use_tpu and FLAGS.tpu_name:
|
901 |
+
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
902 |
+
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
|
903 |
+
)
|
904 |
+
|
905 |
+
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
906 |
+
run_config = tf.contrib.tpu.RunConfig(
|
907 |
+
cluster=tpu_cluster_resolver,
|
908 |
+
master=FLAGS.master,
|
909 |
+
model_dir=FLAGS.output_dir,
|
910 |
+
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
|
911 |
+
tpu_config=tf.contrib.tpu.TPUConfig(
|
912 |
+
iterations_per_loop=FLAGS.iterations_per_loop,
|
913 |
+
num_shards=FLAGS.num_tpu_cores,
|
914 |
+
per_host_input_for_training=is_per_host,
|
915 |
+
),
|
916 |
+
)
|
917 |
+
|
918 |
+
train_examples = None
|
919 |
+
num_train_steps = None
|
920 |
+
num_warmup_steps = None
|
921 |
+
if FLAGS.do_train:
|
922 |
+
train_examples = processor.get_train_examples(FLAGS.data_dir)
|
923 |
+
num_train_steps = int(
|
924 |
+
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
|
925 |
+
)
|
926 |
+
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
|
927 |
+
|
928 |
+
model_fn = model_fn_builder(
|
929 |
+
bert_config=bert_config,
|
930 |
+
num_labels=len(label_list),
|
931 |
+
init_checkpoint=FLAGS.init_checkpoint,
|
932 |
+
learning_rate=FLAGS.learning_rate,
|
933 |
+
num_train_steps=num_train_steps,
|
934 |
+
num_warmup_steps=num_warmup_steps,
|
935 |
+
use_tpu=FLAGS.use_tpu,
|
936 |
+
use_one_hot_embeddings=FLAGS.use_tpu,
|
937 |
+
)
|
938 |
+
|
939 |
+
# If TPU is not available, this will fall back to normal Estimator on CPU
|
940 |
+
# or GPU.
|
941 |
+
estimator = tf.contrib.tpu.TPUEstimator(
|
942 |
+
use_tpu=FLAGS.use_tpu,
|
943 |
+
model_fn=model_fn,
|
944 |
+
config=run_config,
|
945 |
+
train_batch_size=FLAGS.train_batch_size,
|
946 |
+
eval_batch_size=FLAGS.eval_batch_size,
|
947 |
+
predict_batch_size=FLAGS.predict_batch_size,
|
948 |
+
)
|
949 |
+
|
950 |
+
if FLAGS.do_train:
|
951 |
+
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
|
952 |
+
file_based_convert_examples_to_features(
|
953 |
+
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file
|
954 |
+
)
|
955 |
+
tf.logging.info("***** Running training *****")
|
956 |
+
tf.logging.info(" Num examples = %d", len(train_examples))
|
957 |
+
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
|
958 |
+
tf.logging.info(" Num steps = %d", num_train_steps)
|
959 |
+
train_input_fn = file_based_input_fn_builder(
|
960 |
+
input_file=train_file,
|
961 |
+
seq_length=FLAGS.max_seq_length,
|
962 |
+
is_training=True,
|
963 |
+
drop_remainder=True,
|
964 |
+
)
|
965 |
+
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
|
966 |
+
|
967 |
+
if FLAGS.do_eval:
|
968 |
+
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
|
969 |
+
num_actual_eval_examples = len(eval_examples)
|
970 |
+
if FLAGS.use_tpu:
|
971 |
+
# TPU requires a fixed batch size for all batches, therefore the number
|
972 |
+
# of examples must be a multiple of the batch size, or else examples
|
973 |
+
# will get dropped. So we pad with fake examples which are ignored
|
974 |
+
# later on. These do NOT count towards the metric (all tf.metrics
|
975 |
+
# support a per-instance weight, and these get a weight of 0.0).
|
976 |
+
while len(eval_examples) % FLAGS.eval_batch_size != 0:
|
977 |
+
eval_examples.append(PaddingInputExample())
|
978 |
+
|
979 |
+
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
|
980 |
+
file_based_convert_examples_to_features(
|
981 |
+
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file
|
982 |
+
)
|
983 |
+
|
984 |
+
tf.logging.info("***** Running evaluation *****")
|
985 |
+
tf.logging.info(
|
986 |
+
" Num examples = %d (%d actual, %d padding)",
|
987 |
+
len(eval_examples),
|
988 |
+
num_actual_eval_examples,
|
989 |
+
len(eval_examples) - num_actual_eval_examples,
|
990 |
+
)
|
991 |
+
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
|
992 |
+
|
993 |
+
# This tells the estimator to run through the entire set.
|
994 |
+
eval_steps = None
|
995 |
+
# However, if running eval on the TPU, you will need to specify the
|
996 |
+
# number of steps.
|
997 |
+
if FLAGS.use_tpu:
|
998 |
+
assert len(eval_examples) % FLAGS.eval_batch_size == 0
|
999 |
+
eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
|
1000 |
+
|
1001 |
+
eval_drop_remainder = True if FLAGS.use_tpu else False
|
1002 |
+
eval_input_fn = file_based_input_fn_builder(
|
1003 |
+
input_file=eval_file,
|
1004 |
+
seq_length=FLAGS.max_seq_length,
|
1005 |
+
is_training=False,
|
1006 |
+
drop_remainder=eval_drop_remainder,
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
|
1010 |
+
|
1011 |
+
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
|
1012 |
+
with tf.gfile.GFile(output_eval_file, "w") as writer:
|
1013 |
+
tf.logging.info("***** Eval results *****")
|
1014 |
+
for key in sorted(result.keys()):
|
1015 |
+
tf.logging.info(" %s = %s", key, str(result[key]))
|
1016 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
1017 |
+
|
1018 |
+
if FLAGS.do_predict:
|
1019 |
+
predict_examples = processor.get_test_examples(FLAGS.data_dir)
|
1020 |
+
num_actual_predict_examples = len(predict_examples)
|
1021 |
+
if FLAGS.use_tpu:
|
1022 |
+
# TPU requires a fixed batch size for all batches, therefore the number
|
1023 |
+
# of examples must be a multiple of the batch size, or else examples
|
1024 |
+
# will get dropped. So we pad with fake examples which are ignored
|
1025 |
+
# later on.
|
1026 |
+
while len(predict_examples) % FLAGS.predict_batch_size != 0:
|
1027 |
+
predict_examples.append(PaddingInputExample())
|
1028 |
+
|
1029 |
+
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
|
1030 |
+
file_based_convert_examples_to_features(
|
1031 |
+
predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
tf.logging.info("***** Running prediction*****")
|
1035 |
+
tf.logging.info(
|
1036 |
+
" Num examples = %d (%d actual, %d padding)",
|
1037 |
+
len(predict_examples),
|
1038 |
+
num_actual_predict_examples,
|
1039 |
+
len(predict_examples) - num_actual_predict_examples,
|
1040 |
+
)
|
1041 |
+
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
|
1042 |
+
|
1043 |
+
predict_drop_remainder = True if FLAGS.use_tpu else False
|
1044 |
+
predict_input_fn = file_based_input_fn_builder(
|
1045 |
+
input_file=predict_file,
|
1046 |
+
seq_length=FLAGS.max_seq_length,
|
1047 |
+
is_training=False,
|
1048 |
+
drop_remainder=predict_drop_remainder,
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
result = estimator.predict(input_fn=predict_input_fn)
|
1052 |
+
|
1053 |
+
output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
|
1054 |
+
with tf.gfile.GFile(output_predict_file, "w") as writer:
|
1055 |
+
num_written_lines = 0
|
1056 |
+
tf.logging.info("***** Predict results *****")
|
1057 |
+
for (i, prediction) in enumerate(result):
|
1058 |
+
probabilities = prediction["probabilities"]
|
1059 |
+
if i >= num_actual_predict_examples:
|
1060 |
+
break
|
1061 |
+
output_line = (
|
1062 |
+
"\t".join(
|
1063 |
+
str(class_probability) for class_probability in probabilities
|
1064 |
+
)
|
1065 |
+
+ "\n"
|
1066 |
+
)
|
1067 |
+
writer.write(output_line)
|
1068 |
+
num_written_lines += 1
|
1069 |
+
assert num_written_lines == num_actual_predict_examples
|
1070 |
+
|
1071 |
+
|
1072 |
+
if __name__ == "__main__":
|
1073 |
+
flags.mark_flag_as_required("data_dir")
|
1074 |
+
flags.mark_flag_as_required("task_name")
|
1075 |
+
flags.mark_flag_as_required("vocab_file")
|
1076 |
+
flags.mark_flag_as_required("bert_config_file")
|
1077 |
+
flags.mark_flag_as_required("output_dir")
|
1078 |
+
tf.app.run()
|
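Editor's note: the packing convention that `convert_single_example` above implements (the long "type_ids" comment) is easy to check in isolation. The following is a minimal, TF-free sketch, not part of the uploaded files; `pack_pair` and its toy token lists are hypothetical names introduced purely for illustration.

# --- Illustrative sketch only; not part of the repository files. ---
# Packs [CLS] + tokens_a + [SEP] (+ tokens_b + [SEP]), builds segment ids
# (0 for the first segment, 1 for the second), an attention mask of 1s for
# real tokens, and zero-pads everything out to max_seq_length.
def pack_pair(tokens_a, tokens_b=None, max_seq_length=16):
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    if tokens_b:
        tokens += tokens_b + ["[SEP]"]
        segment_ids += [1] * (len(tokens_b) + 1)
    input_mask = [1] * len(tokens)
    pad = max_seq_length - len(tokens)
    return tokens + ["[PAD]"] * pad, segment_ids + [0] * pad, input_mask + [0] * pad

tokens, segment_ids, input_mask = pack_pair(
    ["is", "this", "jack", "##son", "##ville", "?"], ["no", "it", "is", "not", "."]
)
print(tokens)       # [CLS] ... [SEP] ... [SEP] [PAD] [PAD]
print(segment_ids)  # 0s for segment a, 1s for segment b, 0s for padding
print(input_mask)   # 1 for real tokens, 0 for padding (only real tokens attended to)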
arabert/arabert/run_pretraining.py
ADDED
@@ -0,0 +1,593 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run masked LM/next sentence masked_lm pre-training for BERT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import modeling
import optimization
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file",
    None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.",
)

flags.DEFINE_string(
    "input_file", None, "Input TF example files (can be a glob or comma separated)."
)

flags.DEFINE_string(
    "output_dir",
    None,
    "The output directory where the model checkpoints will be written.",
)

## Other parameters
flags.DEFINE_string(
    "init_checkpoint",
    None,
    "Initial checkpoint (usually from a pre-trained BERT model).",
)

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded. Must match data generation.",
)

flags.DEFINE_integer(
    "max_predictions_per_seq",
    20,
    "Maximum number of masked LM predictions per sequence. "
    "Must match data generation.",
)

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

flags.DEFINE_float("poly_power", 1.0, "The power of poly decay.")

flags.DEFINE_enum("optimizer", "lamb", ["adamw", "lamb"],
                  "The optimizer for training.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")

flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")

flags.DEFINE_integer("start_warmup_step", 0, "The starting step of warmup.")

flags.DEFINE_integer(
    "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
)

flags.DEFINE_integer(
    "iterations_per_loop", 1000, "How many steps to make in each estimator call."
)

flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

tf.flags.DEFINE_string(
    "tpu_zone",
    None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string(
    "gcp_project",
    None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores",
    8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
)

flags.DEFINE_integer("keep_checkpoint_max", 10,
                     "How many checkpoints to keep.")


def model_fn_builder(
    bert_config,
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    use_tpu,
    use_one_hot_embeddings,
    optimizer,
    poly_power,
    start_warmup_step,
):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        (
            masked_lm_loss,
            masked_lm_example_loss,
            masked_lm_log_probs,
        ) = get_masked_lm_output(
            bert_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            masked_lm_positions,
            masked_lm_ids,
            masked_lm_weights,
        )

        (
            next_sentence_loss,
            next_sentence_example_loss,
            next_sentence_log_probs,
        ) = get_next_sentence_output(
            bert_config, model.get_pooled_output(), next_sentence_labels
        )

        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(
                "  name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                use_tpu,
                optimizer,
                poly_power,
                start_warmup_step,
            )

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(
                masked_lm_example_loss,
                masked_lm_log_probs,
                masked_lm_ids,
                masked_lm_weights,
                next_sentence_example_loss,
                next_sentence_log_probs,
                next_sentence_labels,
            ):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]]
                )
                masked_lm_predictions = tf.argmax(
                    masked_lm_log_probs, axis=-1, output_type=tf.int32
                )
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights,
                )
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights
                )

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]
                )
                next_sentence_predictions = tf.argmax(
                    next_sentence_log_probs, axis=-1, output_type=tf.int32
                )
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels, predictions=next_sentence_predictions
                )
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss
                )

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (
                metric_fn,
                [
                    masked_lm_example_loss,
                    masked_lm_log_probs,
                    masked_lm_ids,
                    masked_lm_weights,
                    next_sentence_example_loss,
                    next_sentence_log_probs,
                    next_sentence_labels,
                ],
            )
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn,
            )
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec

    return model_fn


def get_masked_lm_output(
    bert_config, input_tensor, output_weights, positions, label_ids, label_weights
):
    """Get loss and log probs for the masked LM."""
    input_tensor = gather_indexes(input_tensor, positions)

    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=modeling.get_activation(bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    bert_config.initializer_range
                ),
            )
            input_tensor = modeling.layer_norm(input_tensor)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer(),
        )
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32
        )

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

    return (loss, per_example_loss, log_probs)


def get_next_sentence_output(bert_config, input_tensor, labels):
    """Get loss and log probs for the next sentence prediction."""

    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    with tf.variable_scope("cls/seq_relationship"):
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range),
        )
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer()
        )

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)


def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]
    )
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor


def input_fn_builder(
    input_files, max_seq_length, max_predictions_per_seq, is_training, num_cpu_threads=4
):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "masked_lm_positions": tf.FixedLenFeature(
                [max_predictions_per_seq], tf.int64
            ),
            "masked_lm_ids": tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights": tf.FixedLenFeature(
                [max_predictions_per_seq], tf.float32
            ),
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
            d = d.repeat()
            d = d.shuffle(buffer_size=len(input_files))

            # `cycle_length` is the number of parallel files that get read.
            cycle_length = min(num_cpu_threads, len(input_files))

            # `sloppy` mode means that the interleaving is not exact. This adds
            # even more randomness to the training pipeline.
            d = d.apply(
                tf.contrib.data.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length,
                )
            )
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)
            # Since we evaluate for a fixed number of steps we don't want to encounter
            # out-of-range exceptions.
            d = d.repeat()

        # We must `drop_remainder` on training because the TPU requires fixed
        # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
        # and we *don't* want to drop the remainder, otherwise we won't cover
        # every sample.
        d = d.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=True,
            )
        )
        return d

    return input_fn


def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
        t = example[name]
        if t.dtype == tf.int64:
            t = tf.to_int32(t)
        example[name] = t

    return example


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    # tf.logging.info("*** Input Files ***")
    # for input_file in input_files:
    #     tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
        )

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host,
        ),
    )

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        optimizer=FLAGS.optimizer,
        poly_power=FLAGS.poly_power,
        start_warmup_step=FLAGS.start_warmup_step,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
    )

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True,
        )
        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False,
        )

        result = estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
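Editor's note: `gather_indexes` above flattens a [batch_size, seq_length, width] tensor and adds a per-row offset (row_index * seq_length) to the masked positions so that a single tf.gather can pick out every masked-token vector at once. The same arithmetic in NumPy, with made-up shapes, purely as a sketch of the offset trick; this snippet is not part of the uploaded files.

# --- Illustrative sketch only; not part of the repository files. ---
import numpy as np

batch_size, seq_length, width = 2, 4, 3
sequence = np.arange(batch_size * seq_length * width).reshape(
    batch_size, seq_length, width
)
positions = np.array([[1, 3], [0, 2]])  # masked positions per example

flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)  # [[0], [4]]
flat_positions = (positions + flat_offsets).reshape(-1)             # [1, 3, 4, 6]
flat_sequence = sequence.reshape(batch_size * seq_length, width)
gathered = flat_sequence[flat_positions]                             # shape (4, 3)

# Row 0 of `gathered` is example 0, position 1; row 2 is example 1, position 0.
assert np.array_equal(gathered[0], sequence[0, 1])
assert np.array_equal(gathered[2], sequence[1, 0])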
arabert/arabert/run_squad.py
ADDED
@@ -0,0 +1,1440 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import math
import os
import random
import modeling
import optimization
import tokenization
import six
import tensorflow as tf

flags = tf.flags

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file",
    None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.",
)

flags.DEFINE_string(
    "vocab_file", None, "The vocabulary file that the BERT model was trained on."
)

flags.DEFINE_string(
    "output_dir",
    None,
    "The output directory where the model checkpoints will be written.",
)

## Other parameters
flags.DEFINE_string(
    "train_file", None, "SQuAD json for training. E.g., train-v1.1.json"
)

flags.DEFINE_string(
    "predict_file",
    None,
    "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
)

flags.DEFINE_string(
    "init_checkpoint",
    None,
    "Initial checkpoint (usually from a pre-trained BERT model).",
)

flags.DEFINE_bool(
    "do_lower_case",
    True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.",
)

flags.DEFINE_integer(
    "max_seq_length",
    384,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.",
)

flags.DEFINE_integer(
    "doc_stride",
    128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.",
)

flags.DEFINE_integer(
    "max_query_length",
    64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.",
)

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predictions.")

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float(
    "num_train_epochs", 3.0, "Total number of training epochs to perform."
)

flags.DEFINE_float(
    "warmup_proportion",
    0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.",
)

flags.DEFINE_integer(
    "save_checkpoints_steps", 1000, "How often to save the model checkpoint."
)

flags.DEFINE_integer(
    "iterations_per_loop", 1000, "How many steps to make in each estimator call."
)

flags.DEFINE_integer(
    "n_best_size",
    20,
    "The total number of n-best predictions to generate in the "
    "nbest_predictions.json output file.",
)

flags.DEFINE_integer(
    "max_answer_length",
    30,
    "The maximum length of an answer that can be generated. This is needed "
    "because the start and end predictions are not conditioned on one another.",
)

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
    "tpu_name",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

tf.flags.DEFINE_string(
    "tpu_zone",
    None,
    "[Optional] GCE zone where the Cloud TPU is located. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string(
    "gcp_project",
    None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.",
)

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores",
    8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.",
)

flags.DEFINE_bool(
    "verbose_logging",
    False,
    "If true, all of the warnings related to data processing will be printed. "
    "A number of warnings are expected for a normal SQuAD evaluation.",
)

flags.DEFINE_bool(
    "version_2_with_negative",
    False,
    "If true, the SQuAD examples contain some that do not have an answer.",
)

flags.DEFINE_float(
    "null_score_diff_threshold",
    0.0,
    "If null_score - best_non_null is greater than the threshold predict null.",
)
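
# Example invocation (illustrative, assumed local paths -- adjust to your own
# checkpoint and data locations; every flag used here is defined above):
#
#   python run_squad.py \
#     --vocab_file=$BERT_DIR/vocab.txt \
#     --bert_config_file=$BERT_DIR/bert_config.json \
#     --init_checkpoint=$BERT_DIR/bert_model.ckpt \
#     --do_train --train_file=train-v1.1.json \
#     --do_predict --predict_file=dev-v1.1.json \
#     --train_batch_size=32 --learning_rate=5e-5 --num_train_epochs=3.0 \
#     --max_seq_length=384 --doc_stride=128 \
#     --output_dir=/tmp/squad_output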


class SquadExample(object):
    """A single training/test example for simple sequence classification.

    For examples without an answer, the start and end position are -1.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        doc_tokens,
        orig_answer_text=None,
        start_position=None,
        end_position=None,
        is_impossible=False,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
        s += ", question_text: %s" % (tokenization.printable_text(self.question_text))
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
            s += ", end_position: %d" % (self.end_position)
            s += ", is_impossible: %r" % (self.is_impossible)
        return s


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(
        self,
        unique_id,
        example_index,
        doc_span_index,
        tokens,
        token_to_orig_map,
        token_is_max_context,
        input_ids,
        input_mask,
        segment_ids,
        start_position=None,
        end_position=None,
        is_impossible=None,
    ):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible


def read_squad_examples(input_file, is_training):
    """Read a SQuAD json file into a list of SquadExample."""
    with tf.gfile.Open(input_file, "r") as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if FLAGS.version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if (len(qa["answers"]) != 1) and (not is_impossible):
                        raise ValueError(
                            "For training, each question should have exactly 1 answer."
                        )
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[
                            answer_offset + answer_length - 1
                        ]
                        # Only add answers where the text can be exactly recovered from the
                        # document. If this CAN'T happen it's likely due to weird Unicode
                        # stuff so we will just skip the example.
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(
                            doc_tokens[start_position : (end_position + 1)]
                        )
                        cleaned_answer_text = " ".join(
                            tokenization.whitespace_tokenize(orig_answer_text)
                        )
                        if actual_text.find(cleaned_answer_text) == -1:
                            tf.logging.warning(
                                "Could not find answer: '%s' vs. '%s'",
                                actual_text,
                                cleaned_answer_text,
                            )
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible,
                )
                examples.append(example)

    return examples
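
# Illustrative note (not from the original source): for the context
# "John Smith was born", doc_tokens becomes ["John", "Smith", "was", "born"]
# and char_to_word_offset maps every character position (spaces included) to
# its word index, e.g. the "S" at char 5 maps to word 1; this is how SQuAD's
# character-based answer_start offsets are converted to word indices above.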


def convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
    output_fn,
):
    """Loads a data file into a list of `InputBatch`s."""

    unique_id = 1000000000

    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens,
                tok_start_position,
                tok_end_position,
                tokenizer,
                example.orig_answer_text,
            )

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"]
        )
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)
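
        # For instance (illustrative numbers, not from the original source): with
        # max_tokens_for_doc=4 and doc_stride=2, a 7-token document yields spans
        # (start=0, length=4), (start=2, length=4), (start=4, length=3); the last
        # span reaches the end of the document and terminates the loop.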

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (
                    tok_start_position >= doc_start and tok_end_position <= doc_end
                ):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
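                    # The +2 accounts for the leading [CLS] token and the [SEP]
                    # that closes the query segment.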
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0

            if example_index < 20:
                tf.logging.info("*** Example ***")
                tf.logging.info("unique_id: %s" % (unique_id))
                tf.logging.info("example_index: %s" % (example_index))
                tf.logging.info("doc_span_index: %s" % (doc_span_index))
                tf.logging.info(
                    "tokens: %s"
                    % " ".join([tokenization.printable_text(x) for x in tokens])
                )
                tf.logging.info(
                    "token_to_orig_map: %s"
                    % " ".join(
                        [
                            "%d:%d" % (x, y)
                            for (x, y) in six.iteritems(token_to_orig_map)
                        ]
                    )
                )
                tf.logging.info(
                    "token_is_max_context: %s"
                    % " ".join(
                        [
                            "%d:%s" % (x, y)
                            for (x, y) in six.iteritems(token_is_max_context)
                        ]
                    )
                )
                tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                tf.logging.info(
                    "input_mask: %s" % " ".join([str(x) for x in input_mask])
                )
                tf.logging.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids])
                )
                if is_training and example.is_impossible:
                    tf.logging.info("impossible example")
                if is_training and not example.is_impossible:
                    answer_text = " ".join(tokens[start_position : (end_position + 1)])
                    tf.logging.info("start_position: %d" % (start_position))
                    tf.logging.info("end_position: %d" % (end_position))
                    tf.logging.info(
                        "answer: %s" % (tokenization.printable_text(answer_text))
                    )

            feature = InputFeatures(
                unique_id=unique_id,
                example_index=example_index,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                start_position=start_position,
                end_position=end_position,
                is_impossible=example.is_impossible,
            )

            # Run callback
            output_fn(feature)

            unique_id += 1


def _improve_answer_span(
    doc_tokens, input_start, input_end, tokenizer, orig_answer_text
):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple spans. E.g.
    #   Doc: the man went to the store and bought a gallon of milk
    #   Span A: the man went to the
    #   Span B: to the store and bought
    #   Span C: and bought a gallon of
    #   ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
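    #
    # Numerically (illustrative, derived from the formula below): span B covers
    # tokens 3-7 and span C covers tokens 6-10, so for the token at position 7
    # span B scores min(4, 0) + 0.01 * 5 = 0.05 while span C scores
    # min(1, 3) + 0.01 * 5 = 1.05, making span C the max-context span.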
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def create_model(
    bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings
):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
    )

    final_hidden = model.get_sequence_output()

    final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
    batch_size = final_hidden_shape[0]
    seq_length = final_hidden_shape[1]
    hidden_size = final_hidden_shape[2]

    output_weights = tf.get_variable(
        "cls/squad/output_weights",
        [2, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02),
    )

    output_bias = tf.get_variable(
        "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()
    )

    final_hidden_matrix = tf.reshape(
        final_hidden, [batch_size * seq_length, hidden_size]
    )
    logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    logits = tf.reshape(logits, [batch_size, seq_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])
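
    # Shape note: `final_hidden` is [batch, seq_len, hidden]; after the matmul
    # and the reshape/transpose above, `logits` is [2, batch, seq_len], which
    # unstacks below into start and end logit tensors of shape [batch, seq_len].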

    unstacked_logits = tf.unstack(logits, axis=0)

    (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

    return (start_logits, end_logits)


def model_fn_builder(
    bert_config,
    init_checkpoint,
    learning_rate,
    num_train_steps,
    num_warmup_steps,
    use_tpu,
    use_one_hot_embeddings,
):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = mode == tf.estimator.ModeKeys.TRAIN

        (start_logits, end_logits) = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
        )

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (
                assignment_map,
                initialized_variable_names,
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(
                "  name = %s, shape = %s%s", var.name, var.shape, init_string
            )

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_length = modeling.get_shape_list(input_ids)[1]

            def compute_loss(logits, positions):
                one_hot_positions = tf.one_hot(
                    positions, depth=seq_length, dtype=tf.float32
                )
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                loss = -tf.reduce_mean(
                    tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
                )
                return loss
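
            # I.e., the cross-entropy of the gold position under a softmax over
            # all sequence positions: loss = -mean(log p(position)). The start
            # and end losses are averaged below.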

            start_positions = features["start_positions"]
            end_positions = features["end_positions"]

            start_loss = compute_loss(start_logits, start_positions)
            end_loss = compute_loss(end_logits, end_positions)

            total_loss = (start_loss + end_loss) / 2.0

            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu
            )

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op, scaffold_fn=scaffold_fn
            )
        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                "unique_ids": unique_ids,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
            )
        else:
            raise ValueError("Only TRAIN and PREDICT modes are supported: %s" % (mode))

        return output_spec

    return model_fn


def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }

    if is_training:
        name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)

        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        # So cast all int64 to int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t

        return example

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file)
        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                drop_remainder=drop_remainder,
            )
        )

        return d

    return input_fn


RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"]
)


def write_predictions(
    all_examples,
    all_features,
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
):
    """Write final predictions to the json file and log-odds of null if needed."""
    tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
    tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if FLAGS.version_2_with_negative:
                feature_null_score = result.start_logits[0] + result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )

        if FLAGS.version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True,
        )

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")
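                # e.g. ["john", "##son"] -> "john ##son" -> "johnson".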

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(tok_text, orig_text, do_lower_case)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit,
                )
            )

        # if we didn't include the empty option in the n-best, include it
        if FLAGS.version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="", start_logit=null_start_logit, end_logit=null_end_logit
                    )
                )
        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        if not FLAGS.version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = (
                score_null
                - best_non_null_entry.start_logit
                - (best_non_null_entry.end_logit)
            )
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > FLAGS.null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text

        all_nbest_json[example.qas_id] = nbest_json

    with tf.gfile.GFile(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with tf.gfile.GFile(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if FLAGS.version_2_with_negative:
        with tf.gfile.GFile(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")


def get_final_text(pred_text, orig_text, do_lower_case):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)
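
    # E.g. (illustrative): _strip_spaces("a b") returns ("ab", {0: 0, 1: 2}),
    # so index 1 of the stripped string maps back to index 2 of the original.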

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if FLAGS.verbose_logging:
            tf.logging.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)
            )
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if FLAGS.verbose_logging:
            tf.logging.info(
                "Length not equal after stripping spaces: '%s' vs '%s'",
                orig_ns_text,
                tok_ns_text,
            )
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text


def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
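    # E.g. (illustrative): scores [1.0, 2.0] give probs [1/(1+e), e/(1+e)],
    # roughly [0.269, 0.731]; subtracting max_score first keeps math.exp from
    # overflowing on large logits.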
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs


class FeatureWriter(object):
    """Writes InputFeature to TF example file."""

    def __init__(self, filename, is_training):
        self.filename = filename
        self.is_training = is_training
        self.num_features = 0
        self._writer = tf.python_io.TFRecordWriter(filename)

    def process_feature(self, feature):
        """Write an InputFeature to the TFRecordWriter as a tf.train.Example."""
        self.num_features += 1

        def create_int_feature(values):
            feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=list(values))
            )
            return feature

        features = collections.OrderedDict()
        features["unique_ids"] = create_int_feature([feature.unique_id])
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)

        if self.is_training:
            features["start_positions"] = create_int_feature([feature.start_position])
            features["end_positions"] = create_int_feature([feature.end_position])
            impossible = 0
            if feature.is_impossible:
                impossible = 1
            features["is_impossible"] = create_int_feature([impossible])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        self._writer.write(tf_example.SerializeToString())

    def close(self):
        self._writer.close()


def validate_flags_or_throw(bert_config):
    """Validate the input FLAGS or throw an exception."""
    tokenization.validate_case_matches_checkpoint(
        FLAGS.do_lower_case, FLAGS.init_checkpoint
    )

    if not FLAGS.do_train and not FLAGS.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if FLAGS.do_train:
        if not FLAGS.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified."
            )
    if FLAGS.do_predict:
        if not FLAGS.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d"
            % (FLAGS.max_seq_length, bert_config.max_position_embeddings)
        )

    if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
        raise ValueError(
            "The max_seq_length (%d) must be greater than max_query_length "
            "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)
        )


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    validate_flags_or_throw(bert_config)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project
        )

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host,
        ),
    )

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(
            input_file=FLAGS.train_file, is_training=True
        )
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs
        )
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
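        # E.g. (illustrative): with 88,000 training examples, batch size 32 and
        # 3 epochs this gives int(88000 / 32 * 3.0) = 8,250 train steps and,
        # at warmup_proportion 0.1, 825 warmup steps.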

        # Pre-shuffle the input to avoid having to make a very large shuffle
        # buffer in the `input_fn`.
        rng = random.Random(12345)
        rng.shuffle(train_examples)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size,
    )

    if FLAGS.do_train:
        # We write to a temporary file to avoid storing very large constant tensors
        # in memory.
        train_writer = FeatureWriter(
            filename=os.path.join(FLAGS.output_dir, "train.tf_record"), is_training=True
        )
        convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=True,
            output_fn=train_writer.process_feature,
        )
        train_writer.close()

        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num orig examples = %d", len(train_examples))
        tf.logging.info("  Num split examples = %d", train_writer.num_features)
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        del train_examples

        train_input_fn = input_fn_builder(
            input_file=train_writer.filename,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True,
        )
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False
        )

        eval_writer = FeatureWriter(
            filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), is_training=False
        )
        eval_features = []

        def append_feature(feature):
            eval_features.append(feature)
            eval_writer.process_feature(feature)

        convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=False,
            output_fn=append_feature,
        )
        eval_writer.close()

        tf.logging.info("***** Running predictions *****")
        tf.logging.info("  Num orig examples = %d", len(eval_examples))
        tf.logging.info("  Num split examples = %d", len(eval_features))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        all_results = []

        predict_input_fn = input_fn_builder(
            input_file=eval_writer.filename,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False,
        )

        # If running eval on the TPU, you will need to specify the number of
        # steps.
        for result in estimator.predict(predict_input_fn, yield_single_examples=True):
            if len(all_results) % 1000 == 0:
                tf.logging.info("Processing example: %d" % (len(all_results)))
            unique_id = int(result["unique_ids"])
            start_logits = [float(x) for x in result["start_logits"].flat]
            end_logits = [float(x) for x in result["end_logits"].flat]
            all_results.append(
                RawResult(
                    unique_id=unique_id,
                    start_logits=start_logits,
                    end_logits=end_logits,
                )
            )

        output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

        write_predictions(
            eval_examples,
            eval_features,
            all_results,
            FLAGS.n_best_size,
            FLAGS.max_answer_length,
            FLAGS.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
        )


if __name__ == "__main__":
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
arabert/arabert/sample_text.txt
ADDED
@@ -0,0 +1,38 @@
Text should be one-sentence-per-line, with empty lines between documents.
This sample text was randomly selected from the pretraining corpus for araBERT with Farasa Tokenization.

• « أدعو ال+ جه +ات ال+ مختص +ة في ال+ دول +ة إلى إجراء دراس +ة مقارن +ة بين مستوي +ات ال+ طلب +ة في زمني و+ في هذا ال+ وقت » .
أريد طرح جانب واحد بسيط في هذه ال+ مساح +ة بناء على حوار بين +ي و+ بين أحد أقربائي عن كيفي +ة تأثير هذه ال+ حقب +ة ال+ رقمي +ة على ال+ حيا +ة ال+ جامعي +ة ، و+ مما لا شك في +ه أن ال+ تقني +ة اقتحم +ت حقل ال+ تعليم كما فعل +ت مع غير +ه و+ غير +ت +ه ب+ شكل جذري .
ال+ يوم : ال+ هواتف ال+ ذكي +ة منتشر +ة و+ ال+ معلم +ون متوافر +ون على مدار ال+ ساع +ة ، و+ لن نبالغ إذا قل +نا إن هاتف أو كمبيوتر ال+ مكتب في زمن +نا هذا قد لا يكون ضروري +ا .
ال+ يوم : ال+ واتساب وسيل +ة ال+ تنسيق و+ ال+ تواصل بين ال+ طلاب ، و+ من خلال +ه تنشأ مجموع +ات تناقش ال+ مشروع +ات ال+ جامعي +ة ، بل إن ال+ واتساب ألغى ال+ حاج +ة إلى ال+ ماسح ال+ ضوئي في حال +ات كثير +ة ، و+ أصبح تصوير ال+ أوراق و+ ال+ مستند +ات يتم عن طريق +ه ب+ استخدام كاميرا ال+ هاتف ال+ محمول .
ال+ يوم : كل شيء يتم تصوير +ه و+ رصد +ه و+ بث +ه على وسائل ال+ تواصل كاف +ة ، و+ إن انتشر تفتح ال+ سلط +ات تحقيق +ا و+ يتحول أي موضوع ، سواء كان تافه +ا أم كبير +ا ، إلى و+ صم +ة عار في ال+ حال +ة ال+ أولى ، و+ قضي +ة في ال+ حال +ة ال+ ثاني +ة .
ال+ يوم : ينشغل ال+ طلاب ب+ وسائل ال+ تواصل ال+ اجتماعي داخل ال+ فصل بين مدرس +ين متساهل +ين و+ متشدد +ين ، و+ هناك طلاب يستخدم +ون ال+ يوتيوب وسيل +ة ترفيهي +ة داخل ال+ فصل على حساب ال+ درس و+ ب+ وجود ال+ معلم .
كن +نا نمزح مع بعض +نا كثير +ا مزاح +ا خفيف +ا و+ ثقيل +ا ، و+ كان +ت تحدث معارك أحيان +ا ، لكن كان +ت فضيل +ة ال+ ستر منتشر +ة بين +نا ، و+ لم يكن أحد يشي ب+ ال+ آخر ، إلا نادر +ا ، و+ إن حدث ذلك ف+ لا دليل علي +ه ، إلا من صاحب ال+ وشاي +ة و+ كلم +ت +ه قد تصدق ، أو يتم تجاهل +ها .
لم يكن استخدام ال+ هاتف ال+ محمول مسموح +ا ب+ +ه داخل ال+ فصول ال+ دراسي +ة إلا في حال +ات ال+ طوارئ ، ال+ معلم +ون ال+ أجانب كان +وا يتساهلون مع ذلك بينما ال+ عرب متشدد +ون ، كان مسموح +ا ل+ +نا ب+ فتح كتاب تعليمي عن ال+ ماد +ة نفس +ها و+ ال+ مطالع +ة في +ه ب+ حضور ال+ مدرس .

أزم +ات ال+ أندي +ة - ال+ إمار +ات ال+ يومكثر ال+ حديث في ال+ آون +ة ال+ أخير +ة عن ال+ مشكل +ات التي تعاني +ها أندي +ت +نا ، و+ ما تواجه +ه من معوق +ات و+ تحدي +ات فرض +ت علي +ها ب+ سبب عوامل و+ تراكم +ات أسلوب ال+ عمل ال+ إداري ، الذي تنتهج +ه ال+ أغلبي +ة من +ها .
ال+ أزم +ات التي تظهر بين فتر +ة و+ أخرى مرد +ها غياب ال+ منهجي +ة و+ سوء ال+ تخطيط و+ ال+ صرف ال+ عشوائي .
أما ال+ أهلي ف+ رغم ما مر ب+ +ه خلال هذا ال+ موسم و+ ما واجه +ه من تحدي +ات ، إلا أن +ه أنعش آمال +ه من جديد و+ بقي +ت ل+ +ه خطو +ة من أجل مرافق +ة ال+ فرق ال+ متأهل +ة ل+ ال+ دور ال+ تالي .
ف+ هو واقع تعيش +ه هذه ال+ أندي +ة و+ ال+ جميع يدرك تداعي +ات +ه ال+ سلبي +ة علي +ها ، التي تنعكس مباشر +ة على أدائ +ها ال+ مؤسسي و+ مخرج +ات +ه ، و+ هو أمر يخالف ال+ طموح .
لكن ال+ سؤال الذي يتردد دائم +ا من ال+ متسبب في هذا ؟ و+ ل+ أجل ذلك عاد ال+ عين ب+ مكاسب عد +ة لعل أهم +ها أن +ه استطاع أن يغسل أحزان +ه ال+ محلي +ة في ال+ بطول +ة ال+ آسيوي +ة ، و+ يحيى أمل +ه في ال+ خروج ب+ مكسب آسيوي ينسي +ه خسار +ة ال+ دوري و+ ال+ كأس ، بل قد يعطي +ه مساح +ة أكبر من ال+ تركيز ل+ ال+ منافس +ة على ال+ أبطال و+ إعاد +ة ذكري +ات 2003 .
و+ لعل ال+ أزم +ات التي تظهر بين فتر +ة و+ أخرى مرد +ها غياب ال+ منهجي +ة و+ سوء ال+ تخطيط ، ب+ ال+ إضاف +ة إلى ال+ صرف ال+ عشوائي الذي كبد ميزاني +ات +ها ال+ كثير ، و+ وضع +ها في خان +ة حرج +ة دفع +ها أحيان +ا ل+ إطلاق صرخ +ات ال+ استغاث +ة ل+ نجد +ت +ها و+ إخراج +ها من تلك ال+ دوام +ات التي تقع في +ها .
و+ لماذا يستمر هذا ال+ وضع في أغلب ال+ أندي +ة دون حراك نحو ال+ تغيير و+ ال+ تطوير و+ خلع

This sample text was randomly selected from the pretraining corpus for araBERT WITHOUT Farasa Tokenization.

• " أدعو الجهات المختصة في الدولة إلى إجراء دراسة مقارنة بين مستويات الطلبة في زمني وفي هذا الوقت ".
أريد طرح جانب واحد بسيط في هذه المساحة بناء على حوار بيني وبين أحد أقربائي عن كيفية تأثير هذه الحقبة الرقمية على الحياة الجامعية ، ومما لا شك فيه أن التقنية اقتحمت حقل التعليم كما فعلت مع غيره وغيرته بشكل جذري.
اليوم : الهواتف الذكية منتشرة والمعلمون متوافرون على مدار الساعة ، ولن نبالغ إذا قلنا إن هاتف أو كمبيوتر المكتب في زمننا هذا قد لا يكون ضروريا.
اليوم : الواتساب وسيلة التنسيق والتواصل بين الطلاب ، ومن خلاله تنشأ مجموعات تناقش المشروعات الجامعية ، بل إن الواتساب ألغى الحاجة إلى الماسح الضوئي في حالات كثيرة ، وأصبح تصوير الأوراق والمستندات يتم عن طريقه باستخدام كاميرا الهاتف المحمول.
اليوم : كل شيء يتم تصويره ورصده وبثه على وسائل التواصل كافة ، وإن انتشر تفتح السلطات تحقيقا ويتحول أي موضوع ، سواء كان تافها أم كبيرا ، إلى وصمة عار في الحالة الأولى ، وقضية في الحالة الثانية.
اليوم : ينشغل الطلاب بوسائل التواصل الاجتماعي داخل الفصل بين مدرسين متساهلين ومتشددين ، وهناك طلاب يستخدمون اليوتيوب وسيلة ترفيهية داخل الفصل على حساب الدرس وبوجود المعلم.
كنا نمزح مع بعضنا كثيرا مزاحا خفيفا وثقيلا ، وكانت تحدث معارك أحيانا ، لكن كانت فضيلة الستر منتشرة بيننا ، ولم يكن أحد يشي بالآخر ، إلا نادرا ، وإن حدث ذلك فلا دليل عليه ، إلا من صاحب الوشاية وكلمته قد تصدق ، أو يتم تجاهلها.
لم يكن استخدام الهاتف المحمول مسموحا به داخل الفصول الدراسية إلا في حالات الطوارئ ، المعلمون الأجانب كانوا يتساهلون مع ذلك بينما العرب متشددون ، كان مسموحا لنا بفتح كتاب تعليمي عن المادة نفسها والمطالعة فيه بحضور المدرس.

أزمات الأندية - الإمارات اليومكثر الحديث في الآونة الأخيرة عن المشكلات التي تعانيها أنديتنا ، وما تواجهه من معوقات وتحديات فرضت عليها بسبب عوامل وتراكمات أسلوب العمل الإداري ، الذي تنتهجه الأغلبية منها.
الأزمات التي تظهر بين فترة وأخرى مردها غياب المنهجية وسوء التخطيط والصرف العشوائي.
أما الأهلي فرغم ما مر به خلال هذا الموسم وما واجهه من تحديات ، إلا أنه أنعش آماله من جديد وبقيت له خطوة من أجل مرافقة الفرق المتأهلة للدور التالي.
فهو واقع تعيشه هذه الأندية والجميع يدرك تداعياته السلبية عليها ، التي تنعكس مباشرة على أدائها المؤسسي ومخرجاته ، وهو أمر يخالف الطموح.
لكن السؤال الذي يتردد دائما من المتسبب في هذا ؟ ولأجل ذلك عاد العين بمكاسب عدة لعل أهمها أنه استطاع أن يغسل أحزانه المحلية في البطولة الآسيوية ، ويحيى أمله في الخروج بمكسب آسيوي ينسيه خسارة الدوري والكأس ، بل قد يعطيه مساحة أكبر من التركيز للمنافسة على الأبطال وإعادة ذكريات 2003.
ولعل الأزمات التي تظهر بين فترة وأخرى مردها غياب المنهجية وسوء التخطيط ، بالإضافة إلى الصرف العشوائي الذي كبد ميزانياتها الكثير ، ووضعها في خانة حرجة دفعها أحيانا لإطلاق صرخات الاستغاثة لنجدتها وإخراجها من تلك الدوامات التي تقع فيها.
ولماذا يستمر هذا الوضع في أغلب الأندية دون حراك نحو التغيير والتطوير وخلع الجلباب الإداري القديم؟
arabert/arabert/tokenization.py
ADDED
@@ -0,0 +1,414 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import unicodedata
import six
import tensorflow as tf


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name."""

    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.

    if not init_checkpoint:
        return

    m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if m is None:
        return

    model_name = m.group(1)

    lower_models = [
        "uncased_L-24_H-1024_A-16",
        "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12",
        "chinese_L-12_H-768_A-12",
    ]

    cased_models = [
        "cased_L-12_H-768_A-12",
        "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12",
    ]

    is_bad_config = False
    if model_name in lower_models and not do_lower_case:
        is_bad_config = True
        actual_flag = "False"
        case_name = "lowercased"
        opposite_flag = "True"

    if model_name in cased_models and do_lower_case:
        is_bad_config = True
        actual_flag = "True"
        case_name = "cased"
        opposite_flag = "False"

    if is_bad_config:
        raise ValueError(
            "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
            "However, `%s` seems to be a %s model, so you "
            "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
            "how the model was pre-trained. If this error is wrong, please "
            "just comment out this check."
            % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)
        )


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python 2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with tf.gfile.GFile(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (
        cp == 91 or cp == 93 or cp == 43
    ):  # [ and ] are not punctuation since they are used in [xx] and the +
        return False

    if (
        (cp >= 33 and cp <= 47)
        or (cp >= 58 and cp <= 64)
        or (cp >= 91 and cp <= 96)
        or (cp >= 123 and cp <= 126)
    ):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
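A minimal usage sketch of the WordPiece pass above (my addition, not part of the repo; it assumes this file is importable as `tokenization` and that `six` and TensorFlow are installed, since the module imports them). A toy in-memory vocabulary makes the greedy longest-match-first behavior visible without a `vocab.txt` on disk:
```
import collections

from tokenization import WordpieceTokenizer  # hypothetical import path

# Toy vocabulary used instead of load_vocab("vocab.txt")
vocab = collections.OrderedDict(
    (tok, i) for i, tok in enumerate(["[UNK]", "un", "##aff", "##able"])
)
wp = WordpieceTokenizer(vocab=vocab)
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able'] (longest match first)
print(wp.tokenize("banana"))     # ['[UNK]'] -- no vocab entry matches any prefix
```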
arabert/arabert_logo.png
ADDED
arabert/araelectra/.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__
.vscode/
data/
*.bat
arabert/araelectra/LICENSE
ADDED
@@ -0,0 +1,76 @@
==========================================
SOFTWARE LICENSE AGREEMENT - AraELECTRA
==========================================

* NAME: AraELECTRA: Pre-Training Text Discriminators for Arabic Language Understanding

* ACKNOWLEDGMENTS

This [software] was generated by [American University of Beirut] (“Owners”).
The statements made herein are solely the responsibility of the author[s].

The following software and programs have been used in the generation of [AraELECTRA]:

+ ELECTRA
  - Kevin Clark and Minh-Thang Luong and Quoc V. Le and Christopher D. Manning.
    "ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
    https://openreview.net/pdf?id=r1xMH1BtvB, 2020.
  - License and link: https://github.com/google-research/electra

+ PyArabic
  - T. Zerrouki, Pyarabic, An Arabic language library for Python,
    https://pypi.python.org/pypi/pyarabic/, 2010
  - License and link: https://github.com/linuxscout/pyarabic/

* LICENSE

This software and database is being provided to you, the LICENSEE, by the
Owners under the following license. By obtaining, using and/or copying this
software and database, you agree that you have read, understood, and will
comply with these terms and conditions. You further agree that you have read
and you will abide by the license agreements provided in the above links
under “acknowledgements”: Permission to use, copy, modify and distribute this
software and database and its documentation for any purpose and without fee
or royalty is hereby granted, provided that you agree to comply with the
following copyright notice and statements, including the disclaimer, and that
the same appear on ALL copies of the software, database and documentation,
including modifications that you make for internal use or for distribution.
[AraELECTRA] Copyright 2020 by [American University of Beirut]. All rights
reserved. If you remix, transform, or build upon the material, you must
distribute your contributions under the same license as this one. You may
not apply legal terms or technological measures that legally restrict others
from doing anything this license permits. THIS SOFTWARE IS PROVIDED "AS IS"
AND THE OWNERS MAKE NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY
WAY OF EXAMPLE, BUT NOT LIMITATION, THE OWNERS MAKE NO REPRESENTATIONS OR
WARRANTIES OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT
THE USE OF THE LICENSED SOFTWARE, DATABASE OR DOCUMENTATION WILL NOT INFRINGE
ANY THIRD PARTY PATENTS, COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS. The name of
the Owners may not be used in advertising or publicity pertaining to
distribution of the software and/or database. Title to copyright in this
software, database and any associated documentation shall at all times remain
with the Owners and LICENSEE agrees to preserve same.

The use of AraELECTRA should be cited as follows:

@inproceedings{antoun-etal-2021-araelectra,
    title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
    author = "Antoun, Wissam and
      Baly, Fady and
      Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
    pages = "191--195",
}

[AraELECTRA] Copyright 2020 by [American University of Beirut].
All rights reserved.
==========================================
arabert/araelectra/README.md
ADDED
@@ -0,0 +1,144 @@
# ELECTRA

## Introduction

**ELECTRA** is a method for self-supervised language representation learning. It can be used to pre-train transformer networks using relatively little compute. ELECTRA models are trained to distinguish "real" input tokens vs "fake" input tokens generated by another neural network, similar to the discriminator of a [GAN](https://arxiv.org/pdf/1406.2661.pdf). AraELECTRA achieves state-of-the-art results on Arabic QA datasets.

For a detailed description, please refer to the AraELECTRA paper [AraELECTRA: Pre-Training Text Discriminators for Arabic Language Understanding](https://arxiv.org/abs/2012.15516).

This repository contains code to pre-train ELECTRA. It also supports fine-tuning ELECTRA on downstream tasks including classification tasks (e.g., [GLUE](https://gluebenchmark.com/)), QA tasks (e.g., [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)), and sequence tagging tasks (e.g., [text chunking](https://www.clips.uantwerpen.be/conll2000/chunking/)).

## Released Models

We are releasing two pre-trained models:

| Model | Layers | Hidden Size | Attention Heads | Params | HuggingFace Model Name |
| --- | --- | --- | --- | --- | --- |
| AraELECTRA-base-discriminator | 12 | 768 | 12 | 136M | [araelectra-base-discriminator](https://huggingface.co/aubmindlab/araelectra-base-discriminator) |
| AraELECTRA-base-generator | 12 | 256 | 4 | 60M | [araelectra-base-generator](https://huggingface.co/aubmindlab/araelectra-base-generator) |

## Results

| Model | TyDiQA (EM - F1) | ARCD (EM - F1) |
|:----|:----:|:----:|
| AraBERTv0.1 | 68.51 - 82.86 | 31.62 - 67.45 |
| AraBERTv1 | 61.11 - 79.36 | 31.7 - 67.8 |
| AraBERTv0.2-base | 73.07 - 85.41 | 32.76 - 66.53 |
| AraBERTv2-base | 61.67 - 81.66 | 31.34 - 67.23 |
| AraBERTv0.2-large | 73.72 - 86.03 | 36.89 - **71.32** |
| AraBERTv2-large | 64.49 - 82.51 | 34.19 - 68.12 |
| ArabicBERT-base | 67.42 - 81.24 | 30.48 - 62.24 |
| ArabicBERT-large | 70.03 - 84.12 | 33.33 - 67.27 |
| Arabic-ALBERT-base | 67.10 - 80.98 | 30.91 - 61.33 |
| Arabic-ALBERT-large | 68.07 - 81.59 | 34.19 - 65.41 |
| Arabic-ALBERT-xlarge | 71.12 - 84.59 | **37.75** - 68.03 |
| AraELECTRA | **74.91 - 86.68** | 37.03 - 71.22 |

## Requirements

* Python 3
* [TensorFlow](https://www.tensorflow.org/) 1.15 (although we hope to support TensorFlow 2.0 at a future date)
* [NumPy](https://numpy.org/)
* [scikit-learn](https://scikit-learn.org/stable/) and [SciPy](https://www.scipy.org/) (for computing some evaluation metrics).

## Pre-training

Use `build_pretraining_dataset.py` or `build_arabert_pretraining_data.py` to create a pre-training dataset from a dump of raw text. It has the following arguments (an example invocation follows the list):

* `--corpus-dir`: A directory containing raw text files to turn into ELECTRA examples. A text file can contain multiple documents with empty lines separating them.
* `--vocab-file`: File defining the wordpiece vocabulary.
* `--output-dir`: Where to write out ELECTRA examples.
* `--max-seq-length`: The number of tokens per example (128 by default).
* `--num-processes`: If >1, parallelize across multiple processes (1 by default).
* `--blanks-separate-docs`: Whether blank lines indicate document boundaries (True by default).
* `--do-lower-case/--no-lower-case`: Whether to lower case the input text (True by default).
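The same arguments can also be passed programmatically; this is a minimal sketch of my own (the placeholder paths and the `Namespace` shim are not part of the repo):
```
import os
from argparse import Namespace

import build_pretraining_dataset

args = Namespace(
    corpus_dir="data/corpus",              # placeholder: directory of raw .txt files
    vocab_file="data/vocab.txt",           # placeholder: wordpiece vocabulary
    output_dir="data/pretrain_tfrecords",  # placeholder: tfrecord destination
    max_seq_length=128,
    num_processes=1,
    blanks_separate_docs=True,
    do_lower_case=True,
)
os.makedirs(args.output_dir, exist_ok=True)  # main() normally prepares this dir
build_pretraining_dataset.write_examples(0, args)  # job 0 of num_processes
```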
Use `run_pretraining.py` to pre-train an ELECTRA model. It has the following arguments:

* `--data-dir`: a directory where pre-training data, model weights, etc. are stored. By default, the training loads examples from `<data-dir>/pretrain_tfrecords` and a vocabulary from `<data-dir>/vocab.txt`.
* `--model-name`: a name for the model being trained. Model weights will be saved in `<data-dir>/models/<model-name>` by default.
* `--hparams` (optional): a JSON dict or path to a JSON file containing model hyperparameters, data paths, etc. See `configure_pretraining.py` for the supported hyperparameters.

If training is halted, re-running `run_pretraining.py` with the same arguments will continue the training where it left off.

You can continue pre-training from the released ELECTRA checkpoints by
1. Setting the model-name to point to a downloaded model (e.g., `--model-name electra_small` if you downloaded weights to `$DATA_DIR/electra_small`).
2. Setting `num_train_steps` by (for example) adding `"num_train_steps": 4010000` to the `--hparams`. This will continue training the small model for 10000 more steps (it has already been trained for 4e6 steps).
3. Increasing the learning rate to account for the linear learning rate decay. For example, to start with a learning rate of 2e-4 you should set the `learning_rate` hparam to 2e-4 * (4e6 + 10000) / 10000 (see the sketch below).
4. For ELECTRA-Small, you also need to specify `"generator_hidden_size": 1.0` in the `hparams` because we did not use a small generator for that model.
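A quick check of the arithmetic in step 3 (my own sketch, not repo code):
```
# With linear decay, the effective rate at step s is lr0 * (T - s) / T, so to
# resume at 2e-4 with T = 4e6 + 10000 total steps and 10000 steps remaining:
prev_steps, extra_steps, target_lr = 4_000_000, 10_000, 2e-4
total_steps = prev_steps + extra_steps
lr0 = target_lr * total_steps / extra_steps
print(lr0)  # 0.0802 -> pass this as the "learning_rate" hparam
```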
#### Evaluating the pre-trained model

To evaluate the model on a downstream task, see the fine-tuning instructions below. To evaluate the generator/discriminator on the openwebtext data run `python3 run_pretraining.py --data-dir $DATA_DIR --model-name electra_small_owt --hparams '{"do_train": false, "do_eval": true}'`. This will print out eval metrics such as the accuracy of the generator and discriminator, and also write the metrics out to `data-dir/model-name/results`.

## Fine-tuning

Use `run_finetuning.py` to fine-tune and evaluate an ELECTRA model on a downstream NLP task. It expects three arguments:

* `--data-dir`: a directory where data, model weights, etc. are stored. By default, the script loads finetuning data from `<data-dir>/finetuning_data/<task-name>` and a vocabulary from `<data-dir>/vocab.txt`.
* `--model-name`: the name of the pre-trained model: the pre-trained weights should exist in `data-dir/models/model-name`.
* `--hparams`: a JSON dict containing model hyperparameters, data paths, etc. (e.g., `--hparams '{"task_names": ["rte"], "model_size": "base", "learning_rate": 1e-4, ...}'`). See `configure_finetuning.py` for the supported hyperparameters. Instead of a dict, this can also be a path to a `.json` file containing the hyperparameters. You must specify the `"task_names"` and `"model_size"` (see examples below).

Eval metrics will be saved in `data-dir/model-name/results` and model weights will be saved in `data-dir/model-name/finetuning_models` by default. Evaluation is done on the dev set by default. To customize the training, add `--hparams '{"hparam1": value1, "hparam2": value2, ...}'` to the run command. Some particularly useful options:

* `"debug": true` fine-tunes a tiny ELECTRA model for a few steps.
* `"task_names": ["task_name"]`: specifies the tasks to train on. A list because the codebase nominally supports multi-task learning (although be warned this has not been thoroughly tested).
* `"model_size": one of "small", "base", or "large"`: determines the size of the model; you must set this to the same size as the pre-trained model.
* `"do_train"` and `"do_eval"`: train and/or evaluate a model (both are set to true by default). For using `"do_eval": true` with `"do_train": false`, you need to specify the `init_checkpoint`, e.g., `python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["mnli"], "do_train": false, "do_eval": true, "init_checkpoint": "<data-dir>/models/electra_base/finetuning_models/mnli_model_1"}'`
* `"num_trials": n`: if >1, does multiple fine-tuning/evaluation runs with different random seeds.
* `"learning_rate": lr, "train_batch_size": n`, etc. can be used to change training hyperparameters.
* `"model_hparam_overrides": {"hidden_size": n, "num_hidden_layers": m}`, etc. can be used to change the hyperparameters for the underlying transformer (the `"model_size"` flag sets the default values).

### Setup

Get a pre-trained ELECTRA model either by training your own (see the pre-training instructions above), or by downloading the released ELECTRA weights and unzipping them under `$DATA_DIR/models` (e.g., you should have a directory `$DATA_DIR/models/electra_large` if you are using the large model).

### Finetune ELECTRA on question answering

The code supports [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) 1.1 and 2.0, as well as datasets in [the 2019 MRQA shared task](https://github.com/mrqa/MRQA-Shared-Task-2019).

* **ARCD**: Download the train/dev datasets from `https://github.com/husseinmozannar/SOQAL` and move them under `$DATA_DIR/finetuning_data/squadv1/(train|dev).json`

Then run (for example)
```
python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["squad"]}'
```

This repository uses the official evaluation code released by the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) authors, or you can use the `transformers` library as shown in the notebooks `ARCD_pytorch.ipynb` or `Tydiqa_ar_pytorch.ipynb` from the examples folder.

### Finetune ELECTRA on sequence tagging

Download the CoNLL-2000 text chunking dataset from [here](https://www.clips.uantwerpen.be/conll2000/chunking/) and put it under `$DATA_DIR/finetuning_data/chunk/(train|dev).txt`. Then run
```
python3 run_finetuning.py --data-dir $DATA_DIR --model-name electra_base --hparams '{"model_size": "base", "task_names": ["chunk"]}'
```

### Adding a new task

The easiest way to run on a new task is to implement a new `finetune.task.Task`, add it to `finetune.task_builder.py`, and then use `run_finetuning.py` as normal. For classification/QA/sequence tagging, you can inherit from `finetune.classification.classification_tasks.ClassificationTask`, `finetune.qa.qa_tasks.QATask`, or `finetune.tagging.tagging_tasks.TaggingTask`.
For preprocessing data, we use the same tokenizer as [BERT](https://github.com/google-research/bert).

## Citation

If you used this model please cite us as:
```
@inproceedings{antoun-etal-2021-araelectra,
    title = "{A}ra{ELECTRA}: Pre-Training Text Discriminators for {A}rabic Language Understanding",
    author = "Antoun, Wissam and
      Baly, Fady and
      Hajj, Hazem",
    booktitle = "Proceedings of the Sixth Arabic Natural Language Processing Workshop",
    month = apr,
    year = "2021",
    address = "Kyiv, Ukraine (Virtual)",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2021.wanlp-1.20",
    pages = "191--195",
}
```
arabert/araelectra/__init__.py
ADDED
@@ -0,0 +1 @@
# coding=utf-8
arabert/araelectra/build_openwebtext_pretraining_dataset.py
ADDED
@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preprocesses the Open WebText corpus for ELECTRA pre-training."""

import argparse
import multiprocessing
import os
import random
import tarfile
import time
import tensorflow as tf

import build_pretraining_dataset
from util import utils


def write_examples(job_id, args):
    """A single process creating and writing out pre-processed examples."""
    job_tmp_dir = os.path.join(args.data_dir, "tmp", "job_" + str(job_id))
    owt_dir = os.path.join(args.data_dir, "openwebtext")

    def log(*args):
        msg = " ".join(map(str, args))
        print("Job {}:".format(job_id), msg)

    log("Creating example writer")
    example_writer = build_pretraining_dataset.ExampleWriter(
        job_id=job_id,
        vocab_file=os.path.join(args.data_dir, "vocab.txt"),
        output_dir=os.path.join(args.data_dir, "pretrain_tfrecords"),
        max_seq_length=args.max_seq_length,
        num_jobs=args.num_processes,
        blanks_separate_docs=False,
        do_lower_case=args.do_lower_case
    )
    log("Writing tf examples")
    fnames = sorted(tf.io.gfile.listdir(owt_dir))
    fnames = [f for (i, f) in enumerate(fnames)
              if i % args.num_processes == job_id]
    random.shuffle(fnames)
    start_time = time.time()
    for file_no, fname in enumerate(fnames):
        if file_no > 0 and file_no % 10 == 0:
            elapsed = time.time() - start_time
            log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, "
                "{:} examples written".format(
                    file_no, len(fnames), 100.0 * file_no / len(fnames),
                    int(elapsed),
                    int((len(fnames) - file_no) / (file_no / elapsed)),
                    example_writer.n_written))
        utils.rmkdir(job_tmp_dir)
        with tarfile.open(os.path.join(owt_dir, fname)) as f:
            f.extractall(job_tmp_dir)
        extracted_files = tf.io.gfile.listdir(job_tmp_dir)
        random.shuffle(extracted_files)
        for txt_fname in extracted_files:
            example_writer.write_examples(os.path.join(job_tmp_dir, txt_fname))
    example_writer.finish()
    log("Done!")


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--data-dir", required=True,
                        help="Location of data (vocab file, corpus, etc).")
    parser.add_argument("--max-seq-length", default=128, type=int,
                        help="Number of tokens per example.")
    parser.add_argument("--num-processes", default=1, type=int,
                        help="Parallelize across multiple processes.")
    parser.add_argument("--do-lower-case", dest='do_lower_case',
                        action='store_true', help="Lower case input text.")
    parser.add_argument("--no-lower-case", dest='do_lower_case',
                        action='store_false', help="Don't lower case input text.")
    parser.set_defaults(do_lower_case=True)
    args = parser.parse_args()

    utils.rmkdir(os.path.join(args.data_dir, "pretrain_tfrecords"))
    if args.num_processes == 1:
        write_examples(0, args)
    else:
        jobs = []
        for i in range(args.num_processes):
            job = multiprocessing.Process(target=write_examples, args=(i, args))
            jobs.append(job)
            job.start()
        for job in jobs:
            job.join()


if __name__ == "__main__":
    main()
arabert/araelectra/build_pretraining_dataset.py
ADDED
@@ -0,0 +1,230 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Writes out text data as tfrecords that ELECTRA can be pre-trained on."""

import argparse
import multiprocessing
import os
import random
import time
import tensorflow as tf

from model import tokenization
from util import utils


def create_int_feature(values):
    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return feature


class ExampleBuilder(object):
    """Given a stream of input text, creates pretraining examples."""

    def __init__(self, tokenizer, max_length):
        self._tokenizer = tokenizer
        self._current_sentences = []
        self._current_length = 0
        self._max_length = max_length
        self._target_length = max_length

    def add_line(self, line):
        """Adds a line of text to the current example being built."""
        line = line.strip().replace("\n", " ")
        if (not line) and self._current_length != 0:  # empty lines separate docs
            return self._create_example()
        bert_tokens = self._tokenizer.tokenize(line)
        bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens)
        self._current_sentences.append(bert_tokids)
        self._current_length += len(bert_tokids)
        if self._current_length >= self._target_length:
            return self._create_example()
        return None

    def _create_example(self):
        """Creates a pre-training example from the current list of sentences."""
        # small chance to only have one segment as in classification tasks
        if random.random() < 0.1:
            first_segment_target_length = 100000
        else:
            # -3 due to not yet having [CLS]/[SEP] tokens in the input text
            first_segment_target_length = (self._target_length - 3) // 2

        first_segment = []
        second_segment = []
        for sentence in self._current_sentences:
            # the sentence goes to the first segment if (1) the first segment is
            # empty, (2) the sentence doesn't put the first segment over length or
            # (3) 50% of the time when it does put the first segment over length
            if (len(first_segment) == 0 or
                len(first_segment) + len(sentence) < first_segment_target_length or
                (len(second_segment) == 0 and
                 len(first_segment) < first_segment_target_length and
                 random.random() < 0.5)):
                first_segment += sentence
            else:
                second_segment += sentence

        # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens
        first_segment = first_segment[:self._max_length - 2]
        second_segment = second_segment[:max(0, self._max_length -
                                             len(first_segment) - 3)]

        # prepare to start building the next example
        self._current_sentences = []
        self._current_length = 0
        # small chance for random-length instead of max_length-length example
        if random.random() < 0.05:
            self._target_length = random.randint(5, self._max_length)
        else:
            self._target_length = self._max_length

        return self._make_tf_example(first_segment, second_segment)

    def _make_tf_example(self, first_segment, second_segment):
        """Converts two "segments" of text into a tf.train.Example."""
        vocab = self._tokenizer.vocab
        input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
        segment_ids = [0] * len(input_ids)
        if second_segment:
            input_ids += second_segment + [vocab["[SEP]"]]
            segment_ids += [1] * (len(second_segment) + 1)
        input_mask = [1] * len(input_ids)
        input_ids += [0] * (self._max_length - len(input_ids))
        input_mask += [0] * (self._max_length - len(input_mask))
        segment_ids += [0] * (self._max_length - len(segment_ids))
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            "input_ids": create_int_feature(input_ids),
            "input_mask": create_int_feature(input_mask),
            "segment_ids": create_int_feature(segment_ids)
        }))
        return tf_example


class ExampleWriter(object):
    """Writes pre-training examples to disk."""

    def __init__(self, job_id, vocab_file, output_dir, max_seq_length,
                 num_jobs, blanks_separate_docs, do_lower_case,
                 num_out_files=1000):
        self._blanks_separate_docs = blanks_separate_docs
        tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file,
            do_lower_case=do_lower_case)
        self._example_builder = ExampleBuilder(tokenizer, max_seq_length)
        self._writers = []
        for i in range(num_out_files):
            if i % num_jobs == job_id:
                output_fname = os.path.join(
                    output_dir, "pretrain_data.tfrecord-{:}-of-{:}".format(
                        i, num_out_files))
                self._writers.append(tf.io.TFRecordWriter(output_fname))
        self.n_written = 0

    def write_examples(self, input_file):
        """Writes out examples from the provided input file."""
        with tf.io.gfile.GFile(input_file) as f:
            for line in f:
                line = line.strip()
                if line or self._blanks_separate_docs:
                    example = self._example_builder.add_line(line)
                    if example:
                        self._writers[self.n_written % len(self._writers)].write(
                            example.SerializeToString())
                        self.n_written += 1
            example = self._example_builder.add_line("")
            if example:
                self._writers[self.n_written % len(self._writers)].write(
                    example.SerializeToString())
                self.n_written += 1

    def finish(self):
        for writer in self._writers:
            writer.close()


def write_examples(job_id, args):
    """A single process creating and writing out pre-processed examples."""

    def log(*args):
        msg = " ".join(map(str, args))
        print("Job {}:".format(job_id), msg)

    log("Creating example writer")
    example_writer = ExampleWriter(
        job_id=job_id,
        vocab_file=args.vocab_file,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        num_jobs=args.num_processes,
        blanks_separate_docs=args.blanks_separate_docs,
        do_lower_case=args.do_lower_case
    )
    log("Writing tf examples")
    fnames = sorted(tf.io.gfile.listdir(args.corpus_dir))
    fnames = [f for (i, f) in enumerate(fnames)
              if i % args.num_processes == job_id]
    random.shuffle(fnames)
    start_time = time.time()
    for file_no, fname in enumerate(fnames):
        if file_no > 0:
            elapsed = time.time() - start_time
            log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, "
                "{:} examples written".format(
                    file_no, len(fnames), 100.0 * file_no / len(fnames),
                    int(elapsed),
                    int((len(fnames) - file_no) / (file_no / elapsed)),
                    example_writer.n_written))
        example_writer.write_examples(os.path.join(args.corpus_dir, fname))
    example_writer.finish()
    log("Done!")


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--corpus-dir", required=True,
                        help="Location of pre-training text files.")
    parser.add_argument("--vocab-file", required=True,
                        help="Location of vocabulary file.")
    parser.add_argument("--output-dir", required=True,
                        help="Where to write out the tfrecords.")
    parser.add_argument("--max-seq-length", default=128, type=int,
                        help="Number of tokens per example.")
    parser.add_argument("--num-processes", default=1, type=int,
                        help="Parallelize across multiple processes.")
    parser.add_argument("--blanks-separate-docs", default=True, type=bool,
                        help="Whether blank lines indicate document boundaries.")
    parser.add_argument("--do-lower-case", dest='do_lower_case',
                        action='store_true', help="Lower case input text.")
    parser.add_argument("--no-lower-case", dest='do_lower_case',
                        action='store_false', help="Don't lower case input text.")
    parser.set_defaults(do_lower_case=True)
    args = parser.parse_args()

    utils.rmkdir(args.output_dir)
    if args.num_processes == 1:
        write_examples(0, args)
    else:
        jobs = []
        for i in range(args.num_processes):
            job = multiprocessing.Process(target=write_examples, args=(i, args))
            jobs.append(job)
            job.start()
        for job in jobs:
            job.join()


if __name__ == "__main__":
    main()
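To make the feature layout of `_make_tf_example` concrete, here is a toy, self-contained sketch of my own (not repo code); ids 2 and 3 are made-up stand-ins for `[CLS]` and `[SEP]`, which a real run takes from the wordpiece vocabulary:
```
import tensorflow as tf

def create_int_feature(values):
    # mirrors the helper in build_pretraining_dataset.py
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

max_length = 8
first_segment, second_segment = [5, 6], [7]  # made-up wordpiece ids

input_ids = [2] + first_segment + [3] + second_segment + [3]
segment_ids = [0] * (len(first_segment) + 2) + [1] * (len(second_segment) + 1)
input_mask = [1] * len(input_ids)

# zero-pad everything out to max_length, exactly as in _make_tf_example
input_ids += [0] * (max_length - len(input_ids))
input_mask += [0] * (max_length - len(input_mask))
segment_ids += [0] * (max_length - len(segment_ids))

example = tf.train.Example(features=tf.train.Features(feature={
    "input_ids": create_int_feature(input_ids),      # [2, 5, 6, 3, 7, 3, 0, 0]
    "input_mask": create_int_feature(input_mask),    # [1, 1, 1, 1, 1, 1, 0, 0]
    "segment_ids": create_int_feature(segment_ids),  # [0, 0, 0, 0, 1, 1, 0, 0]
}))
print(example)
```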
arabert/araelectra/build_pretraining_dataset_single_file.py
ADDED
@@ -0,0 +1,90 @@
# coding=utf-8

import argparse
import os
import tensorflow as tf

import build_pretraining_dataset
from model import tokenization


class ExampleWriter(object):
    """Writes pre-training examples to disk."""

    def __init__(self, input_fname, vocab_file, output_dir, max_seq_length,
                 blanks_separate_docs, do_lower_case):
        self._blanks_separate_docs = blanks_separate_docs
        tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file,
            do_lower_case=do_lower_case)
        self._example_builder = build_pretraining_dataset.ExampleBuilder(
            tokenizer, max_seq_length)
        output_fname = os.path.join(
            output_dir, "{}.tfrecord".format(input_fname.split("/")[-1]))
        self._writer = tf.io.TFRecordWriter(output_fname)
        self.n_written = 0

    def write_examples(self, input_file):
        """Writes out examples from the provided input file."""
        with tf.io.gfile.GFile(input_file) as f:
            for line in f:
                line = line.strip()
                if line or self._blanks_separate_docs:
                    example = self._example_builder.add_line(line)
                    if example:
                        self._writer.write(example.SerializeToString())
                        self.n_written += 1
            example = self._example_builder.add_line("")
            if example:
                self._writer.write(example.SerializeToString())
                self.n_written += 1

    def finish(self):
        self._writer.close()


def write_examples(args):
    """A single process creating and writing out pre-processed examples."""

    def log(*args):
        msg = " ".join(map(str, args))
        print(msg)

    log("Creating example writer")
    example_writer = ExampleWriter(
        input_fname=args.input_file,
        vocab_file=args.vocab_file,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        blanks_separate_docs=args.blanks_separate_docs,
        do_lower_case=args.do_lower_case
    )
    log("Writing tf examples")

    example_writer.write_examples(args.input_file)
    example_writer.finish()
    log("Done!")
    return


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input-file", required=True,
                        help="Location of the pre-training text file.")
    parser.add_argument("--vocab-file", required=True,
                        help="Location of vocabulary file.")
    parser.add_argument("--output-dir", required=True,
                        help="Where to write out the tfrecords.")
    parser.add_argument("--max-seq-length", default=128, type=int,
                        help="Number of tokens per example.")
    parser.add_argument("--blanks-separate-docs", default=True, type=bool,
                        help="Whether blank lines indicate document boundaries.")
    parser.add_argument("--do-lower-case", dest='do_lower_case',
                        action='store_true', help="Lower case input text.")
    parser.add_argument("--no-lower-case", dest='do_lower_case',
                        action='store_false', help="Don't lower case input text.")
    parser.set_defaults(do_lower_case=True)
    args = parser.parse_args()

    write_examples(args)


if __name__ == "__main__":
    main()
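A hypothetical sanity check of my own (not repo code) for reading one record back from a tfrecord produced by the writers above; the path is a placeholder, and TF 1.x needs eager execution enabled for the Python loop:
```
import tensorflow as tf

# No-op on TF 2.x, which runs eagerly by default.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

max_seq_length = 128
feature_spec = {
    "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
}

dataset = tf.data.TFRecordDataset("data/some_file.txt.tfrecord")  # placeholder
for raw_record in dataset.take(1):
    parsed = tf.io.parse_single_example(raw_record, feature_spec)
    print(parsed["input_ids"][:16])  # first few wordpiece ids of the example
```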
arabert/araelectra/configure_finetuning.py
ADDED
@@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config controlling hyperparameters for fine-tuning ELECTRA."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf


class FinetuningConfig(object):
    """Fine-tuning hyperparameters."""

    def __init__(self, model_name, data_dir, **kwargs):
        # general
        self.model_name = model_name
        self.debug = False  # debug mode for quickly running things
        self.log_examples = False  # print out some train examples for debugging
        self.num_trials = 1  # how many train+eval runs to perform
        self.do_train = True  # train a model
        self.do_eval = True  # evaluate the model
        self.keep_all_models = True  # if False, only keep the last trial's ckpt

        # model
        self.model_size = "base"  # one of "small", "base", or "large"
        self.task_names = ["chunk"]  # which tasks to learn
        # override the default transformer hparams for the provided model size;
        # see modeling.BertConfig for the possible hparams and
        # util.training_utils for the defaults
        self.model_hparam_overrides = (
            kwargs["model_hparam_overrides"]
            if "model_hparam_overrides" in kwargs else {})
        self.embedding_size = None  # bert hidden size by default
        self.vocab_size = 64000  # number of tokens in the vocabulary
        self.do_lower_case = True

        # training
        self.learning_rate = 1e-4
        self.weight_decay_rate = 0.01
        # if > 0, the learning rate for a layer is
        # lr * lr_decay^(depth - max_depth), i.e., shallower layers have lower
        # learning rates
        self.layerwise_lr_decay = 0.8
        self.num_train_epochs = 3.0  # passes over the dataset during training
        self.warmup_proportion = 0.1  # how much of training to warm up the LR for
        self.save_checkpoints_steps = 1000000
        self.iterations_per_loop = 1000
        # don't make tfrecords and write them to disc if existing ones are found
        self.use_tfrecords_if_existing = True

        # writing model outputs to disc
        # whether to write test set outputs, currently supported for
        # GLUE + SQuAD 2.0
        self.write_test_outputs = False
        self.n_writes_test = 5  # write test set predictions for the first n trials

        # sizing
        self.max_seq_length = 128
        self.train_batch_size = 32
        self.eval_batch_size = 32
        self.predict_batch_size = 32
        # for tasks like paraphrase where sentence order doesn't matter, train
        # the model on both sentence orderings for each example
        self.double_unordered = True
        # for qa tasks
        self.max_query_length = 64  # max tokens in q as opposed to context
        self.doc_stride = 128  # stride when splitting doc into multiple examples
        self.n_best_size = 20  # number of predictions per example to save
        self.max_answer_length = 30  # filter out answers longer than this length
        self.answerable_classifier = True  # answerable classifier for SQuAD 2.0
        # more advanced answerable classifier using predicted start
        self.answerable_uses_start_logits = True
        self.answerable_weight = 0.5  # weight for answerability loss
        # jointly predict the start and end positions of the answer span
        self.joint_prediction = True
        self.beam_size = 20  # beam size when doing joint predictions
        # threshold for "no answer" when writing SQuAD 2.0 test outputs
        self.qa_na_threshold = -2.75

        # TPU settings
        self.use_tpu = False
        self.num_tpu_cores = 1
        self.tpu_job_name = None
        self.tpu_name = None  # cloud TPU to use for training
        self.tpu_zone = None  # GCE zone where the Cloud TPU is located in
        self.gcp_project = None  # project name for the Cloud TPU-enabled project

        # default locations of data files
        self.data_dir = data_dir
        pretrained_model_dir = os.path.join(data_dir, "models", model_name)
        self.raw_data_dir = os.path.join(data_dir, "finetuning_data", "{:}").format
        self.vocab_file = os.path.join(pretrained_model_dir, "vocab.txt")
        if not tf.io.gfile.exists(self.vocab_file):
            self.vocab_file = os.path.join(self.data_dir, "vocab.txt")
        task_names_str = ",".join(
            kwargs["task_names"] if "task_names" in kwargs else self.task_names)
        self.init_checkpoint = None if self.debug else pretrained_model_dir
        self.model_dir = os.path.join(pretrained_model_dir, "finetuning_models",
                                      task_names_str + "_model")
        results_dir = os.path.join(pretrained_model_dir, "results")
        self.results_txt = os.path.join(results_dir,
116 |
+
task_names_str + "_results.txt")
|
117 |
+
self.results_pkl = os.path.join(results_dir,
|
118 |
+
task_names_str + "_results.pkl")
|
119 |
+
qa_topdir = os.path.join(results_dir, task_names_str + "_qa")
|
120 |
+
self.qa_eval_file = os.path.join(qa_topdir, "{:}_eval.json").format
|
121 |
+
self.qa_preds_file = os.path.join(qa_topdir, "{:}_preds.json").format
|
122 |
+
self.qa_na_file = os.path.join(qa_topdir, "{:}_null_odds.json").format
|
123 |
+
self.preprocessed_data_dir = os.path.join(
|
124 |
+
pretrained_model_dir, "finetuning_tfrecords",
|
125 |
+
task_names_str + "_tfrecords" + ("-debug" if self.debug else ""))
|
126 |
+
self.test_predictions = os.path.join(
|
127 |
+
pretrained_model_dir, "test_predictions",
|
128 |
+
"{:}_{:}_{:}_predictions.pkl").format
|
129 |
+
|
130 |
+
# update defaults with passed-in hyperparameters
|
131 |
+
self.update(kwargs)
|
132 |
+
|
133 |
+
# default hyperparameters for single-task models
|
134 |
+
if len(self.task_names) == 1:
|
135 |
+
task_name = self.task_names[0]
|
136 |
+
if task_name == "rte" or task_name == "sts":
|
137 |
+
self.num_train_epochs = 10.0
|
138 |
+
elif "squad" in task_name or "qa" in task_name:
|
139 |
+
self.max_seq_length = 512
|
140 |
+
self.num_train_epochs = 2.0
|
141 |
+
self.write_distill_outputs = False
|
142 |
+
self.write_test_outputs = False
|
143 |
+
elif task_name == "chunk":
|
144 |
+
self.max_seq_length = 256
|
145 |
+
else:
|
146 |
+
self.num_train_epochs = 3.0
|
147 |
+
|
148 |
+
# default hyperparameters for different model sizes
|
149 |
+
if self.model_size == "large":
|
150 |
+
self.learning_rate = 5e-5
|
151 |
+
self.layerwise_lr_decay = 0.9
|
152 |
+
elif self.model_size == "small":
|
153 |
+
self.embedding_size = 128
|
154 |
+
|
155 |
+
# debug-mode settings
|
156 |
+
if self.debug:
|
157 |
+
self.save_checkpoints_steps = 1000000
|
158 |
+
self.use_tfrecords_if_existing = False
|
159 |
+
self.num_trials = 1
|
160 |
+
self.iterations_per_loop = 1
|
161 |
+
self.train_batch_size = 32
|
162 |
+
self.num_train_epochs = 3.0
|
163 |
+
self.log_examples = True
|
164 |
+
|
165 |
+
# passed-in-arguments override (for example) debug-mode defaults
|
166 |
+
self.update(kwargs)
|
167 |
+
|
168 |
+
def update(self, kwargs):
|
169 |
+
for k, v in kwargs.items():
|
170 |
+
if k not in self.__dict__:
|
171 |
+
raise ValueError("Unknown hparam " + k)
|
172 |
+
self.__dict__[k] = v
|
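As a hedged usage sketch of the class above (the model name and data directory are invented placeholders): construction wires up all the derived paths, and `update` rejects unknown hyperparameter names.

config = FinetuningConfig("araelectra-base", "data",
                          task_names=["sst"], learning_rate=5e-5)
print(config.model_dir)  # data/models/araelectra-base/finetuning_models/sst_model
config.update({"typo_hparam": 1})  # raises ValueError: Unknown hparam typo_hparam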
arabert/araelectra/configure_pretraining.py
ADDED
@@ -0,0 +1,143 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config controlling hyperparameters for pre-training ELECTRA."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+
+class PretrainingConfig(object):
+  """Defines pre-training hyperparameters."""
+
+  def __init__(self, model_name, data_dir, **kwargs):
+    self.model_name = model_name
+    self.debug = False  # debug mode for quickly running things
+    self.do_train = True  # pre-train ELECTRA
+    self.do_eval = False  # evaluate generator/discriminator on unlabeled data
+
+    # loss functions
+    # train ELECTRA or Electric? if both are false, trains a masked LM like BERT
+    self.electra_objective = True
+    self.electric_objective = False
+    self.gen_weight = 1.0  # masked language modeling / generator loss
+    self.disc_weight = 50.0  # discriminator loss
+    self.mask_prob = 0.15  # percent of input tokens to mask out / replace
+
+    # optimization
+    self.learning_rate = 2e-4
+    self.lr_decay_power = 1.0  # linear weight decay by default
+    self.weight_decay_rate = 0.01
+    self.num_warmup_steps = 10000
+
+    # training settings
+    self.iterations_per_loop = 5000
+    self.save_checkpoints_steps = 25000
+    self.num_train_steps = 2000000
+    self.num_eval_steps = 10000
+    self.keep_checkpoint_max = 0  # maximum number of recent checkpoint files
+                                  # to keep; 0 or None keeps all checkpoints
+
+    # model settings
+    self.model_size = "base"  # one of "small", "base", or "large"
+    # override the default transformer hparams for the provided model size; see
+    # modeling.BertConfig for the possible hparams and util.training_utils for
+    # the defaults
+    self.model_hparam_overrides = (
+        kwargs["model_hparam_overrides"]
+        if "model_hparam_overrides" in kwargs else {})
+    self.embedding_size = None  # bert hidden size by default
+    self.vocab_size = 64000  # number of tokens in the vocabulary
+    self.do_lower_case = False  # lowercase the input?
+
+    # generator settings
+    self.uniform_generator = False  # generator is uniform at random
+    self.two_tower_generator = False  # generator is a two-tower cloze model
+    self.untied_generator_embeddings = False  # tie generator/discriminator
+                                              # token embeddings?
+    self.untied_generator = True  # tie all generator/discriminator weights?
+    self.generator_layers = 1.0  # frac of discriminator layers for generator
+    self.generator_hidden_size = 0.25  # frac of discrim hidden size for gen
+    self.disallow_correct = False  # force the generator to sample incorrect
+                                   # tokens (so 15% of tokens are always
+                                   # fake)
+    self.temperature = 1.0  # temperature for sampling from generator
+
+    # batch sizes
+    self.max_seq_length = 512
+    self.train_batch_size = 256
+    self.eval_batch_size = 256
+
+    # TPU settings
+    self.use_tpu = True
+    self.num_tpu_cores = 8
+    self.tpu_job_name = None
+    self.tpu_name = ""  # cloud TPU to use for training
+    self.tpu_zone = ""  # GCE zone where the Cloud TPU is located in
+    self.gcp_project = ""  # project name for the Cloud TPU-enabled project
+
+    # default locations of data files
+    self.pretrain_tfrecords = os.path.join(
+        data_dir, "pretraining_data/512/*")
+    self.vocab_file = os.path.join(data_dir, "bertvocab_final.txt")
+    self.model_dir = os.path.join(data_dir, "models", model_name)
+    results_dir = os.path.join(self.model_dir, "results")
+    self.results_txt = os.path.join(results_dir, "unsup_results.txt")
+    self.results_pkl = os.path.join(results_dir, "unsup_results.pkl")
+
+    # update defaults with passed-in hyperparameters
+    self.update(kwargs)
+
+    self.max_predictions_per_seq = int((self.mask_prob + 0.005) *
+                                       self.max_seq_length)
+
+    # debug-mode settings
+    if self.debug:
+      self.train_batch_size = 8
+      self.num_train_steps = 20
+      self.eval_batch_size = 4
+      self.iterations_per_loop = 1
+      self.num_eval_steps = 2
+
+    # defaults for different-sized model
+    if self.model_size == "small":
+      self.embedding_size = 128
+    # Here are the hyperparameters we used for larger models; see Table 6 in
+    # the paper for the full hyperparameters
+    else:
+      self.max_seq_length = 512
+      self.learning_rate = 2e-4
+      if self.model_size == "base":
+        self.embedding_size = 768
+        self.generator_hidden_size = 0.33333
+        self.train_batch_size = 256
+      else:
+        self.embedding_size = 1024
+        self.mask_prob = 0.25
+        self.train_batch_size = 2048
+    if self.electric_objective:
+      self.two_tower_generator = True  # electric requires a two-tower generator
+
+    # passed-in-arguments override (for example) debug-mode defaults
+    self.update(kwargs)
+
+  def update(self, kwargs):
+    for k, v in kwargs.items():
+      if k not in self.__dict__:
+        raise ValueError("Unknown hparam " + k)
+      self.__dict__[k] = v
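One derived value worth tracing through the defaults above is `max_predictions_per_seq`; a small sketch (arguments are placeholders):

config = PretrainingConfig("araelectra-base", "data")
# int((mask_prob + 0.005) * max_seq_length) = int(0.155 * 512) = int(79.36)
print(config.max_predictions_per_seq)  # 79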
arabert/araelectra/finetune/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
arabert/araelectra/finetune/classification/classification_metrics.py
ADDED
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation metrics for classification tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import numpy as np
+import scipy.stats
+import sklearn
+
+from finetune import scorer
+
+
+class SentenceLevelScorer(scorer.Scorer):
+  """Abstract scorer for classification/regression tasks."""
+
+  __metaclass__ = abc.ABCMeta
+
+  def __init__(self):
+    super(SentenceLevelScorer, self).__init__()
+    self._total_loss = 0
+    self._true_labels = []
+    self._preds = []
+
+  def update(self, results):
+    super(SentenceLevelScorer, self).update(results)
+    self._total_loss += results['loss']
+    self._true_labels.append(results['label_ids'] if 'label_ids' in results
+                             else results['targets'])
+    self._preds.append(results['predictions'])
+
+  def get_loss(self):
+    return self._total_loss / len(self._true_labels)
+
+
+class AccuracyScorer(SentenceLevelScorer):
+
+  def _get_results(self):
+    correct, count = 0, 0
+    for y_true, pred in zip(self._true_labels, self._preds):
+      count += 1
+      correct += (1 if y_true == pred else 0)
+    return [
+        ('accuracy', 100.0 * correct / count),
+        ('loss', self.get_loss()),
+    ]
+
+
+class F1Scorer(SentenceLevelScorer):
+  """Computes F1 for classification tasks."""
+
+  def __init__(self):
+    super(F1Scorer, self).__init__()
+    self._positive_label = 1
+
+  def _get_results(self):
+    n_correct, n_predicted, n_gold = 0, 0, 0
+    for y_true, pred in zip(self._true_labels, self._preds):
+      if y_true == self._positive_label:  # gold positives come from the labels
+        n_gold += 1
+      if pred == self._positive_label:
+        n_predicted += 1
+        if pred == y_true:
+          n_correct += 1
+    if n_correct == 0:
+      p, r, f1 = 0, 0, 0
+    else:
+      p = 100.0 * n_correct / n_predicted
+      r = 100.0 * n_correct / n_gold
+      f1 = 2 * p * r / (p + r)
+    return [
+        ('precision', p),
+        ('recall', r),
+        ('f1', f1),
+        ('loss', self.get_loss()),
+    ]
+
+
+class MCCScorer(SentenceLevelScorer):
+
+  def _get_results(self):
+    return [
+        ('mcc', 100 * sklearn.metrics.matthews_corrcoef(
+            self._true_labels, self._preds)),
+        ('loss', self.get_loss()),
+    ]
+
+
+class RegressionScorer(SentenceLevelScorer):
+
+  def _get_results(self):
+    preds = np.array(self._preds).flatten()
+    return [
+        ('pearson', 100.0 * scipy.stats.pearsonr(
+            self._true_labels, preds)[0]),
+        ('spearman', 100.0 * scipy.stats.spearmanr(
+            self._true_labels, preds)[0]),
+        ('mse', np.mean(np.square(np.array(self._true_labels) - self._preds))),
+        ('loss', self.get_loss()),
+    ]
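A hedged sketch of feeding per-example results to `AccuracyScorer` (the dict keys follow `update()` above; the numbers are invented, and this assumes the base `scorer.Scorer.update` does only bookkeeping):

s = AccuracyScorer()
for y_true, pred in [(1, 1), (0, 1), (1, 1)]:
  s.update({"loss": 0.25, "label_ids": y_true, "predictions": pred})
# _get_results() would report accuracy = 100.0 * 2 / 3 ≈ 66.7 and loss 0.25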
arabert/araelectra/finetune/classification/classification_tasks.py
ADDED
@@ -0,0 +1,439 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text classification and regression tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import csv
+import os
+import tensorflow as tf
+
+import configure_finetuning
+from finetune import feature_spec
+from finetune import task
+from finetune.classification import classification_metrics
+from model import tokenization
+from util import utils
+
+
+class InputExample(task.Example):
+  """A single training/test example for simple sequence classification."""
+
+  def __init__(self, eid, task_name, text_a, text_b=None, label=None):
+    super(InputExample, self).__init__(task_name)
+    self.eid = eid
+    self.text_a = text_a
+    self.text_b = text_b
+    self.label = label
+
+
+class SingleOutputTask(task.Task):
+  """Task with a single prediction per example (e.g., text classification)."""
+
+  __metaclass__ = abc.ABCMeta
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
+               tokenizer):
+    super(SingleOutputTask, self).__init__(config, name)
+    self._tokenizer = tokenizer
+
+  def get_examples(self, split):
+    return self._create_examples(read_tsv(
+        os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"),
+        max_lines=100 if self.config.debug else None), split)
+
+  @abc.abstractmethod
+  def _create_examples(self, lines, split):
+    pass
+
+  def featurize(self, example: InputExample, is_training, log=False):
+    """Turn an InputExample into a dict of features."""
+    tokens_a = self._tokenizer.tokenize(example.text_a)
+    tokens_b = None
+    if example.text_b:
+      tokens_b = self._tokenizer.tokenize(example.text_b)
+
+    if tokens_b:
+      # Modifies `tokens_a` and `tokens_b` in place so that the total
+      # length is less than the specified length.
+      # Account for [CLS], [SEP], [SEP] with "- 3"
+      _truncate_seq_pair(tokens_a, tokens_b, self.config.max_seq_length - 3)
+    else:
+      # Account for [CLS] and [SEP] with "- 2"
+      if len(tokens_a) > self.config.max_seq_length - 2:
+        tokens_a = tokens_a[0:(self.config.max_seq_length - 2)]
+
+    # The convention in BERT is:
+    # (a) For sequence pairs:
+    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
+    # (b) For single sequences:
+    #  tokens:   [CLS] the dog is hairy . [SEP]
+    #  type_ids: 0     0   0   0  0     0 0
+    #
+    # Where "type_ids" are used to indicate whether this is the first
+    # sequence or the second sequence. The embedding vectors for `type=0` and
+    # `type=1` were learned during pre-training and are added to the wordpiece
+    # embedding vector (and position vector). This is not *strictly* necessary
+    # since the [SEP] token unambiguously separates the sequences, but it
+    # makes it easier for the model to learn the concept of sequences.
+    #
+    # For classification tasks, the first vector (corresponding to [CLS]) is
+    # used as the "sentence vector". Note that this only makes sense because
+    # the entire model is fine-tuned.
+    tokens = []
+    segment_ids = []
+    tokens.append("[CLS]")
+    segment_ids.append(0)
+    for token in tokens_a:
+      tokens.append(token)
+      segment_ids.append(0)
+    tokens.append("[SEP]")
+    segment_ids.append(0)
+
+    if tokens_b:
+      for token in tokens_b:
+        tokens.append(token)
+        segment_ids.append(1)
+      tokens.append("[SEP]")
+      segment_ids.append(1)
+
+    input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+
+    # The mask has 1 for real tokens and 0 for padding tokens. Only real
+    # tokens are attended to.
+    input_mask = [1] * len(input_ids)
+
+    # Zero-pad up to the sequence length.
+    while len(input_ids) < self.config.max_seq_length:
+      input_ids.append(0)
+      input_mask.append(0)
+      segment_ids.append(0)
+
+    assert len(input_ids) == self.config.max_seq_length
+    assert len(input_mask) == self.config.max_seq_length
+    assert len(segment_ids) == self.config.max_seq_length
+
+    if log:
+      utils.log("  Example {:}".format(example.eid))
+      utils.log("    tokens: {:}".format(" ".join(
+          [tokenization.printable_text(x) for x in tokens])))
+      utils.log("    input_ids: {:}".format(" ".join(map(str, input_ids))))
+      utils.log("    input_mask: {:}".format(" ".join(map(str, input_mask))))
+      utils.log("    segment_ids: {:}".format(" ".join(map(str, segment_ids))))
+
+    eid = example.eid
+    features = {
+        "input_ids": input_ids,
+        "input_mask": input_mask,
+        "segment_ids": segment_ids,
+        "task_id": self.config.task_names.index(self.name),
+        self.name + "_eid": eid,
+    }
+    self._add_features(features, example, log)
+    return features
+
+  def _load_glue(self, lines, split, text_a_loc, text_b_loc, label_loc,
+                 skip_first_line=False, eid_offset=0, swap=False):
+    examples = []
+    for (i, line) in enumerate(lines):
+      try:
+        if i == 0 and skip_first_line:
+          continue
+        eid = i - (1 if skip_first_line else 0) + eid_offset
+        text_a = tokenization.convert_to_unicode(line[text_a_loc])
+        if text_b_loc is None:
+          text_b = None
+        else:
+          text_b = tokenization.convert_to_unicode(line[text_b_loc])
+        if "test" in split or "diagnostic" in split:
+          label = self._get_dummy_label()
+        else:
+          label = tokenization.convert_to_unicode(line[label_loc])
+        if swap:
+          text_a, text_b = text_b, text_a
+        examples.append(InputExample(eid=eid, task_name=self.name,
+                                     text_a=text_a, text_b=text_b, label=label))
+      except Exception as ex:
+        utils.log("Error constructing example from line", i,
+                  "for task", self.name + ":", ex)
+        utils.log("Input causing the error:", line)
+    return examples
+
+  @abc.abstractmethod
+  def _get_dummy_label(self):
+    pass
+
+  @abc.abstractmethod
+  def _add_features(self, features, example, log):
+    pass
+
+
+class RegressionTask(SingleOutputTask):
+  """Task where the output is a real-valued score for the input text."""
+
+  __metaclass__ = abc.ABCMeta
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
+               tokenizer, min_value, max_value):
+    super(RegressionTask, self).__init__(config, name, tokenizer)
+    self._tokenizer = tokenizer
+    self._min_value = min_value
+    self._max_value = max_value
+
+  def _get_dummy_label(self):
+    return 0.0
+
+  def get_feature_specs(self):
+    feature_specs = [feature_spec.FeatureSpec(self.name + "_eid", []),
+                     feature_spec.FeatureSpec(self.name + "_targets", [],
+                                              is_int_feature=False)]
+    return feature_specs
+
+  def _add_features(self, features, example, log):
+    label = float(example.label)
+    assert self._min_value <= label <= self._max_value
+    # simple normalization of the label
+    label = (label - self._min_value) / self._max_value
+    if log:
+      utils.log("    label: {:}".format(label))
+    features[example.task_name + "_targets"] = label
+
+  def get_prediction_module(self, bert_model, features, is_training,
+                            percent_done):
+    reprs = bert_model.get_pooled_output()
+    if is_training:
+      reprs = tf.nn.dropout(reprs, keep_prob=0.9)
+
+    predictions = tf.layers.dense(reprs, 1)
+    predictions = tf.squeeze(predictions, -1)
+
+    targets = features[self.name + "_targets"]
+    losses = tf.square(predictions - targets)
+    outputs = dict(
+        loss=losses,
+        predictions=predictions,
+        targets=features[self.name + "_targets"],
+        eid=features[self.name + "_eid"]
+    )
+    return losses, outputs
+
+  def get_scorer(self):
+    return classification_metrics.RegressionScorer()
+
+
+class ClassificationTask(SingleOutputTask):
+  """Task where the output is a single categorical label for the input text."""
+  __metaclass__ = abc.ABCMeta
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
+               tokenizer, label_list):
+    super(ClassificationTask, self).__init__(config, name, tokenizer)
+    self._tokenizer = tokenizer
+    self._label_list = label_list
+
+  def _get_dummy_label(self):
+    return self._label_list[0]
+
+  def get_feature_specs(self):
+    return [feature_spec.FeatureSpec(self.name + "_eid", []),
+            feature_spec.FeatureSpec(self.name + "_label_ids", [])]
+
+  def _add_features(self, features, example, log):
+    label_map = {}
+    for (i, label) in enumerate(self._label_list):
+      label_map[label] = i
+    label_id = label_map[example.label]
+    if log:
+      utils.log("    label: {:} (id = {:})".format(example.label, label_id))
+    features[example.task_name + "_label_ids"] = label_id
+
+  def get_prediction_module(self, bert_model, features, is_training,
+                            percent_done):
+    num_labels = len(self._label_list)
+    reprs = bert_model.get_pooled_output()
+
+    if is_training:
+      reprs = tf.nn.dropout(reprs, keep_prob=0.9)
+
+    logits = tf.layers.dense(reprs, num_labels)
+    log_probs = tf.nn.log_softmax(logits, axis=-1)
+
+    label_ids = features[self.name + "_label_ids"]
+    labels = tf.one_hot(label_ids, depth=num_labels, dtype=tf.float32)
+
+    losses = -tf.reduce_sum(labels * log_probs, axis=-1)
+
+    outputs = dict(
+        loss=losses,
+        logits=logits,
+        predictions=tf.argmax(logits, axis=-1),
+        label_ids=label_ids,
+        eid=features[self.name + "_eid"],
+    )
+    return losses, outputs
+
+  def get_scorer(self):
+    return classification_metrics.AccuracyScorer()
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+  """Truncates a sequence pair in place to the maximum length."""
+
+  # This is a simple heuristic which will always truncate the longer sequence
+  # one token at a time. This makes more sense than truncating an equal percent
+  # of tokens from each, since if one sequence is very short then each token
+  # that's truncated likely contains more information than a longer sequence.
+  while True:
+    total_length = len(tokens_a) + len(tokens_b)
+    if total_length <= max_length:
+      break
+    if len(tokens_a) > len(tokens_b):
+      tokens_a.pop()
+    else:
+      tokens_b.pop()
+
+
+def read_tsv(input_file, quotechar=None, max_lines=None):
+  """Reads a tab separated value file."""
+  with tf.io.gfile.GFile(input_file, "r") as f:
+    reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+    lines = []
+    for i, line in enumerate(reader):
+      if max_lines and i >= max_lines:
+        break
+      lines.append(line)
+    return lines
+
+
+class MNLI(ClassificationTask):
+  """Multi-NLI."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(MNLI, self).__init__(config, "mnli", tokenizer,
+                               ["contradiction", "entailment", "neutral"])
+
+  def get_examples(self, split):
+    if split == "dev":
+      split += "_matched"
+    return self._create_examples(read_tsv(
+        os.path.join(self.config.raw_data_dir(self.name), split + ".tsv"),
+        max_lines=100 if self.config.debug else None), split)
+
+  def _create_examples(self, lines, split):
+    if split == "diagnostic":
+      return self._load_glue(lines, split, 1, 2, None, True)
+    else:
+      return self._load_glue(lines, split, 8, 9, -1, True)
+
+  def get_test_splits(self):
+    return ["test_matched", "test_mismatched", "diagnostic"]
+
+
+class MRPC(ClassificationTask):
+  """Microsoft Research Paraphrase Corpus."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(MRPC, self).__init__(config, "mrpc", tokenizer, ["0", "1"])
+
+  def _create_examples(self, lines, split):
+    examples = []
+    examples += self._load_glue(lines, split, 3, 4, 0, True)
+    if self.config.double_unordered and split == "train":
+      examples += self._load_glue(
+          lines, split, 3, 4, 0, True, len(examples), True)
+    return examples
+
+
+class CoLA(ClassificationTask):
+  """Corpus of Linguistic Acceptability."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(CoLA, self).__init__(config, "cola", tokenizer, ["0", "1"])
+
+  def _create_examples(self, lines, split):
+    return self._load_glue(lines, split, 1 if split == "test" else 3,
+                           None, 1, split == "test")
+
+  def get_scorer(self):
+    return classification_metrics.MCCScorer()
+
+
+class SST(ClassificationTask):
+  """Stanford Sentiment Treebank."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(SST, self).__init__(config, "sst", tokenizer, ["0", "1"])
+
+  def _create_examples(self, lines, split):
+    if "test" in split:
+      return self._load_glue(lines, split, 1, None, None, True)
+    else:
+      return self._load_glue(lines, split, 0, None, 1, True)
+
+
+class QQP(ClassificationTask):
+  """Quora Question Pair."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(QQP, self).__init__(config, "qqp", tokenizer, ["0", "1"])
+
+  def _create_examples(self, lines, split):
+    return self._load_glue(lines, split, 1 if split == "test" else 3,
+                           2 if split == "test" else 4, 5, True)
+
+
+class RTE(ClassificationTask):
+  """Recognizing Textual Entailment."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(RTE, self).__init__(config, "rte", tokenizer,
+                              ["entailment", "not_entailment"])
+
+  def _create_examples(self, lines, split):
+    return self._load_glue(lines, split, 1, 2, 3, True)
+
+
+class QNLI(ClassificationTask):
+  """Question NLI."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(QNLI, self).__init__(config, "qnli", tokenizer,
+                               ["entailment", "not_entailment"])
+
+  def _create_examples(self, lines, split):
+    return self._load_glue(lines, split, 1, 2, 3, True)
+
+
+class STS(RegressionTask):
+  """Semantic Textual Similarity."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
+    super(STS, self).__init__(config, "sts", tokenizer, 0.0, 5.0)
+
+  def _create_examples(self, lines, split):
+    examples = []
+    if split == "test":
+      examples += self._load_glue(lines, split, -2, -1, None, True)
+    else:
+      examples += self._load_glue(lines, split, -3, -2, -1, True)
+    if self.config.double_unordered and split == "train":
+      examples += self._load_glue(
+          lines, split, -3, -2, -1, True, len(examples), True)
+    return examples
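New tasks follow the same recipe as MRPC/SST above: subclass `ClassificationTask`, pick a task name and label list, and point `_load_glue` at the right TSV columns. A hypothetical two-class example (the name and column layout are assumptions, not part of this repo):

class BinarySentiment(ClassificationTask):
  """Hypothetical binary sentiment task."""

  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(BinarySentiment, self).__init__(config, "sentiment", tokenizer,
                                          ["0", "1"])

  def _create_examples(self, lines, split):
    # assumed TSV layout: text in column 0, label in column 1, no header row
    return self._load_glue(lines, split, 0, None, 1, False)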
arabert/araelectra/finetune/feature_spec.py
ADDED
@@ -0,0 +1,56 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the inputs used when fine-tuning a model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+import configure_finetuning
+
+
+def get_shared_feature_specs(config: configure_finetuning.FinetuningConfig):
+  """Non-task-specific model inputs."""
+  return [
+      FeatureSpec("input_ids", [config.max_seq_length]),
+      FeatureSpec("input_mask", [config.max_seq_length]),
+      FeatureSpec("segment_ids", [config.max_seq_length]),
+      FeatureSpec("task_id", []),
+  ]
+
+
+class FeatureSpec(object):
+  """Defines a feature passed as input to the model."""
+
+  def __init__(self, name, shape, default_value_fn=None, is_int_feature=True):
+    self.name = name
+    self.shape = shape
+    self.default_value_fn = default_value_fn
+    self.is_int_feature = is_int_feature
+
+  def get_parsing_spec(self):
+    return tf.io.FixedLenFeature(
+        self.shape, tf.int64 if self.is_int_feature else tf.float32)
+
+  def get_default_values(self):
+    if self.default_value_fn:
+      return self.default_value_fn(self.shape)
+    else:
+      return np.zeros(
+          self.shape, np.int64 if self.is_int_feature else np.float32)
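In isolation, a `FeatureSpec` yields a matching parsing spec and an all-zero default; a short sketch (name and shape are arbitrary):

spec = FeatureSpec("input_ids", [128])
print(spec.get_parsing_spec())          # FixedLenFeature(shape=[128], dtype=tf.int64)
print(spec.get_default_values().shape)  # (128,), int64 zeros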
arabert/araelectra/finetune/preprocessing.py
ADDED
@@ -0,0 +1,173 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code for serializing raw fine-tuning data into tfrecords"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import random
+import numpy as np
+import tensorflow as tf
+
+import configure_finetuning
+from finetune import feature_spec
+from util import utils
+
+
+class Preprocessor(object):
+  """Class for loading, preprocessing, and serializing fine-tuning datasets."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, tasks):
+    self._config = config
+    self._tasks = tasks
+    self._name_to_task = {task.name: task for task in tasks}
+
+    self._feature_specs = feature_spec.get_shared_feature_specs(config)
+    for task in tasks:
+      self._feature_specs += task.get_feature_specs()
+    self._name_to_feature_config = {
+        spec.name: spec.get_parsing_spec()
+        for spec in self._feature_specs
+    }
+    assert len(self._name_to_feature_config) == len(self._feature_specs)
+
+  def prepare_train(self):
+    return self._serialize_dataset(self._tasks, True, "train")
+
+  def prepare_predict(self, tasks, split):
+    return self._serialize_dataset(tasks, False, split)
+
+  def _serialize_dataset(self, tasks, is_training, split):
+    """Write out the dataset as tfrecords."""
+    dataset_name = "_".join(sorted([task.name for task in tasks]))
+    dataset_name += "_" + split
+    dataset_prefix = os.path.join(
+        self._config.preprocessed_data_dir, dataset_name)
+    tfrecords_path = dataset_prefix + ".tfrecord"
+    metadata_path = dataset_prefix + ".metadata"
+    batch_size = (self._config.train_batch_size if is_training else
+                  self._config.eval_batch_size)
+
+    utils.log("Loading dataset", dataset_name)
+    n_examples = None
+    if (self._config.use_tfrecords_if_existing and
+        tf.io.gfile.exists(metadata_path)):
+      n_examples = utils.load_json(metadata_path)["n_examples"]
+
+    if n_examples is None:
+      utils.log("Existing tfrecords not found so creating")
+      examples = []
+      for task in tasks:
+        task_examples = task.get_examples(split)
+        examples += task_examples
+      if is_training:
+        random.shuffle(examples)
+      utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
+      n_examples = self.serialize_examples(
+          examples, is_training, tfrecords_path, batch_size)
+      utils.write_json({"n_examples": n_examples}, metadata_path)
+
+    input_fn = self._input_fn_builder(tfrecords_path, is_training)
+    if is_training:
+      steps = int(n_examples // batch_size * self._config.num_train_epochs)
+    else:
+      steps = n_examples // batch_size
+
+    return input_fn, steps
+
+  def serialize_examples(self, examples, is_training, output_file, batch_size):
+    """Convert a set of `InputExample`s to a TFRecord file."""
+    n_examples = 0
+    with tf.io.TFRecordWriter(output_file) as writer:
+      for (ex_index, example) in enumerate(examples):
+        if ex_index % 2000 == 0:
+          utils.log("Writing example {:} of {:}".format(
+              ex_index, len(examples)))
+        for tf_example in self._example_to_tf_example(
+            example, is_training,
+            log=self._config.log_examples and ex_index < 1):
+          writer.write(tf_example.SerializeToString())
+          n_examples += 1
+      # add padding so the dataset is a multiple of batch_size
+      while n_examples % batch_size != 0:
+        writer.write(self._make_tf_example(task_id=len(self._config.task_names))
+                     .SerializeToString())
+        n_examples += 1
+    return n_examples
+
+  def _example_to_tf_example(self, example, is_training, log=False):
+    examples = self._name_to_task[example.task_name].featurize(
+        example, is_training, log)
+    if not isinstance(examples, list):
+      examples = [examples]
+    for example in examples:
+      yield self._make_tf_example(**example)
+
+  def _make_tf_example(self, **kwargs):
+    """Make a tf.train.Example from the provided features."""
+    for k in kwargs:
+      if k not in self._name_to_feature_config:
+        raise ValueError("Unknown feature", k)
+    features = collections.OrderedDict()
+    for spec in self._feature_specs:
+      if spec.name in kwargs:
+        values = kwargs[spec.name]
+      else:
+        values = spec.get_default_values()
+      if (isinstance(values, int) or isinstance(values, bool) or
+          isinstance(values, float) or isinstance(values, np.float32) or
+          (isinstance(values, np.ndarray) and values.size == 1)):
+        values = [values]
+      if spec.is_int_feature:
+        feature = tf.train.Feature(int64_list=tf.train.Int64List(
+            value=list(values)))
+      else:
+        feature = tf.train.Feature(float_list=tf.train.FloatList(
+            value=list(values)))
+      features[spec.name] = feature
+    return tf.train.Example(features=tf.train.Features(feature=features))
+
+  def _input_fn_builder(self, input_file, is_training):
+    """Creates an `input_fn` closure to be passed to TPUEstimator."""
+
+    def input_fn(params):
+      """The actual input function."""
+      d = tf.data.TFRecordDataset(input_file)
+      if is_training:
+        d = d.repeat()
+        d = d.shuffle(buffer_size=100)
+      return d.apply(
+          tf.data.experimental.map_and_batch(
+              self._decode_tfrecord,
+              batch_size=params["batch_size"],
+              drop_remainder=True))
+
+    return input_fn
+
+  def _decode_tfrecord(self, record):
+    """Decodes a record to a TensorFlow example."""
+    example = tf.io.parse_single_example(record, self._name_to_feature_config)
+    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+    # So cast all int64 to int32.
+    for name, tensor in example.items():
+      if tensor.dtype == tf.int64:
+        example[name] = tf.cast(tensor, tf.int32)
+      else:
+        example[name] = tensor
+    return example
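On the padding step in `serialize_examples` above: the writer tops the file up to a multiple of the batch size with dummy records whose `task_id` is `len(config.task_names)`, an id no real task uses. With assumed counts:

n_examples, batch_size = 1003, 32
padding = -n_examples % batch_size  # 21 dummy records appended
print(n_examples + padding)         # 1024 == 32 * 32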
arabert/araelectra/finetune/qa/mrqa_official_eval.py
ADDED
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Official evaluation script for the MRQA Workshop Shared Task.
+Adapted from the SQuAD v1.1 official evaluation script.
+Modified slightly for the ELECTRA codebase.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import string
+import re
+import json
+import tensorflow as tf
+from collections import Counter
+
+import configure_finetuning
+
+
+def normalize_answer(s):
+  """Lower text and remove punctuation, articles and extra whitespace."""
+  def remove_articles(text):
+    return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+  def white_space_fix(text):
+    return ' '.join(text.split())
+
+  def remove_punc(text):
+    exclude = set(string.punctuation)
+    return ''.join(ch for ch in text if ch not in exclude)
+
+  def lower(text):
+    return text.lower()
+
+  return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+  prediction_tokens = normalize_answer(prediction).split()
+  ground_truth_tokens = normalize_answer(ground_truth).split()
+  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+  num_same = sum(common.values())
+  if num_same == 0:
+    return 0
+  precision = 1.0 * num_same / len(prediction_tokens)
+  recall = 1.0 * num_same / len(ground_truth_tokens)
+  f1 = (2 * precision * recall) / (precision + recall)
+  return f1
+
+
+def exact_match_score(prediction, ground_truth):
+  return (normalize_answer(prediction) == normalize_answer(ground_truth))
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+  scores_for_ground_truths = []
+  for ground_truth in ground_truths:
+    score = metric_fn(prediction, ground_truth)
+    scores_for_ground_truths.append(score)
+  return max(scores_for_ground_truths)
+
+
+def read_predictions(prediction_file):
+  with tf.io.gfile.GFile(prediction_file) as f:
+    predictions = json.load(f)
+  return predictions
+
+
+def read_answers(gold_file):
+  answers = {}
+  with tf.io.gfile.GFile(gold_file, 'r') as f:
+    for i, line in enumerate(f):
+      example = json.loads(line)
+      if i == 0 and 'header' in example:
+        continue
+      for qa in example['qas']:
+        answers[qa['qid']] = qa['answers']
+  return answers
+
+
+def evaluate(answers, predictions, skip_no_answer=False):
+  f1 = exact_match = total = 0
+  for qid, ground_truths in answers.items():
+    if qid not in predictions:
+      if not skip_no_answer:
+        message = 'Unanswered question %s will receive score 0.' % qid
+        print(message)
+        total += 1
+      continue
+    total += 1
+    prediction = predictions[qid]
+    exact_match += metric_max_over_ground_truths(
+        exact_match_score, prediction, ground_truths)
+    f1 += metric_max_over_ground_truths(
+        f1_score, prediction, ground_truths)
+
+  exact_match = 100.0 * exact_match / total
+  f1 = 100.0 * f1 / total
+
+  return {'exact_match': exact_match, 'f1': f1}
+
+
+def main(config: configure_finetuning.FinetuningConfig, split, task_name):
+  answers = read_answers(
+      os.path.join(config.raw_data_dir(task_name), split + ".jsonl"))
+  predictions = read_predictions(config.qa_preds_file(task_name))
+  return evaluate(answers, predictions, True)
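A worked example of the normalization and token-overlap F1 above (strings invented for illustration):

print(normalize_answer("The Eiffel Tower!"))                   # "eiffel tower"
print(exact_match_score("the Eiffel Tower", "Eiffel tower."))  # True
# prediction tokens {eiffel, tower, in, paris} vs. gold {eiffel, tower}:
# precision = 2/4, recall = 2/2, F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667
print(f1_score("Eiffel Tower in Paris", "the Eiffel Tower"))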
arabert/araelectra/finetune/qa/qa_metrics.py
ADDED
@@ -0,0 +1,401 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation metrics for question-answering tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+import six
+
+import configure_finetuning
+from finetune import scorer
+from finetune.qa import mrqa_official_eval
+from finetune.qa import squad_official_eval
+from finetune.qa import squad_official_eval_v1
+from model import tokenization
+from util import utils
+
+
+RawResult = collections.namedtuple("RawResult", [
+    "unique_id", "start_logits", "end_logits", "answerable_logit",
+    "start_top_log_probs", "start_top_index", "end_top_log_probs",
+    "end_top_index"
+])
+
+
+class SpanBasedQAScorer(scorer.Scorer):
+  """Runs evaluation for SQuAD 1.1, SQuAD 2.0, and MRQA tasks."""
+
+  def __init__(self, config: configure_finetuning.FinetuningConfig, task, split,
+               v2):
+    super(SpanBasedQAScorer, self).__init__()
+    self._config = config
+    self._task = task
+    self._name = task.name
+    self._split = split
+    self._v2 = v2
+    self._all_results = []
+    self._total_loss = 0
+    self._split = split
+    self._eval_examples = task.get_examples(split)
+
+  def update(self, results):
+    super(SpanBasedQAScorer, self).update(results)
+    self._all_results.append(
+        RawResult(
+            unique_id=results["eid"],
+            start_logits=results["start_logits"],
+            end_logits=results["end_logits"],
+            answerable_logit=results["answerable_logit"],
+            start_top_log_probs=results["start_top_log_probs"],
+            start_top_index=results["start_top_index"],
+            end_top_log_probs=results["end_top_log_probs"],
+            end_top_index=results["end_top_index"],
+        ))
+    self._total_loss += results["loss"]
+
+  def get_loss(self):
+    return self._total_loss / len(self._all_results)
+
+  def _get_results(self):
+    self.write_predictions()
+    if self._name == "squad":
+      squad_official_eval.set_opts(self._config, self._split)
+      squad_official_eval.main()
+      return sorted(utils.load_json(
+          self._config.qa_eval_file(self._name)).items())
+    elif self._name == "squadv1":
+      return sorted(squad_official_eval_v1.main(
+          self._config, self._split).items())
+    else:
+      return sorted(mrqa_official_eval.main(
+          self._config, self._split, self._name).items())
+
+  def write_predictions(self):
+    """Write final predictions to the json file."""
+    unique_id_to_result = {}
+    for result in self._all_results:
+      unique_id_to_result[result.unique_id] = result
+
+    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
+        "PrelimPrediction",
+        ["feature_index", "start_index", "end_index", "start_logit",
+         "end_logit"])
+
+    all_predictions = collections.OrderedDict()
+    all_nbest_json = collections.OrderedDict()
+    scores_diff_json = collections.OrderedDict()
+
+    for example in self._eval_examples:
+      example_id = example.qas_id if "squad" in self._name else example.qid
+      features = self._task.featurize(example, False, for_eval=True)
+
+      prelim_predictions = []
+      # keep track of the minimum score of null start+end of position 0
+      score_null = 1000000  # large and positive
+      for (feature_index, feature) in enumerate(features):
+        result = unique_id_to_result[feature[self._name + "_eid"]]
+        if self._config.joint_prediction:
+          start_indexes = result.start_top_index
+          end_indexes = result.end_top_index
+        else:
+          start_indexes = _get_best_indexes(result.start_logits,
+                                            self._config.n_best_size)
+          end_indexes = _get_best_indexes(result.end_logits,
+                                          self._config.n_best_size)
+        # if we could have irrelevant answers, get the min score of irrelevant
+        if self._v2:
+          if self._config.answerable_classifier:
+            feature_null_score = result.answerable_logit
+          else:
+            feature_null_score = result.start_logits[0] + result.end_logits[0]
+          if feature_null_score < score_null:
if feature_null_score < score_null:
|
129 |
+
score_null = feature_null_score
|
130 |
+
for i, start_index in enumerate(start_indexes):
|
131 |
+
for j, end_index in enumerate(
|
132 |
+
end_indexes[i] if self._config.joint_prediction else end_indexes):
|
133 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
134 |
+
# that the start of the span is in the question. We throw out all
|
135 |
+
# invalid predictions.
|
136 |
+
if start_index >= len(feature[self._name + "_tokens"]):
|
137 |
+
continue
|
138 |
+
if end_index >= len(feature[self._name + "_tokens"]):
|
139 |
+
continue
|
140 |
+
if start_index == 0:
|
141 |
+
continue
|
142 |
+
if start_index not in feature[self._name + "_token_to_orig_map"]:
|
143 |
+
continue
|
144 |
+
if end_index not in feature[self._name + "_token_to_orig_map"]:
|
145 |
+
continue
|
146 |
+
if not feature[self._name + "_token_is_max_context"].get(
|
147 |
+
start_index, False):
|
148 |
+
continue
|
149 |
+
if end_index < start_index:
|
150 |
+
continue
|
151 |
+
length = end_index - start_index + 1
|
152 |
+
if length > self._config.max_answer_length:
|
153 |
+
continue
|
154 |
+
start_logit = (result.start_top_log_probs[i] if
|
155 |
+
self._config.joint_prediction else
|
156 |
+
result.start_logits[start_index])
|
157 |
+
end_logit = (result.end_top_log_probs[i, j] if
|
158 |
+
self._config.joint_prediction else
|
159 |
+
result.end_logits[end_index])
|
160 |
+
prelim_predictions.append(
|
161 |
+
_PrelimPrediction(
|
162 |
+
feature_index=feature_index,
|
163 |
+
start_index=start_index,
|
164 |
+
end_index=end_index,
|
165 |
+
start_logit=start_logit,
|
166 |
+
end_logit=end_logit))
|
167 |
+
|
168 |
+
if self._v2:
|
169 |
+
if len(prelim_predictions) == 0 and self._config.debug:
|
170 |
+
tokid = sorted(feature[self._name + "_token_to_orig_map"].keys())[0]
|
171 |
+
prelim_predictions.append(_PrelimPrediction(
|
172 |
+
feature_index=0,
|
173 |
+
start_index=tokid,
|
174 |
+
end_index=tokid + 1,
|
175 |
+
start_logit=1.0,
|
176 |
+
end_logit=1.0))
|
177 |
+
prelim_predictions = sorted(
|
178 |
+
prelim_predictions,
|
179 |
+
key=lambda x: (x.start_logit + x.end_logit),
|
180 |
+
reverse=True)
|
181 |
+
|
182 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
183 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"])
|
184 |
+
|
185 |
+
seen_predictions = {}
|
186 |
+
nbest = []
|
187 |
+
for pred in prelim_predictions:
|
188 |
+
if len(nbest) >= self._config.n_best_size:
|
189 |
+
break
|
190 |
+
feature = features[pred.feature_index]
|
191 |
+
tok_tokens = feature[self._name + "_tokens"][
|
192 |
+
pred.start_index:(pred.end_index + 1)]
|
193 |
+
orig_doc_start = feature[
|
194 |
+
self._name + "_token_to_orig_map"][pred.start_index]
|
195 |
+
orig_doc_end = feature[
|
196 |
+
self._name + "_token_to_orig_map"][pred.end_index]
|
197 |
+
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
198 |
+
tok_text = " ".join(tok_tokens)
|
199 |
+
|
200 |
+
# De-tokenize WordPieces that have been split off.
|
201 |
+
tok_text = tok_text.replace(" ##", "")
|
202 |
+
tok_text = tok_text.replace("##", "")
|
203 |
+
|
204 |
+
# Clean whitespace
|
205 |
+
tok_text = tok_text.strip()
|
206 |
+
tok_text = " ".join(tok_text.split())
|
207 |
+
orig_text = " ".join(orig_tokens)
|
208 |
+
|
209 |
+
final_text = get_final_text(self._config, tok_text, orig_text)
|
210 |
+
if final_text in seen_predictions:
|
211 |
+
continue
|
212 |
+
|
213 |
+
seen_predictions[final_text] = True
|
214 |
+
|
215 |
+
nbest.append(
|
216 |
+
_NbestPrediction(
|
217 |
+
text=final_text,
|
218 |
+
start_logit=pred.start_logit,
|
219 |
+
end_logit=pred.end_logit))
|
220 |
+
|
221 |
+
# In very rare edge cases we could have no valid predictions. So we
|
222 |
+
# just create a nonce prediction in this case to avoid failure.
|
223 |
+
if not nbest:
|
224 |
+
nbest.append(
|
225 |
+
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
226 |
+
|
227 |
+
assert len(nbest) >= 1
|
228 |
+
|
229 |
+
total_scores = []
|
230 |
+
best_non_null_entry = None
|
231 |
+
for entry in nbest:
|
232 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
233 |
+
if not best_non_null_entry:
|
234 |
+
if entry.text:
|
235 |
+
best_non_null_entry = entry
|
236 |
+
|
237 |
+
probs = _compute_softmax(total_scores)
|
238 |
+
|
239 |
+
nbest_json = []
|
240 |
+
for (i, entry) in enumerate(nbest):
|
241 |
+
output = collections.OrderedDict()
|
242 |
+
output["text"] = entry.text
|
243 |
+
output["probability"] = probs[i]
|
244 |
+
output["start_logit"] = entry.start_logit
|
245 |
+
output["end_logit"] = entry.end_logit
|
246 |
+
nbest_json.append(dict(output))
|
247 |
+
|
248 |
+
assert len(nbest_json) >= 1
|
249 |
+
|
250 |
+
if not self._v2:
|
251 |
+
all_predictions[example_id] = nbest_json[0]["text"]
|
252 |
+
else:
|
253 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
254 |
+
if self._config.answerable_classifier:
|
255 |
+
score_diff = score_null
|
256 |
+
else:
|
257 |
+
score_diff = score_null - best_non_null_entry.start_logit - (
|
258 |
+
best_non_null_entry.end_logit)
|
259 |
+
scores_diff_json[example_id] = score_diff
|
260 |
+
all_predictions[example_id] = best_non_null_entry.text
|
261 |
+
|
262 |
+
all_nbest_json[example_id] = nbest_json
|
263 |
+
|
264 |
+
utils.write_json(dict(all_predictions),
|
265 |
+
self._config.qa_preds_file(self._name))
|
266 |
+
if self._v2:
|
267 |
+
utils.write_json({
|
268 |
+
k: float(v) for k, v in six.iteritems(scores_diff_json)},
|
269 |
+
self._config.qa_na_file(self._name))
|
270 |
+
|
271 |
+
|
272 |
+
def _get_best_indexes(logits, n_best_size):
|
273 |
+
"""Get the n-best logits from a list."""
|
274 |
+
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
275 |
+
|
276 |
+
best_indexes = []
|
277 |
+
for i in range(len(index_and_score)):
|
278 |
+
if i >= n_best_size:
|
279 |
+
break
|
280 |
+
best_indexes.append(index_and_score[i][0])
|
281 |
+
return best_indexes
|
282 |
+
|
283 |
+
|
284 |
+
def _compute_softmax(scores):
|
285 |
+
"""Compute softmax probability over raw logits."""
|
286 |
+
if not scores:
|
287 |
+
return []
|
288 |
+
|
289 |
+
max_score = None
|
290 |
+
for score in scores:
|
291 |
+
if max_score is None or score > max_score:
|
292 |
+
max_score = score
|
293 |
+
|
294 |
+
exp_scores = []
|
295 |
+
total_sum = 0.0
|
296 |
+
for score in scores:
|
297 |
+
x = np.exp(score - max_score)
|
298 |
+
exp_scores.append(x)
|
299 |
+
total_sum += x
|
300 |
+
|
301 |
+
probs = []
|
302 |
+
for score in exp_scores:
|
303 |
+
probs.append(score / total_sum)
|
304 |
+
return probs
|
305 |
+
|
306 |
+
|
307 |
+
def get_final_text(config: configure_finetuning.FinetuningConfig, pred_text,
|
308 |
+
orig_text):
|
309 |
+
"""Project the tokenized prediction back to the original text."""
|
310 |
+
|
311 |
+
# When we created the data, we kept track of the alignment between original
|
312 |
+
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
313 |
+
# now `orig_text` contains the span of our original text corresponding to the
|
314 |
+
# span that we predicted.
|
315 |
+
#
|
316 |
+
# However, `orig_text` may contain extra characters that we don't want in
|
317 |
+
# our prediction.
|
318 |
+
#
|
319 |
+
# For example, let's say:
|
320 |
+
# pred_text = steve smith
|
321 |
+
# orig_text = Steve Smith's
|
322 |
+
#
|
323 |
+
# We don't want to return `orig_text` because it contains the extra "'s".
|
324 |
+
#
|
325 |
+
# We don't want to return `pred_text` because it's already been normalized
|
326 |
+
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
327 |
+
# our tokenizer does additional normalization like stripping accent
|
328 |
+
# characters).
|
329 |
+
#
|
330 |
+
# What we really want to return is "Steve Smith".
|
331 |
+
#
|
332 |
+
# Therefore, we have to apply a semi-complicated alignment heruistic between
|
333 |
+
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
|
334 |
+
# can fail in certain cases in which case we just return `orig_text`.
|
335 |
+
|
336 |
+
def _strip_spaces(text):
|
337 |
+
ns_chars = []
|
338 |
+
ns_to_s_map = collections.OrderedDict()
|
339 |
+
for i, c in enumerate(text):
|
340 |
+
if c == " ":
|
341 |
+
continue
|
342 |
+
ns_to_s_map[len(ns_chars)] = i
|
343 |
+
ns_chars.append(c)
|
344 |
+
ns_text = "".join(ns_chars)
|
345 |
+
return ns_text, dict(ns_to_s_map)
|
346 |
+
|
347 |
+
# We first tokenize `orig_text`, strip whitespace from the result
|
348 |
+
# and `pred_text`, and check if they are the same length. If they are
|
349 |
+
# NOT the same length, the heuristic has failed. If they are the same
|
350 |
+
# length, we assume the characters are one-to-one aligned.
|
351 |
+
tokenizer = tokenization.BasicTokenizer(do_lower_case=config.do_lower_case)
|
352 |
+
|
353 |
+
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
354 |
+
|
355 |
+
start_position = tok_text.find(pred_text)
|
356 |
+
if start_position == -1:
|
357 |
+
if config.debug:
|
358 |
+
utils.log(
|
359 |
+
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
360 |
+
return orig_text
|
361 |
+
end_position = start_position + len(pred_text) - 1
|
362 |
+
|
363 |
+
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
364 |
+
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
365 |
+
|
366 |
+
if len(orig_ns_text) != len(tok_ns_text):
|
367 |
+
if config.debug:
|
368 |
+
utils.log("Length not equal after stripping spaces: '%s' vs '%s'",
|
369 |
+
orig_ns_text, tok_ns_text)
|
370 |
+
return orig_text
|
371 |
+
|
372 |
+
# We then project the characters in `pred_text` back to `orig_text` using
|
373 |
+
# the character-to-character alignment.
|
374 |
+
tok_s_to_ns_map = {}
|
375 |
+
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
|
376 |
+
tok_s_to_ns_map[tok_index] = i
|
377 |
+
|
378 |
+
orig_start_position = None
|
379 |
+
if start_position in tok_s_to_ns_map:
|
380 |
+
ns_start_position = tok_s_to_ns_map[start_position]
|
381 |
+
if ns_start_position in orig_ns_to_s_map:
|
382 |
+
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
383 |
+
|
384 |
+
if orig_start_position is None:
|
385 |
+
if config.debug:
|
386 |
+
utils.log("Couldn't map start position")
|
387 |
+
return orig_text
|
388 |
+
|
389 |
+
orig_end_position = None
|
390 |
+
if end_position in tok_s_to_ns_map:
|
391 |
+
ns_end_position = tok_s_to_ns_map[end_position]
|
392 |
+
if ns_end_position in orig_ns_to_s_map:
|
393 |
+
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
394 |
+
|
395 |
+
if orig_end_position is None:
|
396 |
+
if config.debug:
|
397 |
+
utils.log("Couldn't map end position")
|
398 |
+
return orig_text
|
399 |
+
|
400 |
+
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
401 |
+
return output_text
|
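
A quick standalone illustration of the two pure helpers defined above: `_get_best_indexes` proposes candidate start/end positions and `_compute_softmax` turns the summed start/end logits of the surviving spans into the `probability` field of the n-best JSON. The logit values below are invented for the example, and the extra span filtering done in `write_predictions` (null position 0, max answer length, max-context check) is elided; the helpers are restated in condensed but equivalent form so the sketch runs on its own.

import numpy as np

def get_best_indexes(logits, n_best_size):
  # Condensed equivalent of _get_best_indexes above: positions of the
  # n_best_size largest logits, best first.
  return [i for i, _ in sorted(
      enumerate(logits), key=lambda x: x[1], reverse=True)[:n_best_size]]

def softmax(scores):
  # Condensed equivalent of _compute_softmax above: max-subtracted for
  # numerical stability with large logits.
  exp = np.exp(np.array(scores) - max(scores))
  return list(exp / exp.sum())

# Invented logits for a 6-token feature; position 0 is the [CLS] null span.
start_logits = [1.2, 0.1, 4.5, 0.3, 2.8, 0.0]
end_logits = [1.0, 0.2, 0.4, 5.1, 3.0, 0.1]
starts = get_best_indexes(start_logits, n_best_size=2)  # [2, 4]
ends = get_best_indexes(end_logits, n_best_size=2)      # [3, 4]
# Score every valid (start <= end) pairing by summed logits, as in
# write_predictions, then normalize the scores of the kept spans.
spans = [(s, e, start_logits[s] + end_logits[e])
         for s in starts for e in ends if s <= e]
probs = softmax([score for (_, _, score) in spans])
print(spans)                         # [(2, 3, 9.6), (2, 4, 7.5), (4, 4, 5.8)]
print([round(p, 3) for p in probs])  # approx. [0.873, 0.107, 0.02]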
arabert/araelectra/finetune/qa/qa_tasks.py
ADDED
@@ -0,0 +1,628 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Question answering tasks. SQuAD 1.1/2.0 and 2019 MRQA tasks are supported."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import json
import os
import six
import tensorflow as tf

import configure_finetuning
from finetune import feature_spec
from finetune import task
from finetune.qa import qa_metrics
from model import modeling
from model import tokenization
from util import utils


class QAExample(task.Example):
  """Question-answering example."""

  def __init__(self,
               task_name,
               eid,
               qas_id,
               qid,
               question_text,
               doc_tokens,
               orig_answer_text=None,
               start_position=None,
               end_position=None,
               is_impossible=False):
    super(QAExample, self).__init__(task_name)
    self.eid = eid
    self.qas_id = qas_id
    self.qid = qid
    self.question_text = question_text
    self.doc_tokens = doc_tokens
    self.orig_answer_text = orig_answer_text
    self.start_position = start_position
    self.end_position = end_position
    self.is_impossible = is_impossible

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    s = ""
    s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
    s += ", question_text: %s" % (
        tokenization.printable_text(self.question_text))
    s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
    if self.start_position:
      s += ", start_position: %d" % self.start_position
    if self.end_position:
      s += ", end_position: %d" % self.end_position
    if self.is_impossible:
      s += ", is_impossible: %r" % self.is_impossible
    return s


def _check_is_max_context(doc_spans, cur_span_index, position):
  """Check if this is the 'max context' doc span for the token."""

  # Because of the sliding window approach taken to scoring documents, a single
  # token can appear in multiple documents. E.g.
  # Doc: the man went to the store and bought a gallon of milk
  # Span A: the man went to the
  # Span B: to the store and bought
  # Span C: and bought a gallon of
  # ...
  #
  # Now the word 'bought' will have two scores from spans B and C. We only
  # want to consider the score with "maximum context", which we define as
  # the *minimum* of its left and right context (the *sum* of left and
  # right context will always be the same, of course).
  #
  # In the example the maximum context for 'bought' would be span C since
  # it has 1 left context and 3 right context, while span B has 4 left context
  # and 0 right context.
  best_score = None
  best_span_index = None
  for (span_index, doc_span) in enumerate(doc_spans):
    end = doc_span.start + doc_span.length - 1
    if position < doc_span.start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_span.start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = span_index

  return cur_span_index == best_span_index


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
  """Returns tokenized answer spans that better match the annotated answer."""

  # The SQuAD annotations are character based. We first project them to
  # whitespace-tokenized words. But then after WordPiece tokenization, we can
  # often find a "better match". For example:
  #
  # Question: What year was John Smith born?
  # Context: The leader was John Smith (1895-1943).
  # Answer: 1895
  #
  # The original whitespace-tokenized answer will be "(1895-1943).". However
  # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
  # the exact answer, 1895.
  #
  # However, this is not always possible. Consider the following:
  #
  # Question: What country is the top exporter of electronics?
  # Context: The Japanese electronics industry is the largest in the world.
  # Answer: Japan
  #
  # In this case, the annotator chose "Japan" as a character sub-span of
  # the word "Japanese". Since our WordPiece tokenizer does not split
  # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
  # in SQuAD, but does happen.
  tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

  for new_start in range(input_start, input_end + 1):
    for new_end in range(input_end, new_start - 1, -1):
      text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
      if text_span == tok_answer_text:
        return new_start, new_end

  return input_start, input_end


def is_whitespace(c):
  return c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F


class QATask(task.Task):
  """A span-based question answering task (e.g., SQuAD)."""

  __metaclass__ = abc.ABCMeta

  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
               tokenizer, v2=False):
    super(QATask, self).__init__(config, name)
    self._tokenizer = tokenizer
    self._examples = {}
    self.v2 = v2

  def _add_examples(self, examples, example_failures, paragraph, split):
    paragraph_text = paragraph["context"]
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in paragraph_text:
      if is_whitespace(c):
        prev_is_whitespace = True
      else:
        if prev_is_whitespace:
          doc_tokens.append(c)
        else:
          doc_tokens[-1] += c
        prev_is_whitespace = False
      char_to_word_offset.append(len(doc_tokens) - 1)

    for qa in paragraph["qas"]:
      qas_id = qa["id"] if "id" in qa else None
      qid = qa["qid"] if "qid" in qa else None
      question_text = qa["question"]
      start_position = None
      end_position = None
      orig_answer_text = None
      is_impossible = False
      if split == "train":
        if self.v2:
          is_impossible = qa["is_impossible"]
        if not is_impossible:
          if "detected_answers" in qa:  # MRQA format
            answer = qa["detected_answers"][0]
            answer_offset = answer["char_spans"][0][0]
          else:  # SQuAD format
            answer = qa["answers"][0]
            answer_offset = answer["answer_start"]
          orig_answer_text = answer["text"]
          answer_length = len(orig_answer_text)
          start_position = char_to_word_offset[answer_offset]
          if answer_offset + answer_length - 1 >= len(char_to_word_offset):
            utils.log("End position is out of document!")
            example_failures[0] += 1
            continue
          end_position = char_to_word_offset[answer_offset + answer_length - 1]

          # Only add answers where the text can be exactly recovered from the
          # document. If this CAN'T happen it's likely due to weird Unicode
          # stuff so we will just skip the example.
          #
          # Note that this means for training mode, every example is NOT
          # guaranteed to be preserved.
          actual_text = " ".join(
              doc_tokens[start_position:(end_position + 1)])
          cleaned_answer_text = " ".join(
              tokenization.whitespace_tokenize(orig_answer_text))
          actual_text = actual_text.lower()
          cleaned_answer_text = cleaned_answer_text.lower()
          if actual_text.find(cleaned_answer_text) == -1:
            utils.log("Could not find answer: '{:}' in doc vs. "
                      "'{:}' in provided answer".format(
                          tokenization.printable_text(actual_text),
                          tokenization.printable_text(cleaned_answer_text)))
            example_failures[0] += 1
            continue
        else:
          start_position = -1
          end_position = -1
          orig_answer_text = ""

      example = QAExample(
          task_name=self.name,
          eid=len(examples),
          qas_id=qas_id,
          qid=qid,
          question_text=question_text,
          doc_tokens=doc_tokens,
          orig_answer_text=orig_answer_text,
          start_position=start_position,
          end_position=end_position,
          is_impossible=is_impossible)
      examples.append(example)

  def get_feature_specs(self):
    return [
        feature_spec.FeatureSpec(self.name + "_eid", []),
        feature_spec.FeatureSpec(self.name + "_start_positions", []),
        feature_spec.FeatureSpec(self.name + "_end_positions", []),
        feature_spec.FeatureSpec(self.name + "_is_impossible", []),
    ]

  def featurize(self, example: QAExample, is_training, log=False,
                for_eval=False):
    all_features = []
    query_tokens = self._tokenizer.tokenize(example.question_text)

    if len(query_tokens) > self.config.max_query_length:
      query_tokens = query_tokens[0:self.config.max_query_length]

    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for (i, token) in enumerate(example.doc_tokens):
      orig_to_tok_index.append(len(all_doc_tokens))
      sub_tokens = self._tokenizer.tokenize(token)
      for sub_token in sub_tokens:
        tok_to_orig_index.append(i)
        all_doc_tokens.append(sub_token)

    tok_start_position = None
    tok_end_position = None
    if is_training and example.is_impossible:
      tok_start_position = -1
      tok_end_position = -1
    if is_training and not example.is_impossible:
      tok_start_position = orig_to_tok_index[example.start_position]
      if example.end_position < len(example.doc_tokens) - 1:
        tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
      else:
        tok_end_position = len(all_doc_tokens) - 1
      (tok_start_position, tok_end_position) = _improve_answer_span(
          all_doc_tokens, tok_start_position, tok_end_position, self._tokenizer,
          example.orig_answer_text)

    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = self.config.max_seq_length - len(query_tokens) - 3

    # We can have documents that are longer than the maximum sequence length.
    # To deal with this we do a sliding window approach, where we take chunks
    # of up to our max length with a stride of `doc_stride`.
    _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
        "DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
      length = len(all_doc_tokens) - start_offset
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(_DocSpan(start=start_offset, length=length))
      if start_offset + length == len(all_doc_tokens):
        break
      start_offset += min(length, self.config.doc_stride)

    for (doc_span_index, doc_span) in enumerate(doc_spans):
      tokens = []
      token_to_orig_map = {}
      token_is_max_context = {}
      segment_ids = []
      tokens.append("[CLS]")
      segment_ids.append(0)
      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
      tokens.append("[SEP]")
      segment_ids.append(0)

      for i in range(doc_span.length):
        split_token_index = doc_span.start + i
        token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

        is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                               split_token_index)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[split_token_index])
        segment_ids.append(1)
      tokens.append("[SEP]")
      segment_ids.append(1)

      input_ids = self._tokenizer.convert_tokens_to_ids(tokens)

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      while len(input_ids) < self.config.max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

      assert len(input_ids) == self.config.max_seq_length
      assert len(input_mask) == self.config.max_seq_length
      assert len(segment_ids) == self.config.max_seq_length

      start_position = None
      end_position = None
      if is_training and not example.is_impossible:
        # For training, if our document chunk does not contain an annotation
        # we throw it out, since there is nothing to predict.
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False
        if not (tok_start_position >= doc_start and
                tok_end_position <= doc_end):
          out_of_span = True
        if out_of_span:
          start_position = 0
          end_position = 0
        else:
          doc_offset = len(query_tokens) + 2
          start_position = tok_start_position - doc_start + doc_offset
          end_position = tok_end_position - doc_start + doc_offset

      if is_training and example.is_impossible:
        start_position = 0
        end_position = 0

      if log:
        utils.log("*** Example ***")
        utils.log("doc_span_index: %s" % doc_span_index)
        utils.log("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        utils.log("token_to_orig_map: %s" % " ".join(
            ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
        utils.log("token_is_max_context: %s" % " ".join([
            "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
        ]))
        utils.log("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        utils.log("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        utils.log("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        if is_training and example.is_impossible:
          utils.log("impossible example")
        if is_training and not example.is_impossible:
          answer_text = " ".join(tokens[start_position:(end_position + 1)])
          utils.log("start_position: %d" % start_position)
          utils.log("end_position: %d" % end_position)
          utils.log("answer: %s" % (tokenization.printable_text(answer_text)))

      features = {
          "task_id": self.config.task_names.index(self.name),
          self.name + "_eid": (1000 * example.eid) + doc_span_index,
          "input_ids": input_ids,
          "input_mask": input_mask,
          "segment_ids": segment_ids,
      }
      if for_eval:
        features.update({
            self.name + "_doc_span_index": doc_span_index,
            self.name + "_tokens": tokens,
            self.name + "_token_to_orig_map": token_to_orig_map,
            self.name + "_token_is_max_context": token_is_max_context,
        })
      if is_training:
        features.update({
            self.name + "_start_positions": start_position,
            self.name + "_end_positions": end_position,
            self.name + "_is_impossible": example.is_impossible
        })
      all_features.append(features)
    return all_features

  def get_prediction_module(self, bert_model, features, is_training,
                            percent_done):
    final_hidden = bert_model.get_sequence_output()

    final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
    batch_size = final_hidden_shape[0]
    seq_length = final_hidden_shape[1]

    answer_mask = tf.cast(features["input_mask"], tf.float32)
    answer_mask *= tf.cast(features["segment_ids"], tf.float32)
    answer_mask += tf.one_hot(0, seq_length)

    start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

    start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
    start_top_index = tf.zeros([batch_size, self.config.beam_size], tf.int32)
    end_top_log_probs = tf.zeros([batch_size, self.config.beam_size,
                                  self.config.beam_size])
    end_top_index = tf.zeros([batch_size, self.config.beam_size,
                              self.config.beam_size], tf.int32)
    if self.config.joint_prediction:
      start_logits += 1000.0 * (answer_mask - 1)
      start_log_probs = tf.nn.log_softmax(start_logits)
      start_top_log_probs, start_top_index = tf.nn.top_k(
          start_log_probs, k=self.config.beam_size)

      if not is_training:
        # batch, beam, length, hidden
        end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                               [1, self.config.beam_size, 1, 1])
        # batch, beam, length
        start_index = tf.one_hot(start_top_index,
                                 depth=seq_length, axis=-1, dtype=tf.float32)
        # batch, beam, hidden
        start_features = tf.reduce_sum(
            tf.expand_dims(final_hidden, 1) *
            tf.expand_dims(start_index, -1), axis=-2)
        # batch, beam, length, hidden
        start_features = tf.tile(tf.expand_dims(start_features, 2),
                                 [1, 1, seq_length, 1])
      else:
        start_index = tf.one_hot(
            features[self.name + "_start_positions"], depth=seq_length,
            axis=-1, dtype=tf.float32)
        start_features = tf.reduce_sum(tf.expand_dims(start_index, -1) *
                                       final_hidden, axis=1)
        start_features = tf.tile(tf.expand_dims(start_features, 1),
                                 [1, seq_length, 1])
        end_features = final_hidden

      final_repr = tf.concat([start_features, end_features], -1)
      final_repr = tf.layers.dense(final_repr, 512, activation=modeling.gelu,
                                   name="qa_hidden")
      # batch, beam, length (batch, length when training)
      end_logits = tf.squeeze(tf.layers.dense(final_repr, 1), -1,
                              name="qa_logits")
      if is_training:
        end_logits += 1000.0 * (answer_mask - 1)
      else:
        end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

      if not is_training:
        end_log_probs = tf.nn.log_softmax(end_logits)
        end_top_log_probs, end_top_index = tf.nn.top_k(
            end_log_probs, k=self.config.beam_size)
        end_logits = tf.zeros([batch_size, seq_length])
    else:
      end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
      start_logits += 1000.0 * (answer_mask - 1)
      end_logits += 1000.0 * (answer_mask - 1)

    def compute_loss(logits, positions):
      one_hot_positions = tf.one_hot(
          positions, depth=seq_length, dtype=tf.float32)
      log_probs = tf.nn.log_softmax(logits, axis=-1)
      loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
      return loss

    start_positions = features[self.name + "_start_positions"]
    end_positions = features[self.name + "_end_positions"]

    start_loss = compute_loss(start_logits, start_positions)
    end_loss = compute_loss(end_logits, end_positions)

    losses = (start_loss + end_loss) / 2.0

    answerable_logit = tf.zeros([batch_size])
    if self.config.answerable_classifier:
      final_repr = final_hidden[:, 0]
      if self.config.answerable_uses_start_logits:
        start_p = tf.nn.softmax(start_logits)
        start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) *
                                      final_hidden, axis=1)
        final_repr = tf.concat([final_repr, start_feature], -1)
        final_repr = tf.layers.dense(final_repr, 512,
                                     activation=modeling.gelu)
      answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
      answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.cast(features[self.name + "_is_impossible"], tf.float32),
          logits=answerable_logit)
      losses += answerable_loss * self.config.answerable_weight

    return losses, dict(
        loss=losses,
        start_logits=start_logits,
        end_logits=end_logits,
        answerable_logit=answerable_logit,
        start_positions=features[self.name + "_start_positions"],
        end_positions=features[self.name + "_end_positions"],
        start_top_log_probs=start_top_log_probs,
        start_top_index=start_top_index,
        end_top_log_probs=end_top_log_probs,
        end_top_index=end_top_index,
        eid=features[self.name + "_eid"],
    )

  def get_scorer(self, split="dev"):
    return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)


class MRQATask(QATask):
  """Class for finetuning tasks from the 2019 MRQA shared task."""

  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
               tokenizer):
    super(MRQATask, self).__init__(config, name, tokenizer)

  def get_examples(self, split):
    if split in self._examples:
      utils.log("N EXAMPLES", split, len(self._examples[split]))
      return self._examples[split]

    examples = []
    example_failures = [0]
    with tf.io.gfile.GFile(os.path.join(
        self.config.raw_data_dir(self.name), split + ".jsonl"), "r") as f:
      for i, line in enumerate(f):
        if self.config.debug and i > 10:
          break
        paragraph = json.loads(line.strip())
        if "header" in paragraph:
          continue
        self._add_examples(examples, example_failures, paragraph, split)
    self._examples[split] = examples
    utils.log("{:} examples created, {:} failures".format(
        len(examples), example_failures[0]))
    return examples

  def get_scorer(self, split="dev"):
    return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)


class SQuADTask(QATask):
  """Class for finetuning on SQuAD 2.0 or 1.1."""

  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
               tokenizer, v2=False):
    super(SQuADTask, self).__init__(config, name, tokenizer, v2=v2)

  def get_examples(self, split):
    if split in self._examples:
      return self._examples[split]

    with tf.io.gfile.GFile(os.path.join(
        self.config.raw_data_dir(self.name),
        split + ("-debug" if self.config.debug else "") + ".json"), "r") as f:
      input_data = json.load(f)["data"]

    examples = []
    example_failures = [0]
    for entry in input_data:
      for paragraph in entry["paragraphs"]:
        self._add_examples(examples, example_failures, paragraph, split)
    self._examples[split] = examples
    utils.log("{:} examples created, {:} failures".format(
        len(examples), example_failures[0]))
    return examples

  def get_scorer(self, split="dev"):
    return qa_metrics.SpanBasedQAScorer(self.config, self, split, self.v2)


class SQuAD(SQuADTask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(SQuAD, self).__init__(config, "squad", tokenizer, v2=True)


class SQuADv1(SQuADTask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(SQuADv1, self).__init__(config, "squadv1", tokenizer)


class NewsQA(MRQATask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(NewsQA, self).__init__(config, "newsqa", tokenizer)


class NaturalQuestions(MRQATask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(NaturalQuestions, self).__init__(config, "naturalqs", tokenizer)


class SearchQA(MRQATask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(SearchQA, self).__init__(config, "searchqa", tokenizer)


class TriviaQA(MRQATask):
  def __init__(self, config: configure_finetuning.FinetuningConfig, tokenizer):
    super(TriviaQA, self).__init__(config, "triviaqa", tokenizer)
arabert/araelectra/finetune/qa/squad_official_eval.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Official evaluation script for SQuAD version 2.0.
|
17 |
+
|
18 |
+
In addition to basic functionality, we also compute additional statistics and
|
19 |
+
plot precision-recall curves if an additional na_prob.json file is provided.
|
20 |
+
This file is expected to map question ID's to the model's predicted probability
|
21 |
+
that a question is unanswerable.
|
22 |
+
|
23 |
+
Modified slightly for the ELECTRA codebase.
|
24 |
+
"""
|
25 |
+
from __future__ import absolute_import
|
26 |
+
from __future__ import division
|
27 |
+
from __future__ import print_function
|
28 |
+
|
29 |
+
import argparse
|
30 |
+
import collections
|
31 |
+
import json
|
32 |
+
import numpy as np
|
33 |
+
import os
|
34 |
+
import re
|
35 |
+
import string
|
36 |
+
import sys
|
37 |
+
import tensorflow as tf
|
38 |
+
|
39 |
+
import configure_finetuning
|
40 |
+
|
41 |
+
OPTS = None
|
42 |
+
|
43 |
+
def parse_args():
|
44 |
+
parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
|
45 |
+
parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
|
46 |
+
parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
|
47 |
+
parser.add_argument('--out-file', '-o', metavar='eval.json',
|
48 |
+
help='Write accuracy metrics to file (default is stdout).')
|
49 |
+
parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
|
50 |
+
help='Model estimates of probability of no answer.')
|
51 |
+
parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
|
52 |
+
help='Predict "" if no-answer probability exceeds this (default = 1.0).')
|
53 |
+
parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
|
54 |
+
help='Save precision-recall curves to directory.')
|
55 |
+
parser.add_argument('--verbose', '-v', action='store_true')
|
56 |
+
if len(sys.argv) == 1:
|
57 |
+
parser.print_help()
|
58 |
+
sys.exit(1)
|
59 |
+
return parser.parse_args()
|
60 |
+
|
61 |
+
def set_opts(config: configure_finetuning.FinetuningConfig, split):
|
62 |
+
global OPTS
|
63 |
+
Options = collections.namedtuple("Options", [
|
64 |
+
"data_file", "pred_file", "out_file", "na_prob_file", "na_prob_thresh",
|
65 |
+
"out_image_dir", "verbose"])
|
66 |
+
OPTS = Options(
|
67 |
+
data_file=os.path.join(
|
68 |
+
config.raw_data_dir("squad"),
|
69 |
+
split + ("-debug" if config.debug else "") + ".json"),
|
70 |
+
pred_file=config.qa_preds_file("squad"),
|
71 |
+
out_file=config.qa_eval_file("squad"),
|
72 |
+
na_prob_file=config.qa_na_file("squad"),
|
73 |
+
na_prob_thresh=config.qa_na_threshold,
|
74 |
+
out_image_dir=None,
|
75 |
+
verbose=False
|
76 |
+
)
|
77 |
+
|
78 |
+
def make_qid_to_has_ans(dataset):
|
79 |
+
qid_to_has_ans = {}
|
80 |
+
for article in dataset:
|
81 |
+
for p in article['paragraphs']:
|
82 |
+
for qa in p['qas']:
|
83 |
+
qid_to_has_ans[qa['id']] = bool(qa['answers'])
|
84 |
+
return qid_to_has_ans
|
85 |
+
|
86 |
+
def normalize_answer(s):
|
87 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
88 |
+
def remove_articles(text):
|
89 |
+
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
90 |
+
return re.sub(regex, ' ', text)
|
91 |
+
def white_space_fix(text):
|
92 |
+
return ' '.join(text.split())
|
93 |
+
def remove_punc(text):
|
94 |
+
exclude = set(string.punctuation)
|
95 |
+
return ''.join(ch for ch in text if ch not in exclude)
|
96 |
+
def lower(text):
|
97 |
+
return text.lower()
|
98 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
99 |
+
|
100 |
+
def get_tokens(s):
|
101 |
+
if not s: return []
|
102 |
+
return normalize_answer(s).split()
|
103 |
+
|
104 |
+
def compute_exact(a_gold, a_pred):
|
105 |
+
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
106 |
+
|
107 |
+
def compute_f1(a_gold, a_pred):
|
108 |
+
gold_toks = get_tokens(a_gold)
|
109 |
+
pred_toks = get_tokens(a_pred)
|
110 |
+
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
111 |
+
num_same = sum(common.values())
|
112 |
+
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
113 |
+
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
114 |
+
return int(gold_toks == pred_toks)
|
115 |
+
if num_same == 0:
|
116 |
+
return 0
|
117 |
+
precision = 1.0 * num_same / len(pred_toks)
|
118 |
+
recall = 1.0 * num_same / len(gold_toks)
|
119 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
120 |
+
return f1
|
121 |
+
|
122 |
+
def get_raw_scores(dataset, preds):
|
123 |
+
exact_scores = {}
|
124 |
+
f1_scores = {}
|
125 |
+
for article in dataset:
|
126 |
+
for p in article['paragraphs']:
|
127 |
+
for qa in p['qas']:
|
128 |
+
qid = qa['id']
|
129 |
+
gold_answers = [a['text'] for a in qa['answers']
|
130 |
+
if normalize_answer(a['text'])]
|
131 |
+
if not gold_answers:
|
132 |
+
# For unanswerable questions, only correct answer is empty string
|
133 |
+
gold_answers = ['']
|
134 |
+
if qid not in preds:
|
135 |
+
print('Missing prediction for %s' % qid)
|
136 |
+
continue
|
137 |
+
a_pred = preds[qid]
|
138 |
+
# Take max over all gold answers
|
139 |
+
exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
|
140 |
+
f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
|
141 |
+
return exact_scores, f1_scores
|
142 |
+
|
143 |
+
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
144 |
+
new_scores = {}
|
145 |
+
for qid, s in scores.items():
|
146 |
+
pred_na = na_probs[qid] > na_prob_thresh
|
147 |
+
if pred_na:
|
148 |
+
new_scores[qid] = float(not qid_to_has_ans[qid])
|
149 |
+
else:
|
150 |
+
new_scores[qid] = s
|
151 |
+
return new_scores
|
152 |
+
|
153 |
+
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
154 |
+
if not qid_list:
|
155 |
+
total = len(exact_scores)
|
156 |
+
return collections.OrderedDict([
|
157 |
+
('exact', 100.0 * sum(exact_scores.values()) / total),
|
158 |
+
('f1', 100.0 * sum(f1_scores.values()) / total),
|
159 |
+
('total', total),
|
160 |
+
])
|
161 |
+
else:
|
162 |
+
total = len(qid_list)
|
163 |
+
return collections.OrderedDict([
|
164 |
+
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
165 |
+
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
166 |
+
('total', total),
|
167 |
+
])
|
168 |
+
|
169 |
+
def merge_eval(main_eval, new_eval, prefix):
|
170 |
+
for k in new_eval:
|
171 |
+
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
172 |
+
|
173 |
+
def plot_pr_curve(precisions, recalls, out_image, title):
|
174 |
+
plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
|
175 |
+
plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
|
176 |
+
plt.xlabel('Recall')
|
177 |
+
plt.ylabel('Precision')
|
178 |
+
plt.xlim([0.0, 1.05])
|
179 |
+
plt.ylim([0.0, 1.05])
|
180 |
+
plt.title(title)
|
181 |
+
plt.savefig(out_image)
|
182 |
+
plt.clf()
|
183 |
+
|
184 |
+
def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
|
185 |
+
out_image=None, title=None):
|
186 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
187 |
+
true_pos = 0.0
|
188 |
+
cur_p = 1.0
|
189 |
+
cur_r = 0.0
|
190 |
+
precisions = [1.0]
|
191 |
+
recalls = [0.0]
|
192 |
+
avg_prec = 0.0
|
193 |
+
for i, qid in enumerate(qid_list):
|
194 |
+
if qid_to_has_ans[qid]:
|
195 |
+
true_pos += scores[qid]
|
196 |
+
cur_p = true_pos / float(i+1)
|
197 |
+
cur_r = true_pos / float(num_true_pos)
|
198 |
+
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  if out_image:
    plot_pr_curve(precisions, recalls, out_image, title)
  return {'ap': 100.0 * avg_prec}

def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, out_image_dir):
  if out_image_dir and not os.path.exists(out_image_dir):
    os.makedirs(out_image_dir)
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_exact.png'),
      title='Precision-Recall curve for Exact Match score')
  pr_f1 = make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_f1.png'),
      title='Precision-Recall curve for F1 score')
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
      out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
      title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
  merge_eval(main_eval, pr_exact, 'pr_exact')
  merge_eval(main_eval, pr_f1, 'pr_f1')
  merge_eval(main_eval, pr_oracle, 'pr_oracle')

def histogram_na_prob(na_probs, qid_list, image_dir, name):
  if not qid_list:
    return
  x = [na_probs[k] for k in qid_list]
  weights = np.ones_like(x) / float(len(x))
  plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
  plt.xlabel('Model probability of no-answer')
  plt.ylabel('Proportion of dataset')
  plt.title('Histogram of no-answer probability: %s' % name)
  plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
  plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for i, qid in enumerate(qid_list):
    if qid not in scores: continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if preds[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
  main_eval['best_exact'] = best_exact
  main_eval['best_exact_thresh'] = exact_thresh
  main_eval['best_f1'] = best_f1
  main_eval['best_f1_thresh'] = f1_thresh

def main():
  with tf.io.gfile.GFile(OPTS.data_file) as f:
    dataset_json = json.load(f)
    dataset = dataset_json['data']
  with tf.io.gfile.GFile(OPTS.pred_file) as f:
    preds = json.load(f)
  if OPTS.na_prob_file:
    with tf.io.gfile.GFile(OPTS.na_prob_file) as f:
      na_probs = json.load(f)
  else:
    na_probs = {k: 0.0 for k in preds}
  qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = get_raw_scores(dataset, preds)
  exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                        OPTS.na_prob_thresh)
  f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                     OPTS.na_prob_thresh)
  out_eval = make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
    merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    merge_eval(out_eval, no_ans_eval, 'NoAns')
  if OPTS.na_prob_file:
    find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  if OPTS.na_prob_file and OPTS.out_image_dir:
    run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, OPTS.out_image_dir)
    histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
    histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
  if OPTS.out_file:
    with tf.io.gfile.GFile(OPTS.out_file, 'w') as f:
      json.dump(out_eval, f)
  else:
    print(json.dumps(out_eval, indent=2))

if __name__ == '__main__':
  OPTS = parse_args()
  if OPTS.out_image_dir:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
  main()
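Note on find_best_thresh: it sorts questions by no-answer probability and sweeps a cutoff, starting from the score obtained by declaring everything unanswerable (num_no_ans correct). A toy trace with invented ids, scores, and probabilities (a sketch only; it assumes the definitions above are in scope):

preds = {'q1': 'paris', 'q2': 'london', 'q3': 'rome', 'q4': ''}
exact_raw = {'q1': 1, 'q2': 0, 'q3': 1, 'q4': 0}          # raw EM per question
na_probs = {'q1': 0.1, 'q2': 0.2, 'q3': 0.4, 'q4': 0.9}   # P(no answer)
qid_to_has_ans = {'q1': True, 'q2': True, 'q3': True, 'q4': False}

best, thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
# Sweep: start at 1 (q4 counted as correctly unanswered), +1 for q1, +0 for
# q2, +1 for q3; q4's empty prediction costs 0. Peak is 3 of 4 questions.
print(best, thresh)  # 75.0 0.4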
arabert/araelectra/finetune/qa/squad_official_eval_v1.py
ADDED
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Official evaluation script for v1.1 of the SQuAD dataset.
Modified slightly for the ELECTRA codebase.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import Counter
import string
import re
import json
import sys
import os
import collections
import tensorflow as tf

import configure_finetuning


def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  def remove_articles(text):
    return re.sub(r'\b(a|an|the)\b', ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
  prediction_tokens = normalize_answer(prediction).split()
  ground_truth_tokens = normalize_answer(ground_truth).split()
  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def exact_match_score(prediction, ground_truth):
  return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article['paragraphs']:
      for qa in paragraph['qas']:
        total += 1
        if qa['id'] not in predictions:
          message = 'Unanswered question ' + qa['id'] + \
                    ' will receive score 0.'
          print(message, file=sys.stderr)
          continue
        ground_truths = list(map(lambda x: x['text'], qa['answers']))
        prediction = predictions[qa['id']]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

  exact_match = 100.0 * exact_match / total
  f1 = 100.0 * f1 / total

  return {'exact_match': exact_match, 'f1': f1}


def main(config: configure_finetuning.FinetuningConfig, split):
  expected_version = '1.1'
  # parser = argparse.ArgumentParser(
  #     description='Evaluation for SQuAD ' + expected_version)
  # parser.add_argument('dataset_file', help='Dataset file')
  # parser.add_argument('prediction_file', help='Prediction File')
  # args = parser.parse_args()
  Args = collections.namedtuple("Args", [
      "dataset_file", "prediction_file"
  ])
  args = Args(dataset_file=os.path.join(
      config.raw_data_dir("squadv1"),
      split + ("-debug" if config.debug else "") + ".json"),
      prediction_file=config.qa_preds_file("squadv1"))
  with tf.io.gfile.GFile(args.dataset_file) as dataset_file:
    dataset_json = json.load(dataset_file)
    if dataset_json['version'] != expected_version:
      print('Evaluation expects v-' + expected_version +
            ', but got dataset with v-' + dataset_json['version'],
            file=sys.stderr)
    dataset = dataset_json['data']
  with tf.io.gfile.GFile(args.prediction_file) as prediction_file:
    predictions = json.load(prediction_file)
  return evaluate(dataset, predictions)
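Since evaluate() only needs the parsed 'data' list and a qid-to-prediction dict, it can be exercised without main()'s config plumbing. A fabricated one-question example (all names invented; assumes the functions above are in scope):

dataset = [{
    'paragraphs': [{
        'qas': [{
            'id': 'q1',
            'answers': [{'text': 'the Eiffel Tower'}, {'text': 'Eiffel Tower'}],
        }]
    }]
}]
predictions = {'q1': 'eiffel tower'}

# normalize_answer() lowercases and strips punctuation and articles, so the
# prediction matches both ground truths after normalization.
print(evaluate(dataset, predictions))  # {'exact_match': 100.0, 'f1': 100.0}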
arabert/araelectra/finetune/scorer.py
ADDED
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for evaluation metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc


class Scorer(object):
  """Abstract base class for computing evaluation metrics."""

  __metaclass__ = abc.ABCMeta

  def __init__(self):
    self._updated = False
    self._cached_results = {}

  @abc.abstractmethod
  def update(self, results):
    self._updated = True

  @abc.abstractmethod
  def get_loss(self):
    pass

  @abc.abstractmethod
  def _get_results(self):
    return []

  def get_results(self, prefix=""):
    results = self._get_results() if self._updated else self._cached_results
    self._cached_results = results
    self._updated = False
    return [(prefix + k, v) for k, v in results]

  def results_str(self):
    return " - ".join(["{:}: {:.2f}".format(k, v)
                       for k, v in self.get_results()])
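The caching contract in get_results() is the subtle part: _get_results() is recomputed only when update() has been called since the last read; otherwise the cached list is re-served. A minimal concrete subclass illustrating the flow (hypothetical, not part of the codebase):

class MeanLossScorer(Scorer):
  """Tracks the running mean of a scalar 'loss' entry."""

  def __init__(self):
    super(MeanLossScorer, self).__init__()
    self._losses = []

  def update(self, results):
    super(MeanLossScorer, self).update(results)  # marks results as stale
    self._losses.append(results['loss'])

  def get_loss(self):
    return sum(self._losses) / max(1, len(self._losses))

  def _get_results(self):
    return [('loss', self.get_loss())]

s = MeanLossScorer()
s.update({'loss': 0.5})
s.update({'loss': 1.5})
print(s.results_str())         # loss: 1.00
print(s.get_results('eval_'))  # served from cache: [('eval_loss', 1.0)]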
arabert/araelectra/finetune/tagging/tagging_metrics.py
ADDED
@@ -0,0 +1,116 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metrics for sequence tagging tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

import numpy as np

from finetune import scorer
from finetune.tagging import tagging_utils


class WordLevelScorer(scorer.Scorer):
  """Base class for tagging scorers."""
  __metaclass__ = abc.ABCMeta

  def __init__(self):
    super(WordLevelScorer, self).__init__()
    self._total_loss = 0
    self._total_words = 0
    self._labels = []
    self._preds = []

  def update(self, results):
    super(WordLevelScorer, self).update(results)
    self._total_loss += results['loss']
    n_words = int(round(np.sum(results['labels_mask'])))
    self._labels.append(results['labels'][:n_words])
    self._preds.append(results['predictions'][:n_words])
    self._total_loss += np.sum(results['loss'])
    self._total_words += n_words

  def get_loss(self):
    return self._total_loss / max(1, self._total_words)


class AccuracyScorer(WordLevelScorer):
  """Computes accuracy scores."""

  def __init__(self, auto_fail_label=None):
    super(AccuracyScorer, self).__init__()
    self._auto_fail_label = auto_fail_label

  def _get_results(self):
    correct, count = 0, 0
    for labels, preds in zip(self._labels, self._preds):
      for y_true, y_pred in zip(labels, preds):
        count += 1
        correct += (1 if y_pred == y_true and y_true != self._auto_fail_label
                    else 0)
    return [
        ('accuracy', 100.0 * correct / count),
        ('loss', self.get_loss())
    ]


class F1Scorer(WordLevelScorer):
  """Computes F1 scores."""

  __metaclass__ = abc.ABCMeta

  def __init__(self):
    super(F1Scorer, self).__init__()
    self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0

  def _get_results(self):
    if self._n_correct == 0:
      p, r, f1 = 0, 0, 0
    else:
      p = 100.0 * self._n_correct / self._n_predicted
      r = 100.0 * self._n_correct / self._n_gold
      f1 = 2 * p * r / (p + r)
    return [
        ('precision', p),
        ('recall', r),
        ('f1', f1),
        ('loss', self.get_loss()),
    ]


class EntityLevelF1Scorer(F1Scorer):
  """Computes F1 score for entity-level tasks such as NER."""

  def __init__(self, label_mapping):
    super(EntityLevelF1Scorer, self).__init__()
    self._inv_label_mapping = {v: k for k, v in six.iteritems(label_mapping)}

  def _get_results(self):
    self._n_correct, self._n_predicted, self._n_gold = 0, 0, 0
    for labels, preds in zip(self._labels, self._preds):
      sent_spans = set(tagging_utils.get_span_labels(
          labels, self._inv_label_mapping))
      span_preds = set(tagging_utils.get_span_labels(
          preds, self._inv_label_mapping))
      self._n_correct += len(sent_spans & span_preds)
      self._n_gold += len(sent_spans)
      self._n_predicted += len(span_preds)
    return super(EntityLevelF1Scorer, self)._get_results()
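These scorers consume the dicts emitted by TaggingTask.get_prediction_module (see tagging_tasks.py below): per-example 'labels', 'predictions', a float 'labels_mask' that zeroes out padding, and 'loss'. A toy update with fabricated arrays (a sketch only; values invented, assumes AccuracyScorer is in scope):

import numpy as np

s = AccuracyScorer()
s.update({
    'labels':      np.array([1, 2, 0, 0]),      # padded to max_seq_length=4
    'predictions': np.array([1, 1, 0, 0]),
    'labels_mask': np.array([1., 1., 0., 0.]),  # only 2 real (tagged) words
    'loss':        0.7,
})
# Only the two masked-in positions are scored; one of them matches:
print(s.get_results()[0])  # ('accuracy', 50.0)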
arabert/araelectra/finetune/tagging/tagging_tasks.py
ADDED
@@ -0,0 +1,253 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequence tagging tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import os
import tensorflow as tf

import configure_finetuning
from finetune import feature_spec
from finetune import task
from finetune.tagging import tagging_metrics
from finetune.tagging import tagging_utils
from model import tokenization
from pretrain import pretrain_helpers
from util import utils


LABEL_ENCODING = "BIOES"


class TaggingExample(task.Example):
  """A single tagged input sequence."""

  def __init__(self, eid, task_name, words, tags, is_token_level,
               label_mapping):
    super(TaggingExample, self).__init__(task_name)
    self.eid = eid
    self.words = words
    if is_token_level:
      labels = tags
    else:
      span_labels = tagging_utils.get_span_labels(tags)
      labels = tagging_utils.get_tags(
          span_labels, len(words), LABEL_ENCODING)
    self.labels = [label_mapping[l] for l in labels]


class TaggingTask(task.Task):
  """Defines a sequence tagging task (e.g., part-of-speech tagging)."""

  __metaclass__ = abc.ABCMeta

  def __init__(self, config: configure_finetuning.FinetuningConfig, name,
               tokenizer, is_token_level):
    super(TaggingTask, self).__init__(config, name)
    self._tokenizer = tokenizer
    self._label_mapping_path = os.path.join(
        self.config.preprocessed_data_dir,
        ("debug_" if self.config.debug else "") + self.name +
        "_label_mapping.pkl")
    self._is_token_level = is_token_level
    self._label_mapping = None

  def get_examples(self, split):
    sentences = self._get_labeled_sentences(split)
    examples = []
    label_mapping = self._get_label_mapping(split, sentences)
    for i, (words, tags) in enumerate(sentences):
      examples.append(TaggingExample(
          i, self.name, words, tags, self._is_token_level, label_mapping
      ))
    return examples

  def _get_label_mapping(self, provided_split=None, provided_sentences=None):
    if self._label_mapping is not None:
      return self._label_mapping
    if tf.io.gfile.exists(self._label_mapping_path):
      self._label_mapping = utils.load_pickle(self._label_mapping_path)
      return self._label_mapping
    utils.log("Writing label mapping for task", self.name)
    tag_counts = collections.Counter()
    train_tags = set()
    for split in ["train", "dev", "test"]:
      if not tf.io.gfile.exists(os.path.join(
          self.config.raw_data_dir(self.name), split + ".txt")):
        continue
      if split == provided_split:
        split_sentences = provided_sentences
      else:
        split_sentences = self._get_labeled_sentences(split)
      for _, tags in split_sentences:
        if not self._is_token_level:
          span_labels = tagging_utils.get_span_labels(tags)
          tags = tagging_utils.get_tags(span_labels, len(tags), LABEL_ENCODING)
        for tag in tags:
          tag_counts[tag] += 1
          if provided_split == "train":
            train_tags.add(tag)
    if self.name == "ccg":
      infrequent_tags = []
      for tag in tag_counts:
        if tag not in train_tags:
          infrequent_tags.append(tag)
      label_mapping = {
          label: i for i, label in enumerate(sorted(filter(
              lambda t: t not in infrequent_tags, tag_counts.keys())))
      }
      n = len(label_mapping)
      for tag in infrequent_tags:
        label_mapping[tag] = n
    else:
      labels = sorted(tag_counts.keys())
      label_mapping = {label: i for i, label in enumerate(labels)}
    utils.write_pickle(label_mapping, self._label_mapping_path)
    self._label_mapping = label_mapping
    return label_mapping

  def featurize(self, example: TaggingExample, is_training, log=False):
    words_to_tokens = tokenize_and_align(self._tokenizer, example.words)
    input_ids = []
    tagged_positions = []
    for word_tokens in words_to_tokens:
      if len(words_to_tokens) + len(input_ids) + 1 > self.config.max_seq_length:
        input_ids.append(self._tokenizer.vocab["[SEP]"])
        break
      if "[CLS]" not in word_tokens and "[SEP]" not in word_tokens:
        tagged_positions.append(len(input_ids))
      for token in word_tokens:
        input_ids.append(self._tokenizer.vocab[token])

    pad = lambda x: x + [0] * (self.config.max_seq_length - len(x))
    labels = pad(example.labels[:self.config.max_seq_length])
    labeled_positions = pad(tagged_positions)
    labels_mask = pad([1.0] * len(tagged_positions))
    segment_ids = pad([1] * len(input_ids))
    input_mask = pad([1] * len(input_ids))
    input_ids = pad(input_ids)
    assert len(input_ids) == self.config.max_seq_length
    assert len(input_mask) == self.config.max_seq_length
    assert len(segment_ids) == self.config.max_seq_length
    assert len(labels) == self.config.max_seq_length
    assert len(labels_mask) == self.config.max_seq_length

    return {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        "task_id": self.config.task_names.index(self.name),
        self.name + "_eid": example.eid,
        self.name + "_labels": labels,
        self.name + "_labels_mask": labels_mask,
        self.name + "_labeled_positions": labeled_positions
    }

  def _get_labeled_sentences(self, split):
    sentences = []
    with tf.io.gfile.GFile(os.path.join(self.config.raw_data_dir(self.name),
                                        split + ".txt"), "r") as f:
      sentence = []
      for line in f:
        line = line.strip().split()
        if not line:
          if sentence:
            words, tags = zip(*sentence)
            sentences.append((words, tags))
            sentence = []
            if self.config.debug and len(sentences) > 100:
              return sentences
          continue
        if line[0] == "-DOCSTART-":
          continue
        word, tag = line[0], line[-1]
        sentence.append((word, tag))
    return sentences

  def get_scorer(self):
    return tagging_metrics.AccuracyScorer() if self._is_token_level else \
      tagging_metrics.EntityLevelF1Scorer(self._get_label_mapping())

  def get_feature_specs(self):
    return [
        feature_spec.FeatureSpec(self.name + "_eid", []),
        feature_spec.FeatureSpec(self.name + "_labels",
                                 [self.config.max_seq_length]),
        feature_spec.FeatureSpec(self.name + "_labels_mask",
                                 [self.config.max_seq_length],
                                 is_int_feature=False),
        feature_spec.FeatureSpec(self.name + "_labeled_positions",
                                 [self.config.max_seq_length]),
    ]

  def get_prediction_module(
      self, bert_model, features, is_training, percent_done):
    n_classes = len(self._get_label_mapping())
    reprs = bert_model.get_sequence_output()
    reprs = pretrain_helpers.gather_positions(
        reprs, features[self.name + "_labeled_positions"])
    logits = tf.layers.dense(reprs, n_classes)
    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(features[self.name + "_labels"], n_classes),
        logits=logits)
    losses *= features[self.name + "_labels_mask"]
    losses = tf.reduce_sum(losses, axis=-1)
    return losses, dict(
        loss=losses,
        logits=logits,
        predictions=tf.argmax(logits, axis=-1),
        labels=features[self.name + "_labels"],
        labels_mask=features[self.name + "_labels_mask"],
        eid=features[self.name + "_eid"],
    )

  def _create_examples(self, lines, split):
    pass


def tokenize_and_align(tokenizer, words, cased=False):
  """Splits up words into subword-level tokens."""
  words = ["[CLS]"] + list(words) + ["[SEP]"]
  basic_tokenizer = tokenizer.basic_tokenizer
  tokenized_words = []
  for word in words:
    word = tokenization.convert_to_unicode(word)
    word = basic_tokenizer._clean_text(word)
    if word == "[CLS]" or word == "[SEP]":
      word_toks = [word]
    else:
      if not cased:
        word = word.lower()
        word = basic_tokenizer._run_strip_accents(word)
      word_toks = basic_tokenizer._run_split_on_punc(word)
    tokenized_word = []
    for word_tok in word_toks:
      tokenized_word += tokenizer.wordpiece_tokenizer.tokenize(word_tok)
    tokenized_words.append(tokenized_word)
  assert len(tokenized_words) == len(words)
  return tokenized_words


class Chunking(TaggingTask):
  """Text chunking."""

  def __init__(self, config, tokenizer):
    super(Chunking, self).__init__(config, "chunk", tokenizer, False)
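_get_labeled_sentences expects CoNLL-style text: one token per line with the tag in the last whitespace-separated column, blank lines between sentences, and -DOCSTART- lines skipped. A hypothetical train.txt fragment (contents invented for illustration):

-DOCSTART- -X- O O
Obama NNP B-NP B-PER
spoke VBD B-VP O

Paris NNP B-NP B-LOC

This parses to [(('Obama', 'spoke'), ('B-PER', 'O')), (('Paris',), ('B-LOC',))]. Note that a sentence is only flushed when a blank line follows it, so a file that does not end with a blank line silently drops its last sentence.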
arabert/araelectra/finetune/tagging/tagging_utils.py
ADDED
@@ -0,0 +1,58 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for sequence tagging tasks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def get_span_labels(sentence_tags, inv_label_mapping=None):
  """Go from token-level labels to list of entities (start, end, class)."""
  if inv_label_mapping:
    sentence_tags = [inv_label_mapping[i] for i in sentence_tags]
  span_labels = []
  last = 'O'
  start = -1
  for i, tag in enumerate(sentence_tags):
    pos, _ = (None, 'O') if tag == 'O' else tag.split('-')
    if (pos == 'S' or pos == 'B' or tag == 'O') and last != 'O':
      span_labels.append((start, i - 1, last.split('-')[-1]))
    if pos == 'B' or pos == 'S' or last == 'O':
      start = i
    last = tag
  if sentence_tags[-1] != 'O':
    span_labels.append((start, len(sentence_tags) - 1,
                        sentence_tags[-1].split('-')[-1]))
  return span_labels


def get_tags(span_labels, length, encoding):
  """Converts a list of entities to token-level labels based on the provided
  encoding (e.g., BIOES).
  """

  tags = ['O' for _ in range(length)]
  for s, e, t in span_labels:
    for i in range(s, e + 1):
      tags[i] = 'I-' + t
    if 'E' in encoding:
      tags[e] = 'E-' + t
    if 'B' in encoding:
      tags[s] = 'B-' + t
    if 'S' in encoding and s - e == 0:
      tags[s] = 'S-' + t
  return tags
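get_span_labels and get_tags together re-encode between tag schemes; this is how TaggingExample converts raw tags to the module-level LABEL_ENCODING="BIOES". A round trip on a toy BIO sequence (labels invented for illustration; assumes both functions are in scope):

bio = ['B-PER', 'I-PER', 'O', 'B-LOC']
spans = get_span_labels(bio)
print(spans)  # [(0, 1, 'PER'), (3, 3, 'LOC')]
print(get_tags(spans, len(bio), 'BIOES'))
# ['B-PER', 'E-PER', 'O', 'S-LOC'] -- span ends become E-, single tokens S-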
arabert/araelectra/finetune/task.py
ADDED
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines a supervised NLP task."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
from typing import List, Tuple

import configure_finetuning
from finetune import feature_spec
from finetune import scorer
from model import modeling


class Example(object):
  __metaclass__ = abc.ABCMeta

  def __init__(self, task_name):
    self.task_name = task_name


class Task(object):
  """Override this class to add a new fine-tuning task."""

  __metaclass__ = abc.ABCMeta

  def __init__(self, config: configure_finetuning.FinetuningConfig, name):
    self.config = config
    self.name = name

  def get_test_splits(self):
    return ["test"]

  @abc.abstractmethod
  def get_examples(self, split):
    pass

  @abc.abstractmethod
  def get_scorer(self) -> scorer.Scorer:
    pass

  @abc.abstractmethod
  def get_feature_specs(self) -> List[feature_spec.FeatureSpec]:
    pass

  @abc.abstractmethod
  def featurize(self, example: Example, is_training: bool,
                log: bool=False):
    pass

  @abc.abstractmethod
  def get_prediction_module(
      self, bert_model: modeling.BertModel, features: dict, is_training: bool,
      percent_done: float) -> Tuple:
    pass

  def __repr__(self):
    return "Task(" + self.name + ")"
arabert/araelectra/finetune/task_builder.py
ADDED
@@ -0,0 +1,70 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Returns task instances given the task name."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import configure_finetuning
from finetune.classification import classification_tasks
from finetune.qa import qa_tasks
from finetune.tagging import tagging_tasks
from model import tokenization


def get_tasks(config: configure_finetuning.FinetuningConfig):
  tokenizer = tokenization.FullTokenizer(vocab_file=config.vocab_file,
                                         do_lower_case=config.do_lower_case)
  return [get_task(config, task_name, tokenizer)
          for task_name in config.task_names]


def get_task(config: configure_finetuning.FinetuningConfig, task_name,
             tokenizer):
  """Get an instance of a task based on its name."""
  if task_name == "cola":
    return classification_tasks.CoLA(config, tokenizer)
  elif task_name == "mrpc":
    return classification_tasks.MRPC(config, tokenizer)
  elif task_name == "mnli":
    return classification_tasks.MNLI(config, tokenizer)
  elif task_name == "sst":
    return classification_tasks.SST(config, tokenizer)
  elif task_name == "rte":
    return classification_tasks.RTE(config, tokenizer)
  elif task_name == "qnli":
    return classification_tasks.QNLI(config, tokenizer)
  elif task_name == "qqp":
    return classification_tasks.QQP(config, tokenizer)
  elif task_name == "sts":
    return classification_tasks.STS(config, tokenizer)
  elif task_name == "squad":
    return qa_tasks.SQuAD(config, tokenizer)
  elif task_name == "squadv1":
    return qa_tasks.SQuADv1(config, tokenizer)
  elif task_name == "newsqa":
    return qa_tasks.NewsQA(config, tokenizer)
  elif task_name == "naturalqs":
    return qa_tasks.NaturalQuestions(config, tokenizer)
  elif task_name == "triviaqa":
    return qa_tasks.TriviaQA(config, tokenizer)
  elif task_name == "searchqa":
    return qa_tasks.SearchQA(config, tokenizer)
  elif task_name == "chunk":
    return tagging_tasks.Chunking(config, tokenizer)
  else:
    raise ValueError("Unknown task " + task_name)
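Adding a task means extending this if/elif chain and listing the task name in config.task_names. A hedged sketch of what a hypothetical NER registration could look like (the task name, class, and data layout below are invented, not part of the codebase):

# in get_task(), after the "chunk" branch:
#   elif task_name == "ner":
#     return tagging_tasks.NER(config, tokenizer)

# and in finetune/tagging/tagging_tasks.py, mirroring Chunking:
class NER(TaggingTask):
  """Hypothetical NER task over CoNLL-style files in <raw_data_dir>/ner/."""

  def __init__(self, config, tokenizer):
    # is_token_level=False: tags are span-encoded to BIOES and scored with
    # EntityLevelF1Scorer rather than token-level accuracy.
    super(NER, self).__init__(config, "ner", tokenizer, False)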