vit-gpt2 / run_image_caption.py
ydshieh
clean repo
5cb8b84
raw history blame
No virus
34.5 kB
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning a ViT encoder + GPT-2 decoder model (Flax) for image captioning.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Callable, Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset, load_metric
from tqdm import tqdm
from PIL import Image
import jax
import jax.numpy as jnp
import optax
import transformers
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForSeq2SeqLM,
HfArgumentParser,
TrainingArguments,
is_tensorboard_available,
)
from transformers.file_utils import is_offline_mode
from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
logger = logging.getLogger(__name__)

# ROUGE post-processing below relies on NLTK's "punkt" sentence tokenizer.
# Look it up first and only download it when it is missing; in offline mode a
# download is impossible, so fail with an actionable message instead.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # The file lock serializes the download between concurrent processes.
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

# Model types offered in the `--model_type` help text of `ModelArguments`.
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
@dataclass
class ModelArguments:
    """
    Command-line options selecting the model, config and tokenizer to fine-tune or train from scratch.

    Every field carries its own `help` text, which `HfArgumentParser` shows with `--help`.
    """

    # Checkpoint used to initialize the weights; leave unset to train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
        },
    )
    # One of MODEL_TYPES, used when training from scratch.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either `dataset_name` or at least one of `train_file` / `validation_file` must be given;
    local files must be csv or json. `val_max_target_length` falls back to
    `max_target_length` when left unset.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
            "This argument is also used to override the `max_length` param of `model.generate`, which is used "
            "during evaluation."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
            "which is used during evaluation."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    # Read by `main()` when loading local files. Declared last so the positional
    # order of the pre-existing fields is unchanged.
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to run predictions on (a csv or a json file)."},
    )

    def __post_init__(self):
        """Validate the data sources and fill in the `val_max_target_length` default.

        Raises:
            ValueError: if neither a dataset name nor a training/validation file is given.
            AssertionError: if a provided data file is not a csv or json file.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
        if self.test_file is not None:
            extension = self.test_file.split(".")[-1]
            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."

        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
# Maps a dataset name to its default (text column, summary column) pair.
summarization_name_mapping = {
    dataset_name: (text_column, summary_column)
    for dataset_name, text_column, summary_column in (
        ("amazon_reviews_multi", "review_body", "review_title"),
        ("big_patent", "description", "abstract"),
        ("cnn_dailymail", "article", "highlights"),
        ("orange_sum", "text", "summary"),
        ("pn_summary", "article", "summary"),
        ("psc", "extract_text", "summary_text"),
        ("samsum", "dialogue", "summary"),
        ("thaisum", "body", "summary"),
        ("xglue", "news_body", "news_title"),
        ("xsum", "document", "summary"),
        ("wiki_summary", "article", "highlights"),
    )
}
class TrainState(train_state.TrainState):
    """Flax train state that additionally carries the dropout PRNG key."""

    # PRNG key used for dropout; split at every training step by the caller.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a copy of this state replicated for `pmap`, with `dropout_rng` sharded via `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        sharded_dropout_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_dropout_rng)
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of `batch_size` examples from `dataset`, each passed through `shard`.

    Examples are visited in a random order drawn from `rng` when `shuffle` is `True`, otherwise
    in dataset order. Trailing examples that do not fill a complete batch are dropped.
    """
    num_examples = len(dataset)
    num_batches = num_examples // batch_size

    order = jax.random.permutation(rng, num_examples) if shuffle else jnp.arange(num_examples)

    # Keep only as many indices as fit into whole batches, then group them per batch.
    order = order[: num_batches * batch_size].reshape((num_batches, batch_size))

    for batch_indices in order:
        examples = dataset[batch_indices]
        arrays = {column: jnp.array(values) for column, values in examples.items()}
        yield shard(arrays)
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """
    Write the training time, per-step training metrics and evaluation metrics to `summary_writer`.

    The training values are written at consecutive steps ending at `step`; the training time
    and the evaluation metrics are written at `step` itself.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_train_metrics = get_metrics(train_metrics)
    for name, values in stacked_train_metrics.items():
        # Step assigned to the first value so that the last one lands on `step`.
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(f"train_{name}", value, first_step + offset)

    for name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", value, step)
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Build the learning-rate schedule: linear warmup from 0 to `learning_rate` over
    `num_warmup_steps`, then linear decay to 0 over the remaining training steps.
    """
    # Incomplete trailing batches are not counted.
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs

    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=total_steps - num_warmup_steps
    )
    # Switch from warmup to decay once `num_warmup_steps` steps have elapsed.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])
