Spaces:
Runtime error
Runtime error
File size: 42,837 Bytes
455a40f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 |
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pretraining the library models for denoising language modeling on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be pretrained by this script:
https://huggingface.co/models?filter=bart
"""
# You can also adapt this script on your own denoising language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoTokenizer,
BartConfig,
BatchEncoding,
FlaxBartForConditionalGeneration,
HfArgumentParser,
PreTrainedTokenizerBase,
is_tensorboard_available,
set_seed,
)
from transformers.models.bart.modeling_flax_bart import shift_tokens_right
from transformers.utils import get_full_repo_name, send_example_telemetry
# Keys of the Flax masked-LM auto-mapping (config classes, going by the variable name and the
# `.model_type` attribute read below).
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
# `model_type` of every entry above; joined into the help text of `ModelArguments.model_type`.
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
    """
    Arguments controlling the training/evaluation loop, the optimizer and the upload to the Hub.

    `output_dir` is the only required field. Fields that default to `None` are annotated as
    `Optional[...]` so that the annotations agree with the defaults.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        """Expand a leading `~` in `output_dir` to the user's home directory."""
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serialize this instance to a plain dict suitable for JSON.

        `Enum` values (and non-empty lists whose first element is an `Enum`) are replaced by their `.value`.
        Every field whose name ends in `_token` is replaced by the placeholder `<FIELD_NAME>` (upper-cased),
        whether or not a token was set, so that secrets never end up in logs.
        """
        d = asdict(self)
        # Iterate over a snapshot so that rewriting entries never touches the view being iterated.
        for k, v in list(d.items()):
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d
@dataclass
class ModelArguments:
    """
    Selects the model, config and tokenizer to start from (a pretrained checkpoint or a fresh model),
    plus the floating-point dtype and Hub authentication settings.
    """

    # Checkpoint to initialize weights from; leave unset to train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
        },
    )
    # Architecture to instantiate when no checkpoint is given.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": f"If training from scratch, pass a model type from the list: {', '.join(MODEL_TYPES)}"},
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Name of a `jax.numpy` dtype.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `huggingface-cli login` (necessary to use this script with private models)."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Describes the data fed to the model for training and evaluation, and the denoising noise parameters.
    """

    # --- where the data comes from ---
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )

    # --- preprocessing ---
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # --- denoising noise ---
    mlm_probability: float = field(
        default=0.3,
        metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"},
    )
    permute_sentence_ratio: float = field(
        default=1.0,
        metadata={"help": "Ratio of sentences to be permuted in each document"},
    )
    poisson_lambda: float = field(
        default=3.0,
        metadata={"help": "Mean of Poisson distribution used to generate span-lengths to be masked"},
    )

    def __post_init__(self):
        """
        Validate the data sources.

        Raises `ValueError` when neither a dataset name nor any data file is given, and `AssertionError`
        when a provided file does not end in `csv`, `json` or `txt`.
        """
        data_sources = (self.dataset_name, self.train_file, self.validation_file)
        if all(source is None for source in data_sources):
            raise ValueError("Need either a dataset name or a training/validation file.")

        # The training file is checked first, then the validation file.
        for arg_name in ("train_file", "validation_file"):
            file_path = getattr(self, arg_name)
            if file_path is None:
                continue
            suffix = file_path.split(".")[-1]
            assert suffix in ["csv", "json", "txt"], f"`{arg_name}` should be a csv, a json or a txt file."
