aapot commited on
Commit
0b21f29
1 Parent(s): b741880

Add tokenizer

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "roberta-large",
3
+ "architectures": [
4
+ "RobertaForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "roberta",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 1,
21
+ "position_embedding_type": "absolute",
22
+ "transformers_version": "4.13.0.dev0",
23
+ "type_vocab_size": 1,
24
+ "use_cache": true,
25
+ "vocab_size": 50265
26
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
run_mlm_flax.py ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+
29
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+ from datasets import load_dataset, load_from_disk
35
+ from tqdm import tqdm
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ from flax import jax_utils, traverse_util
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard
44
+ from huggingface_hub import Repository
45
+ from transformers import (
46
+ CONFIG_MAPPING,
47
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
48
+ AutoConfig,
49
+ AutoTokenizer,
50
+ FlaxAutoModelForMaskedLM,
51
+ HfArgumentParser,
52
+ PreTrainedTokenizerBase,
53
+ TensorType,
54
+ TrainingArguments,
55
+ is_tensorboard_available,
56
+ set_seed,
57
+ )
58
+ from transformers.file_utils import get_full_repo_name
59
+
60
+
61
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
62
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
63
+
64
+
65
+ @dataclass
66
+ class ModelArguments:
67
+ """
68
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
69
+ """
70
+
71
+ model_name_or_path: Optional[str] = field(
72
+ default=None,
73
+ metadata={
74
+ "help": "The model checkpoint for weights initialization."
75
+ "Don't set if you want to train a model from scratch."
76
+ },
77
+ )
78
+ model_type: Optional[str] = field(
79
+ default=None,
80
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
81
+ )
82
+ config_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
84
+ )
85
+ tokenizer_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
90
+ )
91
+ use_fast_tokenizer: bool = field(
92
+ default=True,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
94
+ )
95
+ dtype: Optional[str] = field(
96
+ default="float32",
97
+ metadata={
98
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
99
+ },
100
+ )
101
+
102
+
103
+ @dataclass
104
+ class DataTrainingArguments:
105
+ """
106
+ Arguments pertaining to what data we are going to input our model for training and eval.
107
+ """
108
+
109
+ dataset_name: Optional[str] = field(
110
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
111
+ )
112
+ dataset_config_name: Optional[str] = field(
113
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
114
+ )
115
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
116
+ dataset_filepath: Optional[str] = field(
117
+ default=None, metadata={"help": "Filepath to locally saved HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
118
+ )
119
+ save_tokenized_dataset_filepath: Optional[str] = field(
120
+ default=None, metadata={"help": "Filepath for saving tokenized HF Dataset (with 'dataset.save_to_disk' method) to use for future trainings"}
121
+ )
122
+ tokenized_dataset_filepath: Optional[str] = field(
123
+ default=None, metadata={"help": "Filepath to locally saved tokenized HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
124
+ )
125
+ validation_file: Optional[str] = field(
126
+ default=None,
127
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
128
+ )
129
+ train_ref_file: Optional[str] = field(
130
+ default=None,
131
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
132
+ )
133
+ validation_ref_file: Optional[str] = field(
134
+ default=None,
135
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
136
+ )
137
+ overwrite_cache: bool = field(
138
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
139
+ )
140
+ validation_split_percentage: Optional[int] = field(
141
+ default=5,
142
+ metadata={
143
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
144
+ },
145
+ )
146
+ max_seq_length: Optional[int] = field(
147
+ default=None,
148
+ metadata={
149
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
150
+ "than this will be truncated. Default to the max input length of the model."
151
+ },
152
+ )
153
+ preprocessing_num_workers: Optional[int] = field(
154
+ default=None,
155
+ metadata={"help": "The number of processes to use for the preprocessing."},
156
+ )
157
+ mlm_probability: float = field(
158
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
159
+ )
160
+ pad_to_max_length: bool = field(
161
+ default=False,
162
+ metadata={
163
+ "help": "Whether to pad all samples to `max_seq_length`. "
164
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
165
+ },
166
+ )
167
+ line_by_line: bool = field(
168
+ default=False,
169
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
170
+ )
171
+
172
+ def __post_init__(self):
173
+ if self.dataset_name is None and self.train_file is None and self.dataset_filepath is None and self.validation_file is None:
174
+ raise ValueError("Need either a dataset name or a training/validation file.")
175
+ else:
176
+ if self.train_file is not None:
177
+ extension = self.train_file.split(".")[-1]
178
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
179
+ if self.validation_file is not None:
180
+ extension = self.validation_file.split(".")[-1]
181
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
182
+
183
+
184
+ @flax.struct.dataclass
185
+ class FlaxDataCollatorForLanguageModeling:
186
+ """
187
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
188
+ are not all of the same length.
189
+
190
+ Args:
191
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
192
+ The tokenizer used for encoding the data.
193
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
194
+ The probability with which to (randomly) mask tokens in the input.
195
+
196
+ .. note::
197
+
198
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
199
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
200
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
201
+ argument :obj:`return_special_tokens_mask=True`.
202
+ """
203
+
204
+ tokenizer: PreTrainedTokenizerBase
205
+ mlm_probability: float = 0.15
206
+
207
+ def __post_init__(self):
208
+ if self.tokenizer.mask_token is None:
209
+ raise ValueError(
210
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
211
+ "You should pass `mlm=False` to train on causal language modeling instead."
212
+ )
213
+
214
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
215
+ # Handle dict or lists with proper padding and conversion to tensor.
216
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
217
+
218
+ # If special token mask has been preprocessed, pop it from the dict.
219
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
220
+
221
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
222
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
223
+ )
224
+ return batch
225
+
226
+ def mask_tokens(
227
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
228
+ ) -> Tuple[np.ndarray, np.ndarray]:
229
+ """
230
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
231
+ """
232
+ labels = inputs.copy()
233
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
234
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
235
+ special_tokens_mask = special_tokens_mask.astype("bool")
236
+
237
+ probability_matrix[special_tokens_mask] = 0.0
238
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
239
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
240
+
241
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
242
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
243
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
244
+
245
+ # 10% of the time, we replace masked input tokens with random word
246
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
247
+ indices_random &= masked_indices & ~indices_replaced
248
+
249
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
250
+ inputs[indices_random] = random_words[indices_random]
251
+
252
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
253
+ return inputs, labels
254
+
255
+
256
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
257
+ num_samples = len(samples_idx)
258
+ samples_to_remove = num_samples % batch_size
259
+
260
+ if samples_to_remove != 0:
261
+ samples_idx = samples_idx[:-samples_to_remove]
262
+ sections_split = num_samples // batch_size
263
+ batch_idx = np.split(samples_idx, sections_split)
264
+ return batch_idx
265
+
266
+
267
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
268
+ summary_writer.scalar("train_time", train_time, step)
269
+
270
+ train_metrics = get_metrics(train_metrics)
271
+ for key, vals in train_metrics.items():
272
+ tag = f"train_{key}"
273
+ for i, val in enumerate(vals):
274
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
275
+
276
+
277
+ def write_eval_metric(summary_writer, eval_metrics, step):
278
+ for metric_name, value in eval_metrics.items():
279
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
280
+
281
+
282
+ if __name__ == "__main__":
283
+ # See all possible arguments in src/transformers/training_args.py
284
+ # or by passing the --help flag to this script.
285
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
286
+
287
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
288
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
289
+ # If we pass only one argument to the script and it's the path to a json file,
290
+ # let's parse it to get our arguments.
291
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
292
+ else:
293
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
294
+
295
+ if (
296
+ os.path.exists(training_args.output_dir)
297
+ and os.listdir(training_args.output_dir)
298
+ and training_args.do_train
299
+ and not training_args.overwrite_output_dir
300
+ ):
301
+ raise ValueError(
302
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
303
+ "Use --overwrite_output_dir to overcome."
304
+ )
305
+
306
+ # Setup logging
307
+ logging.basicConfig(
308
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
309
+ level="NOTSET",
310
+ datefmt="[%X]",
311
+ )
312
+
313
+ # Log on each process the small summary:
314
+ logger = logging.getLogger(__name__)
315
+
316
+ # Set the verbosity to info of the Transformers logger (on main process only):
317
+ logger.info(f"Training/evaluation parameters {training_args}")
318
+
319
+ # Set seed before initializing model.
320
+ set_seed(training_args.seed)
321
+
322
+ # Handle the repository creation
323
+ if training_args.push_to_hub:
324
+ if training_args.hub_model_id is None:
325
+ repo_name = get_full_repo_name(
326
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
327
+ )
328
+ else:
329
+ repo_name = training_args.hub_model_id
330
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
331
+
332
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
333
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
334
+ # (the dataset will be downloaded automatically from the datasets Hub).
335
+ #
336
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
337
+ # 'text' is found. You can easily tweak this behavior (see below).
338
+ #
339
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
340
+ # download the dataset.
341
+ if data_args.dataset_name is not None:
342
+ # Downloading and loading a dataset from the hub.
343
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
344
+
345
+ if "validation" not in datasets.keys():
346
+ datasets["validation"] = load_dataset(
347
+ data_args.dataset_name,
348
+ data_args.dataset_config_name,
349
+ split=f"train[:{data_args.validation_split_percentage}%]",
350
+ cache_dir=model_args.cache_dir,
351
+ )
352
+ datasets["train"] = load_dataset(
353
+ data_args.dataset_name,
354
+ data_args.dataset_config_name,
355
+ split=f"train[{data_args.validation_split_percentage}%:]",
356
+ cache_dir=model_args.cache_dir,
357
+ )
358
+
359
+ elif data_args.dataset_filepath is not None:
360
+ # Loading a dataset from local file.
361
+ datasets = load_from_disk(data_args.dataset_filepath)
362
+ if "validation" not in datasets.keys():
363
+ datasets = datasets.train_test_split(test_size=data_args.validation_split_percentage/100)
364
+ datasets["validation"] = datasets["test"]
365
+ del datasets["test"]
366
+
367
+ else:
368
+ data_files = {}
369
+ if data_args.train_file is not None:
370
+ data_files["train"] = data_args.train_file
371
+ if data_args.validation_file is not None:
372
+ data_files["validation"] = data_args.validation_file
373
+ extension = data_args.train_file.split(".")[-1]
374
+ if extension == "txt":
375
+ extension = "text"
376
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
377
+
378
+ if "validation" not in datasets.keys():
379
+ datasets["validation"] = load_dataset(
380
+ extension,
381
+ data_files=data_files,
382
+ split=f"train[:{data_args.validation_split_percentage}%]",
383
+ cache_dir=model_args.cache_dir,
384
+ )
385
+ datasets["train"] = load_dataset(
386
+ extension,
387
+ data_files=data_files,
388
+ split=f"train[{data_args.validation_split_percentage}%:]",
389
+ cache_dir=model_args.cache_dir,
390
+ )
391
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
392
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
393
+
394
+ # Load pretrained model and tokenizer
395
+
396
+ # Distributed training:
397
+ # The .from_pretrained methods guarantee that only one local process can concurrently
398
+ # download model & vocab.
399
+ if model_args.config_name:
400
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
401
+ elif model_args.model_name_or_path:
402
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
403
+ else:
404
+ config = CONFIG_MAPPING[model_args.model_type]()
405
+ logger.warning("You are instantiating a new config instance from scratch.")
406
+
407
+ if model_args.tokenizer_name:
408
+ tokenizer = AutoTokenizer.from_pretrained(
409
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
410
+ )
411
+ elif model_args.model_name_or_path:
412
+ tokenizer = AutoTokenizer.from_pretrained(
413
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
414
+ )
415
+ else:
416
+ raise ValueError(
417
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
418
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
419
+ )
420
+
421
+ # Preprocessing the datasets.
422
+ # First we tokenize all the texts.
423
+ if training_args.do_train:
424
+ column_names = datasets["train"].column_names
425
+ else:
426
+ column_names = datasets["validation"].column_names
427
+ text_column_name = "text" if "text" in column_names else column_names[0]
428
+
429
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
430
+
431
+ if data_args.line_by_line:
432
+ # When using line_by_line, we just tokenize each nonempty line.
433
+ padding = "max_length" if data_args.pad_to_max_length else False
434
+
435
+ def tokenize_function(examples):
436
+ # Remove empty lines
437
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
438
+ return tokenizer(
439
+ examples,
440
+ return_special_tokens_mask=True,
441
+ padding=padding,
442
+ truncation=True,
443
+ max_length=max_seq_length,
444
+ )
445
+
446
+ if data_args.tokenized_dataset_filepath is not None:
447
+ # Loading a tokenized dataset from local file.
448
+ tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
449
+ else:
450
+ tokenized_datasets = datasets.map(
451
+ tokenize_function,
452
+ input_columns=[text_column_name],
453
+ batched=True,
454
+ num_proc=data_args.preprocessing_num_workers,
455
+ remove_columns=column_names,
456
+ load_from_cache_file=not data_args.overwrite_cache,
457
+ )
458
+
459
+ else:
460
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
461
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
462
+ # efficient when it receives the `special_tokens_mask`.
463
+ def tokenize_function(examples):
464
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
465
+
466
+ if data_args.tokenized_dataset_filepath is not None:
467
+ # Loading a tokenized dataset from local file.
468
+ tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
469
+ else:
470
+
471
+ tokenized_datasets = datasets.map(
472
+ tokenize_function,
473
+ batched=True,
474
+ num_proc=data_args.preprocessing_num_workers,
475
+ remove_columns=column_names,
476
+ load_from_cache_file=not data_args.overwrite_cache,
477
+ )
478
+
479
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
480
+ # max_seq_length.
481
+ def group_texts(examples):
482
+ # Concatenate all texts.
483
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
484
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
485
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
486
+ # customize this part to your needs.
487
+ if total_length >= max_seq_length:
488
+ total_length = (total_length // max_seq_length) * max_seq_length
489
+ # Split by chunks of max_len.
490
+ result = {
491
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
492
+ for k, t in concatenated_examples.items()
493
+ }
494
+ return result
495
+
496
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
497
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
498
+ # might be slower to preprocess.
499
+ #
500
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
501
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
502
+ tokenized_datasets = tokenized_datasets.map(
503
+ group_texts,
504
+ batched=True,
505
+ num_proc=data_args.preprocessing_num_workers,
506
+ load_from_cache_file=not data_args.overwrite_cache,
507
+ )
508
+
509
+ # save the tokenized dataset for future runs
510
+ if data_args.save_tokenized_dataset_filepath is not None:
511
+ if data_args.dataset_filepath is not None:
512
+ try:
513
+ os.system(f"sudo rm {data_args.dataset_filepath}/train/cache*")
514
+ os.system(f"sudo rm {data_args.dataset_filepath}/validation/cache*")
515
+ os.system(f"sudo rm {data_args.dataset_filepath}/train/tmp*")
516
+ os.system(f"sudo rm {data_args.dataset_filepath}/validation/tmp*")
517
+ except:
518
+ pass
519
+ tokenized_datasets.save_to_disk(data_args.save_tokenized_dataset_filepath)
520
+
521
+
522
+ # Enable tensorboard only on the master node
523
+ has_tensorboard = is_tensorboard_available()
524
+ if has_tensorboard and jax.process_index() == 0:
525
+ try:
526
+ from flax.metrics.tensorboard import SummaryWriter
527
+
528
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
529
+ except ImportError as ie:
530
+ has_tensorboard = False
531
+ logger.warning(
532
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
533
+ )
534
+ else:
535
+ logger.warning(
536
+ "Unable to display metrics through TensorBoard because the package is not installed: "
537
+ "Please run pip install tensorboard to enable."
538
+ )
539
+
540
+ # Data collator
541
+ # This one will take care of randomly masking the tokens.
542
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
543
+
544
+ # Initialize our training
545
+ rng = jax.random.PRNGKey(training_args.seed)
546
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
547
+
548
+ if model_args.model_name_or_path:
549
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
550
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
551
+ )
552
+ else:
553
+ model = FlaxAutoModelForMaskedLM.from_config(
554
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
555
+ )
556
+
557
+ # Store some constant
558
+ num_epochs = int(training_args.num_train_epochs)
559
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
560
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
561
+
562
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
563
+
564
+ # Create learning rate schedule
565
+ warmup_fn = optax.linear_schedule(
566
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
567
+ )
568
+ decay_fn = optax.linear_schedule(
569
+ init_value=training_args.learning_rate,
570
+ end_value=0,
571
+ transition_steps=num_train_steps - training_args.warmup_steps,
572
+ )
573
+ linear_decay_lr_schedule_fn = optax.join_schedules(
574
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
575
+ )
576
+
577
+ # We use Optax's "masking" functionality to not apply weight decay
578
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
579
+ # mask boolean with the same structure as the parameters.
580
+ # The mask is True for parameters that should be decayed.
581
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
582
+ # For other models, one should correct the layer norm parameter naming
583
+ # accordingly.
584
+ def decay_mask_fn(params):
585
+ flat_params = traverse_util.flatten_dict(params)
586
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
587
+ return traverse_util.unflatten_dict(flat_mask)
588
+
589
+ # create adam optimizer
590
+ if training_args.adafactor:
591
+ # We use the default parameters here to initialize adafactor,
592
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
593
+ optimizer = optax.adafactor(
594
+ learning_rate=linear_decay_lr_schedule_fn,
595
+ )
596
+ else:
597
+ optimizer = optax.adamw(
598
+ learning_rate=linear_decay_lr_schedule_fn,
599
+ b1=training_args.adam_beta1,
600
+ b2=training_args.adam_beta2,
601
+ eps=training_args.adam_epsilon,
602
+ weight_decay=training_args.weight_decay,
603
+ mask=decay_mask_fn,
604
+ )
605
+
606
+ # Setup train state
607
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
608
+
609
+ # Define gradient update step fn
610
+ def train_step(state, batch, dropout_rng):
611
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
612
+
613
+ def loss_fn(params):
614
+ labels = batch.pop("labels")
615
+
616
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
617
+
618
+ # compute loss, ignore padded input tokens
619
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
620
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
621
+
622
+ # take average
623
+ loss = loss.sum() / label_mask.sum()
624
+
625
+ return loss
626
+
627
+ grad_fn = jax.value_and_grad(loss_fn)
628
+ loss, grad = grad_fn(state.params)
629
+ grad = jax.lax.pmean(grad, "batch")
630
+ new_state = state.apply_gradients(grads=grad)
631
+
632
+ metrics = jax.lax.pmean(
633
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
634
+ )
635
+
636
+ return new_state, metrics, new_dropout_rng
637
+
638
+ # Create parallel version of the train step
639
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
640
+
641
+ # Define eval fn
642
+ def eval_step(params, batch):
643
+ labels = batch.pop("labels")
644
+
645
+ logits = model(**batch, params=params, train=False)[0]
646
+
647
+ # compute loss, ignore padded input tokens
648
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
649
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
650
+
651
+ # compute accuracy
652
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
653
+
654
+ # summarize metrics
655
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
656
+ metrics = jax.lax.psum(metrics, axis_name="batch")
657
+
658
+ return metrics
659
+
660
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
661
+
662
+ # Replicate the train state on each device
663
+ state = jax_utils.replicate(state)
664
+
665
+ train_time = 0
666
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
667
+ for epoch in epochs:
668
+ # ======================== Training ================================
669
+ train_start = time.time()
670
+ train_metrics = []
671
+
672
+ # Create sampling rng
673
+ rng, input_rng = jax.random.split(rng)
674
+
675
+ # Generate an epoch by shuffling sampling indices from the train dataset
676
+ num_train_samples = len(tokenized_datasets["train"])
677
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
678
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
679
+
680
+ # Gather the indexes for creating the batch and do a training step
681
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
682
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
683
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
684
+
685
+ # Model forward
686
+ model_inputs = shard(model_inputs.data)
687
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
688
+ train_metrics.append(train_metric)
689
+
690
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
691
+
692
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
693
+ # Save metrics
694
+ train_metric = jax_utils.unreplicate(train_metric)
695
+ train_time += time.time() - train_start
696
+ if has_tensorboard and jax.process_index() == 0:
697
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
698
+
699
+ epochs.write(
700
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
701
+ )
702
+
703
+ train_metrics = []
704
+
705
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
706
+ # ======================== Evaluating ==============================
707
+ num_eval_samples = len(tokenized_datasets["validation"])
708
+ eval_samples_idx = jnp.arange(num_eval_samples)
709
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
710
+
711
+ eval_metrics = []
712
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
713
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
714
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
715
+
716
+ # Model forward
717
+ model_inputs = shard(model_inputs.data)
718
+ metrics = p_eval_step(state.params, model_inputs)
719
+ eval_metrics.append(metrics)
720
+
721
+ # normalize eval metrics
722
+ eval_metrics = get_metrics(eval_metrics)
723
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
724
+ eval_normalizer = eval_metrics.pop("normalizer")
725
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
726
+
727
+ # Update progress bar
728
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
729
+
730
+ # Save metrics
731
+ if has_tensorboard and jax.process_index() == 0:
732
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
733
+
734
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
735
+ # save checkpoint after each epoch and push checkpoint to the hub
736
+ if jax.process_index() == 0:
737
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
738
+ model.save_pretrained(training_args.output_dir, params=params)
739
+ tokenizer.save_pretrained(training_args.output_dir)
740
+ if training_args.push_to_hub:
741
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
742
+
743
+ # save also at the end of epoch
744
+ try:
745
+ if jax.process_index() == 0:
746
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
747
+ model.save_pretrained(training_args.output_dir, params=params)
748
+ tokenizer.save_pretrained(training_args.output_dir)
749
+ if training_args.push_to_hub:
750
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
751
+ except:
752
+ # push to hub fails the whole script if nothing new to commit
753
+ pass
754
+
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
start_train.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set train hyperparams
2
+ unset LD_PRELOAD
3
+ python3 run_mlm_flax.py \
4
+ --output_dir="./" \
5
+ --model_type="roberta" \
6
+ --config_name="./" \
7
+ --tokenizer_name="./" \
8
+ --dataset_filepath="/researchdisk/training_dataset_full" \
9
+ --save_tokenized_dataset_filepath="/researchdisk/training_dataset_full_tokenized_128" \
10
+ --max_seq_length="128" \
11
+ --pad_to_max_length \
12
+ --preprocessing_num_workers="96" \
13
+ --per_device_train_batch_size="64" \
14
+ --per_device_eval_batch_size="64" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --adam_epsilon="1e-6" \
18
+ --learning_rate="2e-4" \
19
+ --warmup_steps="1500" \
20
+ --overwrite_output_dir \
21
+ --num_train_epochs="2" \
22
+ --save_strategy="steps" \
23
+ --save_steps="10000" \
24
+ --save_total_limit="5" \
25
+ --eval_steps="10000" \
26
+ --logging_steps="1000" \
27
+ --dtype="bfloat16" \
28
+ --push_to_hub \
29
+ --hub_model_id="Finnish-NLP/roberta-large-finnish-v2" \
30
+ --adafactor
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "RobertaTokenizer"}
train_tokenizer.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_from_disk
2
+ from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
3
+ from transformers import AutoConfig, AutoTokenizer
4
+
5
+
6
+ model_dir = "./"
7
+
8
+ # load roberta-large config
9
+ config = AutoConfig.from_pretrained("roberta-large")
10
+ config.save_pretrained(model_dir)
11
+
12
+ # load dataset
13
+ dataset = load_from_disk("/researchdisk/training_dataset_full")
14
+ dataset = dataset["train"]
15
+
16
+ # Instantiate tokenizer
17
+ tokenizer = ByteLevelBPETokenizer()
18
+ def batch_iterator(batch_size=1000):
19
+ for i in range(0, len(dataset), batch_size):
20
+ yield dataset[i: i + batch_size]["text"]
21
+
22
+ # Customized training
23
+ tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
24
+ "<s>",
25
+ "<pad>",
26
+ "</s>",
27
+ "<unk>",
28
+ "<mask>",
29
+ ])
30
+
31
+ # Save files to disk
32
+ tokenizer.save(f"{model_dir}/tokenizer.json")
33
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
34
+ tokenizer.save_pretrained(model_dir)
vocab.json ADDED
The diff for this file is too large to render. See raw diff