def main():
    """Fine-tune the ViT + GPT-2 captioning model and evaluate it after every epoch.

    Steps, in order:
      1. Parse `ModelArguments`, `DataTrainingArguments` and `TrainingArguments` (from the
         command line, or from a single `.json` file passed as the only argument).
      2. Load the dataset (by name, or from local csv/json files).
      3. Build the feature extractor, tokenizer and model, then preprocess every requested
         split into `pixel_values` / `labels` / `input_ids` / `attention_mask`.
      4. For each epoch: train, optionally evaluate (`--do_eval`), optionally predict
         (`--do_predict`), append a line to `<output_dir>/report.txt`, and save a checkpoint
         to `<output_dir>/ckpt_<epoch>`.

    Raises:
        ValueError: if the output directory is non-empty and may not be overwritten, if a
            requested split is missing, if `--do_train` is not set, or if the training set is
            smaller than one global batch.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"Training/evaluation parameters {training_args}")

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        # NOTE(review): `data_dir` is a hard-coded, machine-specific path; it should
        # probably become a command-line argument.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            data_dir='/home/33611/caption/',
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        # Read defensively: `test_file` may not be declared on the arguments dataclass,
        # in which case a plain attribute access would raise AttributeError.
        test_file = getattr(data_args, "test_file", None)
        if test_file is not None:
            data_files["test"] = test_file
            extension = test_file.split(".")[-1]
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # NOTE(review): the checkpoints below are hard-coded; `model_args.model_name_or_path`,
    # `config_name` and `tokenizer_name` are parsed but not used here.
    vit_name_path = 'google/vit-base-patch16-224-in21k'
    gpt2_name_path = 'asi/gpt-fr-cased-small'

    gpt2_config = GPT2Config.from_pretrained(gpt2_name_path)
    gpt2_config.add_cross_attention = True

    # Set to a combined ViT-GPT2 checkpoint to load it directly; while empty, the model is
    # assembled from the separate ViT and GPT-2 checkpoints.
    vit_gpt2_name_path = ''

    feature_extractor = ViTFeatureExtractor.from_pretrained(vit_name_path)
    tokenizer = GPT2Tokenizer.from_pretrained(gpt2_name_path)

    if not vit_gpt2_name_path:
        assert vit_name_path
        assert gpt2_name_path
        vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
            vit_name_path, gpt2_name_path
        )
    else:
        vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
            vit_gpt2_name_path
        )

    model = vit_gpt2_model
    model.config.is_encoder_decoder = True
    model.config.decoder_start_token_id = gpt2_config.bos_token_id
    model.config.bos_token_id = gpt2_config.bos_token_id
    model.config.eos_token_id = gpt2_config.eos_token_id
    model.config.pad_token_id = gpt2_config.pad_token_id

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # The optimizer schedule and the epoch loop below are built from the training set, so
    # fail early with a clear message instead of a NameError after preprocessing.
    if not training_args.do_train:
        raise ValueError(
            "This script always runs the training loop: `--do_train` is required "
            "(`--do_eval` / `--do_predict` run after each training epoch)."
        )

    image_file_column = 'image_file'
    caption_column = 'fr'

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`
    # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
    # for that dynamically import the `shift_tokens_right` function from the model file
    model_module = __import__(vit_gpt2_model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    def preprocess_function(examples):
        """Turn a batch of (image file, caption) rows into model inputs.

        Images the feature extractor cannot process are skipped together with their caption,
        so the returned batch may be smaller than the input batch.
        """
        _pixel_values = []
        _captions = []

        for image_file, caption in zip(examples[image_file_column], examples[caption_column]):
            with Image.open(image_file) as image:
                try:
                    encoder_inputs = feature_extractor(images=image, return_tensors="np")
                except Exception as err:
                    # Best effort: drop the example, but leave a trace instead of
                    # silently swallowing every error (including KeyboardInterrupt).
                    logger.warning(f"Skipping image {image_file}: feature extraction failed ({err})")
                    continue
            _pixel_values.append(encoder_inputs.pixel_values)
            # The EOS token is appended explicitly to every caption.
            _captions.append(caption + ' ' + tokenizer.eos_token)

        pixel_values = np.concatenate(_pixel_values)
        targets = _captions

        model_inputs = {}
        model_inputs['pixel_values'] = pixel_values

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
            )

        model_inputs["labels"] = labels["input_ids"]

        decoder_input_ids = shift_tokens_right_fn(
            jnp.array(labels["input_ids"]), gpt2_config.pad_token_id, gpt2_config.bos_token_id
        )
        model_inputs["input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["attention_mask"] = labels["attention_mask"]

        return model_inputs

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        # `preprocess_function` reads `max_target_length` from this scope.
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    metric = load_metric("rouge")

    def postprocess_text(preds, labels):
        """Strip whitespace and put each sentence on its own line."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        """Decode token ids and return ROUGE f-measures (x100) plus the mean generation length."""
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        # Extract a few results from ROUGE
        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # With zero steps per epoch the epoch loop would read metrics that were never produced.
    if steps_per_epoch == 0:
        raise ValueError(
            f"The training set ({len(train_dataset)} examples) is smaller than the total train batch size "
            f"({train_batch_size}); no optimization step would be run."
        )

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    # NOTE(review): confirm the layer-norm names below match this model's parameter
    # tree; names that do not match leave those scale parameters under weight decay.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        layer_norm_params = [
            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
        ]
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(
        apply_fn=vit_gpt2_model.__call__, params=vit_gpt2_model.params, tx=adamw, dropout_rng=dropout_rng
    )

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum() / padding_mask.sum()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        """Run one optimization step and return the new state and the averaged metrics."""
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over the "batch" pmap axis.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        """Compute the loss on one batch, averaged over the "batch" pmap axis."""
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels, batch["attention_mask"], label_smoothing_factor)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        """Generate caption token ids for the images in `batch`."""
        model.params = params
        output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f" Total optimization steps = {total_train_steps}")

    # The report file is appended to below; make sure its directory exists.
    os.makedirs(training_args.output_dir, exist_ok=True)
    report_path = os.path.join(training_args.output_dir, 'report.txt')

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)

    def log_report(desc):
        """Show `desc` on the progress bar, log it, and append it to the report file."""
        epochs.write(desc)
        epochs.desc = desc
        logger.info(desc)
        with open(report_path, 'a', encoding='UTF-8') as fp:
            fp.write(desc + '\n')

    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        # train
        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        # Report the metrics of the last training step of the epoch.
        train_metric = unreplicate(train_metric)

        log_report(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        # Stays empty when evaluation is not requested, so only train metrics are written.
        eval_metrics = {}
        if training_args.do_eval:
            eval_step_metrics = []
            eval_preds = []
            eval_labels = []

            eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
            eval_steps = len(eval_dataset) // eval_batch_size
            for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                # Model forward
                batch = next(eval_loader)
                labels = batch["labels"]

                metrics = p_eval_step(state.params, batch)
                eval_step_metrics.append(metrics)

                # generation
                if data_args.predict_with_generate:
                    generated_ids = p_generate_step(state.params, batch)
                    eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                    eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))

            # normalize eval metrics
            eval_metrics = get_metrics(eval_step_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

            # compute ROUGE metrics
            rouge_desc = ""
            if data_args.predict_with_generate:
                rouge_metrics = compute_metrics(eval_preds, eval_labels)
                eval_metrics.update(rouge_metrics)
                rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

            # Print metrics and update progress bar
            log_report(f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})")

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # ======================== Prediction loop ==============================
        if training_args.do_predict:
            logger.info("*** Predict ***")

            pred_metrics = []
            pred_generations = []
            pred_labels = []

            pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
            pred_steps = len(predict_dataset) // eval_batch_size
            for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
                # Model forward
                batch = next(pred_loader)
                labels = batch["labels"]

                metrics = p_eval_step(state.params, batch)
                pred_metrics.append(metrics)

                # generation
                if data_args.predict_with_generate:
                    generated_ids = p_generate_step(state.params, batch)
                    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                    pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))

            # normalize prediction metrics
            pred_metrics = get_metrics(pred_metrics)
            pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)

            # compute ROUGE metrics
            rouge_desc = ""
            if data_args.predict_with_generate:
                rouge_metrics = compute_metrics(pred_generations, pred_labels)
                pred_metrics.update(rouge_metrics)
                rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

            # Print metrics
            log_report(f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})")

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter.
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'),
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )


if __name__ == "__main__":
    main()