# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Exports a minimal TF-Hub module for ALBERT models.""" | |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
from albert import modeling
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

flags.DEFINE_string(
    "albert_directory", None,
    "The directory containing the pre-trained ALBERT checkpoint to export, "
    "its albert_config.json file, and its SentencePiece vocabulary files.")

flags.DEFINE_string(
    "checkpoint_name", "model.ckpt-best",
    "Name of the checkpoint under albert_directory to be exported.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "use_einsum", True,
    "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
    "be set to False for TFLite compatibility.")

flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.")

FLAGS = flags.FLAGS


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions over a minibatch."""
  sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]
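  # Convert per-example positions into indices into a flattened
  # [batch_size * seq_length, width] tensor. Illustrative example: with
  # batch_size=2 and seq_length=3, positions [[0, 2], [1, 1]] get offsets
  # [[0], [3]] and flatten to [0, 2, 4, 4].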
  flat_offsets = tf.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.reshape(sequence_tensor,
                                    [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
  return output_tensor


def get_mlm_logits(model, albert_config, mlm_positions):
  """Gets the masked-LM logits; adapted from run_pretraining.py."""
  input_tensor = gather_indexes(model.get_sequence_output(), mlm_positions)
  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=albert_config.embedding_size,
          activation=modeling.get_activation(albert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              albert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[albert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(
        input_tensor, model.get_embedding_table(), transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
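    # logits has shape [batch_size * num_mlm_positions, vocab_size].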
  return logits


def module_fn(is_training):
  """Builds the module graph; called once per entry in tags_and_args."""
  input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
  input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
  segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
  mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")

  albert_config_path = os.path.join(
      FLAGS.albert_directory, "albert_config.json")
  albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=False,
      use_einsum=FLAGS.use_einsum)
  mlm_logits = get_mlm_logits(model, albert_config, mlm_positions)

  vocab_model_path = os.path.join(FLAGS.albert_directory, "30k-clean.model")
  vocab_file_path = os.path.join(FLAGS.albert_directory, "30k-clean.vocab")

  config_file = tf.constant(
      value=albert_config_path, dtype=tf.string, name="config_file")
  vocab_model = tf.constant(
      value=vocab_model_path, dtype=tf.string, name="vocab_model")
  # This is only for visualization purposes.
  vocab_file = tf.constant(
      value=vocab_file_path, dtype=tf.string, name="vocab_file")

  # Adding `config_file`, `vocab_model`, and `vocab_file` to the
  # ASSET_FILEPATHS collection lets TF-Hub rewrite these tensors so that
  # the assets are bundled with the exported module and remain portable.
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file)
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_model)
  tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)

  hub.add_signature(
      name="tokens",
      inputs=dict(
          input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
      outputs=dict(
          sequence_output=model.get_sequence_output(),
          pooled_output=model.get_pooled_output()))

  hub.add_signature(
      name="mlm",
      inputs=dict(
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          mlm_positions=mlm_positions),
      outputs=dict(
          sequence_output=model.get_sequence_output(),
          pooled_output=model.get_pooled_output(),
          mlm_logits=mlm_logits))
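
  # The tokenization_info signature deliberately exposes the SentencePiece
  # model path under the key "vocab_file": ALBERT's tokenizer is driven by
  # the .model file rather than the plain-text vocab.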
  hub.add_signature(
      name="tokenization_info",
      inputs={},
      outputs=dict(
          vocab_file=vocab_model,
          do_lower_case=tf.constant(FLAGS.do_lower_case)))


def main(_):
  tags_and_args = []
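  # Export two graph variants: a training graph under the "train" tag and a
  # default inference graph with no tags.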
  for is_training in (True, False):
    tags = set()
    if is_training:
      tags.add("train")
    tags_and_args.append((tags, dict(is_training=is_training)))
  spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
  checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
  tf.logging.info("Using checkpoint {}".format(checkpoint_path))
  spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path)
if __name__ == "__main__": | |
flags.mark_flag_as_required("albert_directory") | |
flags.mark_flag_as_required("export_path") | |
app.run(main) | |