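"""Creates TFRecord pretraining data for an Arabic GPT-2 model.

Raw text files are cleaned with arabert.preprocess, tokenized with a
pretrained GPT2TokenizerFast, and packed into fixed-length blocks of
max_len + 1 token ids, with the EOS token marking document boundaries.
The writer appends ".tfrecord" to the value of --output_file.

Example invocation (script and file names are illustrative):
    python create_pretraining_data_gpt2.py \
        --input_file=shard_00.txt,shard_01.txt \
        --output_file=pretraining_data \
        --tokenizer_dir=./gpt2_tokenizer \
        --max_len=1024
"""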
import collections
from transformers import GPT2TokenizerFast
import tensorflow as tf

import sys
# The arabert package lives one directory above this script.
sys.path.append("..")
from arabert.preprocess import preprocess

flags = tf.flags

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "input_file", None, "Input raw text file (or comma-separated list of files)."
)

flags.DEFINE_string(
    "output_file", None, "Output TF example file (or comma-separated list of files)."
)

flags.DEFINE_string(
    "tokenizer_dir", None, "The directory of a pretrained GPT2TokenizerFast."
)

flags.DEFINE_integer(
    "max_len", 1024, "Maximum sequence length; each example stores max_len + 1 token ids."
)

flags.DEFINE_integer(
    "num_examples_print", 0, "Number of examples to print."
)


def create_int_feature(values):
    """Wraps a list of integer token ids in an int64 feature."""
    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return feature


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    logger = tf.get_logger()
    logger.propagate = False

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Reading from input files ***")
    for input_file in input_files:
        tf.logging.info(" %s", input_file)
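    # Load the trained GPT-2 tokenizer and open a single TFRecord writer;
    # examples from every input file go into one output file.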
    gpt2_tok = GPT2TokenizerFast.from_pretrained(FLAGS.tokenizer_dir)

    writer = tf.python_io.TFRecordWriter(FLAGS.output_file + ".tfrecord")
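    # Packing strategy: documents are concatenated into one token stream,
    # with the tokenizer's EOS id inserted at blank lines (document
    # boundaries), and the stream is cut into blocks of max_len + 1 ids,
    # presumably so inputs and labels shifted by one position can be formed
    # at training time.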
    eos_id = gpt2_tok.eos_token_id
    all_examples = []
    for input_file in input_files:
        queue = []
        example = []
        with tf.gfile.GFile(input_file, "r") as reader:
            for line in reader.readlines():
                if line == "\n":
                    # A blank line marks a document boundary.
                    queue.append(eos_id)
                else:
                    line = line.replace("\n", " ")
                    line = preprocess(line, model="gpt2-base-arabic")
                    line = line.strip()
                    enc_line = gpt2_tok.encode(line)
                    queue.extend(enc_line)
                    # Emit fixed-length examples whenever enough ids are queued.
                    while len(queue) > FLAGS.max_len + 1:
                        example = [queue.pop(0) for _ in range(FLAGS.max_len + 1)]
                        assert len(example) == FLAGS.max_len + 1
                        all_examples.append(example)
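    # Each packed block is serialized as a tf.train.Example holding a single
    # int64 "input_ids" feature of max_len + 1 token ids; a reader can recover
    # it with, e.g., a FixedLenFeature of that length (illustrative).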
    for i, ex in enumerate(all_examples):
        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(ex)
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
        if i < FLAGS.num_examples_print:
            tf.logging.info("*** Example ***")
            tf.logging.info("Length: %d" % len(ex))
            tf.logging.info("Tokens: %s" % gpt2_tok.decode(ex))
            tf.logging.info("ids: %s" % " ".join([str(x) for x in ex]))

    writer.close()
    tf.logging.info("Wrote %d total instances", len(all_examples))


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("output_file")
    flags.mark_flag_as_required("tokenizer_dir")
    tf.app.run()