|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Fine-tuning the library models for summarization. |
|
""" |
|
|
|
|
|
import json |
|
import logging |
|
import os |
|
import sys |
|
import time |
|
from dataclasses import asdict, dataclass, field |
|
from enum import Enum |
|
from functools import partial |
|
from pathlib import Path |
|
from typing import Callable, Optional |
|
|
|
import datasets |
|
import nltk |
|
import numpy as np |
|
from datasets import Dataset, load_dataset, load_metric |
|
from tqdm import tqdm |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import optax |
|
import transformers |
|
from filelock import FileLock |
|
from flax import jax_utils, traverse_util |
|
from flax.jax_utils import unreplicate |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key |
|
from huggingface_hub import Repository |
|
from transformers import ( |
|
CONFIG_MAPPING, |
|
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, |
|
AutoConfig, |
|
AutoTokenizer, |
|
FlaxAutoModelForSeq2SeqLM, |
|
HfArgumentParser, |
|
is_tensorboard_available, |
|
) |
|
from transformers.file_utils import get_full_repo_name, is_offline_mode |
|
|
|
|
|
logger = logging.getLogger(__name__)

# The ROUGE post-processing in this script splits text into sentences with `nltk.sent_tokenize`, which needs the
# "punkt" tokenizer data. Look it up and download it on first use.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    # No download is possible in offline mode: fail with an actionable message instead.
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Guard the download with a lock file in the current working directory; presumably this keeps several
    # processes started at the same time from downloading concurrently.
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)


# Config classes of every Flax model registered for sequence-to-sequence LM, and their `model_type` strings
# (listed in the `--model_type` help text of `ModelArguments`).
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
|
|
|
@dataclass
class TrainingArguments:
    """
    Training hyper-parameters and output/hub settings for this script.

    Optional fields default to `None` and are annotated as such so that type checkers and readers see the same
    contract that `HfArgumentParser` implements (it unwraps `Optional[X]` to `X` when building the CLI).
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    label_smoothing_factor: float = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    # Annotated Optional: the default is None.
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        # Expand a leading "~" so that later os/pathlib calls get a usable path.
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance to a dict, replacing `Enum` members (and lists of them) by their values for JSON
        serialization support. Fields whose name ends with `_token` are obfuscated: their value is replaced by a
        `<FIELD_NAME>` placeholder.
        """
        d = asdict(self)
        # Only existing keys are reassigned, so mutating `d` while iterating over it is safe.
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # Separate the two sentences with a space: the literals are concatenated at compile time.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # The name is resolved with `getattr(jnp, dtype)` when the model is loaded.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. "
            "Choose one of `[float32, float16, bfloat16]`."
        },
    )
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Validation in `__post_init__` raises `ValueError` when neither a dataset name nor a train/validation file
    is given, or when a given data file is not a csv or json file.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    summary_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or json file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the metrics on (a csv or json file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input predict data file to do prediction on (a csv or json file)."},
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    val_max_target_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
            "This argument is also used to override the `max_length` param of `model.generate`, which is used "
            "during evaluation."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
            "which is used during evaluation."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # Raise explicitly rather than `assert`: asserts are stripped under `python -O`. The extension also
        # selects the `load_dataset` builder, so only "csv" and "json" can be loaded.
        for arg_name in ("train_file", "validation_file", "test_file"):
            path = getattr(self, arg_name)
            if path is not None and path.split(".")[-1] not in ("csv", "json"):
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")
        if self.val_max_target_length is None:
            self.val_max_target_length = self.max_target_length
|
|
|
|
|
# Default (text column, summary column) pair for well-known hub summarization datasets, keyed by dataset name.
# Consulted when `--text_column` / `--summary_column` are not given explicitly.
summarization_name_mapping = dict(
    amazon_reviews_multi=("review_body", "review_title"),
    big_patent=("description", "abstract"),
    cnn_dailymail=("article", "highlights"),
    orange_sum=("text", "summary"),
    pn_summary=("article", "summary"),
    psc=("extract_text", "summary_text"),
    samsum=("dialogue", "summary"),
    thaisum=("body", "summary"),
    xglue=("news_body", "news_title"),
    xsum=("document", "summary"),
    wiki_summary=("article", "highlights"),
)
|
|
|
|
|
class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with the PRNG key that drives dropout during training."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """
        Replicate the whole state across local devices with `jax_utils.replicate`, then replace the replicated
        dropout key with `shard_prng_key(self.dropout_rng)` (one key per local device).
        """
        replicated_state = jax_utils.replicate(self)
        per_device_rngs = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=per_device_rngs)
|
|
|
|
|
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
    Shuffle batches if `shuffle` is `True`.

    Only `len(dataset) // batch_size` full batches are yielded; trailing examples that do not fill a complete
    batch are skipped.

    Args:
        rng: PRNG key used to permute the example order. Only read when `shuffle` is True.
        dataset: tokenized dataset whose columns can be converted to arrays.
        batch_size: number of examples per yielded batch (across all local devices).
        shuffle: whether to visit the examples in a random order.

    Yields:
        A dict mapping each column name to a host (NumPy) array, passed through `shard`.
    """
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        # Bring the permutation to host memory once; iterating a device array row by row would dispatch a
        # separate device slice for every batch.
        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip the incomplete last batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        # Keep the batch on the host: `pmap` transfers each shard straight to its device, whereas `jnp.array`
        # would first copy the whole batch to the default device.
        batch = {k: np.array(v) for k, v in batch.items()}

        batch = shard(batch)

        yield batch
