bertin-base-gaussian / run_mlm_flax_stream.py.orig
versae's picture
Dump 187k steps, acc 0.6574
9d93315
raw
history blame
30.8 kB
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with random token masking on a
dataset streamed through the `datasets` library.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
import logging
import json
import os
import shutil
import sys
import time
from collections import defaultdict
from dataclasses import dataclass, field
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import joblib
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
import optax
from flax import jax_utils, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
TrainingArguments,
is_tensorboard_available,
set_seed,
)
def _release_tuple(version: str) -> Tuple[int, ...]:
    """Return the first three numeric release components of a version string.

    Non-numeric suffixes (for example ``dev0`` or ``rc1``) are ignored and missing
    components are padded with zeros, so ``"1.10.0.dev0"`` becomes ``(1, 10, 0)`` and
    ``"1.9"`` becomes ``(1, 9, 0)``.
    """
    numbers = []
    for piece in version.split(".")[:3]:
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


# Versions are compared numerically: a plain string comparison orders "1.10.0" before
# "1.8.0" and would therefore reject recent releases.
if _release_tuple(datasets.__version__) < (1, 9, 0):
    raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")

# Model types accepted by `--model_type` when a model is trained from scratch.
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    The dataclass is handed to `HfArgumentParser` in the script entry point, which fills it from the
    command line or from a JSON file.
    """

    # Checkpoint used both for the weights and, unless overridden below, for the config and the tokenizer.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Only consulted when neither `config_name` nor `model_name_or_path` is given.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # The entry point resolves this name with `getattr(jnp, dtype)`.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises:
        ValueError: from `__post_init__`, when no data source is given or when a data file has an
            unsupported extension.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated. Default to the max input length of the model."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    text_column_name: str = field(
        default="text", metadata={"help": "The name of the column to retrieve the training text."}
    )
    shuffle_buffer_size: int = field(
        default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
    )
    num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
    num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})

    def __post_init__(self):
        """Validate that a data source is given and that data files have a supported extension."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        allowed_extensions = ("csv", "json", "jsonl", "txt", "gz")
        for arg_name in ("train_file", "validation_file"):
            file_path = getattr(self, arg_name)
            if file_path is None:
                continue
            extension = file_path.split(".")[-1]
            # Raise instead of asserting so the check is still enforced under `python -O`.
            if extension not in allowed_extensions:
                raise ValueError(
                    f"`{arg_name}` should be a csv, a json (lines), a txt or a gz file, got `{file_path}`."
                )
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15

    def __post_init__(self):
        """Reject tokenizers that have no mask token, since masking would be impossible."""
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
        """Pad `examples` into a numpy batch, then mask `input_ids` and add the matching `labels`."""
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        `inputs` is modified in place and returned together with `labels`, a copy of the original ids in
        which every position that was not selected for masking is set to -100. When `special_tokens_mask`
        is `None` it is computed with the tokenizer, so special tokens are never selected.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # The batch carried no precomputed mask: ask the tokenizer which ids are special tokens.
            special_tokens_mask = np.array(
                [
                    self.tokenizer.get_special_tokens_mask(token_ids, already_has_special_tokens=True)
                    for token_ids in labels.tolist()
                ]
            )
        special_tokens_mask = special_tokens_mask.astype("bool")

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of selected positions that were not replaced by the mask token)
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced

        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
@dataclass
class SamplingArguments:
    """
    Arguments pertaining to how to perform sampling of the dataset.

    After initialization `boundaries` holds a list of floats (or `None` when no boundaries were given)
    instead of the raw comma-separated string.
    """

    perplexity_model: Optional[str] = field(
        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
    )
    sampling_method: Optional[str] = field(
        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
    )
    sampling_factor: Optional[float] = field(
        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
    )
    boundaries: Optional[str] = field(
        default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
    )

    def __post_init__(self):
        """Convert `boundaries` to a list of floats.

        A comma-separated string is split and parsed; an iterable of numbers is converted element-wise;
        `None` is kept as is.

        Raises:
            ValueError: if one of the comma-separated pieces is not a valid float.
        """
        raw_boundaries = self.boundaries
        if raw_boundaries is None:
            return
        if isinstance(raw_boundaries, str):
            self.boundaries = [float(q.strip()) for q in raw_boundaries.split(",")]
        else:
            # Already parsed, e.g. given as a list rather than as a string.
            self.boundaries = [float(q) for q in raw_boundaries]
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> List[np.ndarray]:
    """Split `samples_idx` into consecutive batches of exactly `batch_size` indices.

    Trailing indices that do not fill a complete batch are dropped.

    Args:
        samples_idx: One-dimensional array of sample indices.
        batch_size: Number of indices per batch.

    Returns:
        A list of arrays of length `batch_size`; the list is empty when `samples_idx` holds fewer than
        `batch_size` indices.
    """
    num_batches = len(samples_idx) // batch_size
    if num_batches == 0:
        # `np.split` cannot split into zero sections.
        return []
    usable_idx = samples_idx[: num_batches * batch_size]
    return np.split(usable_idx, num_batches)
