Spaces:
No application file
No application file
File size: 4,728 Bytes
d08dd00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Tests for run_pretraining."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import tempfile
from absl.testing import flagsaver
from albert import modeling
from albert import run_pretraining
import tensorflow.compat.v1 as tf
# Global flag values object; the test below assigns the flags it needs on it
# before invoking run_pretraining.main.
FLAGS = tf.app.flags.FLAGS
def _create_config_file(filename, max_seq_length, vocab_size):
  """Builds a small AlbertConfig and writes its JSON form to `filename`.

  Args:
    filename: Path of the file that receives the JSON-serialized config.
    max_seq_length: Value used for the config's `max_position_embeddings`.
    vocab_size: Vocabulary size, passed as the config's first argument.
  """
  # Deliberately tiny model dimensions: this config only has to be valid.
  hyperparameters = dict(
      embedding_size=5,
      hidden_size=14,
      num_hidden_layers=3,
      num_hidden_groups=1,
      num_attention_heads=2,
      intermediate_size=19,
      inner_group_num=1,
      down_scale_factor=1,
      hidden_act="gelu",
      hidden_dropout_prob=0,
      attention_probs_dropout_prob=0,
      max_position_embeddings=max_seq_length,
      type_vocab_size=2,
      initializer_range=0.02,
  )
  config = modeling.AlbertConfig(vocab_size, **hyperparameters)

  with tf.gfile.Open(filename, "w") as config_file:
    config_file.write(config.to_json_string())
def _create_record(max_predictions_per_seq, max_seq_length, vocab_size):
  """Builds a tf.train.Example whose features hold random pretraining data.

  Args:
    max_predictions_per_seq: Number of entries written to each of the
      `masked_lm_*` features.
    max_seq_length: Number of entries written to `input_ids`, `input_mask`
      and `segment_ids`.
    vocab_size: Exclusive upper bound for the randomly drawn token ids.

  Returns:
    The populated tf.train.Example.
  """

  def _random_ints(upper, count):
    # `count` draws, each uniform over the inclusive range [0, upper].
    return [random.randint(0, upper) for _ in range(count)]

  record = tf.train.Example()
  features = record.features.feature

  # The features are filled in a fixed order so that the sequence of calls
  # into `random` stays the same from one feature to the next.
  features["input_ids"].int64_list.value.extend(
      _random_ints(vocab_size - 1, max_seq_length))
  features["input_mask"].int64_list.value.extend(
      _random_ints(1, max_seq_length))
  features["masked_lm_positions"].int64_list.value.extend(
      _random_ints(max_seq_length - 1, max_predictions_per_seq))
  features["masked_lm_ids"].int64_list.value.extend(
      _random_ints(vocab_size - 1, max_predictions_per_seq))

  # Constant-valued features: every masked position has weight 1 and every
  # token is assigned segment 0.
  features["masked_lm_weights"].float_list.value.extend(
      [1.] * max_predictions_per_seq)
  features["segment_ids"].int64_list.value.extend([0] * max_seq_length)

  features["next_sentence_labels"].int64_list.value.append(
      random.randint(0, 1))
  return record
def _create_input_file(filename,
                       max_predictions_per_seq,
                       max_seq_length,
                       vocab_size,
                       size=1000):
  """Writes `size` randomly generated examples to a TFRecord file.

  Args:
    filename: Path of the TFRecord file to write.
    max_predictions_per_seq: Forwarded to `_create_record`.
    max_seq_length: Forwarded to `_create_record`.
    vocab_size: Forwarded to `_create_record`.
    size: Number of records to write.
  """
  with tf.io.TFRecordWriter(filename) as record_writer:
    for _ in range(size):
      record = _create_record(
          max_predictions_per_seq=max_predictions_per_seq,
          max_seq_length=max_seq_length,
          vocab_size=vocab_size)
      serialized = record.SerializeToString()
      record_writer.write(serialized)
class RunPretrainingTest(tf.test.TestCase):
  """Smoke test that drives `run_pretraining.main` on random data."""

  def _make_temp_dir(self, suffix):
    """Creates a temporary directory that is deleted when the test finishes.

    Args:
      suffix: Suffix given to the generated directory name.

    Returns:
      The path of the newly created directory.
    """
    path = tempfile.mkdtemp(suffix)
    # Register the removal immediately so that the directory is cleaned up
    # even when a later step of the test fails.
    self.addCleanup(tf.gfile.DeleteRecursively, path)
    return path

  def _verify_output_file(self, basename):
    """Asserts that `basename` exists inside `FLAGS.output_dir`."""
    path = os.path.join(FLAGS.output_dir, basename)
    self.assertTrue(
        tf.gfile.Exists(path), msg="Missing expected output file: %s" % path)

  def _verify_checkpoint_files(self, name):
    """Asserts that the three files of checkpoint `name` were written."""
    for suffix in (".meta", ".index", ".data-00000-of-00001"):
      self._verify_output_file(name + suffix)

  @flagsaver.flagsaver
  def test_pretraining(self):
    """Runs one training step plus eval and checks the files produced."""
    # Set up required flags.
    vocab_size = 97
    FLAGS.max_predictions_per_seq = 7
    FLAGS.max_seq_length = 13
    FLAGS.output_dir = self._make_temp_dir("output_dir")
    FLAGS.albert_config_file = os.path.join(
        self._make_temp_dir("config_dir"), "albert_config.json")
    FLAGS.input_file = os.path.join(
        self._make_temp_dir("input_dir"), "input_data.tfrecord")
    FLAGS.do_train = True
    FLAGS.do_eval = True
    FLAGS.num_train_steps = 1
    FLAGS.save_checkpoints_steps = 1

    # Construct requisite input files.
    _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
                        vocab_size)
    _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
                       FLAGS.max_seq_length, vocab_size)

    # Run the pretraining.
    run_pretraining.main(None)

    # Verify output.
    self._verify_checkpoint_files("model.ckpt-best")
    self._verify_checkpoint_files("model.ckpt-1")
    self._verify_output_file("eval_results.txt")
    self._verify_output_file("checkpoint")
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when run as a script.
  tf.test.main()
|