# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Tests for run_pretraining."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import tempfile

from absl.testing import flagsaver
from albert import modeling
from albert import run_pretraining
import tensorflow.compat.v1 as tf

FLAGS = tf.app.flags.FLAGS


def _create_config_file(filename, max_seq_length, vocab_size):
  """Creates an AlbertConfig and saves it to file."""
  albert_config = modeling.AlbertConfig(
      vocab_size,
      embedding_size=5,
      hidden_size=14,
      num_hidden_layers=3,
      num_hidden_groups=1,
      num_attention_heads=2,
      intermediate_size=19,
      inner_group_num=1,
      down_scale_factor=1,
      hidden_act="gelu",
      hidden_dropout_prob=0,
      attention_probs_dropout_prob=0,
      max_position_embeddings=max_seq_length,
      type_vocab_size=2,
      initializer_range=0.02)
  with tf.gfile.Open(filename, "w") as outfile:
    outfile.write(albert_config.to_json_string())


def _create_record(max_predictions_per_seq, max_seq_length, vocab_size):
  """Returns a tf.train.Example containing random data."""
  example = tf.train.Example()
  example.features.feature["input_ids"].int64_list.value.extend(
      [random.randint(0, vocab_size - 1) for _ in range(max_seq_length)])
  example.features.feature["input_mask"].int64_list.value.extend(
      [random.randint(0, 1) for _ in range(max_seq_length)])
  example.features.feature["masked_lm_positions"].int64_list.value.extend([
      random.randint(0, max_seq_length - 1)
      for _ in range(max_predictions_per_seq)
  ])
  example.features.feature["masked_lm_ids"].int64_list.value.extend([
      random.randint(0, vocab_size - 1)
      for _ in range(max_predictions_per_seq)
  ])
  example.features.feature["masked_lm_weights"].float_list.value.extend(
      [1. for _ in range(max_predictions_per_seq)])
  example.features.feature["segment_ids"].int64_list.value.extend(
      [0 for _ in range(max_seq_length)])
  example.features.feature["next_sentence_labels"].int64_list.value.append(
      random.randint(0, 1))
  return example


def _create_input_file(filename,
                       max_predictions_per_seq,
                       max_seq_length,
                       vocab_size,
                       size=1000):
  """Creates an input TFRecord file of specified size."""
  with tf.io.TFRecordWriter(filename) as writer:
    for _ in range(size):
      ex = _create_record(max_predictions_per_seq, max_seq_length, vocab_size)
      writer.write(ex.SerializeToString())


class RunPretrainingTest(tf.test.TestCase):

  def _verify_output_file(self, basename):
    self.assertTrue(tf.gfile.Exists(os.path.join(FLAGS.output_dir, basename)))

  def _verify_checkpoint_files(self, name):
    self._verify_output_file(name + ".meta")
    self._verify_output_file(name + ".index")
    self._verify_output_file(name + ".data-00000-of-00001")

  @flagsaver.flagsaver
  def test_pretraining(self):
    # Set up required flags.
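    # The values below are deliberately tiny (13-token sequences, a 97-word
    # vocabulary, a single training step) so the end-to-end run stays fast.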
    vocab_size = 97
    FLAGS.max_predictions_per_seq = 7
    FLAGS.max_seq_length = 13
    FLAGS.output_dir = tempfile.mkdtemp("output_dir")
    FLAGS.albert_config_file = os.path.join(
        tempfile.mkdtemp("config_dir"), "albert_config.json")
    FLAGS.input_file = os.path.join(
        tempfile.mkdtemp("input_dir"), "input_data.tfrecord")
    FLAGS.do_train = True
    FLAGS.do_eval = True
    FLAGS.num_train_steps = 1
    FLAGS.save_checkpoints_steps = 1

    # Construct requisite input files.
    _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
                        vocab_size)
    _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
                       FLAGS.max_seq_length, vocab_size)

    # Run the pretraining.
    run_pretraining.main(None)

    # Verify output.
    self._verify_checkpoint_files("model.ckpt-best")
    self._verify_checkpoint_files("model.ckpt-1")
    self._verify_output_file("eval_results.txt")
    self._verify_output_file("checkpoint")


if __name__ == "__main__":
  tf.test.main()