def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
    """
    The training iterator is advanced so that after groupifying the samples,
    `num_samples` of length `max_seq_length` are returned.

    Args:
        train_iterator: Iterator yielding dict-like tokenized examples that contain an `"input_ids"` list.
        num_samples: Number of fixed-length sequences to build.
        max_seq_length: Length of each returned sequence.

    Returns:
        A dict mapping every key of the consumed examples to a list of `num_samples` slices of length
        `max_seq_length`.

    Raises:
        StopIteration: if `train_iterator` is exhausted before enough tokens were collected.
    """
    num_total_tokens = max_seq_length * num_samples
    samples = defaultdict(list)

    consumed_tokens = 0
    while consumed_tokens < num_total_tokens:
        tokenized_samples = next(train_iterator)
        consumed_tokens += len(tokenized_samples["input_ids"])

        # Append in place: rebuilding the accumulated lists on every iteration is quadratic.
        for key, values in tokenized_samples.items():
            samples[key].extend(values)

    # Concatenated tokens are split to lists of length `max_seq_length`.
    # Note that the remainder beyond `num_total_tokens` is thrown away.
    return {
        key: [tokens[start : start + max_seq_length] for start in range(0, num_total_tokens, max_seq_length)]
        for key, tokens in samples.items()
    }
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """Log the elapsed training time and every accumulated training metric.

    `train_metrics` is combined with `get_metrics`; the values of each metric are then written one
    scalar at a time under the tag `train_<key>`, numbered so that the last value lands on `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for key, vals in stacked_metrics.items():
        tag = f"train_{key}"
        # Step index of the first value, so that the final one is logged at `step`.
        first_step = step - len(vals) + 1
        for offset, val in enumerate(vals):
            summary_writer.scalar(tag, val, first_step + offset)
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write each entry of `eval_metrics` as the scalar `eval_<name>` at `step`."""
    for metric_name in eval_metrics:
        tag = f"eval_{metric_name}"
        summary_writer.scalar(tag, eval_metrics[metric_name], step)
def save_checkpoint_files(state, data_collator, training_args, save_dir):
    """Write the auxiliary checkpoint files into `save_dir`.

    Four files are produced from the unreplicated train state and the given objects:
    `optimizer_state.msgpack`, `training_args.joblib`, `data_collator.joblib` and
    `training_state.json` (which records the current step).
    """
    host_state = jax_utils.unreplicate(state)

    optimizer_path = os.path.join(save_dir, "optimizer_state.msgpack")
    with open(optimizer_path, "wb") as optimizer_file:
        optimizer_file.write(to_bytes(host_state.opt_state))

    joblib_targets = (
        (training_args, "training_args.joblib"),
        (data_collator, "data_collator.joblib"),
    )
    for obj, filename in joblib_targets:
        joblib.dump(obj, os.path.join(save_dir, filename))

    state_path = os.path.join(save_dir, "training_state.json")
    with open(state_path, "w") as state_file:
        json.dump({"step": host_state.step.item()}, state_file)
def rotate_checkpoints(path, max_checkpoints=5):
    """Delete the oldest entries of `path`, keeping the `max_checkpoints` most recently modified ones.

    Args:
        path: Directory whose direct children are the checkpoints.
        max_checkpoints: Number of entries to keep. `None` means "no limit" and nothing is deleted.
    """
    if max_checkpoints is None:
        return

    newest_first = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
    for stale_path in newest_first[max_checkpoints:]:
        # Directories are removed recursively; plain files and symlinks are unlinked.
        if stale_path.is_dir() and not stale_path.is_symlink():
            shutil.rmtree(stale_path)
        else:
            os.remove(stale_path)
