File size: 4,728 Bytes
d08dd00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# coding=utf-8
# Copyright 2018 The Google AI Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Tests for run_pretraining."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import tempfile
from absl.testing import flagsaver
from albert import modeling
from albert import run_pretraining
import tensorflow.compat.v1 as tf

# Shared flag-values object from tf.app.flags. The test in this module assigns
# flag values on it directly before calling run_pretraining.main.
FLAGS = tf.app.flags.FLAGS


def _create_config_file(filename, max_seq_length, vocab_size):
  """Writes a small AlbertConfig to `filename` as JSON.

  Args:
    filename: Destination path; opened for writing through tf.gfile.
    max_seq_length: Value used for the config's `max_position_embeddings`.
    vocab_size: Vocabulary size, passed positionally to AlbertConfig.
  """
  # Fixed hyperparameters for the test model; only the vocabulary size and
  # the maximum position count come from the caller.
  hyperparams = dict(
      embedding_size=5,
      hidden_size=14,
      num_hidden_layers=3,
      num_hidden_groups=1,
      num_attention_heads=2,
      intermediate_size=19,
      inner_group_num=1,
      down_scale_factor=1,
      hidden_act="gelu",
      hidden_dropout_prob=0,
      attention_probs_dropout_prob=0,
      max_position_embeddings=max_seq_length,
      type_vocab_size=2,
      initializer_range=0.02)
  config = modeling.AlbertConfig(vocab_size, **hyperparams)
  with tf.gfile.Open(filename, "w") as config_file:
    serialized_config = config.to_json_string()
    config_file.write(serialized_config)


def _create_record(max_predictions_per_seq, max_seq_length, vocab_size):
  """Builds a tf.train.Example filled with random pretraining features.

  Args:
    max_predictions_per_seq: Number of entries written to each
      `masked_lm_*` feature.
    max_seq_length: Number of entries written to `input_ids`, `input_mask`
      and `segment_ids`.
    vocab_size: Exclusive upper bound for the random token ids.

  Returns:
    The populated tf.train.Example.
  """

  def _random_ints(upper, count):
    # random.randint is inclusive on both ends, so `upper` can be drawn.
    return [random.randint(0, upper) for _ in range(count)]

  record = tf.train.Example()
  features = record.features.feature

  # The random draws happen in the same order as the features are listed.
  features["input_ids"].int64_list.value.extend(
      _random_ints(vocab_size - 1, max_seq_length))
  features["input_mask"].int64_list.value.extend(
      _random_ints(1, max_seq_length))
  features["masked_lm_positions"].int64_list.value.extend(
      _random_ints(max_seq_length - 1, max_predictions_per_seq))
  features["masked_lm_ids"].int64_list.value.extend(
      _random_ints(vocab_size - 1, max_predictions_per_seq))

  # Every masked position gets weight 1.0 and every token gets segment 0.
  features["masked_lm_weights"].float_list.value.extend(
      [1.] * max_predictions_per_seq)
  features["segment_ids"].int64_list.value.extend([0] * max_seq_length)

  sentence_label = random.randint(0, 1)
  features["next_sentence_labels"].int64_list.value.append(sentence_label)
  return record


def _create_input_file(filename,
                       max_predictions_per_seq,
                       max_seq_length,
                       vocab_size,
                       size=1000):
  """Writes `size` random serialized examples to a TFRecord file.

  Args:
    filename: Path of the TFRecord file to write.
    max_predictions_per_seq: Forwarded to `_create_record`.
    max_seq_length: Forwarded to `_create_record`.
    vocab_size: Forwarded to `_create_record`.
    size: Number of records to write.
  """
  with tf.io.TFRecordWriter(filename) as record_writer:
    for _ in range(size):
      record = _create_record(
          max_predictions_per_seq=max_predictions_per_seq,
          max_seq_length=max_seq_length,
          vocab_size=vocab_size)
      serialized = record.SerializeToString()
      record_writer.write(serialized)


class RunPretrainingTest(tf.test.TestCase):
  """Smoke test that calls run_pretraining.main and checks its output files."""

  def _remove_dir(self, path):
    """Deletes `path` recursively if it still exists."""
    if tf.gfile.Exists(path):
      tf.gfile.DeleteRecursively(path)

  def _make_temp_dir(self, suffix):
    """Creates a temporary directory that is deleted after the test.

    Args:
      suffix: Suffix passed to tempfile.mkdtemp for the directory name.

    Returns:
      The path of the newly created directory.
    """
    path = tempfile.mkdtemp(suffix)
    # Register the cleanup immediately so the directory is removed even if
    # the test fails part-way through.
    self.addCleanup(self._remove_dir, path)
    return path

  def _verify_output_file(self, basename):
    """Asserts that `basename` exists inside FLAGS.output_dir."""
    path = os.path.join(FLAGS.output_dir, basename)
    self.assertTrue(
        tf.gfile.Exists(path), msg="Expected output file not found: %s" % path)

  def _verify_checkpoint_files(self, name):
    """Asserts that the meta, index and data files of checkpoint `name` exist."""
    self._verify_output_file(name + ".meta")
    self._verify_output_file(name + ".index")
    self._verify_output_file(name + ".data-00000-of-00001")

  @flagsaver.flagsaver
  def test_pretraining(self):
    """Calls run_pretraining.main on random data and checks its outputs."""
    # Set up required flags. The flagsaver decorator restores them afterwards.
    vocab_size = 97
    FLAGS.max_predictions_per_seq = 7
    FLAGS.max_seq_length = 13
    FLAGS.output_dir = self._make_temp_dir("output_dir")
    FLAGS.albert_config_file = os.path.join(
        self._make_temp_dir("config_dir"), "albert_config.json")
    FLAGS.input_file = os.path.join(
        self._make_temp_dir("input_dir"), "input_data.tfrecord")
    FLAGS.do_train = True
    FLAGS.do_eval = True
    FLAGS.num_train_steps = 1
    FLAGS.save_checkpoints_steps = 1

    # Construct requisite input files.
    _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
                        vocab_size)
    _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
                       FLAGS.max_seq_length, vocab_size)

    # Run the pretraining.
    run_pretraining.main(None)

    # Verify output.
    self._verify_checkpoint_files("model.ckpt-best")
    self._verify_checkpoint_files("model.ckpt-1")
    self._verify_output_file("eval_results.txt")
    self._verify_output_file("checkpoint")


# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()