@flax.struct.dataclass
class FlaxDataCollatorForBartDenoisingLM:
    """
    Data collator used for BART denoising language modeling. The code is largely copied from
    `<https://github.com/morganmcg1/rotobart/blob/main/data_collator.py#L223>`__.
    For more information on how BART denoising language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.13461.pdf>`__
    or the `official code for preprocessing <https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/denoising_dataset.py>`__ .

    Calling the collator on a list of examples returns a batch with `input_ids` (noised), `labels`,
    `decoder_input_ids`, `attention_mask` and `decoder_attention_mask`. Noise is drawn from the global
    `np.random` state.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data
        mask_ratio (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input
        poisson_lambda (:obj:`float`):
            Mean parameter of Poisson distribution used to generate span-lengths to be masked
        permute_sentence_ratio (:obj:`float`):
            Ratio of sentences to be permuted in each document
        decoder_start_token_id: (:obj:`int`):
            The decoder start token id of the model
    """

    tokenizer: PreTrainedTokenizerBase
    decoder_start_token_id: int
    mask_ratio: float = 0.3
    poisson_lambda: float = 3.0
    permute_sentence_ratio: float = 1.0

    def __post_init__(self):
        """Raise `ValueError` if the tokenizer has no mask token or no eos token."""
        if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token or eos token token which is necessary for denoising"
                " language modeling. "
            )

    def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
        """
        Collate `examples` into one noised batch.

        Every example must provide the keys of the first example; each key is stacked into one numpy array.
        """
        # convert list to dict and tensorize input
        batch = BatchEncoding(
            {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
        )
        # labels and decoder inputs are derived from the clean input, before any noise is applied
        batch["labels"] = batch["input_ids"].copy()
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id, self.decoder_start_token_id
        )
        # permuting sentences
        do_permute = False
        if self.permute_sentence_ratio > 0.0:
            batch["input_ids"] = self.permute_sentences(batch["input_ids"])
            do_permute = True

        # masking span of tokens (text infilling in the paper)
        if self.mask_ratio:
            batch["input_ids"], batch["labels"] = self.span_mask_tokens(
                batch["input_ids"], batch["labels"], do_permute
            )

        # ignore pad tokens
        batch["attention_mask"] = (batch["input_ids"] != self.tokenizer.pad_token_id).astype(int)
        batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != self.tokenizer.pad_token_id).astype(int)
        return batch

    def permute_sentences(self, input_ids):
        """
        Shuffle sentences in each document.

        A sentence ends at every pad token. Rows without any pad token are returned unchanged, and so are
        the tokens that follow the last pad token of a row. Returns a new array; `input_ids` is not modified.
        """
        results = input_ids.copy()

        # find end locations of sentences
        end_sentence_mask = input_ids == self.tokenizer.pad_token_id
        # (row, column) pairs of all pad tokens
        sentence_ends = np.argwhere(end_sentence_mask)
        # make the column an exclusive end index, so that a sentence slice includes its pad token
        sentence_ends[:, 1] += 1
        # despite its name, this holds the ids of all rows with at least one pad token
        example_has_multiple_sentences, num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)
        num_sentences_map = dict(zip(example_has_multiple_sentences, num_sentences))

        num_to_permute = np.ceil(num_sentences * self.permute_sentence_ratio).astype(int)
        num_to_permute_map = dict(zip(example_has_multiple_sentences, num_to_permute))

        # one array of sentence ends per row: split at the first occurrence of each row id
        sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])
        sentence_ends_map = dict(zip(example_has_multiple_sentences, sentence_ends))

        for i in range(input_ids.shape[0]):
            if i not in example_has_multiple_sentences:
                continue
            # pick which sentences take part in the shuffle, then shuffle only those among themselves
            substitutions = np.random.permutation(num_sentences_map[i])[: num_to_permute_map[i]]
            ordering = np.arange(0, num_sentences_map[i])
            ordering[substitutions] = substitutions[np.random.permutation(num_to_permute_map[i])]

            # write shuffled sentences into results
            index = 0
            for j in ordering:
                # sentence j starts where sentence j - 1 ended (or at column 0 for the first sentence)
                sentence = input_ids[i, (sentence_ends_map[i][j - 1] if j > 0 else 0) : sentence_ends_map[i][j]]
                results[i, index : index + sentence.shape[0]] = sentence
                index += sentence.shape[0]
        return results

    def span_mask_tokens(self, input_ids, labels, do_permute):
        """
        Sampling text spans with span lengths drawn from a Poisson distribution and masking them.

        The number of tokens to mask, the span lengths and the span starts are all drawn for the batch as a
        whole, not per row. Each span ends up as a single mask token in the returned inputs, and rows are
        right-padded with the pad token. `input_ids` and `labels` are modified in place.

        Returns:
            A tuple `(new_input_ids, labels)`; when there is nothing to mask, the arguments are returned as is.
        """
        special_tokens_mask_labels = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        special_tokens_mask_inputs = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in input_ids.tolist()
        ]
        special_tokens_mask_labels = np.array(special_tokens_mask_labels, dtype=bool)
        special_tokens_mask_inputs = np.array(special_tokens_mask_inputs, dtype=bool)

        # determine how many tokens we need to mask in total
        is_token_mask = ~(input_ids == self.tokenizer.pad_token_id) & ~special_tokens_mask_inputs
        num_tokens_to_mask = int(math.ceil(is_token_mask.astype(float).sum() * self.mask_ratio))
        if num_tokens_to_mask == 0:
            return input_ids, labels

        # generate a sufficient number of span lengths
        span_lengths = np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))
        # NOTE(review): this loop only ends once the drawn lengths sum high enough - confirm poisson_lambda > 0
        while np.cumsum(span_lengths, 0)[-1] < num_tokens_to_mask:
            span_lengths = np.concatenate(
                [span_lengths, np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))]
            )

        # remove all spans of length 0
        # note that BART inserts additional mask tokens where length == 0,
        # which we do not implement for now as it adds additional complexity
        span_lengths = span_lengths[span_lengths > 0]

        # trim to about num_tokens_to_mask tokens
        cutoff_idx = np.argmin(np.abs(np.cumsum(span_lengths, 0) - num_tokens_to_mask)) + 1
        span_lengths = span_lengths[:cutoff_idx]

        # randomly choose starting positions for masking
        token_indices = np.argwhere(is_token_mask == 1)
        span_starts = np.random.permutation(token_indices.shape[0])[: span_lengths.shape[0]]
        # prepare mask
        masked_indices = np.array(token_indices[span_starts])
        # np.full_like keeps the dtype of input_ids, which is why the mask is compared with 0 / 1 below
        mask = np.full_like(input_ids, fill_value=False)

        # mask starting positions
        for mi in masked_indices:
            mask[tuple(mi)] = True
        span_lengths -= 1

        # fill up spans
        # each pass extends every unfinished span by one column to the right, stopping at the last column
        max_index = input_ids.shape[1] - 1
        remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            for mi in masked_indices:
                mask[tuple(mi)] = True
            span_lengths -= 1
            remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)

        # place the mask tokens
        # spans may have run over special tokens; those positions are never masked
        mask[np.where(special_tokens_mask_inputs)] = False
        input_ids[np.where(mask)] = self.tokenizer.mask_token_id
        if not do_permute:
            # without permutation, only masked positions keep a label
            labels[np.where(mask == 0)] = -100
        else:
            # with permutation, every position except special tokens keeps its label
            labels[np.where(special_tokens_mask_labels)] = -100

        # remove mask tokens that are not starts of spans
        # NOTE(review): np.roll wraps around, so column 0 is compared against the last column of its row
        to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
        new_input_ids = np.full_like(input_ids, fill_value=self.tokenizer.pad_token_id)
        for i, example in enumerate(input_ids):
            # keep the surviving tokens left-aligned; the tail of the row stays filled with pad tokens
            new_example = example[~to_remove[i]]
            new_input_ids[i, : new_example.shape[0]] = new_example

        return new_input_ids, labels
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
    """
    Split sample indices into batches of at most `batch_size` entries, preserving their order.

    Args:
        samples_idx: 1-D array of sample indices.
        batch_size: Maximum number of indices per batch.
        drop_last: If `True`, the trailing samples that do not fill a whole batch are dropped and a 2-D
            array of shape `(num_samples // batch_size, batch_size)` is returned. If `False`, every sample
            is kept and a list of `ceil(num_samples / batch_size)` arrays is returned; `np.array_split`
            makes these nearly equal in size (each at most `batch_size`), so the remainder is spread over
            the batches rather than left in one short final batch.

    Returns:
        The batches as described above. With `drop_last=False` and no samples, an empty list is returned.
    """
    num_samples = len(samples_idx)

    if drop_last:
        # Keep only as many samples as fill whole batches, so that the result is rectangular.
        num_batches = num_samples // batch_size
        return samples_idx[: num_batches * batch_size].reshape((num_batches, batch_size))

    if num_samples == 0:
        # `np.array_split` raises on zero sections; no samples simply means no batches.
        return []

    num_batches = math.ceil(num_samples / batch_size)
    return np.array_split(samples_idx, num_batches)
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Log the elapsed training time and the collected training metrics to `summary_writer`.

    `train_time` is written under the tag `train_time` at `step`. The metrics returned by
    `get_metrics(train_metrics)` are written under `train_<name>`, one scalar per value, on consecutive
    steps arranged so that the last value of each metric lands on `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    for metric_name, history in get_metrics(train_metrics).items():
        scalar_tag = f"train_{metric_name}"
        for position, value in enumerate(history):
            # position 0 maps to the oldest step, the final position maps to `step` itself
            summary_writer.scalar(scalar_tag, value, step - len(history) + position + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write every entry of `eval_metrics` to `summary_writer` as the scalar `eval_<name>` at `step`."""
    for name, metric_value in eval_metrics.items():
        scalar_tag = f"eval_{name}"
        summary_writer.scalar(scalar_tag, metric_value, step)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_bart_dlm", model_args, data_args, framework="flax")
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level=logging.INFO,
datefmt="[%X]",
)
# Log on each process the small summary:
logger = logging.getLogger(__name__)
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
if "validation" not in datasets.keys():
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
use_auth_token=True if model_args.use_auth_token else None,
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
if model_args.config_name:
config = BartConfig.from_pretrained(
model_args.config_name,
cache_dir=model_args.cache_dir,
vocab_size=len(tokenizer),
use_auth_token=True if model_args.use_auth_token else None,
)
elif model_args.model_name_or_path:
config = BartConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Use Punkt Sentence Tokenizer to divide a document into a list of sentences
nltk.download("punkt")
sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
def sentence_split_function(example):
sents = sentence_tokenizer.tokenize(example["text"])
# use pad token as end of sentence indicator
new_text = tokenizer.bos_token + f"{tokenizer.pad_token}".join(sents) + tokenizer.eos_token
return {"text": new_text}
split_datasets = datasets.map(
sentence_split_function,
batched=False,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Tokenize every text, then concatenate them together before splitting them in smaller parts.
# Since we make sure that all sequences are of the same length, no attention_mask is needed.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], add_special_tokens=False, return_attention_mask=False)
tokenized_datasets = split_datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=text_column_name,
load_from_cache_file=not data_args.overwrite_cache,
)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= max_seq_length:
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
if model_args.model_name_or_path:
model = FlaxBartForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
use_auth_token=True if model_args.use_auth_token else None,
)
else:
config.vocab_size = len(tokenizer)
model = FlaxBartForConditionalGeneration(
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
)
# Data collator
# This one will take care of randomly masking the tokens and permuting the sentences.
data_collator = FlaxDataCollatorForBartDenoisingLM(
tokenizer=tokenizer,
decoder_start_token_id=model.config.decoder_start_token_id,
mask_ratio=data_args.mlm_probability,
poisson_lambda=data_args.poisson_lambda,
permute_sentence_ratio=data_args.permute_sentence_ratio,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
)
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
    """Build the weight-decay mask for ``params``.

    Returns a nested dict with the same structure as ``params`` whose leaves
    are booleans: ``True`` for parameters that should be decayed, ``False``
    for biases and for parameters identified as belonging to a LayerNorm.
    """
    flat_params = traverse_util.flatten_dict(params)

    # A path counts as LayerNorm when its concatenated, lower-cased key
    # contains one of these markers. Only the last two path components are
    # remembered, so any other path sharing that suffix is excluded as well.
    layer_norm_markers = ("layernorm", "layer_norm", "ln")
    layer_norm_suffixes = set()
    for path in flat_params:
        joined_name = "".join(path).lower()
        if any(marker in joined_name for marker in layer_norm_markers):
            layer_norm_suffixes.add(path[-2:])

    flat_mask = {}
    for path in flat_params:
        is_bias = path[-1] == "bias"
        is_layer_norm = path[-2:] in layer_norm_suffixes
        flat_mask[path] = not (is_bias or is_layer_norm)
    return traverse_util.unflatten_dict(flat_mask)
# Create the optimizer: Adafactor when requested, otherwise AdamW with masked weight decay.
if not training_args.adafactor:
    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )
else:
    # Adafactor is initialized with its default parameters (only the schedule is passed).
    # For more details about the parameters please check
    # https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
    optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)

# Setup train state
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer,
)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):
    """Run one optimization step on a device-local batch.

    Meant to be executed under ``jax.pmap`` with axis name ``"batch"``: loss,
    label count and gradients are summed across that axis before normalizing.

    Args:
        state: train state providing ``apply_fn``, ``params``, ``step`` and
            ``apply_gradients``.
        batch: dict of model inputs that also holds a ``"labels"`` entry
            (removed from the dict before the forward pass).
        dropout_rng: PRNG key; one half is used for this step's dropout.

    Returns:
        ``(new_state, metrics, new_dropout_rng)`` where ``metrics`` holds the
        normalized ``"loss"`` and the current ``"learning_rate"``.
    """
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")
        outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
        logits = outputs[0]

        # Only positions with a positive label id contribute to the loss.
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        token_loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # Return un-normalized sums; the division happens after the cross-device sum.
        return (token_loss * label_mask).sum(), label_mask.sum()

    (loss, num_labels), grad = jax.value_and_grad(loss_fn, has_aux=True)(state.params)

    # true loss = total loss / total samples
    num_labels = jax.lax.psum(num_labels, "batch")
    loss = jax.lax.psum(loss, "batch") / num_labels

    # true grad = total grad / total samples
    summed_grad = jax.lax.psum(grad, "batch")
    grad = jax.tree_util.tree_map(lambda g: g / num_labels, summed_grad)

    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    return new_state, metrics, new_dropout_rng


# Create parallel version of the train step; the incoming state (argument 0) is donated.
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Define eval fn
def eval_step(params, batch):
    """Compute summed evaluation statistics for a device-local batch.

    Meant to be executed under ``jax.pmap`` with axis name ``"batch"``.

    Args:
        params: model parameters used for the forward pass.
        batch: dict of model inputs that also holds a ``"labels"`` entry
            (removed from the dict before the forward pass).

    Returns:
        Dict with the cross-device sums ``"loss"``, ``"accuracy"`` (number of
        correct predictions) and ``"normalizer"`` (number of counted label
        positions). Callers divide by ``"normalizer"`` to obtain averages.
    """
    labels = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    # compute loss; only positions with a positive label id are counted
    label_mask = jnp.where(labels > 0, 1.0, 0.0)
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

    # compute accuracy
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

    # summarize metrics
    metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
    metrics = jax.lax.psum(metrics, axis_name="batch")

    return metrics


# Do NOT donate argument 0 here: it is `state.params`, which the caller keeps
# using for further training steps and checkpointing after evaluation. Donated
# buffers must not be reused, and the scalar outputs could not reuse them anyway.
p_eval_step = jax.pmap(eval_step, "batch")
# Replicate the train state on each device
state = jax_utils.replicate(state)

train_time = 0
# `train_start` marks the beginning of the wall-clock interval that has not yet
# been added to `train_time`. It is advanced at every logging step so each
# stretch of time is counted exactly once (previously it was only set at the
# start of an epoch, so every log re-added the time since the epoch began).
train_start = time.time()
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
    # ======================== Training ================================
    train_metrics = []

    # Create sampling rng (`input_rng` is not consumed below: the shuffle uses
    # NumPy's global RNG).
    rng, input_rng = jax.random.split(rng)

    # Generate an epoch by shuffling sampling indices from the train dataset
    num_train_samples = len(tokenized_datasets["train"])
    # Avoid using jax.numpy here in case of TPU training
    train_samples_idx = np.random.permutation(np.arange(num_train_samples))
    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

    # Gather the indexes for creating the batch and do a training step
    for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
        samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
        model_inputs = data_collator(samples)

        # Model forward
        model_inputs = shard(model_inputs.data)
        state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
        train_metrics.append(train_metric)

        # Global step index across epochs
        cur_step = epoch * (num_train_samples // train_batch_size) + step

        if cur_step % training_args.logging_steps == 0 and cur_step > 0:
            # Save metrics
            train_metric = jax_utils.unreplicate(train_metric)
            now = time.time()
            train_time += now - train_start
            train_start = now
            if has_tensorboard and jax.process_index() == 0:
                write_train_metric(summary_writer, train_metrics, train_time, cur_step)

            epochs.write(
                f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
                f" {train_metric['learning_rate']})"
            )

            train_metrics = []

        if cur_step % training_args.eval_steps == 0 and cur_step > 0:
            # ======================== Evaluating ==============================
            num_eval_samples = len(tokenized_datasets["validation"])
            # Avoid using jax.numpy here in case of TPU training
            eval_samples_idx = np.arange(num_eval_samples)
            eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

            eval_metrics = []
            for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
                samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
                model_inputs = data_collator(samples)

                # Model forward
                metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                    state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
                )
                eval_metrics.append(metrics)

            # normalize eval metrics: sum over batches, then divide by the label count
            eval_metrics = get_metrics(eval_metrics)
            eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
            eval_normalizer = eval_metrics.pop("normalizer")
            eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)

            # Update progress bar
            epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"

            # Save metrics
            if has_tensorboard and jax.process_index() == 0:
                write_eval_metric(summary_writer, eval_metrics, cur_step)

        if cur_step % training_args.save_steps == 0 and cur_step > 0:
            # Save a checkpoint every `save_steps` steps (process 0 only) and
            # optionally push it to the hub.
            if jax.process_index() == 0:
                # Take the first replica of every parameter and move it to host memory.
                params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
                model.save_pretrained(training_args.output_dir, params=params)
                tokenizer.save_pretrained(training_args.output_dir)
                if training_args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
# Eval after training
if training_args.do_eval:
    validation_set = tokenized_datasets["validation"]
    # Avoid using jax.numpy here in case of TPU training
    eval_batch_idx = generate_batch_splits(np.arange(len(validation_set)), eval_batch_size)

    # Wrap the parallel eval step once and reuse the wrapper for every batch.
    padded_eval_step = pad_shard_unpad(p_eval_step, static_return=True)

    eval_metrics = []
    for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
        model_inputs = data_collator([validation_set[int(idx)] for idx in batch_idx])

        # Model forward
        eval_metrics.append(
            padded_eval_step(
                state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
            )
        )

    # normalize eval metrics: sum over batches into Python scalars, then divide
    # every remaining metric by the total label count
    eval_metrics = get_metrics(eval_metrics)
    eval_metrics = jax.tree_util.tree_map(lambda stacked: jnp.sum(stacked).item(), eval_metrics)
    eval_normalizer = eval_metrics.pop("normalizer")
    eval_metrics = jax.tree_util.tree_map(lambda total: total / eval_normalizer, eval_metrics)

    # Perplexity is exp(mean loss); fall back to infinity when exp overflows.
    try:
        eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
    except OverflowError:
        eval_metrics["perplexity"] = float("inf")

    # Only the first process writes the results file.
    if jax.process_index() == 0:
        eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
        path = os.path.join(training_args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metrics, f, indent=4, sort_keys=True)
# Script entry point: run the pre-training routine only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|