Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| # Copyright 2021 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE.""" | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import random | |
| import sys | |
| import time | |
| import warnings | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Optional, Tuple | |
| import datasets | |
| import evaluate | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import optax | |
| from datasets import load_dataset | |
| from flax import struct, traverse_util | |
| from flax.jax_utils import pad_shard_unpad, replicate, unreplicate | |
| from flax.training import train_state | |
| from flax.training.common_utils import get_metrics, onehot, shard | |
| from huggingface_hub import Repository, create_repo | |
| from tqdm import tqdm | |
| import transformers | |
| from transformers import ( | |
| AutoConfig, | |
| AutoTokenizer, | |
| FlaxAutoModelForSequenceClassification, | |
| HfArgumentParser, | |
| PretrainedConfig, | |
| TrainingArguments, | |
| is_tensorboard_available, | |
| ) | |
| from transformers.utils import check_min_version, send_example_telemetry | |
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.34.0.dev0")

# Readability aliases used in the type annotations of the helpers below.
Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any

# GLUE task name -> (first text column, second text column); the second entry is None for
# single-sentence tasks, which are then tokenized as a single text.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Must be a dataclass: it is handed to `HfArgumentParser`, which builds the command-line options
    from the dataclass fields declared below.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_slow_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Defaults to None, so the annotation must be Optional.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`; None means "not given".
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # Trailing spaces matter: adjacent literals are concatenated verbatim.
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Must be a dataclass: it is handed to `HfArgumentParser`, and `__post_init__` validates the
    parsed values.
    """

    task_name: Optional[str] = field(
        default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
    )
    text_column_name: Optional[str] = field(
        default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
    )
    label_column_name: Optional[str] = field(
        default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Defaults to None, so the annotation must be Optional.
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. If set, sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        """Normalise and validate the data arguments.

        Raises:
            ValueError: if neither a task name nor any data file is given, if `task_name` is not a
                known GLUE task, or if a train/validation file is not a `.csv`/`.json` file.
        """
        # Lower-case first so that e.g. "MRPC" matches the lower-case keys of `task_to_keys`.
        if isinstance(self.task_name, str):
            self.task_name = self.task_name.lower()

        if self.task_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Fail early with a clear message instead of a KeyError / download error much later.
        if self.task_name is not None and self.task_name not in task_to_keys:
            raise ValueError(
                f"Unknown GLUE task {self.task_name!r}; expected one of {list(task_to_keys.keys())}."
            )

        # Explicit exceptions rather than `assert`, which is stripped when Python runs with -O.
        for arg_name, path in (("train_file", self.train_file), ("validation_file", self.validation_file)):
            if path is not None and path.split(".")[-1] not in ("csv", "json"):
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")
def create_train_state(
    model: FlaxAutoModelForSequenceClassification,
    learning_rate_fn: Callable[[int], float],
    is_regression: bool,
    num_labels: int,
    weight_decay: float,
) -> train_state.TrainState:
    """Create initial training state.

    Wraps `model.__call__` / `model.params` in a `TrainState` that uses an AdamW optimizer and carries
    task-specific `logits_fn` / `loss_fn` callables.

    Args:
        model: The Flax sequence-classification model to fine-tune.
        learning_rate_fn: Learning-rate schedule, passed to `optax.adamw` as `learning_rate`.
        is_regression: True -> squared-error loss on the first logit and that logit as the prediction;
            False -> softmax cross-entropy on one-hot labels and argmax as the prediction.
        num_labels: Number of classes used for one-hot encoding (classification only).
        weight_decay: AdamW weight-decay coefficient (masked, see `decay_mask_fn`).

    Returns:
        The initial `TrainState`.
    """

    class TrainState(train_state.TrainState):
        """Train state with an Optax optimizer.

        The two functions below differ depending on whether the task is classification
        or regression.

        Args:
          logits_fn: Applied to last layer to obtain the logits.
          loss_fn: Function to compute the loss.
        """

        # Plain callables, not arrays: keep them out of the pytree leaves.
        logits_fn: Callable = struct.field(pytree_node=False)
        loss_fn: Callable = struct.field(pytree_node=False)

    def decay_mask_fn(params):
        # Boolean pytree shaped like `params`: True = apply weight decay. Biases and every
        # parameter whose last-two-key suffix belongs to a LayerNorm-like path are excluded.
        flat_params = traverse_util.flatten_dict(params)
        norm_markers = ("layernorm", "layer_norm", "ln")
        norm_suffixes = set()
        for path in flat_params:
            joined_path = "".join(path).lower()
            if any(marker in joined_path for marker in norm_markers):
                norm_suffixes.add(path[-2:])
        flat_mask = {path: path[-1] != "bias" and path[-2:] not in norm_suffixes for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    optimizer = optax.adamw(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )

    if is_regression:

        def task_loss(logits, labels):
            # Mean squared error on the single regression output.
            return jnp.mean((logits[..., 0] - labels) ** 2)

        def task_logits(logits):
            return logits[..., 0]

    else:  # Classification.

        def task_loss(logits, labels):
            one_hot_labels = onehot(labels, num_classes=num_labels)
            return jnp.mean(optax.softmax_cross_entropy(logits, one_hot_labels))

        def task_logits(logits):
            return logits.argmax(-1)

    return TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
        logits_fn=task_logits,
        loss_fn=task_loss,
    )
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function.

    The schedule rises linearly from 0 to `learning_rate` over the first `num_warmup_steps` steps,
    then falls linearly to 0 at the last training step. The total step count uses floor division
    (`train_ds_size // train_batch_size` steps per epoch), i.e. an incomplete final batch is not counted.
    """
    total_train_steps = (train_ds_size // train_batch_size) * num_train_epochs
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps),
            optax.linear_schedule(
                init_value=learning_rate, end_value=0, transition_steps=total_train_steps - num_warmup_steps
            ),
        ],
        boundaries=[num_warmup_steps],
    )
def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices.

    The trailing examples that cannot fill a whole batch are skipped, so every batch has exactly
    `batch_size` rows before sharding.

    Args:
        rng: JAX PRNG key that fixes the shuffle order.
        dataset: Tokenized training dataset, indexable by an array of row indices.
        batch_size: Global batch size across all local devices.

    Yields:
        Dicts of NumPy arrays (one per column) passed through `shard` to split them over local devices.
    """
    steps_per_epoch = len(dataset) // batch_size
    # Bring the permutation to host memory once, so that per-batch row selection works on NumPy
    # indices instead of device-resident jax.Array slices.
    perms = np.asarray(jax.random.permutation(rng, len(dataset)))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: np.array(v) for k, v in batch.items()}
        yield shard(batch)