if __name__ == "__main__":
    # Script entry point: parse the arguments, build the streamed and tokenized dataset, create the
    # model and optimizer, then run the train / eval / checkpoint loop for `num_train_steps` steps.

    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, sampling_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level="INFO",
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        filepaths = {}
        if data_args.train_file:
            filepaths["train"] = data_args.train_file
        if data_args.validation_file:
            filepaths["validation"] = data_args.validation_file
        try:
            # First try a dataset script that accepts the perplexity-sampling keyword arguments.
            dataset = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=model_args.cache_dir,
                streaming=True,
                split="train",
                sampling_method=sampling_args.sampling_method,
                sampling_factor=sampling_args.sampling_factor,
                boundaries=sampling_args.boundaries,
                perplexity_model=sampling_args.perplexity_model,
                seed=training_args.seed,
                data_files=filepaths,
            )
        except Exception as exc:
            # Deliberate best-effort fallback: load the dataset without the sampling arguments.
            logger.warning(
                f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
            )
            dataset = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=model_args.cache_dir,
                streaming=True,
                split="train",
            )
    else:
        # Without a dataset name nothing is loaded and `dataset` would be undefined further down.
        raise ValueError("This script streams its training data and requires --dataset_name to be set.")

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
    # efficient when it receives the `special_tokens_mask`.
    def tokenize_function(examples):
        return tokenizer(
            examples[data_args.text_column_name],
            return_special_tokens_mask=True
        )

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
    )

    shuffle_seed = training_args.seed
    tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
            # Enable Weight&Biases
            import wandb

            wandb.init(
                entity='wandb',
                project='hf-flax-bertin-roberta-es',
                sync_tensorboard=True,
            )
            wandb.config.update(training_args)
            wandb.config.update(model_args)
            wandb.config.update(data_args)
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    # One dropout key per local device, consumed and refreshed by `p_train_step`.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        model = FlaxAutoModelForMaskedLM.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    # define number steps per stream epoch
    num_train_steps = data_args.num_train_steps

    # Create learning rate schedule: linear warmup to the peak rate, then linear decay to zero.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBERT-like models.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss, ignore positions that were not selected for masking: the data collator
            # sets their label to -100, so `labels > 0` filters them out
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

            # take average
            loss = loss.sum() / label_mask.sum()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss, ignore positions that were not selected for masking (label -100)
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics

    # The parameters are not donated here: they belong to the train state and are needed again by
    # the following training steps.
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_start = time.time()
    train_metrics = []
    eval_metrics = []

    training_iter = iter(tokenized_datasets)

    # `--max_seq_length` is optional and defaults to the maximum input length of the tokenizer.
    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
    else:
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
    # The first samples of the stream are set aside for evaluation.
    eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)

    steps = tqdm(range(num_train_steps), desc="Training...", position=0)
    for step in range(num_train_steps):
        # ======================== Training ================================
        try:
            samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
        except StopIteration:
            # Once the end of the dataset stream is reached, the training iterator
            # is reinitialized and reshuffled and a new eval dataset is randomly chosen.
            shuffle_seed += 1
            tokenized_datasets.set_epoch(shuffle_seed)

            training_iter = iter(tokenized_datasets)

            # Store the new evaluation set under the name the evaluation loop reads.
            eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
            samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)

        # process input samples
        model_inputs = data_collator(samples, pad_to_multiple_of=16)

        # Model forward
        model_inputs = shard(model_inputs.data)
        state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)

        train_metrics.append(train_metric)

        if step % training_args.logging_steps == 0 and step > 0:
            steps.write(
                f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
            )
            # NOTE(review): `train_start` is never reset, so each logging step adds the whole time
            # elapsed since the start of training again - confirm whether this is intended.
            train_time += time.time() - train_start
            if has_tensorboard and jax.process_index() == 0:
                write_train_metric(summary_writer, train_metrics, train_time, step)
            train_metrics = []

        # ======================== Evaluating ==============================
        if step % training_args.eval_steps == 0 and step > 0:
            eval_samples_idx = jnp.arange(data_args.num_eval_samples)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

            for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=1):
                # process input samples
                batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
                model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)

                # Model forward
                model_inputs = shard(model_inputs.data)
                metrics = p_eval_step(state.params, model_inputs)
                eval_metrics.append(metrics)

            # normalize eval metrics: sum over all batches, then divide by the number of scored tokens
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
            eval_normalizer = eval_metrics.pop("normalizer")
            eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)

            # Update progress bar
            steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"

            if has_tensorboard and jax.process_index() == 0:
                write_eval_metric(summary_writer, eval_metrics, step)
            eval_metrics = []

        # save checkpoint after eval_steps
        # NOTE(review): placed at loop level because the condition does not depend on `eval_steps`;
        # the original indentation was lost - confirm it was not nested in the evaluation branch.
        if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
            logger.info(f"Saving checkpoint at {step + 1} steps")
            # Take the copy held by the first device and bring it to host memory.
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of step {step + 1}",
            )
            save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
            checkpoints_dir.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(checkpoints_dir, params=params,)
            save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
            # `save_total_limit` is optional: without a limit every checkpoint is kept.
            if training_args.save_total_limit is not None:
                rotate_checkpoints(
                    Path(training_args.output_dir) / "checkpoints",
                    max_checkpoints=training_args.save_total_limit
                )

        # update tqdm bar
        steps.update(1)