Commit c130734 • Pradeep Kumar committed • 1 Parent(s): b64b72d

Upload 10 files

Files changed:
- export_tfhub.py +219 -0
- export_tfhub_lib.py +493 -0
- export_tfhub_lib_test.py +1080 -0
- squad_evaluate_v1_1.py +106 -0
- squad_evaluate_v2_0.py +249 -0
- tf1_bert_checkpoint_converter_lib.py +201 -0
- tf2_albert_encoder_checkpoint_converter.py +170 -0
- tf2_bert_encoder_checkpoint_converter.py +160 -0
- tokenization_test.py +156 -0
export_tfhub.py
ADDED
@@ -0,0 +1,219 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.

This tool creates preprocessor and encoder SavedModels suitable for uploading
to https://tfhub.dev that implement the preprocessor and encoder APIs defined
at https://www.tensorflow.org/hub/common_saved_model_apis/text.

For a full usage guide, see
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md

Minimal usage examples:

1) Exporting an Encoder from checkpoint and config.

```
export_tfhub \
  --encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
  --model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
  --vocab_file=${BERT_DIR:?}/vocab.txt \
  --export_type=model \
  --export_path=/tmp/bert_model
```

An --encoder_config_file can specify encoder types other than BERT.
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.

Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
from the vocab_file path) capture how BertTokenizer was used in pre-training.
Use flag --sp_model_file instead if SentencepieceTokenizer was used.

Changing --export_type to model_with_mlm additionally creates an `.mlm`
subobject on the exported SavedModel that can be called to produce
the logits of the Masked Language Model task from pretraining.
The help string for flag --model_checkpoint_path explains the checkpoint
formats required for each --export_type.


2) Exporting a preprocessor SavedModel

```
export_tfhub \
  --vocab_file ${BERT_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/bert_preprocessing
```

Be sure to use flag values that match the encoder and how it has been
pre-trained (see above for --vocab_file vs --sp_model_file).

If your encoder has been trained with text preprocessing for which tfhub.dev
already has a SavedModel, you could guide your users to reuse that one instead
of exporting and publishing your own.

TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
`--experimental_disable_assert_in_preprocessing`.
"""

from absl import app
from absl import flags
import gin

from official.legacy.bert import configs
from official.modeling import hyperparams
from official.nlp.configs import encoders
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")
flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")
flags.DEFINE_string(
    "encoder_config_file", None,
    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pool layer "
    "of next sentence prediction task exists in `ClassificationHead` passed "
    "to `BertPretrainerV2`. If True, we will copy this pooler's dense layer "
    "to the encoder that is exported by this tool (as in classic BERT). "
    "Using `BertPretrainerV2` and leaving this False exports an untrained "
    "(randomly initialized) pooling layer, which some authors recommend for "
    "subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")
flags.DEFINE_string(
    "vocab_file", None,
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_string(
    "sp_model_file", None,
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None, and --vocab_file is set, do_lower_case will be enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None, and --sp_model_file is set, do_lower_case defaults to "
    "True. Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from the "
    "top-level preprocess method. This is also the default "
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")
flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")
flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "can be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  if FLAGS.export_type in ("model", "model_with_mlm"):
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` can be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    if FLAGS.bert_config_file:
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
      encoder_config = None
    else:
      bert_config = None
      encoder_config = encoders.EncoderConfig()
      encoder_config = hyperparams.override_params_dict(
          encoder_config, FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        with_mlm=FLAGS.export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)

  elif FLAGS.export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=do_lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=(
            FLAGS.experimental_disable_assert_in_preprocessing))

  else:
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % FLAGS.export_type)


if __name__ == "__main__":
  app.run(main)
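For orientation beyond the docstring examples, here is a minimal sketch (not part of this commit) of how the two exports are typically consumed together at fine-tuning time, following the TF Hub text API linked above. The /tmp paths reuse the placeholder export paths from the usage examples, and the two-class Dense head is an illustrative assumption, not something this tool produces.

```
# Minimal fine-tuning sketch; paths and the classifier head are placeholders.
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # Registers the custom ops the preprocessor needs.

preprocessor = hub.KerasLayer("/tmp/bert_preprocessing")
encoder = hub.KerasLayer("/tmp/bert_model", trainable=True)

sentences = tf.keras.layers.Input(shape=(), dtype=tf.string, name="sentences")
encoder_inputs = preprocessor(sentences)  # Packed to default_seq_length.
pooled_output = encoder(encoder_inputs)["pooled_output"]
logits = tf.keras.layers.Dense(2)(pooled_output)  # Hypothetical 2-class head.
classifier = tf.keras.Model(sentences, logits)
```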
export_tfhub_lib.py
ADDED
@@ -0,0 +1,493 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of components of export_tfhub.py. See docstring there for more."""

import contextlib
import hashlib
import os
import tempfile

from typing import Optional, Text, Tuple

# Import libraries
from absl import logging
import tensorflow as tf, tf_keras
# pylint: disable=g-direct-tensorflow-import  TODO(b/175369555): Remove these.
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.ops import control_flow_assert
# pylint: enable=g-direct-tensorflow-import
from official.legacy.bert import configs
from official.modeling import tf_utils
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.modeling import networks


def get_bert_encoder(bert_config):
  """Returns a BertEncoder with dict outputs."""
  bert_encoder = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      embedding_width=bert_config.embedding_size,
      dict_outputs=True)

  return bert_encoder


def get_do_lower_case(do_lower_case, vocab_file=None, sp_model_file=None):
  """Returns do_lower_case, replacing None by a guess from vocab file name."""
  if do_lower_case is not None:
    return do_lower_case
  elif vocab_file:
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)
    return do_lower_case
  elif sp_model_file:
    do_lower_case = True  # All public ALBERTs (as of Oct 2020) do it.
    logging.info("Defaulting to do_lower_case=%s for Sentencepiece tokenizer",
                 do_lower_case)
    return do_lower_case
  else:
    raise ValueError("Must set vocab_file or sp_model_file.")


def _create_model(
    *,
    bert_config: Optional[configs.BertConfig] = None,
    encoder_config: Optional[encoders.EncoderConfig] = None,
    with_mlm: bool,
) -> Tuple[tf_keras.Model, tf_keras.Model]:
  """Creates the model to export and the model to restore the checkpoint.

  Args:
    bert_config: A legacy `BertConfig` to create a `BertEncoder` object.
      Exactly one of encoder_config and bert_config must be set.
    encoder_config: An `EncoderConfig` to create an encoder of the configured
      type (`BertEncoder` or other).
    with_mlm: A bool to control the second component of the result. If True,
      will create a `BertPretrainerV2` object; otherwise, will create a
      `BertEncoder` object.

  Returns:
    A Tuple of (1) a Keras model that will be exported, (2) a `BertPretrainerV2`
    object or `BertEncoder` object depending on the value of `with_mlm`
    argument, which contains the first model and will be used for restoring
    weights from the checkpoint.
  """
  if (bert_config is not None) == (encoder_config is not None):
    raise ValueError("Exactly one of `bert_config` and `encoder_config` "
                     "can be specified, but got %s and %s" %
                     (bert_config, encoder_config))

  if bert_config is not None:
    encoder = get_bert_encoder(bert_config)
  else:
    encoder = encoders.build_encoder(encoder_config)

  # Convert from list of named inputs to dict of inputs keyed by name.
  # Only the latter accepts a dict of inputs after restoring from SavedModel.
  if isinstance(encoder.inputs, (list, tuple)):
    encoder_inputs_dict = {x.name: x for x in encoder.inputs}
  else:
    # encoder.inputs by default is dict for BertEncoderV2.
    encoder_inputs_dict = encoder.inputs
  encoder_output_dict = encoder(encoder_inputs_dict)
  # For interchangeability with other text representations,
  # add "default" as an alias for BERT's whole-input representations.
  encoder_output_dict["default"] = encoder_output_dict["pooled_output"]
  core_model = tf_keras.Model(
      inputs=encoder_inputs_dict, outputs=encoder_output_dict)

  if with_mlm:
    if bert_config is not None:
      hidden_act = bert_config.hidden_act
    else:
      assert encoder_config is not None
      hidden_act = encoder_config.get().hidden_activation

    pretrainer = models.BertPretrainerV2(
        encoder_network=encoder,
        mlm_activation=tf_utils.get_activation(hidden_act))

    if isinstance(pretrainer.inputs, dict):
      pretrainer_inputs_dict = pretrainer.inputs
    else:
      pretrainer_inputs_dict = {x.name: x for x in pretrainer.inputs}
    pretrainer_output_dict = pretrainer(pretrainer_inputs_dict)
    mlm_model = tf_keras.Model(
        inputs=pretrainer_inputs_dict, outputs=pretrainer_output_dict)
    # Set `_auto_track_sub_layers` to False, so that the additional weights
    # from `mlm` sub-object will not be included in the core model.
    # TODO(b/169210253): Use a public API when available.
    core_model._auto_track_sub_layers = False  # pylint: disable=protected-access
    core_model.mlm = mlm_model
    return core_model, pretrainer
  else:
    return core_model, encoder


def export_model(export_path: Text,
                 *,
                 bert_config: Optional[configs.BertConfig] = None,
                 encoder_config: Optional[encoders.EncoderConfig] = None,
                 model_checkpoint_path: Text,
                 with_mlm: bool,
                 copy_pooler_dense_to_encoder: bool = False,
                 vocab_file: Optional[Text] = None,
                 sp_model_file: Optional[Text] = None,
                 do_lower_case: Optional[bool] = None) -> None:
  """Exports an Encoder as SavedModel after restoring pre-trained weights.

  The exported SavedModel implements a superset of the Encoder API for
  Text embeddings with Transformer Encoders described at
  https://www.tensorflow.org/hub/common_saved_model_apis/text.

  In particular, the exported SavedModel can be used in the following way:

  ```
  # Calls default interface (encoder only).

  encoder = hub.load(...)
  encoder_inputs = dict(
      input_word_ids=...,  # Shape [batch, seq_length], dtype=int32
      input_mask=...,      # Shape [batch, seq_length], dtype=int32
      input_type_ids=...,  # Shape [batch, seq_length], dtype=int32
  )
  encoder_outputs = encoder(encoder_inputs)
  assert encoder_outputs.keys() == {
      "pooled_output",    # Shape [batch_size, width], dtype=float32
      "default",          # Alias for "pooled_output" (aligns with other models).
      "sequence_output",  # Shape [batch_size, seq_length, width], dtype=float32
      "encoder_outputs",  # List of Tensors with outputs of all transformer layers.
  }
  ```

  If `with_mlm` is True, the exported SavedModel can also be called in the
  following way:

  ```
  # Calls expanded interface that includes logits of the Masked Language Model.
  mlm_inputs = dict(
      input_word_ids=...,       # Shape [batch, seq_length], dtype=int32
      input_mask=...,           # Shape [batch, seq_length], dtype=int32
      input_type_ids=...,       # Shape [batch, seq_length], dtype=int32
      masked_lm_positions=...,  # Shape [batch, num_predictions], dtype=int32
  )
  mlm_outputs = encoder.mlm(mlm_inputs)
  assert mlm_outputs.keys() == {
      "pooled_output",    # Shape [batch, width], dtype=float32
      "sequence_output",  # Shape [batch, seq_length, width], dtype=float32
      "encoder_outputs",  # List of Tensors with outputs of all transformer layers.
      "mlm_logits"        # Shape [batch, num_predictions, vocab_size], dtype=float32
  }
  ```

  Args:
    export_path: The SavedModel output directory.
    bert_config: An optional `configs.BertConfig` object. Note: exactly one of
      `bert_config` and following `encoder_config` must be specified.
    encoder_config: An optional `encoders.EncoderConfig` object.
    model_checkpoint_path: The path to the checkpoint.
    with_mlm: Whether to export the additional mlm sub-object.
    copy_pooler_dense_to_encoder: Whether to copy the pooler's dense layer used
      in the next sentence prediction task to the encoder.
    vocab_file: The path to the wordpiece vocab file, or None.
    sp_model_file: The path to the sentencepiece model file, or None. Exactly
      one of vocab_file and sp_model_file must be set.
    do_lower_case: Whether to lower-case text before tokenization.
  """
  if with_mlm:
    core_model, pretrainer = _create_model(
        bert_config=bert_config,
        encoder_config=encoder_config,
        with_mlm=with_mlm)
    encoder = pretrainer.encoder_network
    # It supports both the new pretrainer checkpoint produced by TF-NLP and
    # the checkpoint converted from TF1 (original BERT, SmallBERTs).
    checkpoint_items = pretrainer.checkpoint_items
    checkpoint = tf.train.Checkpoint(**checkpoint_items)
  else:
    core_model, encoder = _create_model(
        bert_config=bert_config,
        encoder_config=encoder_config,
        with_mlm=with_mlm)
    checkpoint = tf.train.Checkpoint(
        model=encoder,  # Legacy checkpoints.
        encoder=encoder)
  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()

  if copy_pooler_dense_to_encoder:
    logging.info("Copy pooler's dense layer to the encoder.")
    pooler_checkpoint = tf.train.Checkpoint(
        **{"next_sentence.pooler_dense": encoder.pooler_layer})
    pooler_checkpoint.restore(
        model_checkpoint_path).assert_existing_objects_matched()

  # Before SavedModels for preprocessing appeared in Oct 2020, the encoders
  # provided this information to let users do preprocessing themselves.
  # We keep doing that for now. It helps users to upgrade incrementally.
  # Moreover, it offers an escape hatch for advanced users who want the
  # full vocab, not the high-level operations from the preprocessing model.
  if vocab_file:
    core_model.vocab_file = tf.saved_model.Asset(vocab_file)
    if do_lower_case is None:
      raise ValueError("Must pass do_lower_case if passing vocab_file.")
    core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  elif sp_model_file:
    # This was used by ALBERT, with implied values of do_lower_case=True
    # and strip_diacritics=True.
    core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
  else:
    raise ValueError("Must set vocab_file or sp_model_file.")
  core_model.save(export_path, include_optimizer=False, save_format="tf")


class BertPackInputsSavedModelWrapper(tf.train.Checkpoint):
  """Wraps a BertPackInputs layer for export to SavedModel.

  The wrapper object is suitable for use with `tf.saved_model.save()` and
  `.load()`. The wrapper object is callable with inputs and outputs like the
  BertPackInputs layer, but differs from saving an unwrapped Keras object:

    - The inputs can be a list of 1 or 2 RaggedTensors of dtype int32 and
      ragged rank 1 or 2. (In Keras, saving to a tf.function in a SavedModel
      would fix the number of RaggedTensors and their ragged rank.)
    - The call accepts an optional keyword argument `seq_length=` to override
      the layer's .seq_length hyperparameter. (In Keras, a hyperparameter
      could not be changed after saving to a tf.function in a SavedModel.)
  """

  def __init__(self, bert_pack_inputs: layers.BertPackInputs):
    super().__init__()

    # Preserve the layer's configured seq_length as a default but make it
    # overridable. Having this dynamically determined default argument
    # requires self.__call__ to be defined in this indirect way.
    default_seq_length = bert_pack_inputs.seq_length

    @tf.function(autograph=False)
    def call(inputs, seq_length=default_seq_length):
      return layers.BertPackInputs.bert_pack_inputs(
          inputs,
          seq_length=seq_length,
          start_of_sequence_id=bert_pack_inputs.start_of_sequence_id,
          end_of_segment_id=bert_pack_inputs.end_of_segment_id,
          padding_id=bert_pack_inputs.padding_id)

    self.__call__ = call

    for ragged_rank in range(1, 3):
      for num_segments in range(1, 3):
        _ = self.__call__.get_concrete_function([
            tf.RaggedTensorSpec([None] * (ragged_rank + 1), dtype=tf.int32)
            for _ in range(num_segments)
        ],
                                                seq_length=tf.TensorSpec(
                                                    [], tf.int32))


def create_preprocessing(*,
                         vocab_file: Optional[str] = None,
                         sp_model_file: Optional[str] = None,
                         do_lower_case: bool,
                         tokenize_with_offsets: bool,
                         default_seq_length: int) -> tf_keras.Model:
  """Returns a preprocessing Model for given tokenization parameters.

  This function builds a Keras Model with attached subobjects suitable for
  saving to a SavedModel. The resulting SavedModel implements the Preprocessor
  API for Text embeddings with Transformer Encoders described at
  https://www.tensorflow.org/hub/common_saved_model_apis/text.

  Args:
    vocab_file: The path to the wordpiece vocab file, or None.
    sp_model_file: The path to the sentencepiece model file, or None. Exactly
      one of vocab_file and sp_model_file must be set. This determines the
      type of tokenizer that is used.
    do_lower_case: Whether to do lower case.
    tokenize_with_offsets: Whether to include the .tokenize_with_offsets
      subobject.
    default_seq_length: The sequence length of preprocessing results from the
      root callable. This is also the default sequence length for the
      bert_pack_inputs subobject.

  Returns:
    A tf_keras.Model object with several attached subobjects, suitable for
    saving as a preprocessing SavedModel.
  """
  # Select tokenizer.
  if bool(vocab_file) == bool(sp_model_file):
    raise ValueError("Must set exactly one of vocab_file, sp_model_file")
  if vocab_file:
    tokenize = layers.BertTokenizer(
        vocab_file=vocab_file,
        lower_case=do_lower_case,
        tokenize_with_offsets=tokenize_with_offsets)
  else:
    tokenize = layers.SentencepieceTokenizer(
        model_file_path=sp_model_file,
        lower_case=do_lower_case,
        strip_diacritics=True,  # Strip diacritics to follow ALBERT model.
        tokenize_with_offsets=tokenize_with_offsets)

  # The root object of the preprocessing model can be called to do
  # one-shot preprocessing for users with single-sentence inputs.
  sentences = tf_keras.layers.Input(shape=(), dtype=tf.string, name="sentences")
  if tokenize_with_offsets:
    tokens, start_offsets, limit_offsets = tokenize(sentences)
  else:
    tokens = tokenize(sentences)
  pack = layers.BertPackInputs(
      seq_length=default_seq_length,
      special_tokens_dict=tokenize.get_special_tokens_dict())
  model_inputs = pack(tokens)
  preprocessing = tf_keras.Model(sentences, model_inputs)

  # Individual steps of preprocessing are made available as named subobjects
  # to enable more general preprocessing. For saving, they need to be Models
  # in their own right.
  preprocessing.tokenize = tf_keras.Model(sentences, tokens)
  # Provide an equivalent to tokenize.get_special_tokens_dict().
  preprocessing.tokenize.get_special_tokens_dict = tf.train.Checkpoint()
  preprocessing.tokenize.get_special_tokens_dict.__call__ = tf.function(
      lambda: tokenize.get_special_tokens_dict(),  # pylint: disable=unnecessary-lambda
      input_signature=[])
  if tokenize_with_offsets:
    preprocessing.tokenize_with_offsets = tf_keras.Model(
        sentences, [tokens, start_offsets, limit_offsets])
    preprocessing.tokenize_with_offsets.get_special_tokens_dict = (
        preprocessing.tokenize.get_special_tokens_dict)
  # Conceptually, this should be
  # preprocessing.bert_pack_inputs = tf_keras.Model(tokens, model_inputs)
  # but technicalities require us to use a wrapper (see comments there).
  # In particular, seq_length can be overridden when calling this.
  preprocessing.bert_pack_inputs = BertPackInputsSavedModelWrapper(pack)

  return preprocessing


def _move_to_tmpdir(file_path: Optional[Text], tmpdir: Text) -> Optional[Text]:
  """Returns new path with same basename and hash of original path."""
  if file_path is None:
    return None
  olddir, filename = os.path.split(file_path)
  hasher = hashlib.sha1()
  hasher.update(olddir.encode("utf-8"))
  target_dir = os.path.join(tmpdir, hasher.hexdigest())
  target_file = os.path.join(target_dir, filename)
  tf.io.gfile.mkdir(target_dir)
  tf.io.gfile.copy(file_path, target_file)
  return target_file


def export_preprocessing(export_path: Text,
                         *,
                         vocab_file: Optional[Text] = None,
                         sp_model_file: Optional[Text] = None,
                         do_lower_case: bool,
                         tokenize_with_offsets: bool,
                         default_seq_length: int,
                         experimental_disable_assert: bool = False) -> None:
  """Exports preprocessing to a SavedModel for TF Hub."""
  with tempfile.TemporaryDirectory() as tmpdir:
    # TODO(b/175369555): Remove experimental_disable_assert and its use.
    with _maybe_disable_assert(experimental_disable_assert):
      preprocessing = create_preprocessing(
          vocab_file=_move_to_tmpdir(vocab_file, tmpdir),
          sp_model_file=_move_to_tmpdir(sp_model_file, tmpdir),
          do_lower_case=do_lower_case,
          tokenize_with_offsets=tokenize_with_offsets,
          default_seq_length=default_seq_length)
      preprocessing.save(export_path, include_optimizer=False,
                         save_format="tf")
    if experimental_disable_assert:
      _check_no_assert(export_path)
  # It helps the unit test to prevent stray copies of the vocab file.
  if tf.io.gfile.exists(tmpdir):
    raise IOError("Failed to clean up TemporaryDirectory")


# TODO(b/175369555): Remove all workarounds for this bug of TensorFlow 2.4
# when this bug is no longer a concern for publishing new models.
# TensorFlow 2.4 has a placement issue with Assert ops in tf.functions called
# from Dataset.map() on a TPU worker. They end up on the TPU coordinator,
# and invoking them from the TPU worker is either inefficient (when possible)
# or impossible (notably when using "headless" TPU workers on Cloud that do not
# have a channel to the coordinator). The bug has been fixed in time for TF 2.5.
# To work around this, the following code avoids Assert ops in the exported
# SavedModels. It monkey-patches calls to tf.Assert from inside TensorFlow and
# replaces them by a no-op while building the exported model. This is fragile,
# so _check_no_assert() validates the result. The resulting model should be fine
# to read on future versions of TF, even if this workaround at export time
# may break eventually. (Failing unit tests will tell.)


def _dont_assert(condition, data, summarize=None, name="Assert"):
  """The no-op version of tf.Assert installed by _maybe_disable_assert."""
  del condition, data, summarize  # Unused.
  if tf.executing_eagerly():
    return
  with tf.name_scope(name):
    return tf.no_op(name="dont_assert")


@contextlib.contextmanager
def _maybe_disable_assert(disable_assert):
  """Scoped monkey patch of control_flow_assert.Assert to a no-op."""
  if not disable_assert:
    yield
    return

  original_assert = control_flow_assert.Assert
  control_flow_assert.Assert = _dont_assert
  yield
  control_flow_assert.Assert = original_assert


def _check_no_assert(saved_model_path):
  """Raises AssertionError if SavedModel contains Assert ops."""
  saved_model_filename = os.path.join(saved_model_path, "saved_model.pb")
  with tf.io.gfile.GFile(saved_model_filename, "rb") as f:
    saved_model = saved_model_pb2.SavedModel.FromString(f.read())

  assert_nodes = []
  graph_def = saved_model.meta_graphs[0].graph_def
  assert_nodes += [
      "node '{}' in global graph".format(n.name)
      for n in graph_def.node
      if n.op == "Assert"
  ]
  for fdef in graph_def.library.function:
    assert_nodes += [
        "node '{}' in function '{}'".format(n.name, fdef.signature.name)
        for n in fdef.node_def
        if n.op == "Assert"
    ]
  if assert_nodes:
    raise AssertionError(
        "Internal tool error: "
        "failed to suppress {} Assert ops in SavedModel:\n{}".format(
            len(assert_nodes), "\n".join(assert_nodes[:10])))
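As a usage note for `BertPackInputsSavedModelWrapper` above, here is a minimal sketch of the per-call `seq_length` override it enables after reloading; it assumes a preprocessor was already exported to the placeholder path /tmp/bert_preprocessing.

```
import tensorflow as tf
import tensorflow_text  # Registers the tokenizer ops used by the SavedModel.

preprocessor = tf.saved_model.load("/tmp/bert_preprocessing")
tokens = preprocessor.tokenize(tf.constant(["hello world"]))
# Packing with the exported default_seq_length:
packed = preprocessor.bert_pack_inputs([tokens])
# Overriding seq_length per call, which a plain Keras export would not allow:
packed_short = preprocessor.bert_pack_inputs([tokens], seq_length=32)
```

The concrete functions traced in `__init__` are what make these ragged-input signatures and the scalar `seq_length` argument available on the restored object.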
export_tfhub_lib_test.py
ADDED
@@ -0,0 +1,1080 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests export_tfhub_lib."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import tempfile
|
19 |
+
|
20 |
+
from absl.testing import parameterized
|
21 |
+
import numpy as np
|
22 |
+
import tensorflow as tf, tf_keras
|
23 |
+
from tensorflow import estimator as tf_estimator
|
24 |
+
import tensorflow_hub as hub
|
25 |
+
import tensorflow_text as text
|
26 |
+
|
27 |
+
from sentencepiece import SentencePieceTrainer
|
28 |
+
from official.legacy.bert import configs
|
29 |
+
from official.modeling import tf_utils
|
30 |
+
from official.nlp.configs import encoders
|
31 |
+
from official.nlp.modeling import layers
|
32 |
+
from official.nlp.modeling import models
|
33 |
+
from official.nlp.tools import export_tfhub_lib
|
34 |
+
|
35 |
+
|
36 |
+
def _get_bert_config_or_encoder_config(use_bert_config,
|
37 |
+
hidden_size,
|
38 |
+
num_hidden_layers,
|
39 |
+
encoder_type="albert",
|
40 |
+
vocab_size=100):
|
41 |
+
"""Generates config args for export_tfhub_lib._create_model().
|
42 |
+
|
43 |
+
Args:
|
44 |
+
use_bert_config: bool. If True, returns legacy BertConfig.
|
45 |
+
hidden_size: int.
|
46 |
+
num_hidden_layers: int.
|
47 |
+
encoder_type: str. Can be ['albert', 'bert', 'bert_v2']. If use_bert_config
|
48 |
+
== True, then model_type is not used.
|
49 |
+
vocab_size: int.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
bert_config, encoder_config. Only one is not None. If
|
53 |
+
`use_bert_config` == True, the first config is valid. Otherwise
|
54 |
+
`bert_config` == None.
|
55 |
+
"""
|
56 |
+
if use_bert_config:
|
57 |
+
bert_config = configs.BertConfig(
|
58 |
+
vocab_size=vocab_size,
|
59 |
+
hidden_size=hidden_size,
|
60 |
+
intermediate_size=32,
|
61 |
+
max_position_embeddings=128,
|
62 |
+
num_attention_heads=2,
|
63 |
+
num_hidden_layers=num_hidden_layers)
|
64 |
+
encoder_config = None
|
65 |
+
else:
|
66 |
+
bert_config = None
|
67 |
+
if encoder_type == "albert":
|
68 |
+
encoder_config = encoders.EncoderConfig(
|
69 |
+
type="albert",
|
70 |
+
albert=encoders.AlbertEncoderConfig(
|
71 |
+
vocab_size=vocab_size,
|
72 |
+
embedding_width=16,
|
73 |
+
hidden_size=hidden_size,
|
74 |
+
intermediate_size=32,
|
75 |
+
max_position_embeddings=128,
|
76 |
+
num_attention_heads=2,
|
77 |
+
num_layers=num_hidden_layers,
|
78 |
+
dropout_rate=0.1))
|
79 |
+
else:
|
80 |
+
# encoder_type can be 'bert' or 'bert_v2'.
|
81 |
+
model_config = encoders.BertEncoderConfig(
|
82 |
+
vocab_size=vocab_size,
|
83 |
+
embedding_size=16,
|
84 |
+
hidden_size=hidden_size,
|
85 |
+
intermediate_size=32,
|
86 |
+
max_position_embeddings=128,
|
87 |
+
num_attention_heads=2,
|
88 |
+
num_layers=num_hidden_layers,
|
89 |
+
dropout_rate=0.1)
|
90 |
+
kwargs = {"type": encoder_type, encoder_type: model_config}
|
91 |
+
encoder_config = encoders.EncoderConfig(**kwargs)
|
92 |
+
|
93 |
+
return bert_config, encoder_config
|
94 |
+
|
95 |
+
|
96 |
+
def _get_vocab_or_sp_model_dummy(temp_dir, use_sp_model):
|
97 |
+
"""Returns tokenizer asset args for export_tfhub_lib.export_model()."""
|
98 |
+
dummy_file = os.path.join(temp_dir, "dummy_file.txt")
|
99 |
+
with tf.io.gfile.GFile(dummy_file, "w") as f:
|
100 |
+
f.write("dummy content")
|
101 |
+
if use_sp_model:
|
102 |
+
vocab_file, sp_model_file = None, dummy_file
|
103 |
+
else:
|
104 |
+
vocab_file, sp_model_file = dummy_file, None
|
105 |
+
return vocab_file, sp_model_file
|
106 |
+
|
107 |
+
|
108 |
+
def _read_asset(asset: tf.saved_model.Asset):
|
109 |
+
return tf.io.gfile.GFile(asset.asset_path.numpy()).read()
|
110 |
+
|
111 |
+
|
112 |
+
def _find_lambda_layers(layer):
|
113 |
+
"""Returns list of all Lambda layers in a Keras model."""
|
114 |
+
if isinstance(layer, tf_keras.layers.Lambda):
|
115 |
+
return [layer]
|
116 |
+
elif hasattr(layer, "layers"): # It's nested, like a Model.
|
117 |
+
result = []
|
118 |
+
for l in layer.layers:
|
119 |
+
result += _find_lambda_layers(l)
|
120 |
+
return result
|
121 |
+
else:
|
122 |
+
return []
|
123 |
+
|
124 |
+
|
125 |
+
class ExportModelTest(tf.test.TestCase, parameterized.TestCase):
|
126 |
+
"""Tests exporting a Transformer Encoder model as a SavedModel.
|
127 |
+
|
128 |
+
This covers export from an Encoder checkpoint to a SavedModel without
|
129 |
+
the .mlm subobject. This is no longer preferred, but still useful
|
130 |
+
for models like Electra that are trained without the MLM task.
|
131 |
+
|
132 |
+
The export code is generic. This test focuses on two main cases
|
133 |
+
(the most important ones in practice when this was written in 2020):
|
134 |
+
- BERT built from a legacy BertConfig, for use with BertTokenizer.
|
135 |
+
- ALBERT built from an EncoderConfig (as a representative of all other
|
136 |
+
choices beyond BERT, for use with SentencepieceTokenizer (the one
|
137 |
+
alternative to BertTokenizer).
|
138 |
+
"""
|
139 |
+
|
140 |
+
@parameterized.named_parameters(
|
141 |
+
("Bert_Legacy", True, None), ("Albert", False, "albert"),
|
142 |
+
("BertEncoder", False, "bert"), ("BertEncoderV2", False, "bert_v2"))
|
143 |
+
def test_export_model(self, use_bert, encoder_type):
|
144 |
+
# Create the encoder and export it.
|
145 |
+
hidden_size = 16
|
146 |
+
num_hidden_layers = 1
|
147 |
+
bert_config, encoder_config = _get_bert_config_or_encoder_config(
|
148 |
+
use_bert,
|
149 |
+
hidden_size=hidden_size,
|
150 |
+
num_hidden_layers=num_hidden_layers,
|
151 |
+
encoder_type=encoder_type)
|
152 |
+
bert_model, encoder = export_tfhub_lib._create_model(
|
153 |
+
bert_config=bert_config, encoder_config=encoder_config, with_mlm=False)
|
154 |
+
self.assertEmpty(
|
155 |
+
_find_lambda_layers(bert_model),
|
156 |
+
"Lambda layers are non-portable since they serialize Python bytecode.")
|
157 |
+
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
|
158 |
+
checkpoint = tf.train.Checkpoint(encoder=encoder)
|
159 |
+
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
|
160 |
+
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
|
161 |
+
|
162 |
+
vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
|
163 |
+
self.get_temp_dir(), use_sp_model=not use_bert)
|
164 |
+
export_path = os.path.join(self.get_temp_dir(), "hub")
|
165 |
+
export_tfhub_lib.export_model(
|
166 |
+
export_path=export_path,
|
167 |
+
bert_config=bert_config,
|
168 |
+
encoder_config=encoder_config,
|
169 |
+
model_checkpoint_path=model_checkpoint_path,
|
170 |
+
with_mlm=False,
|
171 |
+
vocab_file=vocab_file,
|
172 |
+
sp_model_file=sp_model_file,
|
173 |
+
do_lower_case=True)
|
174 |
+
|
175 |
+
# Restore the exported model.
|
176 |
+
hub_layer = hub.KerasLayer(export_path, trainable=True)
|
177 |
+
|
178 |
+
# Check legacy tokenization data.
|
179 |
+
if use_bert:
|
180 |
+
self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
|
181 |
+
self.assertEqual("dummy content",
|
182 |
+
_read_asset(hub_layer.resolved_object.vocab_file))
|
183 |
+
self.assertFalse(hasattr(hub_layer.resolved_object, "sp_model_file"))
|
184 |
+
else:
|
185 |
+
self.assertFalse(hasattr(hub_layer.resolved_object, "do_lower_case"))
|
186 |
+
self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
|
187 |
+
self.assertEqual("dummy content",
|
188 |
+
_read_asset(hub_layer.resolved_object.sp_model_file))
|
189 |
+
|
190 |
+
# Check restored weights.
|
191 |
+
self.assertEqual(
|
192 |
+
len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
|
193 |
+
for source_weight, hub_weight in zip(bert_model.trainable_weights,
|
194 |
+
hub_layer.trainable_weights):
|
195 |
+
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
|
196 |
+
|
197 |
+
# Check computation.
|
198 |
+
seq_length = 10
|
199 |
+
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
|
200 |
+
input_dict = dict(
|
201 |
+
input_word_ids=dummy_ids,
|
202 |
+
input_mask=dummy_ids,
|
203 |
+
input_type_ids=dummy_ids)
|
204 |
+
hub_output = hub_layer(input_dict)
|
205 |
+
source_output = bert_model(input_dict)
|
206 |
+
encoder_output = encoder(input_dict)
|
207 |
+
self.assertEqual(hub_output["pooled_output"].shape, (2, hidden_size))
|
208 |
+
self.assertEqual(hub_output["sequence_output"].shape,
|
209 |
+
(2, seq_length, hidden_size))
|
210 |
+
self.assertLen(hub_output["encoder_outputs"], num_hidden_layers)
|
211 |
+
|
212 |
+
for key in ("pooled_output", "sequence_output", "encoder_outputs"):
|
213 |
+
self.assertAllClose(source_output[key], hub_output[key])
|
214 |
+
self.assertAllClose(source_output[key], encoder_output[key])
|
215 |
+
|
216 |
+
# The "default" output of BERT as a text representation is pooled_output.
|
217 |
+
self.assertAllClose(hub_output["pooled_output"], hub_output["default"])
|
218 |
+
|
219 |
+
# Test that training=True makes a difference (activates dropout).
|
220 |
+
def _dropout_mean_stddev(training, num_runs=20):
|
221 |
+
input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
|
222 |
+
input_dict = dict(
|
223 |
+
input_word_ids=input_ids,
|
224 |
+
input_mask=np.ones_like(input_ids),
|
225 |
+
input_type_ids=np.zeros_like(input_ids))
|
226 |
+
outputs = np.concatenate([
|
227 |
+
hub_layer(input_dict, training=training)["pooled_output"]
|
228 |
+
for _ in range(num_runs)
|
229 |
+
])
|
230 |
+
return np.mean(np.std(outputs, axis=0))
|
231 |
+
|
232 |
+
self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
|
233 |
+
self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
|
234 |
+
|
235 |
+
# Test propagation of seq_length in shape inference.
|
236 |
+
input_word_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
|
237 |
+
input_mask = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
|
238 |
+
input_type_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
|
239 |
+
input_dict = dict(
|
240 |
+
input_word_ids=input_word_ids,
|
241 |
+
input_mask=input_mask,
|
242 |
+
input_type_ids=input_type_ids)
|
243 |
+
output_dict = hub_layer(input_dict)
|
244 |
+
pooled_output = output_dict["pooled_output"]
|
245 |
+
sequence_output = output_dict["sequence_output"]
|
246 |
+
encoder_outputs = output_dict["encoder_outputs"]
|
247 |
+
|
248 |
+
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
|
249 |
+
self.assertEqual(sequence_output.shape.as_list(),
|
250 |
+
[None, seq_length, hidden_size])
|
251 |
+
self.assertLen(encoder_outputs, num_hidden_layers)
|
252 |
+
|
253 |
+
|
254 |
+
class ExportModelWithMLMTest(tf.test.TestCase, parameterized.TestCase):
|
255 |
+
"""Tests exporting a Transformer Encoder model as a SavedModel.
|
256 |
+
|
257 |
+
This covers export from a Pretrainer checkpoint to a SavedModel including
|
258 |
+
the .mlm subobject, which is the preferred way since 2020.
|
259 |
+
|
260 |
+
The export code is generic. This test focuses on two main cases
|
261 |
+
(the most important ones in practice when this was written in 2020):
|
262 |
+
- BERT built from a legacy BertConfig, for use with BertTokenizer.
|
263 |
+
- ALBERT built from an EncoderConfig (as a representative of all other
|
264 |
+
choices beyond BERT, for use with SentencepieceTokenizer (the one
|
265 |
+
alternative to BertTokenizer).
|
266 |
+
"""
|
267 |
+
|
268 |
+
  def test_copy_pooler_dense_to_encoder(self):
    encoder_config = encoders.EncoderConfig(
        type="bert",
        bert=encoders.BertEncoderConfig(
            hidden_size=24, intermediate_size=48, num_layers=2))
    cls_heads = [
        layers.ClassificationHead(
            inner_dim=24, num_classes=2, name="next_sentence")
    ]
    encoder = encoders.build_encoder(encoder_config)
    pretrainer = models.BertPretrainerV2(
        encoder_network=encoder,
        classification_heads=cls_heads,
        mlm_activation=tf_utils.get_activation(
            encoder_config.get().hidden_activation))
    # Makes sure the pretrainer variables are created.
    _ = pretrainer(pretrainer.inputs)
    checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))

    vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
        self.get_temp_dir(), use_sp_model=True)
    export_path = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub_lib.export_model(
        export_path=export_path,
        encoder_config=encoder_config,
        model_checkpoint_path=tf.train.latest_checkpoint(model_checkpoint_dir),
        with_mlm=True,
        copy_pooler_dense_to_encoder=True,
        vocab_file=vocab_file,
        sp_model_file=sp_model_file,
        do_lower_case=True)
    # Restores a hub KerasLayer.
    hub_layer = hub.KerasLayer(export_path, trainable=True)
    dummy_ids = np.zeros((2, 10), dtype=np.int32)
    input_dict = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    hub_pooled_output = hub_layer(input_dict)["pooled_output"]
    encoder_outputs = encoder(input_dict)
    # Verify that hub_layer's pooled_output is the same as the output of the
    # next sentence prediction head's dense layer.
    pretrained_pooled_output = cls_heads[0].dense(
        encoder_outputs["sequence_output"][:, 0, :])
    self.assertAllClose(hub_pooled_output, pretrained_pooled_output)
    # But the pooled_outputs of the encoder and hub_layer are not the same.
    encoder_pooled_output = encoder_outputs["pooled_output"]
    self.assertNotAllClose(hub_pooled_output, encoder_pooled_output)

  @parameterized.named_parameters(
      ("Bert", True),
      ("Albert", False),
  )
  def test_export_model_with_mlm(self, use_bert):
    # Create the encoder and export it.
    hidden_size = 16
    num_hidden_layers = 2
    bert_config, encoder_config = _get_bert_config_or_encoder_config(
        use_bert, hidden_size, num_hidden_layers)
    bert_model, pretrainer = export_tfhub_lib._create_model(
        bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
    self.assertEmpty(
        _find_lambda_layers(bert_model),
        "Lambda layers are non-portable since they serialize Python bytecode.")
    bert_model_with_mlm = bert_model.mlm
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(
        self.get_temp_dir(), use_sp_model=not use_bert)
    export_path = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub_lib.export_model(
        export_path=export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=model_checkpoint_path,
        with_mlm=True,
        vocab_file=vocab_file,
        sp_model_file=sp_model_file,
        do_lower_case=True)

    # Restore the exported model.
    hub_layer = hub.KerasLayer(export_path, trainable=True)

    # Check legacy tokenization data.
    if use_bert:
      self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
      self.assertEqual("dummy content",
                       _read_asset(hub_layer.resolved_object.vocab_file))
      self.assertFalse(hasattr(hub_layer.resolved_object, "sp_model_file"))
    else:
      self.assertFalse(hasattr(hub_layer.resolved_object, "do_lower_case"))
      self.assertFalse(hasattr(hub_layer.resolved_object, "vocab_file"))
      self.assertEqual("dummy content",
                       _read_asset(hub_layer.resolved_object.sp_model_file))

    # Check restored weights.
    # Note that we set `_auto_track_sub_layers` to False when exporting the
    # SavedModel, so hub_layer has the same number of weights as bert_model;
    # otherwise, hub_layer will have extra weights from its `mlm` subobject.
    self.assertEqual(
        len(bert_model.trainable_weights), len(hub_layer.trainable_weights))
    for source_weight, hub_weight in zip(bert_model.trainable_weights,
                                         hub_layer.trainable_weights):
      self.assertAllClose(source_weight, hub_weight)

    # Check computation.
    seq_length = 10
    dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
    input_dict = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids)
    hub_outputs_dict = hub_layer(input_dict)
    source_outputs_dict = bert_model(input_dict)
    encoder_outputs_dict = pretrainer.encoder_network(
        [dummy_ids, dummy_ids, dummy_ids])
    self.assertEqual(hub_outputs_dict["pooled_output"].shape, (2, hidden_size))
    self.assertEqual(hub_outputs_dict["sequence_output"].shape,
                     (2, seq_length, hidden_size))
    for output_key in ("pooled_output", "sequence_output", "encoder_outputs"):
      self.assertAllClose(source_outputs_dict[output_key],
                          hub_outputs_dict[output_key])
      self.assertAllClose(source_outputs_dict[output_key],
                          encoder_outputs_dict[output_key])

    # The "default" output of BERT as a text representation is pooled_output.
    self.assertAllClose(hub_outputs_dict["pooled_output"],
                        hub_outputs_dict["default"])

    # Test that training=True makes a difference (activates dropout).
    def _dropout_mean_stddev(training, num_runs=20):
      input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
      input_dict = dict(
          input_word_ids=input_ids,
          input_mask=np.ones_like(input_ids),
          input_type_ids=np.zeros_like(input_ids))
      outputs = np.concatenate([
          hub_layer(input_dict, training=training)["pooled_output"]
          for _ in range(num_runs)
      ])
      return np.mean(np.std(outputs, axis=0))

    self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
    self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)

    # Checks sub-object `mlm`.
    self.assertTrue(hasattr(hub_layer.resolved_object, "mlm"))
    self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
                   len(bert_model_with_mlm.trainable_weights))
    self.assertLen(hub_layer.resolved_object.mlm.trainable_variables,
                   len(pretrainer.trainable_weights))
    for source_weight, hub_weight, pretrainer_weight in zip(
        bert_model_with_mlm.trainable_weights,
        hub_layer.resolved_object.mlm.trainable_variables,
        pretrainer.trainable_weights):
      self.assertAllClose(source_weight, hub_weight)
      self.assertAllClose(source_weight, pretrainer_weight)

    max_predictions_per_seq = 4
    mlm_positions = np.zeros((2, max_predictions_per_seq), dtype=np.int32)
    input_dict = dict(
        input_word_ids=dummy_ids,
        input_mask=dummy_ids,
        input_type_ids=dummy_ids,
        masked_lm_positions=mlm_positions)
    hub_mlm_outputs_dict = hub_layer.resolved_object.mlm(input_dict)
    source_mlm_outputs_dict = bert_model_with_mlm(input_dict)
    for output_key in ("pooled_output", "sequence_output", "mlm_logits",
                       "encoder_outputs"):
      self.assertAllClose(hub_mlm_outputs_dict[output_key],
                          source_mlm_outputs_dict[output_key])

    pretrainer_mlm_logits_output = pretrainer(input_dict)["mlm_logits"]
    self.assertAllClose(hub_mlm_outputs_dict["mlm_logits"],
                        pretrainer_mlm_logits_output)

    # Test that training=True makes a difference (activates dropout).
    def _dropout_mean_stddev_mlm(training, num_runs=20):
      input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
      mlm_position_ids = np.array([[1, 2, 3, 4]], np.int32)
      input_dict = dict(
          input_word_ids=input_ids,
          input_mask=np.ones_like(input_ids),
          input_type_ids=np.zeros_like(input_ids),
          masked_lm_positions=mlm_position_ids)
      outputs = np.concatenate([
          hub_layer.resolved_object.mlm(input_dict,
                                        training=training)["pooled_output"]
          for _ in range(num_runs)
      ])
      return np.mean(np.std(outputs, axis=0))

    self.assertLess(_dropout_mean_stddev_mlm(training=False), 1e-6)
    self.assertGreater(_dropout_mean_stddev_mlm(training=True), 1e-3)

    # Test propagation of seq_length in shape inference.
    input_word_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_mask = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_type_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_dict = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids)
    hub_outputs_dict = hub_layer(input_dict)
    self.assertEqual(hub_outputs_dict["pooled_output"].shape.as_list(),
                     [None, hidden_size])
    self.assertEqual(hub_outputs_dict["sequence_output"].shape.as_list(),
                     [None, seq_length, hidden_size])

_STRING_NOT_TO_LEAK = "private_path_component_"


class ExportPreprocessingTest(tf.test.TestCase, parameterized.TestCase):

  def _make_vocab_file(self, vocab, filename="vocab.txt",
                       add_mask_token=False):
    """Creates wordpiece vocab file with given words plus special tokens.

    The tokens of the resulting model are, in this order:
        [PAD], [UNK], [CLS], [SEP], [MASK]*, ...vocab...
    *=if requested by args.

    This function also accepts wordpieces that start with the ## continuation
    marker, but avoiding those makes this function interchangeable with
    _make_sp_model_file(), up to the extra dimension returned by BertTokenizer.

    Args:
      vocab: a list of strings with the words or wordpieces to put into the
        model's vocabulary. Do not include special tokens here.
      filename: Optionally, a filename (relative to the temporary directory
        created by this function).
      add_mask_token: an optional bool, whether to include a [MASK] token.

    Returns:
      The absolute filename of the created vocab file.
    """
    full_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"
                 ] + ["[MASK]"] * add_mask_token + vocab
    path = os.path.join(
        tempfile.mkdtemp(
            dir=self.get_temp_dir(),  # New subdir each time.
            prefix=_STRING_NOT_TO_LEAK),
        filename)
    with tf.io.gfile.GFile(path, "w") as f:
      f.write("\n".join(full_vocab + [""]))
    return path

  def _make_sp_model_file(self, vocab, prefix="spm", add_mask_token=False):
    """Creates Sentencepiece word model with given words plus special tokens.

    The tokens of the resulting model are, in this order:
        <pad>, <unk>, [CLS], [SEP], [MASK]*, ...vocab..., <s>, </s>
    *=if requested by args.

    The words in the input vocab are plain text, without the whitespace marker.
    That makes this function interchangeable with _make_vocab_file().

    Args:
      vocab: a list of strings with the words to put into the model's
        vocabulary. Do not include special tokens here.
      prefix: an optional string, to change the filename prefix for the model
        (relative to the temporary directory created by this function).
      add_mask_token: an optional bool, whether to include a [MASK] token.

    Returns:
      The absolute filename of the created Sentencepiece model file.
    """
    model_prefix = os.path.join(
        tempfile.mkdtemp(dir=self.get_temp_dir()),  # New subdir each time.
        prefix)
    input_file = model_prefix + "_train_input.txt"
    # Create input text for training the sp model from the tokens provided.
    # Repeat earlier tokens more often, because tokens are sorted by frequency.
    input_text = []
    for i, token in enumerate(vocab):
      input_text.append(" ".join([token] * (len(vocab) - i)))
    with tf.io.gfile.GFile(input_file, "w") as f:
      f.write("\n".join(input_text + [""]))
    control_symbols = "[CLS],[SEP]"
    full_vocab_size = len(vocab) + 6  # <pad>, <unk>, [CLS], [SEP], <s>, </s>.
    if add_mask_token:
      control_symbols += ",[MASK]"
      full_vocab_size += 1
    flags = dict(
        model_prefix=model_prefix,
        model_type="word",
        input=input_file,
        pad_id=0,
        unk_id=1,
        control_symbols=control_symbols,
        vocab_size=full_vocab_size,
        bos_id=full_vocab_size - 2,
        eos_id=full_vocab_size - 1)
    SentencePieceTrainer.Train(" ".join(
        ["--{}={}".format(k, v) for k, v in flags.items()]))
    return model_prefix + ".model"

  def _do_export(self,
                 vocab,
                 do_lower_case,
                 default_seq_length=128,
                 tokenize_with_offsets=True,
                 use_sp_model=False,
                 experimental_disable_assert=False,
                 add_mask_token=False):
    """Runs SavedModel export and returns the export_path."""
    export_path = tempfile.mkdtemp(dir=self.get_temp_dir())
    vocab_file = sp_model_file = None
    if use_sp_model:
      sp_model_file = self._make_sp_model_file(
          vocab, add_mask_token=add_mask_token)
    else:
      vocab_file = self._make_vocab_file(vocab, add_mask_token=add_mask_token)
    export_tfhub_lib.export_preprocessing(
        export_path,
        vocab_file=vocab_file,
        sp_model_file=sp_model_file,
        do_lower_case=do_lower_case,
        tokenize_with_offsets=tokenize_with_offsets,
        default_seq_length=default_seq_length,
        experimental_disable_assert=experimental_disable_assert)
    # Invalidate the original filename to verify loading from the SavedModel.
    tf.io.gfile.remove(sp_model_file or vocab_file)
    return export_path

  def test_no_leaks(self):
    """Tests not leaking the path to the original vocab file."""
    path = self._do_export(["d", "ef", "abc", "xy"],
                           do_lower_case=True,
                           use_sp_model=False)
    with tf.io.gfile.GFile(os.path.join(path, "saved_model.pb"), "rb") as f:
      self.assertFalse(  # pylint: disable=g-generic-assert
          _STRING_NOT_TO_LEAK.encode("ascii") in f.read())

  @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
  def test_exported_callables(self, use_sp_model):
    preprocess = tf.saved_model.load(
        self._do_export(
            ["d", "ef", "abc", "xy"],
            do_lower_case=True,
            # TODO(b/181866850): drop this.
            tokenize_with_offsets=not use_sp_model,
            # TODO(b/175369555): drop this.
            experimental_disable_assert=True,
            use_sp_model=use_sp_model))

    def fold_dim(rt):
      """Removes the word/subword distinction of BertTokenizer."""
      return rt if use_sp_model else rt.merge_dims(1, 2)

    # .tokenize()
    inputs = tf.constant(["abc d ef", "ABC D EF d"])
    token_ids = preprocess.tokenize(inputs)
    self.assertAllEqual(
        fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))

    special_tokens_dict = {
        k: v.numpy().item()  # Expecting eager Tensor, converting to Python.
        for k, v in preprocess.tokenize.get_special_tokens_dict().items()
    }
    self.assertDictEqual(
        special_tokens_dict,
        dict(
            padding_id=0,
            start_of_sequence_id=2,
            end_of_segment_id=3,
            vocab_size=4 + 6 if use_sp_model else 4 + 4))

    # .tokenize_with_offsets()
    if use_sp_model:
      # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
      self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
    else:
      token_ids, start_offsets, limit_offsets = (
          preprocess.tokenize_with_offsets(inputs))
      self.assertAllEqual(
          fold_dim(token_ids), tf.ragged.constant([[6, 4, 5], [6, 4, 5, 4]]))
      self.assertAllEqual(
          fold_dim(start_offsets),
          tf.ragged.constant([[0, 4, 6], [0, 4, 6, 9]]))
      self.assertAllEqual(
          fold_dim(limit_offsets),
          tf.ragged.constant([[3, 5, 8], [3, 5, 8, 10]]))
      self.assertIs(preprocess.tokenize.get_special_tokens_dict,
                    preprocess.tokenize_with_offsets.get_special_tokens_dict)

    # Root callable.
    bert_inputs = preprocess(inputs)
    self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 128])
    self.assertAllEqual(
        bert_inputs["input_word_ids"][:, :10],
        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
    self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 128])
    self.assertAllEqual(
        bert_inputs["input_mask"][:, :10],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
    self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 128])
    self.assertAllEqual(
        bert_inputs["input_type_ids"][:, :10],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

    # .bert_pack_inputs()
    inputs_2 = tf.constant(["d xy", "xy abc"])
    token_ids_2 = preprocess.tokenize(inputs_2)
    bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2],
                                              seq_length=256)
    self.assertAllEqual(bert_inputs["input_word_ids"].shape.as_list(), [2, 256])
    self.assertAllEqual(
        bert_inputs["input_word_ids"][:, :10],
        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
                     [2, 6, 4, 5, 4, 3, 7, 6, 3, 0]]))
    self.assertAllEqual(bert_inputs["input_mask"].shape.as_list(), [2, 256])
    self.assertAllEqual(
        bert_inputs["input_mask"][:, :10],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]))
    self.assertAllEqual(bert_inputs["input_type_ids"].shape.as_list(), [2, 256])
    self.assertAllEqual(
        bert_inputs["input_type_ids"][:, :10],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]]))

  # For BertTokenizer only: repeat relevant parts for do_lower_case=False,
  # default_seq_length=10, experimental_disable_assert=False,
  # tokenize_with_offsets=False, and without folding the word/subword dimension.
  def test_cased_length10(self):
    preprocess = tf.saved_model.load(
        self._do_export(["d", "##ef", "abc", "ABC"],
                        do_lower_case=False,
                        default_seq_length=10,
                        tokenize_with_offsets=False,
                        use_sp_model=False,
                        experimental_disable_assert=False))
    inputs = tf.constant(["abc def", "ABC DEF"])
    token_ids = preprocess.tokenize(inputs)
    self.assertAllEqual(token_ids,
                        tf.ragged.constant([[[6], [4, 5]], [[7], [1]]]))

    self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))

    bert_inputs = preprocess(inputs)
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
                     [2, 7, 1, 3, 0, 0, 0, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

    inputs_2 = tf.constant(["d ABC", "ABC abc"])
    token_ids_2 = preprocess.tokenize(inputs_2)
    bert_inputs = preprocess.bert_pack_inputs([token_ids, token_ids_2])
    # Test default seq_length=10.
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[2, 6, 4, 5, 3, 4, 7, 3, 0, 0],
                     [2, 7, 1, 3, 7, 6, 3, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]]))

  # XLA requires fixed shapes for tensors found in graph mode.
  # Statically known shapes in Python are a particularly firm way to
  # guarantee that, and they are generally more convenient to work with.
  # We test that the exported SavedModel plays well with TF's shape
  # inference when applied to fully or partially known input shapes.
  @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
  def test_shapes(self, use_sp_model):
    preprocess = tf.saved_model.load(
        self._do_export(
            ["abc", "def"],
            do_lower_case=True,
            # TODO(b/181866850): drop this.
            tokenize_with_offsets=not use_sp_model,
            # TODO(b/175369555): drop this.
            experimental_disable_assert=True,
            use_sp_model=use_sp_model))

    def expected_bert_input_shapes(batch_size, seq_length):
      return dict(
          input_word_ids=[batch_size, seq_length],
          input_mask=[batch_size, seq_length],
          input_type_ids=[batch_size, seq_length])

    for batch_size in [7, None]:
      if use_sp_model:
        token_out_shape = [batch_size, None]  # No word/subword distinction.
      else:
        token_out_shape = [batch_size, None, None]
      self.assertEqual(
          _result_shapes_in_tf_function(preprocess.tokenize,
                                        tf.TensorSpec([batch_size], tf.string)),
          token_out_shape, "with batch_size=%s" % batch_size)
      # TODO(b/181866850): Enable tokenize_with_offsets when it works and test.
      if use_sp_model:
        self.assertFalse(hasattr(preprocess, "tokenize_with_offsets"))
      else:
        self.assertEqual(
            _result_shapes_in_tf_function(
                preprocess.tokenize_with_offsets,
                tf.TensorSpec([batch_size], tf.string)), [token_out_shape] * 3,
            "with batch_size=%s" % batch_size)
      self.assertEqual(
          _result_shapes_in_tf_function(
              preprocess.bert_pack_inputs,
              [tf.RaggedTensorSpec([batch_size, None, None], tf.int32)] * 2,
              seq_length=256), expected_bert_input_shapes(batch_size, 256),
          "with batch_size=%s" % batch_size)
      self.assertEqual(
          _result_shapes_in_tf_function(preprocess,
                                        tf.TensorSpec([batch_size], tf.string)),
          expected_bert_input_shapes(batch_size, 128),
          "with batch_size=%s" % batch_size)

  @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
  def test_reexport(self, use_sp_model):
    """Tests that preprocessing keeps working after another save/load cycle."""
    path1 = self._do_export(
        ["d", "ef", "abc", "xy"],
        do_lower_case=True,
        default_seq_length=10,
        tokenize_with_offsets=False,
        experimental_disable_assert=True,  # TODO(b/175369555): drop this.
        use_sp_model=use_sp_model)
    path2 = path1.rstrip("/") + ".2"
    model1 = tf.saved_model.load(path1)
    tf.saved_model.save(model1, path2)
    # Delete the first SavedModel to test that the second one loads by itself.
    # https://github.com/tensorflow/tensorflow/issues/46456 reports such a
    # failure case for BertTokenizer.
    tf.io.gfile.rmtree(path1)
    model2 = tf.saved_model.load(path2)

    inputs = tf.constant(["abc d ef", "ABC D EF d"])
    bert_inputs = model2(inputs)
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[2, 6, 4, 5, 3, 0, 0, 0, 0, 0],
                     [2, 6, 4, 5, 4, 3, 0, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

  @parameterized.named_parameters(("Bert", True), ("Albert", False))
  def test_preprocessing_for_mlm(self, use_bert):
    """Combines both SavedModel types and TF.text helpers for MLM."""
    # Create the preprocessing SavedModel with a [MASK] token.
    non_special_tokens = [
        "hello", "world", "nice", "movie", "great", "actors", "quick", "fox",
        "lazy", "dog"
    ]

    preprocess = tf.saved_model.load(
        self._do_export(
            non_special_tokens,
            do_lower_case=True,
            tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
            experimental_disable_assert=True,  # TODO(b/175369555): drop this.
            add_mask_token=True,
            use_sp_model=not use_bert))
    vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

    # Create the encoder SavedModel with an .mlm subobject.
    hidden_size = 16
    num_hidden_layers = 2
    bert_config, encoder_config = _get_bert_config_or_encoder_config(
        use_bert_config=use_bert,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size)
    _, pretrainer = export_tfhub_lib._create_model(
        bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
    vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(  # Not used below.
        self.get_temp_dir(), use_sp_model=not use_bert)
    encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
    export_tfhub_lib.export_model(
        export_path=encoder_export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=model_checkpoint_path,
        with_mlm=True,
        vocab_file=vocab_file,
        sp_model_file=sp_model_file,
        do_lower_case=True)
    encoder = tf.saved_model.load(encoder_export_path)

    # Get special tokens from the vocab (and vocab size).
    special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
    self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
    padding_id = int(special_tokens_dict["padding_id"])
    self.assertEqual(padding_id, 0)
    start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
    self.assertEqual(start_of_sequence_id, 2)
    end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
    self.assertEqual(end_of_segment_id, 3)
    mask_id = int(special_tokens_dict["mask_id"])
    self.assertEqual(mask_id, 4)

    # A batch of 3 segment pairs.
    raw_segments = [
        tf.constant(["hello", "nice movie", "quick fox"]),
        tf.constant(["world", "great actors", "lazy dog"])
    ]
    batch_size = 3

    # Misc hyperparameters.
    seq_length = 10
    max_selections_per_seq = 2

    # Tokenize inputs.
    tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
    # Trim inputs to eventually fit seq_length.
    num_special_tokens = len(raw_segments) + 1
    trimmed_segments = text.WaterfallTrimmer(
        seq_length - num_special_tokens).trim(tokenized_segments)
    # Combine input segments into one input sequence.
    input_ids, segment_ids = text.combine_segments(
        trimmed_segments,
        start_of_sequence_id=start_of_sequence_id,
        end_of_segment_id=end_of_segment_id)
    # Apply random masking controlled by policy objects.
    (masked_input_ids, masked_lm_positions,
     masked_ids) = text.mask_language_model(
         input_ids=input_ids,
         item_selector=text.RandomItemSelector(
             max_selections_per_seq,
             selection_rate=0.5,  # Adjusted for the short test examples.
             unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
         mask_values_chooser=text.MaskValuesChooser(
             vocab_size=vocab_size,
             mask_token=mask_id,
             # Always put [MASK] to have a predictable result.
             mask_token_rate=1.0,
             random_token_rate=0.0))
    # Pad to fixed-length Transformer encoder inputs.
    input_word_ids, _ = text.pad_model_inputs(
        masked_input_ids, seq_length, pad_value=padding_id)
    input_type_ids, input_mask = text.pad_model_inputs(
        segment_ids, seq_length, pad_value=0)
    masked_lm_positions, _ = text.pad_model_inputs(
        masked_lm_positions, max_selections_per_seq, pad_value=0)
    masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
    num_predictions = int(tf.shape(masked_lm_positions)[1])

    # Test transformer inputs.
    self.assertEqual(num_predictions, max_selections_per_seq)
    expected_word_ids = np.array([
        # [CLS] hello [SEP] world [SEP]
        [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
        # [CLS] nice movie [SEP] great actors [SEP]
        [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
        # [CLS] quick fox [SEP] lazy dog [SEP]
        [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
    ])
    for i in range(batch_size):
      for j in range(num_predictions):
        k = int(masked_lm_positions[i, j])
        if k != 0:
          expected_word_ids[i, k] = 4  # [MASK]
    self.assertAllEqual(input_word_ids, expected_word_ids)

    # Call the MLM head of the Transformer encoder.
    mlm_inputs = dict(
        input_word_ids=input_word_ids,
        input_mask=input_mask,
        input_type_ids=input_type_ids,
        masked_lm_positions=masked_lm_positions,
    )
    mlm_outputs = encoder.mlm(mlm_inputs)
    self.assertEqual(mlm_outputs["pooled_output"].shape,
                     (batch_size, hidden_size))
    self.assertEqual(mlm_outputs["sequence_output"].shape,
                     (batch_size, seq_length, hidden_size))
    self.assertEqual(mlm_outputs["mlm_logits"].shape,
                     (batch_size, num_predictions, vocab_size))
    self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)

    # A real trainer would now compute the loss of mlm_logits
    # trying to predict the masked_ids.
    del masked_ids  # Unused.

  @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
  def test_special_tokens_in_estimator(self, use_sp_model):
    """Tests getting special tokens without an Eager init context."""
    preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
                                             do_lower_case=True,
                                             use_sp_model=use_sp_model,
                                             tokenize_with_offsets=False)

    def _get_special_tokens_dict(obj):
      """Returns special tokens of restored tokenizer as Python values."""
      if tf.executing_eagerly():
        special_tokens_numpy = {
            k: v.numpy() for k, v in obj.get_special_tokens_dict().items()
        }
      else:
        with tf.Graph().as_default():
          # This code expects `get_special_tokens_dict()` to be a tf.function
          # with no dependencies (bound args) from the context it was loaded
          # in, and boldly assumes it can just be called in a different context.
          special_tokens_tensors = obj.get_special_tokens_dict()
          with tf.compat.v1.Session() as sess:
            special_tokens_numpy = sess.run(special_tokens_tensors)
      return {
          k: v.item()  # Numpy to Python.
          for k, v in special_tokens_numpy.items()
      }

    def input_fn():
      self.assertFalse(tf.executing_eagerly())
      # Build a preprocessing Model.
      sentences = tf_keras.layers.Input(shape=[], dtype=tf.string)
      preprocess = tf.saved_model.load(preprocess_export_path)
      tokenize = hub.KerasLayer(preprocess.tokenize)
      special_tokens_dict = _get_special_tokens_dict(tokenize.resolved_object)
      for k, v in special_tokens_dict.items():
        self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
      tokens = tokenize(sentences)
      packed_inputs = layers.BertPackInputs(
          4, special_tokens_dict=special_tokens_dict)(
              tokens)
      preprocessing = tf_keras.Model(sentences, packed_inputs)
      # Map the dataset.
      ds = tf.data.Dataset.from_tensors(
          (tf.constant(["abc", "D EF"]), tf.constant([0, 1])))
      ds = ds.map(lambda features, labels: (preprocessing(features), labels))
      return ds

    def model_fn(features, labels, mode):
      del labels  # Unused.
      return tf_estimator.EstimatorSpec(
          mode=mode, predictions=features["input_word_ids"])

    estimator = tf_estimator.Estimator(model_fn=model_fn)
    outputs = list(estimator.predict(input_fn))
    self.assertAllEqual(outputs, np.array([[2, 6, 3, 0], [2, 4, 5, 3]]))

  # TODO(b/175369555): Remove that code and its test.
  @parameterized.named_parameters(("Bert", False), ("Sentencepiece", True))
  def test_check_no_assert(self, use_sp_model):
    """Tests the self-check during export without assertions."""
    preprocess_export_path = self._do_export(["d", "ef", "abc", "xy"],
                                             do_lower_case=True,
                                             use_sp_model=use_sp_model,
                                             tokenize_with_offsets=False,
                                             experimental_disable_assert=False)
    with self.assertRaisesRegex(AssertionError,
                                r"failed to suppress \d+ Assert ops"):
      export_tfhub_lib._check_no_assert(preprocess_export_path)


def _result_shapes_in_tf_function(fn, *args, **kwargs):
  """Returns shapes (as lists) observed on the result of `fn`.

  Args:
    fn: A callable.
    *args: TensorSpecs for Tensor-valued arguments and actual values for
      Python-valued arguments to fn.
    **kwargs: Same for keyword arguments.

  Returns:
    The nest of partial tensor shapes (as lists) that is statically known
    inside tf.function(fn)(*args, **kwargs) for the nest of its results.
  """
  # Use a captured mutable container for a side output from the wrapper.
  uninitialized = "uninitialized!"
  result_shapes_container = [uninitialized]
  assert result_shapes_container[0] is uninitialized

  @tf.function
  def shape_reporting_wrapper(*args, **kwargs):
    result = fn(*args, **kwargs)
    result_shapes_container[0] = tf.nest.map_structure(
        lambda x: x.shape.as_list(), result)
    return result

  shape_reporting_wrapper.get_concrete_function(*args, **kwargs)
  assert result_shapes_container[0] is not uninitialized
  return result_shapes_container[0]


if __name__ == "__main__":
  tf.test.main()
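The `_result_shapes_in_tf_function` helper above works because tracing a `tf.function` with `TensorSpec` arguments runs the Python body once on symbolic tensors, so whatever shape inference can prove is observable as a side effect of the trace. A minimal standalone sketch of the same pattern (the `static_result_shape` name is illustrative, not part of the library):

```python
import tensorflow as tf


def static_result_shape(fn, input_spec):
  """Traces `fn` once and returns the statically inferred result shape."""
  captured = []

  @tf.function
  def wrapper(x):
    y = fn(x)
    captured.append(y.shape.as_list())  # Recorded at trace time, not run time.
    return y

  wrapper.get_concrete_function(input_spec)  # Triggers tracing only.
  return captured[0]


# A shape-preserving op keeps the (possibly unknown) batch dimension:
print(static_result_shape(tf.square, tf.TensorSpec([None, 128], tf.float32)))
# -> [None, 128]
```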
squad_evaluate_v1_1.py
ADDED
@@ -0,0 +1,106 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation of SQuAD predictions (version 1.1).

The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.

The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""

import collections
import re
import string

# pylint: disable=g-bad-import-order

from absl import logging
# pylint: enable=g-bad-import-order


def _normalize_answer(s):
  """Lowercases text and removes punctuation, articles and extra whitespace."""

  def remove_articles(text):
    return re.sub(r"\b(a|an|the)\b", " ", text)

  def white_space_fix(text):
    return " ".join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
  """Computes F1 score by comparing prediction to ground truth."""
  prediction_tokens = _normalize_answer(prediction).split()
  ground_truth_tokens = _normalize_answer(ground_truth).split()
  prediction_counter = collections.Counter(prediction_tokens)
  ground_truth_counter = collections.Counter(ground_truth_tokens)
  common = prediction_counter & ground_truth_counter
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _exact_match_score(prediction, ground_truth):
  """Checks if predicted answer exactly matches ground truth answer."""
  return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  """Computes the max over all metric scores."""
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
  """Evaluates predictions for a dataset."""
  f1 = exact_match = total = 0
  for article in dataset:
    for paragraph in article["paragraphs"]:
      for qa in paragraph["qas"]:
        total += 1
        if qa["id"] not in predictions:
          message = "Unanswered question " + qa["id"] + " will receive score 0."
          logging.error(message)
          continue
        ground_truths = [entry["text"] for entry in qa["answers"]]
        prediction = predictions[qa["id"]]
        exact_match += _metric_max_over_ground_truths(_exact_match_score,
                                                      prediction, ground_truths)
        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
                                             ground_truths)

  exact_match = exact_match / total
  f1 = f1 / total

  return {"exact_match": exact_match, "final_f1": f1}
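For reference, `evaluate()` above expects the `data` list from the SQuAD 1.1 JSON and a dict mapping question IDs to predicted answer strings. A minimal, hypothetical invocation (file names are placeholders, and the module is assumed to be importable under its file name):

```python
import json

from squad_evaluate_v1_1 import evaluate

with open("dev-v1.1.json") as f:
  dataset = json.load(f)["data"]  # List of articles with "paragraphs"/"qas".
with open("predictions.json") as f:
  predictions = json.load(f)  # {question_id: answer_string}

metrics = evaluate(dataset, predictions)
print(metrics)  # {"exact_match": ..., "final_f1": ...}
```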
squad_evaluate_v2_0.py
ADDED
@@ -0,0 +1,249 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation script for SQuAD version 2.0.

The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lowercases text and removes punctuation, articles and extra whitespace."""
  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)
  def white_space_fix(text):
    return ' '.join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()
  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Computes the F1 score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Computes raw exact-match and F1 scores per question ID."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if _normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, the only correct answer is the
          # empty string.
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers.
        exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(
    scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Makes the evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
  """Makes an evaluation dictionary containing the average precision."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # I.e., if we can put a threshold after this point.
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Runs precision-recall analysis and merges the results into main_eval."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Finds the best threshold for the no-answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(
    main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(
      predictions, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(
      predictions, f1_raw, na_probs, qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluates prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # Maps qid to True/False.
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = _make_eval_dict(
        exact_thresh, f1_thresh, qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')

  _find_all_best_thresh(
      out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  _run_precision_recall_analysis(
      out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  return out_eval
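The v2.0 `evaluate()` additionally accepts `na_probs`, mapping question IDs to the model's predicted probability that a question is unanswerable; when it is omitted, every question is scored as answered (probability 0.0). A hypothetical invocation along the same lines as the v1.1 example (paths and module name are placeholders):

```python
import json

from squad_evaluate_v2_0 import evaluate

with open("dev-v2.0.json") as f:
  dataset = json.load(f)["data"]
with open("predictions.json") as f:
  predictions = json.load(f)
with open("na_prob.json") as f:
  na_probs = json.load(f)  # {question_id: P(unanswerable)}

metrics = evaluate(dataset, predictions, na_probs=na_probs)
print(metrics["final_f1"], metrics["final_f1_thresh"])
```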
tf1_bert_checkpoint_converter_lib.py
ADDED
@@ -0,0 +1,201 @@
1 |
+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import tensorflow.compat.v1 as tf # TF 1.x
|
19 |
+
|
20 |
+
# Mapping between old <=> new names. The source pattern in original variable
|
21 |
+
# name will be replaced by destination pattern.
|
22 |
+
BERT_NAME_REPLACEMENTS = (
|
23 |
+
("bert", "bert_model"),
|
24 |
+
("embeddings/word_embeddings", "word_embeddings/embeddings"),
|
25 |
+
("embeddings/token_type_embeddings",
|
26 |
+
"embedding_postprocessor/type_embeddings"),
|
27 |
+
("embeddings/position_embeddings",
|
28 |
+
"embedding_postprocessor/position_embeddings"),
|
29 |
+
("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
|
30 |
+
("attention/self", "self_attention"),
|
31 |
+
("attention/output/dense", "self_attention_output"),
|
32 |
+
("attention/output/LayerNorm", "self_attention_layer_norm"),
|
33 |
+
("intermediate/dense", "intermediate"),
|
34 |
+
("output/dense", "output"),
|
35 |
+
("output/LayerNorm", "output_layer_norm"),
|
36 |
+
("pooler/dense", "pooler_transform"),
|
37 |
+
)
|
38 |
+
|
39 |
+
BERT_V2_NAME_REPLACEMENTS = (
|
40 |
+
("bert/", ""),
|
41 |
+
("encoder", "transformer"),
|
42 |
+
("embeddings/word_embeddings", "word_embeddings/embeddings"),
|
43 |
+
("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
|
44 |
+
("embeddings/position_embeddings", "position_embedding/embeddings"),
|
45 |
+
("embeddings/LayerNorm", "embeddings/layer_norm"),
|
46 |
+
("attention/self", "self_attention"),
|
47 |
+
("attention/output/dense", "self_attention/attention_output"),
|
48 |
+
("attention/output/LayerNorm", "self_attention_layer_norm"),
|
49 |
+
("intermediate/dense", "intermediate"),
|
50 |
+
("output/dense", "output"),
|
51 |
+
("output/LayerNorm", "output_layer_norm"),
|
52 |
+
("pooler/dense", "pooler_transform"),
|
53 |
+
("cls/predictions", "bert/cls/predictions"),
|
54 |
+
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
|
55 |
+
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
|
56 |
+
("cls/seq_relationship/output_weights",
|
57 |
+
"predictions/transform/logits/kernel"),
|
58 |
+
)
|
59 |
+
|
60 |
+
BERT_PERMUTATIONS = ()
|
61 |
+
|
62 |
+
BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)


def _bert_name_replacement(var_name, name_replacements):
  """Gets the variable name replacement."""
  for src_pattern, tgt_pattern in name_replacements:
    if src_pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(src_pattern, tgt_pattern)
      tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
  return var_name


def _has_exclude_patterns(name, exclude_patterns):
  """Checks if a string contains substrings that match patterns to exclude."""
  for p in exclude_patterns:
    if p in name:
      return True
  return False


def _get_permutation(name, permutations):
  """Checks whether a variable requires transposition by pattern matching."""
  for src_pattern, permutation in permutations:
    if src_pattern in name:
      tf.logging.info("Permuted: %s --> %s", name, permutation)
      return permutation

  return None


def _get_new_shape(name, shape, num_heads):
  """Checks whether a variable requires reshape by pattern matching."""
  if "self_attention/attention_output/kernel" in name:
    return tuple([num_heads, shape[0] // num_heads, shape[1]])
  if "self_attention/attention_output/bias" in name:
    return shape

  patterns = [
      "self_attention/query", "self_attention/value", "self_attention/key"
  ]
  for pattern in patterns:
    if pattern in name:
      if "kernel" in name:
        return tuple([shape[0], num_heads, shape[1] // num_heads])
      if "bias" in name:
        return tuple([num_heads, shape[0] // num_heads])
  return None
110 |
+
|
111 |
+
|
112 |
+
def create_v2_checkpoint(model,
|
113 |
+
src_checkpoint,
|
114 |
+
output_path,
|
115 |
+
checkpoint_model_name="model"):
|
116 |
+
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
|
117 |
+
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
|
118 |
+
model.load_weights(src_checkpoint).assert_existing_objects_matched()
|
119 |
+
if hasattr(model, "checkpoint_items"):
|
120 |
+
checkpoint_items = model.checkpoint_items
|
121 |
+
else:
|
122 |
+
checkpoint_items = {}
|
123 |
+
|
124 |
+
checkpoint_items[checkpoint_model_name] = model
|
125 |
+
checkpoint = tf.train.Checkpoint(**checkpoint_items)
|
126 |
+
checkpoint.save(output_path)
|
127 |
+
|
128 |
+
|
129 |
+
def convert(checkpoint_from_path,
|
130 |
+
checkpoint_to_path,
|
131 |
+
num_heads,
|
132 |
+
name_replacements,
|
133 |
+
permutations,
|
134 |
+
exclude_patterns=None):
|
135 |
+
"""Migrates the names of variables within a checkpoint.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
checkpoint_from_path: Path to source checkpoint to be read in.
|
139 |
+
checkpoint_to_path: Path to checkpoint to be written out.
|
140 |
+
num_heads: The number of heads of the model.
|
141 |
+
name_replacements: A list of tuples of the form (match_str, replace_str)
|
142 |
+
describing variable names to adjust.
|
143 |
+
permutations: A list of tuples of the form (match_str, permutation)
|
144 |
+
describing permutations to apply to given variables. Note that match_str
|
145 |
+
should match the original variable name, not the replaced one.
|
146 |
+
exclude_patterns: A list of string patterns to exclude variables from
|
147 |
+
checkpoint conversion.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
A dictionary that maps the new variable names to the Variable objects.
|
151 |
+
A dictionary that maps the old variable names to the new variable names.
|
152 |
+
"""
|
153 |
+
with tf.Graph().as_default():
|
154 |
+
tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
|
155 |
+
reader = tf.train.NewCheckpointReader(checkpoint_from_path)
|
156 |
+
name_shape_map = reader.get_variable_to_shape_map()
|
157 |
+
new_variable_map = {}
|
158 |
+
conversion_map = {}
|
159 |
+
for var_name in name_shape_map:
|
160 |
+
if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
|
161 |
+
continue
|
162 |
+
# Get the original tensor data.
|
163 |
+
tensor = reader.get_tensor(var_name)
|
164 |
+
|
165 |
+
# Look up the new variable name, if any.
|
166 |
+
new_var_name = _bert_name_replacement(var_name, name_replacements)
|
167 |
+
|
168 |
+
# See if we need to reshape the underlying tensor.
|
169 |
+
new_shape = None
|
170 |
+
if num_heads > 0:
|
171 |
+
new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
|
172 |
+
if new_shape:
|
173 |
+
tf.logging.info("Veriable %s has a shape change from %s to %s",
|
174 |
+
var_name, tensor.shape, new_shape)
|
175 |
+
tensor = np.reshape(tensor, new_shape)
|
176 |
+
|
177 |
+
# See if we need to permute the underlying tensor.
|
178 |
+
permutation = _get_permutation(var_name, permutations)
|
179 |
+
if permutation:
|
180 |
+
tensor = np.transpose(tensor, permutation)
|
181 |
+
|
182 |
+
# Create a new variable with the possibly-reshaped or transposed tensor.
|
183 |
+
var = tf.Variable(tensor, name=var_name)
|
184 |
+
|
185 |
+
# Save the variable into the new variable map.
|
186 |
+
new_variable_map[new_var_name] = var
|
187 |
+
|
188 |
+
# Keep a list of converter variables for sanity checking.
|
189 |
+
if new_var_name != var_name:
|
190 |
+
conversion_map[var_name] = new_var_name
|
191 |
+
|
192 |
+
saver = tf.train.Saver(new_variable_map)
|
193 |
+
|
194 |
+
with tf.Session() as sess:
|
195 |
+
sess.run(tf.global_variables_initializer())
|
196 |
+
tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
|
197 |
+
saver.save(sess, checkpoint_to_path, write_meta_graph=False)
|
198 |
+
|
199 |
+
tf.logging.info("Summary:")
|
200 |
+
tf.logging.info(" Converted %d variable name(s).", len(new_variable_map))
|
201 |
+
tf.logging.info(" Converted: %s", str(conversion_map))
|
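
# End-to-end usage sketch (illustrative; the paths, the num_heads value, and
# the Keras `model` object below are placeholders, not part of this module):
#
#   convert(checkpoint_from_path="/tmp/v1/bert_model.ckpt",
#           checkpoint_to_path="/tmp/v1_renamed/ckpt",
#           num_heads=12,
#           name_replacements=BERT_V2_NAME_REPLACEMENTS,
#           permutations=BERT_V2_PERMUTATIONS,
#           exclude_patterns=["adam", "Adam"])
#   create_v2_checkpoint(model, "/tmp/v1_renamed/ckpt", "/tmp/v2/ckpt")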
tf2_albert_encoder_checkpoint_converter.py
ADDED
@@ -0,0 +1,170 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A converter from a tf1 ALBERT encoder checkpoint to a tf2 encoder checkpoint.

The conversion will yield an object-oriented checkpoint that can be used
to restore an AlbertEncoder object.
"""
import os

from absl import app
from absl import flags

import tensorflow as tf, tf_keras
from official.legacy.albert import configs
from official.modeling import tf_utils
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib

FLAGS = flags.FLAGS

flags.DEFINE_string("albert_config_file", None,
                    "Albert configuration file to define core bert layers.")
flags.DEFINE_string(
    "checkpoint_to_convert", None,
    "Initial checkpoint from a pretrained BERT model core (that is, only the "
    "BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
                    "The name of the model when saving the checkpoint, i.e., "
                    "the checkpoint will be saved using: "
                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
    "converted_model", "encoder", ["encoder", "pretrainer"],
    "Whether to convert the checkpoint to an `AlbertEncoder` model or a "
    "`BertPretrainerV2` model (with mlm but without classification heads).")


ALBERT_NAME_REPLACEMENTS = (
    ("bert/encoder/", ""),
    ("bert/", ""),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    ("embedding_hidden_mapping_in", "embedding_projection"),
    ("group_0/inner_group_0/", ""),
    ("attention_1/self", "self_attention"),
    ("attention_1/output/dense", "self_attention/attention_output"),
    ("transformer/LayerNorm/", "transformer/self_attention_layer_norm/"),
    ("ffn_1/intermediate/dense", "intermediate"),
    ("ffn_1/intermediate/output/dense", "output"),
    ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
    ("pooler/dense", "pooler_transform"),
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    ("cls/seq_relationship/output_weights",
     "predictions/transform/logits/kernel"),
)


def _create_albert_model(cfg):
  """Creates an ALBERT keras core model from ALBERT configuration.

  Args:
    cfg: An `AlbertConfig` to create the core model.

  Returns:
    A keras model.
  """
  albert_encoder = networks.AlbertEncoder(
      vocab_size=cfg.vocab_size,
      hidden_size=cfg.hidden_size,
      embedding_width=cfg.embedding_size,
      num_layers=cfg.num_hidden_layers,
      num_attention_heads=cfg.num_attention_heads,
      intermediate_size=cfg.intermediate_size,
      activation=tf_utils.get_activation(cfg.hidden_act),
      dropout_rate=cfg.hidden_dropout_prob,
      attention_dropout_rate=cfg.attention_probs_dropout_prob,
      max_sequence_length=cfg.max_position_embeddings,
      type_vocab_size=cfg.type_vocab_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))
  return albert_encoder


def _create_pretrainer_model(cfg):
  """Creates a pretrainer with AlbertEncoder from ALBERT configuration.

  Args:
    cfg: An `AlbertConfig` to create the core model.

  Returns:
    A BertPretrainerV2 model.
  """
  albert_encoder = _create_albert_model(cfg)
  pretrainer = models.BertPretrainerV2(
      encoder_network=albert_encoder,
      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
      mlm_initializer=tf_keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))
  # Makes sure masked_lm layer's variables in pretrainer are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer


def convert_checkpoint(bert_config, output_path, v1_checkpoint,
                       checkpoint_model_name,
                       converted_model="encoder"):
  """Converts a V1 checkpoint into an OO V2 checkpoint."""
  output_dir, _ = os.path.split(output_path)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")
  tf1_bert_checkpoint_converter_lib.convert(
      checkpoint_from_path=v1_checkpoint,
      checkpoint_to_path=temporary_checkpoint,
      num_heads=bert_config.num_attention_heads,
      name_replacements=ALBERT_NAME_REPLACEMENTS,
      permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
      exclude_patterns=["adam", "Adam"])

  # Create a V2 checkpoint from the temporary checkpoint.
  if converted_model == "encoder":
    model = _create_albert_model(bert_config)
  elif converted_model == "pretrainer":
    model = _create_pretrainer_model(bert_config)
  else:
    raise ValueError("Unsupported converted_model: %s" % converted_model)

  tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
      model, temporary_checkpoint, output_path, checkpoint_model_name)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def main(_):
  output_path = FLAGS.converted_checkpoint_path
  v1_checkpoint = FLAGS.checkpoint_to_convert
  checkpoint_model_name = FLAGS.checkpoint_model_name
  converted_model = FLAGS.converted_model
  albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
  convert_checkpoint(albert_config, output_path, v1_checkpoint,
                     checkpoint_model_name,
                     converted_model=converted_model)


if __name__ == "__main__":
  app.run(main)
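
# Example invocation (a sketch; the file paths are placeholders):
#
#   python3 tf2_albert_encoder_checkpoint_converter.py \
#     --albert_config_file=/tmp/albert/albert_config.json \
#     --checkpoint_to_convert=/tmp/albert/model.ckpt-best \
#     --converted_checkpoint_path=/tmp/albert_v2/ckpt \
#     --converted_model=encoder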
tf2_bert_encoder_checkpoint_converter.py
ADDED
@@ -0,0 +1,160 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.

The conversion will yield an object-oriented checkpoint that can be used
to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
"""

import os

from absl import app
from absl import flags

import tensorflow as tf, tf_keras
from official.legacy.bert import configs
from official.modeling import tf_utils
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.tools import tf1_bert_checkpoint_converter_lib

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "checkpoint_to_convert", None,
    "Initial checkpoint from a pretrained BERT model core (that is, only the "
    "BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
                    "The name of the model when saving the checkpoint, i.e., "
                    "the checkpoint will be saved using: "
                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
    "converted_model", "encoder", ["encoder", "pretrainer"],
    "Whether to convert the checkpoint to a `BertEncoder` model or a "
    "`BertPretrainerV2` model (with mlm but without classification heads).")


def _create_bert_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertEncoder network.
  """
  bert_encoder = networks.BertEncoder(
      vocab_size=cfg.vocab_size,
      hidden_size=cfg.hidden_size,
      num_layers=cfg.num_hidden_layers,
      num_attention_heads=cfg.num_attention_heads,
      intermediate_size=cfg.intermediate_size,
      activation=tf_utils.get_activation(cfg.hidden_act),
      dropout_rate=cfg.hidden_dropout_prob,
      attention_dropout_rate=cfg.attention_probs_dropout_prob,
      max_sequence_length=cfg.max_position_embeddings,
      type_vocab_size=cfg.type_vocab_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range),
      embedding_width=cfg.embedding_size)

  return bert_encoder


def _create_bert_pretrainer_model(cfg):
  """Creates a BERT pretrainer model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertPretrainerV2 model.
  """
  bert_encoder = _create_bert_model(cfg)
  pretrainer = models.BertPretrainerV2(
      encoder_network=bert_encoder,
      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
      mlm_initializer=tf_keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))
  # Makes sure the pretrainer variables are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer


def convert_checkpoint(bert_config,
                       output_path,
                       v1_checkpoint,
                       checkpoint_model_name="model",
                       converted_model="encoder"):
  """Converts a V1 checkpoint into an OO V2 checkpoint."""
  output_dir, _ = os.path.split(output_path)
  tf.io.gfile.makedirs(output_dir)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")

  tf1_bert_checkpoint_converter_lib.convert(
      checkpoint_from_path=v1_checkpoint,
      checkpoint_to_path=temporary_checkpoint,
      num_heads=bert_config.num_attention_heads,
      name_replacements=(
          tf1_bert_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS),
      permutations=tf1_bert_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
      exclude_patterns=["adam", "Adam"])

  if converted_model == "encoder":
    model = _create_bert_model(bert_config)
  elif converted_model == "pretrainer":
    model = _create_bert_pretrainer_model(bert_config)
  else:
    raise ValueError("Unsupported converted_model: %s" % converted_model)

  # Create a V2 checkpoint from the temporary checkpoint.
  tf1_bert_checkpoint_converter_lib.create_v2_checkpoint(
      model, temporary_checkpoint, output_path, checkpoint_model_name)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  output_path = FLAGS.converted_checkpoint_path
  v1_checkpoint = FLAGS.checkpoint_to_convert
  checkpoint_model_name = FLAGS.checkpoint_model_name
  converted_model = FLAGS.converted_model
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert_checkpoint(
      bert_config=bert_config,
      output_path=output_path,
      v1_checkpoint=v1_checkpoint,
      checkpoint_model_name=checkpoint_model_name,
      converted_model=converted_model)


if __name__ == "__main__":
  app.run(main)
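
# Example invocation (a sketch; the file paths are placeholders):
#
#   python3 tf2_bert_encoder_checkpoint_converter.py \
#     --bert_config_file=/tmp/bert_base/bert_config.json \
#     --checkpoint_to_convert=/tmp/bert_base/bert_model.ckpt \
#     --converted_checkpoint_path=/tmp/bert_base_v2/ckpt \
#     --converted_model=pretrainer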
tokenization_test.py
ADDED
@@ -0,0 +1,156 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

import six
import tensorflow as tf, tf_keras

from official.nlp.tools import tokenization


class TokenizationTest(tf.test.TestCase):
  """Tokenization test.

  The implementation is forked from
  https://github.com/google-research/bert/blob/master/tokenization_test.py.
  """

  def test_full_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      if six.PY2:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
      else:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens
                                   ]).encode("utf-8"))

      vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])

  def test_basic_tokenizer_no_split_on_punc(self):
    tokenizer = tokenization.BasicTokenizer(
        do_lower_case=True, split_on_punc=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello!how", "are", "you?"])

  def test_wordpiece_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", "##!", "!"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

    self.assertAllEqual(tokenizer.tokenize(""), [])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running !"),
        ["un", "##want", "##ed", "runn", "##ing", "!"])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running!"),
        ["un", "##want", "##ed", "runn", "##ing", "##!"])

    self.assertAllEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

  def test_convert_tokens_to_ids(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i

    self.assertAllEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])

  def test_is_whitespace(self):
    self.assertTrue(tokenization._is_whitespace(u" "))
    self.assertTrue(tokenization._is_whitespace(u"\t"))
    self.assertTrue(tokenization._is_whitespace(u"\r"))
    self.assertTrue(tokenization._is_whitespace(u"\n"))
    self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

    self.assertFalse(tokenization._is_whitespace(u"A"))
    self.assertFalse(tokenization._is_whitespace(u"-"))

  def test_is_control(self):
    self.assertTrue(tokenization._is_control(u"\u0005"))

    self.assertFalse(tokenization._is_control(u"A"))
    self.assertFalse(tokenization._is_control(u" "))
    self.assertFalse(tokenization._is_control(u"\t"))
    self.assertFalse(tokenization._is_control(u"\r"))
    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))

  def test_is_punctuation(self):
    self.assertTrue(tokenization._is_punctuation(u"-"))
    self.assertTrue(tokenization._is_punctuation(u"$"))
    self.assertTrue(tokenization._is_punctuation(u"`"))
    self.assertTrue(tokenization._is_punctuation(u"."))

    self.assertFalse(tokenization._is_punctuation(u"A"))
    self.assertFalse(tokenization._is_punctuation(u" "))


if __name__ == "__main__":
  tf.test.main()
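
# To run these tests directly (assuming the TF Model Garden `official`
# package is importable on PYTHONPATH):
#
#   python3 tokenization_test.py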