#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a masked language model on TPU."""

import argparse
import logging
import os
import re

import tensorflow as tf

from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PushToHubCallback,
    TFAutoModelForMaskedLM,
    create_optimizer,
)


logger = logging.getLogger(__name__)

AUTO = tf.data.AUTOTUNE


def parse_args():
    parser = argparse.ArgumentParser(description="Train a masked language model on TPU.")
    parser.add_argument(
        "--pretrained_model_config",
        type=str,
        default="roberta-base",
        help="The model config to use. Note that we don't copy the model's weights, only the config!",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="unigram-tokenizer-wikitext",
        help="The name of the tokenizer to load. We use the pretrained tokenizer to initialize the model's vocab size.",
    )
    parser.add_argument(
        "--per_replica_batch_size",
        type=int,
        default=8,
        help="Batch size per TPU core.",
    )
    parser.add_argument(
        "--no_tpu",
        action="store_true",
        help="If set, don't try to initialize a TPU and instead run on a single GPU. Useful for debugging on non-TPU instances.",
    )
    parser.add_argument(
        "--tpu_name",
        type=str,
        help="Name of TPU resource to initialize. Should be blank on Colab, and 'local' on TPU VMs.",
        default="local",
    )
    parser.add_argument(
        "--tpu_zone",
        type=str,
        help="Google cloud zone that TPU resource is located in. Only used for non-Colab TPU nodes.",
    )
    parser.add_argument(
        "--gcp_project", type=str, help="Google cloud project name. Only used for non-Colab TPU nodes."
    )
    parser.add_argument(
        "--bfloat16",
        action="store_true",
        help="Use mixed-precision bfloat16 for training. This is the recommended lower-precision format for TPU.",
    )
    parser.add_argument(
        "--train_dataset",
        type=str,
        help="Path to training dataset to load. If the path begins with `gs://`"
        " then the dataset will be loaded from a Google Cloud Storage bucket.",
    )
    parser.add_argument(
        "--shuffle_buffer_size",
        type=int,
        default=2**18,  # Default corresponds to a 1GB buffer for seq_len 512
        help="Size of the shuffle buffer (in samples)",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        help="Path to evaluation dataset to load. If the path begins with `gs://`"
        " then the dataset will be loaded from a Google Cloud Storage bucket.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="Number of epochs to train for.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use for training.",
    )
    parser.add_argument(
        "--weight_decay_rate",
        type=float,
        default=1e-3,
        help="Weight decay rate to use for training.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum length of tokenized sequences. Should match the setting used in prepare_tfrecord_shards.py",
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Fraction of tokens to mask during training.",
    )
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save model checkpoints to.")
    parser.add_argument("--hub_model_id", type=str, help="Model ID to upload to on the Hugging Face Hub.")

    args = parser.parse_args()
    return args
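

# TPUClusterResolver with no arguments auto-discovers the TPU address from the environment (the
# usual situation on Colab), while passing "local" points the resolver at the accelerators
# attached to the host itself (the usual situation on a TPU VM).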
def initialize_tpu(args):
    try:
        if args.tpu_name:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
                args.tpu_name, zone=args.tpu_zone, project=args.gcp_project
            )
        else:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
        raise RuntimeError(
            "Couldn't connect to TPU! Most likely you need to specify --tpu_name, --tpu_zone, or "
            "--gcp_project. When running on a TPU VM, use --tpu_name local."
        )

    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    return tpu
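

# The sample count of each shard is recovered from its filename: the regex below expects names
# ending in "-<number>-<num_samples>.tfrecord", so a shard called e.g. "dataset-3-2048.tfrecord"
# (an illustrative name) contributes 2048 samples to the total.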
def count_samples(file_list):
    num_samples = 0
    for file in file_list:
        filename = file.split("/")[-1]
        sample_count = re.search(r"-\d+-(\d+)\.tfrecord", filename).group(1)
        sample_count = int(sample_count)
        num_samples += sample_count

    return num_samples
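

# Builds the tf.data input pipeline: shuffle the shard filenames, read the TFRecords in parallel,
# assert the total cardinality computed above, decode and (optionally) shuffle individual samples,
# then batch, apply the masking function, and prefetch.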
def prepare_dataset(records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None):
    num_samples = count_samples(records)
    dataset = tf.data.Dataset.from_tensor_slices(records)
    if shuffle:
        dataset = dataset.shuffle(len(dataset))
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=AUTO)
    # TF can't infer the total sample count because it doesn't read all the records yet, so we assert it here
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
    dataset = dataset.map(decode_fn, num_parallel_calls=AUTO)
    if shuffle:
        assert shuffle_buffer_size is not None
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(mask_fn, num_parallel_calls=AUTO)
    dataset = dataset.prefetch(AUTO)

    return dataset


def main(args):
    if not args.no_tpu:
        tpu = initialize_tpu(args)
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    if args.bfloat16:
        tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    config = AutoConfig.from_pretrained(args.pretrained_model_config)
    config.vocab_size = tokenizer.vocab_size

    training_records = tf.io.gfile.glob(os.path.join(args.train_dataset, "*.tfrecord"))
    if not training_records:
        raise ValueError(f"No .tfrecord files found in {args.train_dataset}.")

    eval_records = tf.io.gfile.glob(os.path.join(args.eval_dataset, "*.tfrecord"))
    if not eval_records:
        raise ValueError(f"No .tfrecord files found in {args.eval_dataset}.")

    num_train_samples = count_samples(training_records)
    steps_per_epoch = num_train_samples // (args.per_replica_batch_size * strategy.num_replicas_in_sync)
    total_train_steps = steps_per_epoch * args.num_epochs

    with strategy.scope():
        model = TFAutoModelForMaskedLM.from_config(config)
        model(model.dummy_inputs)  # Pass some dummy inputs through the model to ensure all the weights are built
        optimizer, schedule = create_optimizer(
            num_train_steps=total_train_steps,
            num_warmup_steps=total_train_steps // 20,
            init_lr=args.learning_rate,
            weight_decay_rate=args.weight_decay_rate,
        )

        # Transformers models compute the right loss for their task by default when labels are passed, and will
        # use this for training unless you specify your own loss function in compile().
        model.compile(optimizer=optimizer, metrics=["accuracy"])

    def decode_fn(example):
        features = {
            "input_ids": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
            "attention_mask": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
        }
        return tf.io.parse_single_example(example, features)

    # Many of the data collators in Transformers are TF-compilable when return_tensors == "tf", so we can
    # use their methods in our data pipeline.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=args.mlm_probability, mlm=True, return_tensors="tf"
    )

    def mask_with_collator(batch):
        # TF really needs an isin() function
        special_tokens_mask = (
            ~tf.cast(batch["attention_mask"], tf.bool)
            | (batch["input_ids"] == tokenizer.cls_token_id)
            | (batch["input_ids"] == tokenizer.sep_token_id)
        )
        batch["input_ids"], batch["labels"] = data_collator.tf_mask_tokens(
            batch["input_ids"],
            vocab_size=len(tokenizer),
            mask_token_id=tokenizer.mask_token_id,
            special_tokens_mask=special_tokens_mask,
        )
        return batch
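
    # The global batch size scales with the number of replicas; e.g. the default
    # --per_replica_batch_size of 8 on an 8-core TPU yields 64 samples per step
    # (illustrative numbers; the actual replica count comes from the strategy at runtime).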
    batch_size = args.per_replica_batch_size * strategy.num_replicas_in_sync

    train_dataset = prepare_dataset(
        training_records,
        decode_fn=decode_fn,
        mask_fn=mask_with_collator,
        batch_size=batch_size,
        shuffle=True,
        shuffle_buffer_size=args.shuffle_buffer_size,
    )

    eval_dataset = prepare_dataset(
        eval_records,
        decode_fn=decode_fn,
        mask_fn=mask_with_collator,
        batch_size=batch_size,
        shuffle=False,
    )

    callbacks = []
    if args.hub_model_id:
        callbacks.append(
            PushToHubCallback(output_dir=args.output_dir, hub_model_id=args.hub_model_id, tokenizer=tokenizer)
        )

    model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=args.num_epochs,
        callbacks=callbacks,
    )

    model.save_pretrained(args.output_dir)


if __name__ == "__main__":
    args = parse_args()
    main(args)