aapot committed
Commit: d4f665d
1 Parent(s): 83a33f1

Add training files

config.json ADDED
@@ -0,0 +1,74 @@
+ {
+   "activation_dropout": 0.1,
+   "activation_function": "gelu",
+   "add_bias_logits": false,
+   "add_final_layer_norm": false,
+   "architectures": [
+     "BartForConditionalGeneration"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 0,
+   "classif_dropout": 0.1,
+   "classifier_dropout": 0.0,
+   "d_model": 512,
+   "decoder_attention_heads": 8,
+   "decoder_ffn_dim": 2048,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 6,
+   "decoder_start_token_id": 2,
+   "dropout": 0.1,
+   "early_stopping": true,
+   "encoder_attention_heads": 8,
+   "encoder_ffn_dim": 2048,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 6,
+   "eos_token_id": 2,
+   "forced_bos_token_id": 0,
+   "forced_eos_token_id": 2,
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "max_position_embeddings": 1024,
+   "model_type": "bart",
+   "no_repeat_ngram_size": 3,
+   "normalize_before": false,
+   "normalize_embedding": true,
+   "num_beams": 4,
+   "num_hidden_layers": 6,
+   "pad_token_id": 1,
+   "scale_embedding": false,
+   "task_specific_params": {
+     "summarization": {
+       "length_penalty": 1.0,
+       "max_length": 128,
+       "min_length": 12,
+       "num_beams": 4
+     },
+     "summarization_cnn": {
+       "length_penalty": 2.0,
+       "max_length": 142,
+       "min_length": 56,
+       "num_beams": 4
+     },
+     "summarization_xsum": {
+       "length_penalty": 1.0,
+       "max_length": 62,
+       "min_length": 11,
+       "num_beams": 6
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.26.0.dev0",
+   "use_cache": true,
+   "vocab_size": 50265
+ }
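
For reference, a minimal sketch of how this config can be used to build the randomly initialized Flax model, mirroring the from-scratch branch of run_bart_dlm_flax.py below. The "./" path, seed and dtype mirror start_train.sh and the script defaults in this repo; the snippet itself is illustrative and not part of the commit.

import jax.numpy as jnp
from transformers import BartConfig, FlaxBartForConditionalGeneration

# Read the config.json shown above (assumes it sits in the current directory).
config = BartConfig.from_pretrained("./")
# Instantiate an untrained bart-small model the way the training script's
# from-scratch branch does; seed=42 and bfloat16 mirror the training setup.
model = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.bfloat16)
print(model.config.d_model, model.config.vocab_size)  # 512 50265
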
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
run_bart_dlm_flax.py ADDED
@@ -0,0 +1,1003 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for denoising language modeling on a text file or a dataset.
18
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
19
+ https://huggingface.co/models?filter=bart
20
+ """
21
+ # You can also adapt this script on your own denoising language modeling task. Pointers for this are left as comments.
22
+
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+ from pathlib import Path
33
+ from typing import Dict, List, Optional
34
+
35
+ import nltk
36
+ import numpy as np
37
+ from datasets import load_dataset, load_from_disk
38
+ from tqdm import tqdm
39
+
40
+ import flax
41
+ import jax
42
+ import jax.numpy as jnp
43
+ import optax
44
+ from flax import jax_utils, traverse_util
45
+ from flax.jax_utils import pad_shard_unpad
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository, create_repo
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoTokenizer,
53
+ BartConfig,
54
+ BatchEncoding,
55
+ FlaxBartForConditionalGeneration,
56
+ HfArgumentParser,
57
+ PreTrainedTokenizerBase,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.models.bart.modeling_flax_bart import shift_tokens_right
62
+ from transformers.utils import get_full_repo_name, send_example_telemetry
63
+
64
+
65
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
66
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
67
+
68
+
69
+ @dataclass
70
+ class TrainingArguments:
71
+ output_dir: str = field(
72
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
73
+ )
74
+ overwrite_output_dir: bool = field(
75
+ default=False,
76
+ metadata={
77
+ "help": (
78
+ "Overwrite the content of the output directory. "
79
+ "Use this to continue training if output_dir points to a checkpoint directory."
80
+ )
81
+ },
82
+ )
83
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
84
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
85
+ per_device_train_batch_size: int = field(
86
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
87
+ )
88
+ per_device_eval_batch_size: int = field(
89
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
90
+ )
91
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
92
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
93
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
94
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
95
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
96
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
97
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
98
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
99
+ logging_steps: int = field(default=500, metadata={"help": "Log every X update steps."})
100
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X update steps."})
101
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
102
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
103
+ push_to_hub: bool = field(
104
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
105
+ )
106
+ hub_model_id: str = field(
107
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
108
+ )
109
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
110
+
111
+ def __post_init__(self):
112
+ if self.output_dir is not None:
113
+ self.output_dir = os.path.expanduser(self.output_dir)
114
+
115
+ def to_dict(self):
116
+ """
117
+ Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It obfuscates
118
+ the token values by removing their value.
119
+ """
120
+ d = asdict(self)
121
+ for k, v in d.items():
122
+ if isinstance(v, Enum):
123
+ d[k] = v.value
124
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
125
+ d[k] = [x.value for x in v]
126
+ if k.endswith("_token"):
127
+ d[k] = f"<{k.upper()}>"
128
+ return d
129
+
130
+
131
+ @dataclass
132
+ class ModelArguments:
133
+ """
134
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
135
+ """
136
+
137
+ model_name_or_path: Optional[str] = field(
138
+ default=None,
139
+ metadata={
140
+ "help": (
141
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
142
+ )
143
+ },
144
+ )
145
+ model_type: Optional[str] = field(
146
+ default=None,
147
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
148
+ )
149
+ config_name: Optional[str] = field(
150
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
151
+ )
152
+ tokenizer_name: Optional[str] = field(
153
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
154
+ )
155
+ cache_dir: Optional[str] = field(
156
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
157
+ )
158
+ use_fast_tokenizer: bool = field(
159
+ default=True,
160
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
161
+ )
162
+ dtype: Optional[str] = field(
163
+ default="float32",
164
+ metadata={
165
+ "help": (
166
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
167
+ " `[float32, float16, bfloat16]`."
168
+ )
169
+ },
170
+ )
171
+ use_auth_token: bool = field(
172
+ default=False,
173
+ metadata={
174
+ "help": (
175
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
176
+ "with private models)."
177
+ )
178
+ },
179
+ )
180
+
181
+
182
+ @dataclass
183
+ class DataTrainingArguments:
184
+ """
185
+ Arguments pertaining to what data we are going to input our model for training and eval.
186
+ """
187
+
188
+ dataset_name: Optional[str] = field(
189
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
190
+ )
191
+ dataset_config_name: Optional[str] = field(
192
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
193
+ )
194
+ dataset_filepath: Optional[str] = field(
195
+ default=None, metadata={"help": "Filepath to locally saved HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
196
+ )
197
+ tokenized_dataset_filepath: Optional[str] = field(
198
+ default=None, metadata={"help": "Filepath to locally saved pre-tokenized HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
199
+ )
200
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
201
+ validation_file: Optional[str] = field(
202
+ default=None,
203
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
204
+ )
205
+ train_ref_file: Optional[str] = field(
206
+ default=None,
207
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
208
+ )
209
+ validation_ref_file: Optional[str] = field(
210
+ default=None,
211
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
212
+ )
213
+ overwrite_cache: bool = field(
214
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
215
+ )
216
+ validation_split_percentage: Optional[int] = field(
217
+ default=5,
218
+ metadata={
219
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
220
+ },
221
+ )
222
+ max_seq_length: Optional[int] = field(
223
+ default=None,
224
+ metadata={
225
+ "help": (
226
+ "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
227
+ " will be truncated. Default to the max input length of the model."
228
+ )
229
+ },
230
+ )
231
+ preprocessing_num_workers: Optional[int] = field(
232
+ default=None,
233
+ metadata={"help": "The number of processes to use for the preprocessing."},
234
+ )
235
+ mlm_probability: float = field(
236
+ default=0.3, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
237
+ )
238
+ permute_sentence_ratio: float = field(
239
+ default=1.0, metadata={"help": "Ratio of sentences to be permuted in each document"}
240
+ )
241
+ poisson_lambda: float = field(
242
+ default=3.5, metadata={"help": "Mean of Poisson distribution used to generate span-lengths to be masked"}
243
+ )
244
+
245
+ def __post_init__(self):
246
+ if self.dataset_name is None and self.dataset_filepath is None and self.tokenized_dataset_filepath is None and self.train_file is None and self.validation_file is None:
247
+ raise ValueError("Need either a dataset name or a training/validation file.")
248
+ else:
249
+ if self.train_file is not None:
250
+ extension = self.train_file.split(".")[-1]
251
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
252
+ if self.validation_file is not None:
253
+ extension = self.validation_file.split(".")[-1]
254
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
255
+
256
+
257
+ @flax.struct.dataclass
258
+ class FlaxDataCollatorForBartDenoisingLM:
259
+ """
260
+ Data collator used for BART denoising language modeling. The code is largely copied from
261
+ `<https://github.com/morganmcg1/rotobart/blob/main/data_collator.py#L223>`__.
262
+ For more information on how BART denoising language modeling works, one can take a look
263
+ at the `official paper <https://arxiv.org/pdf/1910.13461.pdf>`__
264
+ or the `official code for preprocessing <https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/denoising_dataset.py>`__ .
265
+ Args:
266
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
267
+ The tokenizer used for encoding the data
268
+ mask_ratio (:obj:`float`):
269
+ The probability with which to (randomly) mask tokens in the input
270
+ poisson_lambda (:obj:`float`):
271
+ Mean parameter of Poisson distribution used to generate span-lengths to be masked
272
+ permute_sentence_ratio (:obj:`float`):
273
+ Ratio of sentences to be permuted in each document
274
+ decoder_start_token_id (:obj:`int`):
275
+ The decoder start token id of the model
276
+ """
277
+
278
+ tokenizer: PreTrainedTokenizerBase
279
+ decoder_start_token_id: int
280
+ mask_ratio: float = 0.3
281
+ poisson_lambda: float = 3.0
282
+ permute_sentence_ratio: float = 1.0
283
+
284
+ def __post_init__(self):
285
+ if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
286
+ raise ValueError(
287
+ "This tokenizer does not have a mask token or eos token token which is necessary for denoising"
288
+ " language modeling. "
289
+ )
290
+
291
+ def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
292
+ # convert list to dict and tensorize input
293
+ batch = BatchEncoding(
294
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
295
+ )
296
+ batch["labels"] = batch["input_ids"].copy()
297
+ batch["decoder_input_ids"] = shift_tokens_right(
298
+ batch["labels"], self.tokenizer.pad_token_id, self.decoder_start_token_id
299
+ )
300
+ # permuting sentences
301
+ do_permute = False
302
+ if self.permute_sentence_ratio > 0.0:
303
+ batch["input_ids"] = self.permute_sentences(batch["input_ids"])
304
+ do_permute = True
305
+
306
+ # masking span of tokens (text infilling in the paper)
307
+ if self.mask_ratio:
308
+ batch["input_ids"], batch["labels"] = self.span_mask_tokens(
309
+ batch["input_ids"], batch["labels"], do_permute
310
+ )
311
+
312
+ # ignore pad tokens
313
+ batch["attention_mask"] = (batch["input_ids"] != self.tokenizer.pad_token_id).astype(int)
314
+ batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != self.tokenizer.pad_token_id).astype(int)
315
+ return batch
316
+
317
+ def permute_sentences(self, input_ids):
318
+ """
319
+ Shuffle sentences in each document.
320
+ """
321
+ results = input_ids.copy()
322
+
323
+ # find end locations of sentences
324
+ end_sentence_mask = input_ids == self.tokenizer.pad_token_id
325
+ sentence_ends = np.argwhere(end_sentence_mask)
326
+ sentence_ends[:, 1] += 1
327
+ example_has_multiple_sentences, num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)
328
+ num_sentences_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_sentences)}
329
+
330
+ num_to_permute = np.ceil(num_sentences * self.permute_sentence_ratio).astype(int)
331
+ num_to_permute_map = {
332
+ sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_to_permute)
333
+ }
334
+
335
+ sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])
336
+ sentence_ends_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, sentence_ends)}
337
+
338
+ for i in range(input_ids.shape[0]):
339
+ if i not in example_has_multiple_sentences:
340
+ continue
341
+ substitutions = np.random.permutation(num_sentences_map[i])[: num_to_permute_map[i]]
342
+ ordering = np.arange(0, num_sentences_map[i])
343
+ ordering[substitutions] = substitutions[np.random.permutation(num_to_permute_map[i])]
344
+
345
+ # write shuffled sentences into results
346
+ index = 0
347
+ for j in ordering:
348
+ sentence = input_ids[i, (sentence_ends_map[i][j - 1] if j > 0 else 0) : sentence_ends_map[i][j]]
349
+ results[i, index : index + sentence.shape[0]] = sentence
350
+ index += sentence.shape[0]
351
+ return results
352
+
353
+ def span_mask_tokens(self, input_ids, labels, do_permute):
354
+ """
355
+ Sampling text spans with span lengths drawn from a Poisson distribution and masking them.
356
+ """
357
+ special_tokens_mask_labels = [
358
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
359
+ ]
360
+ special_tokens_mask_inputs = [
361
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in input_ids.tolist()
362
+ ]
363
+ special_tokens_mask_labels = np.array(special_tokens_mask_labels, dtype=bool)
364
+ special_tokens_mask_inputs = np.array(special_tokens_mask_inputs, dtype=bool)
365
+
366
+ # determine how many tokens we need to mask in total
367
+ is_token_mask = ~(input_ids == self.tokenizer.pad_token_id) & ~special_tokens_mask_inputs
368
+ num_tokens_to_mask = int(math.ceil(is_token_mask.astype(float).sum() * self.mask_ratio))
369
+ if num_tokens_to_mask == 0:
370
+ return input_ids, labels
371
+
372
+ # generate a sufficient number of span lengths
373
+ span_lengths = np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))
374
+ while np.cumsum(span_lengths, 0)[-1] < num_tokens_to_mask:
375
+ span_lengths = np.concatenate(
376
+ [span_lengths, np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))]
377
+ )
378
+
379
+ # remove all spans of length 0
380
+ # note that BART inserts additional mask tokens where length == 0,
381
+ # which we do not implement for now as it adds additional complexity
382
+ span_lengths = span_lengths[span_lengths > 0]
383
+
384
+ # trim to about num_tokens_to_mask tokens
385
+ cutoff_idx = np.argmin(np.abs(np.cumsum(span_lengths, 0) - num_tokens_to_mask)) + 1
386
+ span_lengths = span_lengths[:cutoff_idx]
387
+
388
+ # randomly choose starting positions for masking
389
+ token_indices = np.argwhere(is_token_mask == 1)
390
+ span_starts = np.random.permutation(token_indices.shape[0])[: span_lengths.shape[0]]
391
+ # prepare mask
392
+ masked_indices = np.array(token_indices[span_starts])
393
+ mask = np.full_like(input_ids, fill_value=False)
394
+
395
+ # mask starting positions
396
+ for mi in masked_indices:
397
+ mask[tuple(mi)] = True
398
+ span_lengths -= 1
399
+
400
+ # fill up spans
401
+ max_index = input_ids.shape[1] - 1
402
+ remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
403
+ while np.any(remaining):
404
+ masked_indices[remaining, 1] += 1
405
+ for mi in masked_indices:
406
+ mask[tuple(mi)] = True
407
+ span_lengths -= 1
408
+ remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
409
+
410
+ # place the mask tokens
411
+ mask[np.where(special_tokens_mask_inputs)] = False
412
+ input_ids[np.where(mask)] = self.tokenizer.mask_token_id
413
+ if not do_permute:
414
+ labels[np.where(mask == 0)] = -100
415
+ else:
416
+ labels[np.where(special_tokens_mask_labels)] = -100
417
+
418
+ # remove mask tokens that are not starts of spans
419
+ to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
420
+ new_input_ids = np.full_like(input_ids, fill_value=self.tokenizer.pad_token_id)
421
+ for i, example in enumerate(input_ids):
422
+ new_example = example[~to_remove[i]]
423
+ new_input_ids[i, : new_example.shape[0]] = new_example
424
+
425
+ return new_input_ids, labels
426
+
427
+
428
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
429
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
430
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
431
+ num_samples = len(samples_idx)
432
+ if drop_last:
433
+ samples_to_remove = num_samples % batch_size
434
+ if samples_to_remove != 0:
435
+ samples_idx = samples_idx[:-samples_to_remove]
436
+ sections_split = num_samples // batch_size
437
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
438
+ else:
439
+ sections_split = math.ceil(num_samples / batch_size)
440
+ samples_idx = np.array_split(samples_idx, sections_split)
441
+ return samples_idx
442
+
443
+
444
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
445
+ summary_writer.scalar("train_time", train_time, step)
446
+
447
+ train_metrics = get_metrics(train_metrics)
448
+ for key, vals in train_metrics.items():
449
+ tag = f"train_{key}"
450
+ for i, val in enumerate(vals):
451
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
452
+
453
+
454
+ def write_eval_metric(summary_writer, eval_metrics, step):
455
+ for metric_name, value in eval_metrics.items():
456
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
457
+
458
+
459
+ def main():
460
+ # See all possible arguments in src/transformers/training_args.py
461
+ # or by passing the --help flag to this script.
462
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
463
+
464
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
465
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
466
+ # If we pass only one argument to the script and it's the path to a json file,
467
+ # let's parse it to get our arguments.
468
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
469
+ else:
470
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
471
+
472
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
473
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
474
+ send_example_telemetry("run_bart_dlm", model_args, data_args, framework="flax")
475
+
476
+ if (
477
+ os.path.exists(training_args.output_dir)
478
+ and os.listdir(training_args.output_dir)
479
+ and training_args.do_train
480
+ and not training_args.overwrite_output_dir
481
+ ):
482
+ raise ValueError(
483
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
484
+ "Use --overwrite_output_dir to overcome."
485
+ )
486
+
487
+ # Setup logging
488
+ logging.basicConfig(
489
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
490
+ level=logging.INFO,
491
+ datefmt="[%X]",
492
+ )
493
+
494
+ # Log on each process the small summary:
495
+ logger = logging.getLogger(__name__)
496
+
497
+ # Set the verbosity to info of the Transformers logger (on main process only):
498
+ logger.info(f"Training/evaluation parameters {training_args}")
499
+
500
+ # Set seed before initializing model.
501
+ set_seed(training_args.seed)
502
+
503
+ # Handle the repository creation
504
+ if training_args.push_to_hub:
505
+ if training_args.hub_model_id is None:
506
+ repo_name = get_full_repo_name(
507
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
508
+ )
509
+ else:
510
+ repo_name = training_args.hub_model_id
511
+ create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
512
+ repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
513
+
514
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
515
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
516
+ # (the dataset will be downloaded automatically from the datasets Hub).
517
+ #
518
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
519
+ # 'text' is found. You can easily tweak this behavior (see below).
520
+ if not data_args.tokenized_dataset_filepath:
521
+ if data_args.dataset_name is not None:
522
+ # Downloading and loading a dataset from the hub.
523
+ datasets = load_dataset(
524
+ data_args.dataset_name,
525
+ data_args.dataset_config_name,
526
+ cache_dir=model_args.cache_dir,
527
+ use_auth_token=True if model_args.use_auth_token else None,
528
+ )
529
+
530
+ if "validation" not in datasets.keys():
531
+ datasets["validation"] = load_dataset(
532
+ data_args.dataset_name,
533
+ data_args.dataset_config_name,
534
+ split=f"train[:{data_args.validation_split_percentage}%]",
535
+ cache_dir=model_args.cache_dir,
536
+ use_auth_token=True if model_args.use_auth_token else None,
537
+ )
538
+ datasets["train"] = load_dataset(
539
+ data_args.dataset_name,
540
+ data_args.dataset_config_name,
541
+ split=f"train[{data_args.validation_split_percentage}%:]",
542
+ cache_dir=model_args.cache_dir,
543
+ use_auth_token=True if model_args.use_auth_token else None,
544
+ )
545
+ elif data_args.dataset_filepath is not None:
546
+ # Loading a dataset from the local dataset
547
+ datasets = load_from_disk(
548
+ data_args.dataset_filepath
549
+ )
550
+ if "validation" not in datasets.keys():
551
+ datasets = datasets.train_test_split(
552
+ test_size=data_args.validation_split_percentage/100, shuffle=True, seed=training_args.seed)
553
+ datasets["validation"] = datasets["test"]
554
+ keys_to_remove = set(datasets.keys()) - \
555
+ set(["train", "validation"])
556
+ for key in keys_to_remove:
557
+ del datasets[key]
558
+ else:
559
+ data_files = {}
560
+ if data_args.train_file is not None:
561
+ data_files["train"] = data_args.train_file
562
+ if data_args.validation_file is not None:
563
+ data_files["validation"] = data_args.validation_file
564
+ extension = data_args.train_file.split(".")[-1]
565
+ if extension == "txt":
566
+ extension = "text"
567
+ datasets = load_dataset(
568
+ extension,
569
+ data_files=data_files,
570
+ cache_dir=model_args.cache_dir,
571
+ use_auth_token=True if model_args.use_auth_token else None,
572
+ )
573
+
574
+ if "validation" not in datasets.keys():
575
+ datasets["validation"] = load_dataset(
576
+ extension,
577
+ data_files=data_files,
578
+ split=f"train[:{data_args.validation_split_percentage}%]",
579
+ cache_dir=model_args.cache_dir,
580
+ use_auth_token=True if model_args.use_auth_token else None,
581
+ )
582
+ datasets["train"] = load_dataset(
583
+ extension,
584
+ data_files=data_files,
585
+ split=f"train[{data_args.validation_split_percentage}%:]",
586
+ cache_dir=model_args.cache_dir,
587
+ use_auth_token=True if model_args.use_auth_token else None,
588
+ )
589
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
590
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
591
+ print(datasets)
592
+
593
+ # Load pretrained model and tokenizer
594
+
595
+ if model_args.tokenizer_name:
596
+ tokenizer = AutoTokenizer.from_pretrained(
597
+ model_args.tokenizer_name,
598
+ cache_dir=model_args.cache_dir,
599
+ use_fast=model_args.use_fast_tokenizer,
600
+ use_auth_token=True if model_args.use_auth_token else None,
601
+ )
602
+ elif model_args.model_name_or_path:
603
+ tokenizer = AutoTokenizer.from_pretrained(
604
+ model_args.model_name_or_path,
605
+ cache_dir=model_args.cache_dir,
606
+ use_fast=model_args.use_fast_tokenizer,
607
+ use_auth_token=True if model_args.use_auth_token else None,
608
+ )
609
+ else:
610
+ raise ValueError(
611
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
612
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
613
+ )
614
+
615
+ if model_args.config_name:
616
+ config = BartConfig.from_pretrained(
617
+ model_args.config_name,
618
+ cache_dir=model_args.cache_dir,
619
+ vocab_size=len(tokenizer),
620
+ use_auth_token=True if model_args.use_auth_token else None,
621
+ )
622
+ elif model_args.model_name_or_path:
623
+ config = BartConfig.from_pretrained(
624
+ model_args.model_name_or_path,
625
+ cache_dir=model_args.cache_dir,
626
+ use_auth_token=True if model_args.use_auth_token else None,
627
+ )
628
+ else:
629
+ config = CONFIG_MAPPING[model_args.model_type]()
630
+ logger.warning("You are instantiating a new config instance from scratch.")
631
+
632
+ if not data_args.tokenized_dataset_filepath:
633
+ # Preprocessing the datasets.
634
+ # First we tokenize all the texts.
635
+ if training_args.do_train:
636
+ column_names = datasets["train"].column_names
637
+ else:
638
+ column_names = datasets["validation"].column_names
639
+ text_column_name = "text" if "text" in column_names else column_names[0]
640
+
641
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
642
+
643
+ # Use Punkt Sentence Tokenizer to divide a document into a list of sentences
644
+ nltk.download("punkt")
645
+ sentence_tokenizer = nltk.data.load("tokenizers/punkt/finnish.pickle")
646
+
647
+ def sentence_split_function(example):
648
+ sents = sentence_tokenizer.tokenize(example["text"])
649
+ # use pad token as end of sentence indicator
650
+ new_text = tokenizer.bos_token + f"{tokenizer.pad_token}".join(sents) + tokenizer.eos_token
651
+ return {"text": new_text}
652
+
653
+ split_datasets = datasets.map(
654
+ sentence_split_function,
655
+ batched=False,
656
+ num_proc=data_args.preprocessing_num_workers,
657
+ remove_columns=column_names,
658
+ load_from_cache_file=not data_args.overwrite_cache,
659
+ )
660
+
661
+ # Tokenize every text, then concatenate them together before splitting them in smaller parts.
662
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
663
+ def tokenize_function(examples):
664
+ return tokenizer(examples[text_column_name], add_special_tokens=False, return_attention_mask=False)
665
+
666
+ tokenized_datasets = split_datasets.map(
667
+ tokenize_function,
668
+ batched=True,
669
+ num_proc=data_args.preprocessing_num_workers,
670
+ remove_columns=text_column_name,
671
+ load_from_cache_file=not data_args.overwrite_cache,
672
+ )
673
+
674
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
675
+ # max_seq_length.
676
+ def group_texts(examples):
677
+ # Concatenate all texts.
678
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
679
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
680
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
681
+ # customize this part to your needs.
682
+ if total_length >= max_seq_length:
683
+ total_length = (total_length // max_seq_length) * max_seq_length
684
+ # Split by chunks of max_len.
685
+ result = {
686
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
687
+ for k, t in concatenated_examples.items()
688
+ }
689
+ return result
690
+
691
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
692
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
693
+ # might be slower to preprocess.
694
+ #
695
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
696
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
697
+ tokenized_datasets = tokenized_datasets.map(
698
+ group_texts,
699
+ batched=True,
700
+ num_proc=data_args.preprocessing_num_workers,
701
+ load_from_cache_file=not data_args.overwrite_cache,
702
+ )
703
+
704
+ tokenized_datasets.save_to_disk("/researchdisk/lm_training_dataset_tokenized")
705
+ else:
706
+ tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
707
+
708
+ # Enable tensorboard only on the master node
709
+ has_tensorboard = is_tensorboard_available()
710
+ if has_tensorboard and jax.process_index() == 0:
711
+ try:
712
+ from flax.metrics.tensorboard import SummaryWriter
713
+
714
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
715
+ except ImportError as ie:
716
+ has_tensorboard = False
717
+ logger.warning(
718
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
719
+ )
720
+ else:
721
+ logger.warning(
722
+ "Unable to display metrics through TensorBoard because the package is not installed: "
723
+ "Please run pip install tensorboard to enable."
724
+ )
725
+
726
+ # Initialize our training
727
+ rng = jax.random.PRNGKey(training_args.seed)
728
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
729
+
730
+ if model_args.model_name_or_path:
731
+ model = FlaxBartForConditionalGeneration.from_pretrained(
732
+ model_args.model_name_or_path,
733
+ config=config,
734
+ seed=training_args.seed,
735
+ dtype=getattr(jnp, model_args.dtype),
736
+ use_auth_token=True if model_args.use_auth_token else None,
737
+ )
738
+ else:
739
+ config.vocab_size = len(tokenizer)
740
+ model = FlaxBartForConditionalGeneration(
741
+ config,
742
+ seed=training_args.seed,
743
+ dtype=getattr(jnp, model_args.dtype),
744
+ )
745
+
746
+ # Data collator
747
+ # This one will take care of randomly masking the tokens and permuting the sentences.
748
+ data_collator = FlaxDataCollatorForBartDenoisingLM(
749
+ tokenizer=tokenizer,
750
+ decoder_start_token_id=model.config.decoder_start_token_id,
751
+ mask_ratio=data_args.mlm_probability,
752
+ poisson_lambda=data_args.poisson_lambda,
753
+ permute_sentence_ratio=data_args.permute_sentence_ratio,
754
+ )
755
+
756
+ # Store some constant
757
+ num_epochs = int(training_args.num_train_epochs)
758
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
759
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
760
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
761
+
762
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
763
+
764
+ # Create learning rate schedule
765
+ warmup_fn = optax.linear_schedule(
766
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
767
+ )
768
+ decay_fn = optax.linear_schedule(
769
+ init_value=training_args.learning_rate,
770
+ end_value=0,
771
+ transition_steps=num_train_steps - training_args.warmup_steps,
772
+ )
773
+ linear_decay_lr_schedule_fn = optax.join_schedules(
774
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
775
+ )
776
+
777
+ # We use Optax's "masking" functionality to not apply weight decay
778
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
779
+ # mask boolean with the same structure as the parameters.
780
+ # The mask is True for parameters that should be decayed.
781
+ def decay_mask_fn(params):
782
+ flat_params = traverse_util.flatten_dict(params)
783
+ # find out all LayerNorm parameters
784
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
785
+ layer_norm_named_params = set(
786
+ [
787
+ layer[-2:]
788
+ for layer_norm_name in layer_norm_candidates
789
+ for layer in flat_params.keys()
790
+ if layer_norm_name in "".join(layer).lower()
791
+ ]
792
+ )
793
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
794
+ return traverse_util.unflatten_dict(flat_mask)
795
+
796
+ # create adam optimizer
797
+ if training_args.adafactor:
798
+ # We use the default parameters here to initialize adafactor,
799
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
800
+ optimizer = optax.adafactor(
801
+ learning_rate=linear_decay_lr_schedule_fn,
802
+ )
803
+ else:
804
+ optimizer = optax.adamw(
805
+ learning_rate=linear_decay_lr_schedule_fn,
806
+ b1=training_args.adam_beta1,
807
+ b2=training_args.adam_beta2,
808
+ weight_decay=training_args.weight_decay,
809
+ mask=decay_mask_fn,
810
+ )
811
+
812
+ # Setup train state
813
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
814
+
815
+ # Define gradient update step fn
816
+ def train_step(state, batch, dropout_rng):
817
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
818
+
819
+ def loss_fn(params):
820
+ labels = batch.pop("labels")
821
+
822
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
823
+
824
+ # compute loss, ignore padded input tokens and special tokens
825
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
826
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
827
+
828
+ # take average
829
+ loss = loss.sum()
830
+ num_labels = label_mask.sum()
831
+
832
+ return loss, num_labels
833
+
834
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
835
+ (loss, num_labels), grad = grad_fn(state.params)
836
+ num_labels = jax.lax.psum(num_labels, "batch")
837
+
838
+ # true loss = total loss / total samples
839
+ loss = jax.lax.psum(loss, "batch")
840
+ loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
841
+
842
+ # true grad = total grad / total samples
843
+ grad = jax.lax.psum(grad, "batch")
844
+ grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
845
+ new_state = state.apply_gradients(grads=grad)
846
+
847
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
848
+ return new_state, metrics, new_dropout_rng
849
+
850
+ # Create parallel version of the train step
851
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
852
+
853
+ # Define eval fn
854
+ def eval_step(params, batch):
855
+ labels = batch.pop("labels")
856
+
857
+ logits = model(**batch, params=params, train=False)[0]
858
+
859
+ # compute loss, ignore padded input tokens and special tokens
860
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
861
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
862
+
863
+ # compute accuracy
864
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
865
+
866
+ # summarize metrics
867
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
868
+ metrics = jax.lax.psum(metrics, axis_name="batch")
869
+
870
+ return metrics
871
+
872
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
873
+
874
+ # Replicate the train state on each device
875
+ state = jax_utils.replicate(state)
876
+
877
+ train_time = 0
878
+ epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
879
+ for epoch in epochs:
880
+ # ======================== Training ================================
881
+ train_start = time.time()
882
+ train_metrics = []
883
+
884
+ # Create sampling rng
885
+ rng, input_rng = jax.random.split(rng)
886
+
887
+ # Generate an epoch by shuffling sampling indices from the train dataset
888
+ num_train_samples = len(tokenized_datasets["train"])
889
+ # Avoid using jax.numpy here in case of TPU training
890
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
891
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
892
+
893
+ # Gather the indexes for creating the batch and do a training step
894
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
895
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
896
+ model_inputs = data_collator(samples)
897
+
898
+ # Model forward
899
+ model_inputs = shard(model_inputs.data)
900
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
901
+ train_metrics.append(train_metric)
902
+
903
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
904
+
905
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
906
+ # Save metrics
907
+ train_metric = jax_utils.unreplicate(train_metric)
908
+ train_time += time.time() - train_start
909
+ if has_tensorboard and jax.process_index() == 0:
910
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
911
+
912
+ epochs.write(
913
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
914
+ f" {train_metric['learning_rate']})"
915
+ )
916
+
917
+ train_metrics = []
918
+
919
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
920
+ # ======================== Evaluating ==============================
921
+ num_eval_samples = len(tokenized_datasets["validation"])
922
+ # Avoid using jax.numpy here in case of TPU training
923
+ eval_samples_idx = np.arange(num_eval_samples)
924
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
925
+
926
+ eval_metrics = []
927
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
928
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
929
+ model_inputs = data_collator(samples)
930
+
931
+ # Model forward
932
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
933
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
934
+ )
935
+ eval_metrics.append(metrics)
936
+
937
+ # normalize eval metrics
938
+ eval_metrics = get_metrics(eval_metrics)
939
+ eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
940
+ eval_normalizer = eval_metrics.pop("normalizer")
941
+ eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
942
+
943
+ try:
944
+ perplexity = math.exp(eval_metrics["loss"])
945
+ except OverflowError:
946
+ perplexity = float("inf")
947
+ eval_metrics["perplexity"] = perplexity
948
+
949
+ # Update progress bar
950
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
951
+
952
+ # Save metrics
953
+ if has_tensorboard and jax.process_index() == 0:
954
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
955
+
956
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
957
+ # save checkpoint after each epoch and push checkpoint to the hub
958
+ if jax.process_index() == 0:
959
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
960
+ model.save_pretrained(training_args.output_dir, params=params)
961
+ tokenizer.save_pretrained(training_args.output_dir)
962
+ if training_args.push_to_hub:
963
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
964
+
965
+ # Eval after training
966
+ if training_args.do_eval:
967
+ num_eval_samples = len(tokenized_datasets["validation"])
968
+ # Avoid using jax.numpy here in case of TPU training
969
+ eval_samples_idx = np.arange(num_eval_samples)
970
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
971
+
972
+ eval_metrics = []
973
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
974
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
975
+ model_inputs = data_collator(samples)
976
+
977
+ # Model forward
978
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
979
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
980
+ )
981
+ eval_metrics.append(metrics)
982
+
983
+ # normalize eval metrics
984
+ eval_metrics = get_metrics(eval_metrics)
985
+ eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
986
+ eval_normalizer = eval_metrics.pop("normalizer")
987
+ eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
988
+
989
+ try:
990
+ perplexity = math.exp(eval_metrics["loss"])
991
+ except OverflowError:
992
+ perplexity = float("inf")
993
+ eval_metrics["perplexity"] = perplexity
994
+
995
+ if jax.process_index() == 0:
996
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
997
+ path = os.path.join(training_args.output_dir, "eval_results.json")
998
+ with open(path, "w") as f:
999
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
1000
+
1001
+
1002
+ if __name__ == "__main__":
1003
+ main()
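
As a quick sanity check of the denoising collator defined in this script, the sketch below runs it on one toy example. It assumes the tokenizer files from this commit are in the working directory and that the script's dependencies are installed; the example sentences and printed quantities are purely illustrative and not part of the commit.

from transformers import AutoTokenizer

from run_bart_dlm_flax import FlaxDataCollatorForBartDenoisingLM

tokenizer = AutoTokenizer.from_pretrained("./")
collator = FlaxDataCollatorForBartDenoisingLM(
    tokenizer=tokenizer,
    decoder_start_token_id=2,     # decoder_start_token_id from config.json
    mask_ratio=0.3,               # mlm_probability default above
    poisson_lambda=3.5,           # poisson_lambda default above
    permute_sentence_ratio=1.0,   # permute_sentence_ratio default above
)

# Mirror sentence_split_function: sentences are joined with the pad token.
sents = ["Tämä on ensimmäinen lause.", "Tämä on toinen lause."]
text = tokenizer.bos_token + tokenizer.pad_token.join(sents) + tokenizer.eos_token
example = tokenizer(text, add_special_tokens=False, return_attention_mask=False)

batch = collator([example])
print(batch["input_ids"].shape)  # (1, sequence_length)
print((batch["input_ids"] == tokenizer.mask_token_id).sum(), "mask tokens inserted")
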
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "unk_token": "<unk>"
+ }
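
A small, assumed cross-check (not part of the commit): because train_tokenizer.py registers these special tokens first, they should map to ids 0-4, which is consistent with bos_token_id=0, pad_token_id=1 and eos_token_id=2 in config.json. With the tokenizer files from this commit in the working directory:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./")
for token in ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]:
    # Expected ids: 0, 1, 2, 3, 4 (the special tokens were added first during training).
    print(token, tokenizer.convert_tokens_to_ids(token))
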
start_train.sh ADDED
@@ -0,0 +1,21 @@
+ python run_bart_dlm_flax.py \
+     --output_dir="./" \
+     --config_name="./" \
+     --tokenizer_name="./" \
+     --tokenized_dataset_filepath="/researchdisk/lm_training_dataset_tokenized" \
+     --preprocessing_num_workers="96" \
+     --max_seq_length="1024" \
+     --per_device_train_batch_size="16" \
+     --per_device_eval_batch_size="16" \
+     --learning_rate="4e-4" \
+     --weight_decay="0.01" \
+     --warmup_steps="10000" \
+     --overwrite_output_dir \
+     --num_train_epochs="5" \
+     --logging_steps="500" \
+     --save_steps="10000" \
+     --eval_steps="10000" \
+     --dtype="bfloat16" \
+     --use_auth_token \
+     --hub_model_id="Finnish-NLP/bart-small-finnish" \
+     --push_to_hub
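
For these flags, run_bart_dlm_flax.py builds a linear warmup plus linear decay learning-rate schedule. The sketch below reproduces that schedule outside the script to show its shape; num_train_steps is a made-up placeholder, since the real value depends on the size of the tokenized dataset.

import optax

learning_rate = 4e-4       # --learning_rate
warmup_steps = 10_000      # --warmup_steps
num_train_steps = 500_000  # placeholder; the script derives this from the dataset size

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps
)
schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

# 0.0 at step 0, peaks at 4e-4 after warmup, then decays linearly towards 0.
print(schedule(0), schedule(warmup_steps), schedule(num_train_steps))
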
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "errors": "replace",
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "name_or_path": "./",
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "special_tokens_map_file": null,
+   "tokenizer_class": "BartTokenizer",
+   "trim_offsets": true,
+   "unk_token": "<unk>"
+ }
train_tokenizer.py ADDED
@@ -0,0 +1,24 @@
+ from datasets import load_from_disk
+ from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
+
+ # load dataset
+ dataset = load_from_disk("/researchdisk/lm_training_dataset_full")["train"]
+
+ # Instantiate tokenizer
+ tokenizer = ByteLevelBPETokenizer()
+
+ def batch_iterator(batch_size=5000):
+     for i in range(0, len(dataset), batch_size):
+         yield dataset[i: i + batch_size]["text"]
+
+ # Customized training
+ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
+     "<s>",
+     "<pad>",
+     "</s>",
+     "<unk>",
+     "<mask>",
+ ])
+
+ # Save files to disk
+ tokenizer.save("./tokenizer.json")
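
The script above only writes tokenizer.json. One plausible way to produce the vocab.json, merges.txt, tokenizer_config.json and special_tokens_map.json files that are also part of this commit (an assumption; the commit does not show this step) is to wrap the trained tokenizer in a transformers fast tokenizer and save it:

from transformers import BartTokenizerFast

# Load the tokenizers-library file trained above into a transformers wrapper.
hf_tokenizer = BartTokenizerFast(tokenizer_file="./tokenizer.json")
# save_pretrained writes tokenizer_config.json, special_tokens_map.json,
# tokenizer.json and the legacy vocab.json / merges.txt files.
hf_tokenizer.save_pretrained("./")
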
vocab.json ADDED
The diff for this file is too large to render. See raw diff