def glue_eval_data_collator(dataset: Dataset, batch_size: int):
    """Returns unshuffled batches covering every example of `dataset` exactly once, in order.

    The examples are split with `np.array_split` into `ceil(len(dataset) / batch_size)` contiguous,
    near-equal chunks, so each batch has at most `batch_size` rows (not necessarily exactly
    `batch_size`). Sharding handled by `pad_shard_unpad` in the eval loop.

    Args:
        dataset: Tokenized evaluation dataset, indexable by an array of row indices.
        batch_size: Maximum number of examples per batch.

    Yields:
        Dicts mapping each column name to a NumPy array of the batch's values. Nothing is yielded
        for an empty dataset.
    """
    num_examples = len(dataset)
    if num_examples == 0:
        # np.array_split raises for zero sections; an empty dataset simply produces no batches.
        return

    steps_per_epoch = math.ceil(num_examples / batch_size)
    for idx in np.array_split(np.arange(num_examples), steps_per_epoch):
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}
def main():
    """Fine-tune and evaluate a Flax sequence-classification model on a GLUE task or CSV/JSON files.

    Arguments are read from a single JSON file (when it is the only CLI argument) or from the command
    line. Training is data-parallel via `jax.pmap`. Evaluation runs every `eval_steps` steps (when set)
    and at the end of each epoch. Checkpoints go to `output_dir` (optionally pushed to the Hub), and
    process 0 writes the last evaluation metrics to `output_dir/eval_results.json`.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args, framework="flax")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Only the first JAX process logs at INFO level; the others only report errors.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        # Retrieve or infer repo_name
        repo_name = training_args.hub_model_id
        if repo_name is None:
            repo_name = Path(training_args.output_dir).absolute().name
        # Create repo and retrieve repo_id
        repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
        # Clone repo locally
        repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # At least one of the two files is set here (DataTrainingArguments.__post_init__ rejects neither).
        data_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = data_file.split(".")[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=not model_args.use_slow_tokenizer,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Preprocessing the datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        elif len(non_label_column_names) >= 2:
            sentence1_key, sentence2_key = non_label_column_names[:2]
        else:
            sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # One concatenated message: a stray comma here would pass the second part as a %-format argument.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        # `label_list` only exists for classification; regression labels are used as-is.
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks). Unlabelled rows (-1, e.g. GLUE test
                # splits) have no entry in the mapping and are kept as -1.
                result["labels"] = [
                    label_to_id[label] if label != -1 else -1 for label in examples["label"]
                ]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=raw_datasets["train"].column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    train_dataset = processed_datasets["train"]
    if data_args.max_train_samples is not None:
        train_dataset = train_dataset.select(range(min(len(train_dataset), data_args.max_train_samples)))
    eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
    if data_args.max_eval_samples is not None:
        eval_dataset = eval_dataset.select(range(min(len(eval_dataset), data_args.max_eval_samples)))

    # Log a few random samples from the training set (random.sample fails if asked for more than exist).
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            hparams = {**training_args.to_dict(), **vars(model_args), **vars(data_args)}
            # Keep credentials out of the event files: `output_dir` is what gets pushed to the Hub.
            for hparam_name, hparam_value in list(hparams.items()):
                if (hparam_name == "token" or hparam_name.endswith("_token")) and hparam_value is not None:
                    hparams[hparam_name] = "<REDACTED>"
            summary_writer.hparams(hparams)
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        # Each buffered metric is logged at the step it was produced (the last one at `step`).
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    # Derive a dedicated dropout key so that `rng` itself is never consumed twice (JAX keys must not be reused).
    rng, dropout_rng = jax.random.split(rng)
    dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())

    train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    # Use the integer epoch count that the loop below actually runs, so the decay reaches 0 at the last step.
    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        num_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`.

        Returns `(new_state, metrics, new_dropout_rng)`; gradients and metrics are averaged over the
        `"batch"` pmap axis.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if data_args.task_name is not None:
        metric = evaluate.load("glue", data_args.task_name)
    else:
        metric = evaluate.load("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0
    # Stays None if no evaluation ever runs (e.g. fewer training examples than one batch).
    eval_metric = None
    # `eval_steps` may be None (TrainingArguments default); then evaluate only at epoch boundaries.
    eval_steps = training_args.eval_steps

    # make sure weights are replicated on each device
    state = replicate(state)

    steps_per_epoch = len(train_dataset) // train_batch_size
    total_steps = steps_per_epoch * num_epochs
    if steps_per_epoch == 0:
        logger.warning(
            f"The training set ({len(train_dataset)} examples) is smaller than one batch ({train_batch_size}); "
            "no training steps will be run."
        )
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
    for epoch in epochs:
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
        for step, batch in enumerate(
            tqdm(
                train_loader,
                total=steps_per_epoch,
                desc="Training...",
                position=1,
            ),
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = (epoch * steps_per_epoch) + (step + 1)  # always >= 1

            if cur_step % training_args.logging_steps == 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                # Restart the clock so this interval is not added again at the next logging step.
                train_start = time.time()
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
                    f" {train_metric['learning_rate']})"
                )

                train_metrics = []

            is_eval_step = bool(eval_steps) and cur_step % eval_steps == 0
            if is_eval_step or cur_step % steps_per_epoch == 0:
                # evaluate
                eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
                for batch in tqdm(
                    eval_loader,
                    total=math.ceil(len(eval_dataset) / eval_batch_size),
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = pad_shard_unpad(p_eval_step)(
                        state, batch, min_device_batch=per_device_eval_batch_size
                    )
                    metric.add_batch(predictions=np.array(predictions), references=labels)

                eval_metric = metric.compute()

                logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")

                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metric, cur_step)

            if cur_step % training_args.save_steps == 0 or cur_step == total_steps:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
        epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

    # save the eval metrics in json
    if jax.process_index() == 0 and eval_metric is not None:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        os.makedirs(training_args.output_dir, exist_ok=True)
        path = os.path.join(training_args.output_dir, "eval_results.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)
if __name__ == "__main__":
    # Run the fine-tuning script only when executed directly, not when imported.
    main()