|
|
|
|
|
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """
    Write training time, per-step training metrics and evaluation metrics to `summary_writer`.

    Args:
        summary_writer: object exposing `scalar(tag, value, step)`.
        train_metrics: list of per-step training metric dicts; stacked with `get_metrics`.
        eval_metrics: mapping of metric name to scalar value, written at `step`.
        train_time: training time, written at `step`.
        step: step assigned to the last recorded training value.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked = get_metrics(train_metrics)
    for name, history in stacked.items():
        # Lay the recorded values out on consecutive steps ending at `step`.
        first_step = step - len(history) + 1
        for offset, value in enumerate(history):
            summary_writer.scalar(f"train_{name}", value, first_step + offset)

    for name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", value, step)
|
|
|
|
|
def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """
    Returns a linear warmup, linear decay learning rate function.

    The rate rises linearly from 0 to `learning_rate` over `num_warmup_steps` steps. It then decays linearly
    towards 0 over the remaining `total_steps - num_warmup_steps` steps, where `total_steps` is
    `(train_ds_size // train_batch_size) * int(num_train_epochs)`.

    Args:
        train_ds_size: number of training examples.
        train_batch_size: global batch size. Incomplete batches are not counted as steps.
        num_train_epochs: number of epochs. It is truncated to an int so the decay horizon matches a training
            loop that runs `int(num_train_epochs)` full epochs; an untruncated float such as 2.5 would leave the
            learning rate above 0 at the last real step.
        num_warmup_steps: number of linear warmup steps.
        learning_rate: peak learning rate reached at the end of warmup.
    """
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * int(num_train_epochs)
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
|
|
|
|
|
def main():
    """
    Parse arguments, load and tokenize the data, then fine-tune, evaluate and/or predict.

    Arguments are read from a `.json` file when it is the only command-line argument, otherwise from the command
    line (see `ModelArguments`, `DataTrainingArguments` and `TrainingArguments`).
    - With `--do_train`, the model is trained for `int(num_train_epochs)` epochs and saved after each epoch. With
      `--do_eval` as well, it is evaluated after each epoch.
    - With `--do_eval` alone, one evaluation pass is run.
    - With `--do_predict`, the test split is scored and its metrics are written to
      `<output_dir>/test_results.json`.
    """
    # See all possible arguments by passing the --help flag to this script.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Only the first process logs at INFO level; the other processes only report errors.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation.
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: either a dataset from the hub, or local csv/json files.
    if data_args.dataset_name is not None:
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        # The loader builder is picked from the extension of the last file given.
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

    # Load pretrained model, tokenizer and config.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets: the column names come from the first split this run will use.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target: explicit arguments win, then known-dataset defaults, then the
    # first two columns.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"`--text_column` value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"`--summary_column` value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training. `preprocess_function` reads this variable when it runs,
    # so rebinding it below (for validation/test) changes the target length used by later `map` calls.
    max_target_length = data_args.max_target_length

    # `shift_tokens_right` is defined in the model's own modeling module; fetch it from there.
    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Tokenize inputs and targets, and build the decoder inputs by shifting the labels right.
    def preprocess_function(examples):
        inputs = examples[text_column]
        targets = examples[summary_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(
            inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
        )

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
            )

        model_inputs["labels"] = labels["input_ids"]
        decoder_input_ids = shift_tokens_right_fn(
            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
        )
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # The decoder attention mask is the target padding mask; it also masks the loss.
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    metric = load_metric("rouge")

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        """Decode token ids and return ROUGE F-measures (x100) plus mean generated length, rounded to 4 places."""
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        # Extract a few results from ROUGE
        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        result = {k: round(v, 4) for k, v in result.items()}
        return result

    # Enable tensorboard only on the master node.
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    elif not has_tensorboard:
        # Only warn when tensorboard is really missing, not merely because this is a non-master process.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constants. Without --do_train there is no train split; the optimizer below is still built (the
    # train state needs one) but is never stepped.
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    num_train_examples = len(train_dataset) if training_args.do_train else 0
    steps_per_epoch = num_train_examples // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs
    if training_args.do_train and steps_per_epoch == 0:
        raise ValueError(
            f"The training set ({num_train_examples} examples) is smaller than the total train batch size "
            f"({train_batch_size}), so no full training batch can be formed."
        )

    # Create learning rate schedule. Pass the truncated epoch count so that the decay ends on the last step the
    # loop below actually runs.
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        num_train_examples,
        train_batch_size,
        num_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBart.
    # For FlaxT5, one should correct the layer norm parameter naming
    # accordingly - see `run_t5_mlm_flax.py` e.g.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        layer_norm_params = [
            (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
        ]
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum() / padding_mask.sum()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before the update.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        # NOTE(review): this assigns the per-device params onto the shared `model` object while tracing, because
        # `model.generate` is called without an explicit `params` argument here.
        model.params = params
        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    def evaluate_dataset(params, tokenized_dataset, desc):
        """
        Score every full batch of `tokenized_dataset` with `params`.

        Returns `(metrics, rouge_metrics)`. `metrics` holds the mean "loss" plus every entry of `rouge_metrics`.
        `rouge_metrics` is empty unless `--predict_with_generate` is set, in which case summaries are generated
        and scored with `compute_metrics`.
        """
        step_metrics = []
        generations = []
        references = []

        # `data_loader` only reads its rng when shuffling, which evaluation never does.
        loader = data_loader(rng, tokenized_dataset, eval_batch_size)
        num_steps = len(tokenized_dataset) // eval_batch_size
        for _ in tqdm(range(num_steps), desc=desc, position=2, leave=False):
            # Model forward
            batch = next(loader)
            labels = batch["labels"]

            step_metrics.append(p_eval_step(params, batch))

            # generation
            if data_args.predict_with_generate:
                generated_ids = p_generate_step(params, batch)
                generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                references.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))

        # normalize metrics
        metrics = get_metrics(step_metrics)
        metrics = jax.tree_map(jnp.mean, metrics)

        # compute ROUGE metrics
        rouge_metrics = {}
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(generations, references)
            metrics.update(rouge_metrics)
        return metrics, rouge_metrics

    if training_args.do_train:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(train_dataset)}")
        logger.info(f"  Num Epochs = {num_epochs}")
        logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
        logger.info(f"  Total optimization steps = {total_train_steps}")

        train_time = 0
        epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
        for epoch in epochs:
            # ======================== Training ================================
            train_start = time.time()

            # Create sampling rng
            rng, input_rng = jax.random.split(rng)
            train_metrics = []

            # Generate an epoch by shuffling sampling indices from the train dataset
            train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
            # train
            for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
                batch = next(train_loader)
                state, train_metric = p_train_step(state, batch)
                train_metrics.append(train_metric)

            train_time += time.time() - train_start

            # `steps_per_epoch > 0` is checked above, so the loop ran and `train_metric` is bound.
            train_metric = unreplicate(train_metric)

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
            )

            # ======================== Evaluating ==============================
            eval_metrics = {}
            if training_args.do_eval:
                eval_metrics, rouge_metrics = evaluate_dataset(state.params, eval_dataset, "Evaluating...")
                rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

                # Print metrics and update progress bar
                desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
                epochs.write(desc)
                epochs.desc = desc

            # Save metrics. This epoch's training steps end at global step (epoch + 1) * steps_per_epoch, which is
            # the step `write_metric` assigns to the last recorded train value.
            if has_tensorboard and jax.process_index() == 0:
                cur_step = (epoch + 1) * steps_per_epoch
                write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

            # save checkpoint after each epoch and push checkpoint to the hub
            if jax.process_index() == 0:
                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                model.save_pretrained(training_args.output_dir, params=params)
                tokenizer.save_pretrained(training_args.output_dir)
                if training_args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
    elif training_args.do_eval:
        # Evaluation without training: a single pass with the initial parameters.
        logger.info("*** Evaluate ***")
        eval_metrics, rouge_metrics = evaluate_dataset(state.params, eval_dataset, "Evaluating...")
        rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
        logger.info(f"Eval Loss: {eval_metrics['loss']} | {rouge_desc}")

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics, rouge_metrics = evaluate_dataset(state.params, predict_dataset, "Predicting...")
        rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc}"
        logger.info(desc)

        # Save final metrics in json. `pred_metrics` always has "loss", so this also works without
        # --predict_with_generate. Values may be JAX/NumPy scalars, hence the float() conversion for json.
        if jax.process_index() == 0:
            test_metrics = {f"test_{metric_name}": float(value) for metric_name, value in pred_metrics.items()}
            os.makedirs(training_args.output_dir, exist_ok=True)
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(test_metrics, f, indent=4, sort_keys=True)


if __name__ == "__main__":
    main()
|
|