imthanhlv commited on
Commit
07119aa
1 Parent(s): 02b48e7

added tokenizer

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. README.md +2 -0
  3. run_t5_mlm_flax.py +799 -0
  4. t5_tokenizer_model.py +112 -0
  5. tokenizer.json +0 -0
  6. train_tokenizer.py +32 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ __pycache__/
README.md ADDED
@@ -0,0 +1,2 @@
 
 
1
+ # T5 Vietnamese pretrain on news corpus
2
+
run_t5_mlm_flax.py ADDED
@@ -0,0 +1,799 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
+ https://huggingface.co/models?filter=t5
21
+ """
22
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Dict, List, Optional
30
+
31
+ import numpy as np
32
+ from datasets import load_dataset
33
+ from tqdm import tqdm
34
+
35
+ import flax
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import optax
39
+ from flax import jax_utils, traverse_util
40
+ from flax.training import train_state
41
+ from flax.training.common_utils import get_metrics, onehot, shard
42
+ from transformers import (
43
+ CONFIG_MAPPING,
44
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
45
+ AutoTokenizer,
46
+ BatchEncoding,
47
+ FlaxT5ForConditionalGeneration,
48
+ HfArgumentParser,
49
+ PreTrainedTokenizerBase,
50
+ T5Config,
51
+ TrainingArguments,
52
+ is_tensorboard_available,
53
+ set_seed,
54
+ )
55
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
56
+
57
+
58
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
59
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
60
+
61
+
62
+ @dataclass
63
+ class ModelArguments:
64
+ """
65
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
66
+ """
67
+
68
+ model_name_or_path: Optional[str] = field(
69
+ default=None,
70
+ metadata={
71
+ "help": "The model checkpoint for weights initialization."
72
+ "Don't set if you want to train a model from scratch."
73
+ },
74
+ )
75
+ model_type: Optional[str] = field(
76
+ default=None,
77
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ cache_dir: Optional[str] = field(
86
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
87
+ )
88
+ use_fast_tokenizer: bool = field(
89
+ default=True,
90
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
91
+ )
92
+ dtype: Optional[str] = field(
93
+ default="float32",
94
+ metadata={
95
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
96
+ },
97
+ )
98
+
99
+
100
+ @dataclass
101
+ class DataTrainingArguments:
102
+ """
103
+ Arguments pertaining to what data we are going to input our model for training and eval.
104
+ """
105
+
106
+ dataset_name: Optional[str] = field(
107
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
108
+ )
109
+ dataset_config_name: Optional[str] = field(
110
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
111
+ )
112
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
113
+ validation_file: Optional[str] = field(
114
+ default=None,
115
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
116
+ )
117
+ train_ref_file: Optional[str] = field(
118
+ default=None,
119
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
120
+ )
121
+ validation_ref_file: Optional[str] = field(
122
+ default=None,
123
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
124
+ )
125
+ overwrite_cache: bool = field(
126
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
127
+ )
128
+ validation_split_percentage: Optional[int] = field(
129
+ default=5,
130
+ metadata={
131
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
132
+ },
133
+ )
134
+ max_seq_length: Optional[int] = field(
135
+ default=None,
136
+ metadata={
137
+ "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
138
+ },
139
+ )
140
+ preprocessing_num_workers: Optional[int] = field(
141
+ default=None,
142
+ metadata={"help": "The number of processes to use for the preprocessing."},
143
+ )
144
+ mlm_probability: float = field(
145
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
146
+ )
147
+ mean_noise_span_length: float = field(
148
+ default=3.0,
149
+ metadata={"help": "Mean span length of masked tokens"},
150
+ )
151
+
152
+ def __post_init__(self):
153
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
154
+ raise ValueError("Need either a dataset name or a training/validation file.")
155
+ else:
156
+ if self.train_file is not None:
157
+ extension = self.train_file.split(".")[-1]
158
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
159
+ if self.validation_file is not None:
160
+ extension = self.validation_file.split(".")[-1]
161
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
162
+
163
+
164
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
165
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
166
+
167
+ Training parameters to avoid padding with random_spans_noise_mask.
168
+ When training a model with random_spans_noise_mask, we would like to set the other
169
+ training hyperparmeters in a way that avoids padding.
170
+ This function helps us compute these hyperparameters.
171
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
172
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
173
+ This function tells us the required number of tokens in the raw example (for split_tokens())
174
+ as well as the length of the encoded targets. Note that this function assumes
175
+ the inputs and targets will have EOS appended and includes that in the reported length.
176
+
177
+ Args:
178
+ inputs_length: an integer - desired length of the tokenized inputs sequence
179
+ noise_density: a float
180
+ mean_noise_span_length: a float
181
+ Returns:
182
+ tokens_length: length of original text in tokens
183
+ targets_length: an integer - length in tokens of encoded targets sequence
184
+ """
185
+
186
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
187
+ num_noise_tokens = int(round(tokens_length * noise_density))
188
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
189
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
190
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
191
+ # and one EOS token.
192
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
193
+ _output_length = num_noise_tokens + num_noise_spans + 1
194
+ return _input_length, _output_length
195
+
196
+ tokens_length = inputs_length
197
+
198
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
199
+ tokens_length += 1
200
+
201
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
202
+
203
+ # minor hack to get the targets length to be equal to inputs length
204
+ # which is more likely to have been set to a nice round number.
205
+ if noise_density == 0.5 and targets_length > inputs_length:
206
+ tokens_length -= 1
207
+ targets_length -= 1
208
+ return tokens_length, targets_length
209
+
210
+
211
+ @flax.struct.dataclass
212
+ class FlaxDataCollatorForT5MLM:
213
+ """
214
+ Data collator used for T5 span-masked language modeling.
215
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
216
+ For more information on how T5 span-masked language modeling works, one can take a look
217
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
218
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
219
+
220
+ Args:
221
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
222
+ The tokenizer used for encoding the data.
223
+ noise_density (:obj:`float`):
224
+ The probability with which to (randomly) mask tokens in the input.
225
+ mean_noise_span_length (:obj:`float`):
226
+ The average span length of the masked tokens.
227
+ input_length (:obj:`int`):
228
+ The expected input length after masking.
229
+ target_length (:obj:`int`):
230
+ The expected target length after masking.
231
+ pad_token_id: (:obj:`int`):
232
+ The pad token id of the model
233
+ decoder_start_token_id: (:obj:`int):
234
+ The decoder start token id of the model
235
+ """
236
+
237
+ tokenizer: PreTrainedTokenizerBase
238
+ noise_density: float
239
+ mean_noise_span_length: float
240
+ input_length: int
241
+ target_length: int
242
+ pad_token_id: int
243
+ decoder_start_token_id: int
244
+
245
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
246
+
247
+ # convert list to dict and tensorize input
248
+ batch = BatchEncoding(
249
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
250
+ )
251
+
252
+ input_ids = batch["input_ids"]
253
+ batch_size, expandend_input_length = input_ids.shape
254
+
255
+ mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
256
+ labels_mask = ~mask_indices
257
+
258
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
259
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
260
+
261
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
262
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
263
+
264
+ if batch["input_ids"].shape[-1] != self.input_length:
265
+ raise ValueError(
266
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
267
+ )
268
+
269
+ if batch["labels"].shape[-1] != self.target_length:
270
+ raise ValueError(
271
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
272
+ )
273
+
274
+ # to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
275
+ batch["decoder_input_ids"] = shift_tokens_right(
276
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
277
+ )
278
+
279
+ return batch
280
+
281
+ def create_sentinel_ids(self, mask_indices):
282
+ """
283
+ Sentinel ids creation given the indices that should be masked.
284
+ The start indices of each mask are replaced by the sentinel ids in increasing
285
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
286
+ """
287
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
288
+ start_indices[:, 0] = mask_indices[:, 0]
289
+
290
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
291
+ sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
292
+ sentinel_ids -= mask_indices - start_indices
293
+
294
+ return sentinel_ids
295
+
296
+ def filter_input_ids(self, input_ids, sentinel_ids):
297
+ """
298
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
299
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
300
+ """
301
+ batch_size = input_ids.shape[0]
302
+
303
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
304
+ input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
305
+ input_ids = np.concatenate(
306
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
307
+ )
308
+ return input_ids
309
+
310
+ def random_spans_noise_mask(self, length):
311
+
312
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
313
+
314
+ Noise mask consisting of random spans of noise tokens.
315
+ The number of noise tokens and the number of noise spans and non-noise spans
316
+ are determined deterministically as follows:
317
+ num_noise_tokens = round(length * noise_density)
318
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
319
+ Spans alternate between non-noise and noise, beginning with non-noise.
320
+ Subject to the above restrictions, all masks are equally likely.
321
+
322
+ Args:
323
+ length: an int32 scalar (length of the incoming token sequence)
324
+ noise_density: a float - approximate density of output mask
325
+ mean_noise_span_length: a number
326
+
327
+ Returns:
328
+ a boolean tensor with shape [length]
329
+ """
330
+
331
+ orig_length = length
332
+
333
+ num_noise_tokens = int(np.round(length * self.noise_density))
334
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
335
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
336
+ num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
337
+
338
+ # avoid degeneracy by ensuring positive number of noise spans
339
+ num_noise_spans = max(num_noise_spans, 1)
340
+ num_nonnoise_tokens = length - num_noise_tokens
341
+
342
+ # pick the lengths of the noise spans and the non-noise spans
343
+ def _random_segmentation(num_items, num_segments):
344
+ """Partition a sequence of items randomly into non-empty segments.
345
+ Args:
346
+ num_items: an integer scalar > 0
347
+ num_segments: an integer scalar in [1, num_items]
348
+ Returns:
349
+ a Tensor with shape [num_segments] containing positive integers that add
350
+ up to num_items
351
+ """
352
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
353
+ np.random.shuffle(mask_indices)
354
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
355
+ segment_id = np.cumsum(first_in_segment)
356
+ # count length of sub segments assuming that list is sorted
357
+ _, segment_length = np.unique(segment_id, return_counts=True)
358
+ return segment_length
359
+
360
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
361
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
362
+
363
+ interleaved_span_lengths = np.reshape(
364
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
365
+ )
366
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
367
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
368
+ span_start_indicator[span_starts] = True
369
+ span_num = np.cumsum(span_start_indicator)
370
+ is_noise = np.equal(span_num % 2, 1)
371
+
372
+ return is_noise[:orig_length]
373
+
374
+
375
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
376
+ num_samples = len(samples_idx)
377
+ samples_to_remove = num_samples % batch_size
378
+
379
+ if samples_to_remove != 0:
380
+ samples_idx = samples_idx[:-samples_to_remove]
381
+ sections_split = num_samples // batch_size
382
+ batch_idx = np.split(samples_idx, sections_split)
383
+ return batch_idx
384
+
385
+
386
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
387
+ summary_writer.scalar("train_time", train_time, step)
388
+
389
+ train_metrics = get_metrics(train_metrics)
390
+ for key, vals in train_metrics.items():
391
+ tag = f"train_{key}"
392
+ for i, val in enumerate(vals):
393
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
394
+
395
+
396
+ def write_eval_metric(summary_writer, eval_metrics, step):
397
+ for metric_name, value in eval_metrics.items():
398
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
399
+
400
+
401
+ if __name__ == "__main__":
402
+ # See all possible arguments in src/transformers/training_args.py
403
+ # or by passing the --help flag to this script.
404
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
405
+
406
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
407
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
408
+ # If we pass only one argument to the script and it's the path to a json file,
409
+ # let's parse it to get our arguments.
410
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
411
+ else:
412
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
413
+
414
+ if (
415
+ os.path.exists(training_args.output_dir)
416
+ and os.listdir(training_args.output_dir)
417
+ and training_args.do_train
418
+ and not training_args.overwrite_output_dir
419
+ ):
420
+ raise ValueError(
421
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
422
+ "Use --overwrite_output_dir to overcome."
423
+ )
424
+
425
+ # Setup logging
426
+ logging.basicConfig(
427
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
428
+ level="NOTSET",
429
+ datefmt="[%X]",
430
+ )
431
+
432
+ # Log on each process the small summary:
433
+ logger = logging.getLogger(__name__)
434
+
435
+ # Set the verbosity to info of the Transformers logger (on main process only):
436
+ logger.info(f"Training/evaluation parameters {training_args}")
437
+
438
+ # Set seed before initializing model.
439
+ set_seed(training_args.seed)
440
+
441
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
442
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
443
+ # (the dataset will be downloaded automatically from the datasets Hub).
444
+ #
445
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
446
+ # 'text' is found. You can easily tweak this behavior (see below).
447
+ if data_args.dataset_name is not None:
448
+ # Downloading and loading a dataset from the hub.
449
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
450
+
451
+ if "validation" not in datasets.keys():
452
+ datasets["validation"] = load_dataset(
453
+ data_args.dataset_name,
454
+ data_args.dataset_config_name,
455
+ split=f"train[:{data_args.validation_split_percentage}%]",
456
+ cache_dir=model_args.cache_dir,
457
+ )
458
+ datasets["train"] = load_dataset(
459
+ data_args.dataset_name,
460
+ data_args.dataset_config_name,
461
+ split=f"train[{data_args.validation_split_percentage}%:]",
462
+ cache_dir=model_args.cache_dir,
463
+ )
464
+ else:
465
+ data_files = {}
466
+ if data_args.train_file is not None:
467
+ data_files["train"] = data_args.train_file
468
+ if data_args.validation_file is not None:
469
+ data_files["validation"] = data_args.validation_file
470
+ extension = data_args.train_file.split(".")[-1]
471
+ if extension == "txt":
472
+ extension = "text"
473
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
474
+
475
+ if "validation" not in datasets.keys():
476
+ datasets["validation"] = load_dataset(
477
+ extension,
478
+ data_files=data_files,
479
+ split=f"train[:{data_args.validation_split_percentage}%]",
480
+ cache_dir=model_args.cache_dir,
481
+ )
482
+ datasets["train"] = load_dataset(
483
+ extension,
484
+ data_files=data_files,
485
+ split=f"train[{data_args.validation_split_percentage}%:]",
486
+ cache_dir=model_args.cache_dir,
487
+ )
488
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
489
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
490
+
491
+ # Load pretrained model and tokenizer
492
+
493
+ if model_args.tokenizer_name:
494
+ tokenizer = AutoTokenizer.from_pretrained(
495
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
496
+ )
497
+ elif model_args.model_name_or_path:
498
+ tokenizer = AutoTokenizer.from_pretrained(
499
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
500
+ )
501
+ else:
502
+ raise ValueError(
503
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
504
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
505
+ )
506
+
507
+ if model_args.config_name:
508
+ config = T5Config.from_pretrained(
509
+ model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
510
+ )
511
+ elif model_args.model_name_or_path:
512
+ config = T5Config.from_pretrained(
513
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
514
+ )
515
+ else:
516
+ config = CONFIG_MAPPING[model_args.model_type]()
517
+ logger.warning("You are instantiating a new config instance from scratch.")
518
+
519
+ # Preprocessing the datasets.
520
+ # First we tokenize all the texts.
521
+ if training_args.do_train:
522
+ column_names = datasets["train"].column_names
523
+ else:
524
+ column_names = datasets["validation"].column_names
525
+ text_column_name = "text" if "text" in column_names else column_names[0]
526
+
527
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
528
+
529
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
530
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
531
+ def tokenize_function(examples):
532
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
533
+
534
+ tokenized_datasets = datasets.map(
535
+ tokenize_function,
536
+ batched=True,
537
+ num_proc=data_args.preprocessing_num_workers,
538
+ remove_columns=column_names,
539
+ load_from_cache_file=not data_args.overwrite_cache,
540
+ )
541
+
542
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
543
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
544
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
545
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
546
+ inputs_length=max_seq_length,
547
+ noise_density=data_args.mlm_probability,
548
+ mean_noise_span_length=data_args.mean_noise_span_length,
549
+ )
550
+
551
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
552
+ def group_texts(examples):
553
+ # Concatenate all texts.
554
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
555
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
556
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
557
+ # customize this part to your needs.
558
+ if total_length >= expanded_inputs_length:
559
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
560
+ # Split by chunks of max_len.
561
+ result = {
562
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
563
+ for k, t in concatenated_examples.items()
564
+ }
565
+ return result
566
+
567
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
568
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
569
+ # might be slower to preprocess.
570
+ #
571
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
572
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
573
+ tokenized_datasets = tokenized_datasets.map(
574
+ group_texts,
575
+ batched=True,
576
+ num_proc=data_args.preprocessing_num_workers,
577
+ load_from_cache_file=not data_args.overwrite_cache,
578
+ )
579
+
580
+ # Enable tensorboard only on the master node
581
+ has_tensorboard = is_tensorboard_available()
582
+ if has_tensorboard and jax.process_index() == 0:
583
+ try:
584
+ from flax.metrics.tensorboard import SummaryWriter
585
+
586
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
587
+ except ImportError as ie:
588
+ has_tensorboard = False
589
+ logger.warning(
590
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
591
+ )
592
+ else:
593
+ logger.warning(
594
+ "Unable to display metrics through TensorBoard because the package is not installed: "
595
+ "Please run pip install tensorboard to enable."
596
+ )
597
+
598
+ # Initialize our training
599
+ rng = jax.random.PRNGKey(training_args.seed)
600
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
601
+
602
+ if model_args.model_name_or_path:
603
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
604
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
605
+ )
606
+ else:
607
+ model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
608
+
609
+ # Data collator
610
+ # This one will take care of randomly masking the tokens.
611
+ data_collator = FlaxDataCollatorForT5MLM(
612
+ tokenizer=tokenizer,
613
+ noise_density=data_args.mlm_probability,
614
+ mean_noise_span_length=data_args.mean_noise_span_length,
615
+ input_length=max_seq_length,
616
+ target_length=targets_length,
617
+ pad_token_id=model.config.pad_token_id,
618
+ decoder_start_token_id=model.config.decoder_start_token_id,
619
+ )
620
+
621
+ # Store some constant
622
+ num_epochs = int(training_args.num_train_epochs)
623
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
624
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
625
+
626
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
627
+
628
+ # Create learning rate schedule
629
+ warmup_fn = optax.linear_schedule(
630
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
631
+ )
632
+ decay_fn = optax.linear_schedule(
633
+ init_value=training_args.learning_rate,
634
+ end_value=0,
635
+ transition_steps=num_train_steps - training_args.warmup_steps,
636
+ )
637
+ linear_decay_lr_schedule_fn = optax.join_schedules(
638
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
639
+ )
640
+
641
+ # We use Optax's "masking" functionality to not apply weight decay
642
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
643
+ # mask boolean with the same structure as the parameters.
644
+ # The mask is True for parameters that should be decayed.
645
+ def decay_mask_fn(params):
646
+ flat_params = traverse_util.flatten_dict(params)
647
+ flat_mask = {
648
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
649
+ for path in flat_params
650
+ }
651
+ return traverse_util.unflatten_dict(flat_mask)
652
+
653
+ # create adam optimizer
654
+ if training_args.adafactor:
655
+ # We use the default parameters here to initialize adafactor,
656
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
657
+ optimizer = optax.adafactor(
658
+ learning_rate=linear_decay_lr_schedule_fn,
659
+ )
660
+ else:
661
+ optimizer = optax.adamw(
662
+ learning_rate=linear_decay_lr_schedule_fn,
663
+ b1=training_args.adam_beta1,
664
+ b2=training_args.adam_beta2,
665
+ weight_decay=training_args.weight_decay,
666
+ mask=decay_mask_fn,
667
+ )
668
+
669
+ # Setup train state
670
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
671
+
672
+ # Define gradient update step fn
673
+ def train_step(state, batch, dropout_rng):
674
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
675
+
676
+ def loss_fn(params):
677
+ labels = batch.pop("labels")
678
+
679
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
680
+
681
+ # compute loss
682
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
683
+
684
+ return loss
685
+
686
+ grad_fn = jax.value_and_grad(loss_fn)
687
+ loss, grad = grad_fn(state.params)
688
+ grad = jax.lax.pmean(grad, "batch")
689
+ new_state = state.apply_gradients(grads=grad)
690
+
691
+ metrics = jax.lax.pmean(
692
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
693
+ )
694
+
695
+ return new_state, metrics, new_dropout_rng
696
+
697
+ # Create parallel version of the train step
698
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
699
+
700
+ # Define eval fn
701
+ def eval_step(params, batch):
702
+ labels = batch.pop("labels")
703
+
704
+ logits = model(**batch, params=params, train=False)[0]
705
+
706
+ # compute loss
707
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
708
+
709
+ # compute accuracy
710
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
711
+
712
+ # summarize metrics
713
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
714
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
715
+
716
+ return metrics
717
+
718
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
719
+
720
+ # Replicate the train state on each device
721
+ state = jax_utils.replicate(state)
722
+
723
+ train_time = 0
724
+ epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
725
+ for epoch in epochs:
726
+ # ======================== Training ================================
727
+ train_start = time.time()
728
+ train_metrics = []
729
+
730
+ # Create sampling rng
731
+ rng, input_rng = jax.random.split(rng)
732
+
733
+ # Generate an epoch by shuffling sampling indices from the train dataset
734
+ num_train_samples = len(tokenized_datasets["train"])
735
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
736
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
737
+
738
+ # Gather the indexes for creating the batch and do a training step
739
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
740
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
741
+ model_inputs = data_collator(samples)
742
+
743
+ # Model forward
744
+ model_inputs = shard(model_inputs.data)
745
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
746
+ train_metrics.append(train_metric)
747
+
748
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
749
+
750
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
751
+ # Save metrics
752
+ train_metric = jax_utils.unreplicate(train_metric)
753
+ train_time += time.time() - train_start
754
+ if has_tensorboard and jax.process_index() == 0:
755
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
756
+
757
+ epochs.write(
758
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
759
+ )
760
+
761
+ train_metrics = []
762
+
763
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
764
+ # ======================== Evaluating ==============================
765
+ num_eval_samples = len(tokenized_datasets["validation"])
766
+ eval_samples_idx = jnp.arange(num_eval_samples)
767
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
768
+
769
+ eval_metrics = []
770
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
771
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
772
+ model_inputs = data_collator(samples)
773
+
774
+ # Model forward
775
+ model_inputs = shard(model_inputs.data)
776
+ metrics = p_eval_step(state.params, model_inputs)
777
+ eval_metrics.append(metrics)
778
+
779
+ # get eval metrics
780
+ eval_metrics = get_metrics(eval_metrics)
781
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
782
+
783
+ # Update progress bar
784
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
785
+
786
+ # Save metrics
787
+ if has_tensorboard and jax.process_index() == 0:
788
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
789
+
790
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
791
+ # save checkpoint after each epoch and push checkpoint to the hub
792
+ if jax.process_index() == 0:
793
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
794
+ model.save_pretrained(
795
+ training_args.output_dir,
796
+ params=params,
797
+ push_to_hub=training_args.push_to_hub,
798
+ commit_message=f"Saving weights and logs of step {cur_step}",
799
+ )
t5_tokenizer_model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import json
3
+ from typing import Iterator, List, Union
4
+
5
+ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
6
+ from tokenizers.implementations.base_tokenizer import BaseTokenizer
7
+ from tokenizers.models import Unigram
8
+ from tokenizers.processors import TemplateProcessing
9
+
10
+
11
+ class SentencePieceUnigramTokenizer(BaseTokenizer):
12
+ """
13
+ This class is a copy of `DeDLOC's tokenizer implementation <https://github.com/yandex-research/DeDLOC/blob/main/sahajbert/tokenizer/tokenizer_model.py>`__ .
14
+
15
+ Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization
16
+ Represents the Unigram algorithm, with the pretokenization used by SentencePiece
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ replacement: str = "▁",
22
+ add_prefix_space: bool = True,
23
+ unk_token: Union[str, AddedToken] = "<unk>",
24
+ eos_token: Union[str, AddedToken] = "</s>",
25
+ pad_token: Union[str, AddedToken] = "<pad>",
26
+ ):
27
+ self.special_tokens = {
28
+ "pad": {"id": 0, "token": pad_token},
29
+ "eos": {"id": 1, "token": eos_token},
30
+ "unk": {"id": 2, "token": unk_token},
31
+ }
32
+
33
+ self.special_tokens_list = [None] * len(self.special_tokens)
34
+ for token_dict in self.special_tokens.values():
35
+ self.special_tokens_list[token_dict["id"]] = token_dict["token"]
36
+
37
+ tokenizer = Tokenizer(Unigram())
38
+
39
+ tokenizer.normalizer = normalizers.Sequence(
40
+ [
41
+ normalizers.Nmt(),
42
+ normalizers.NFKC(),
43
+ normalizers.Replace(Regex(" {2,}"), " "),
44
+ normalizers.Lowercase(),
45
+ ]
46
+ )
47
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
48
+ [
49
+ pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
50
+ pre_tokenizers.Digits(individual_digits=True),
51
+ pre_tokenizers.Punctuation(),
52
+ ]
53
+ )
54
+ tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
55
+
56
+ tokenizer.post_processor = TemplateProcessing(
57
+ single=f"$A {self.special_tokens['eos']['token']}",
58
+ special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])],
59
+ )
60
+
61
+ parameters = {
62
+ "model": "SentencePieceUnigram",
63
+ "replacement": replacement,
64
+ "add_prefix_space": add_prefix_space,
65
+ }
66
+
67
+ super().__init__(tokenizer, parameters)
68
+
69
+ def train(
70
+ self,
71
+ files: Union[str, List[str]],
72
+ vocab_size: int = 8000,
73
+ show_progress: bool = True,
74
+ ):
75
+ """Train the model using the given files"""
76
+
77
+ trainer = trainers.UnigramTrainer(
78
+ vocab_size=vocab_size,
79
+ special_tokens=self.special_tokens_list,
80
+ show_progress=show_progress,
81
+ )
82
+
83
+ if isinstance(files, str):
84
+ files = [files]
85
+ self._tokenizer.train(files, trainer=trainer)
86
+
87
+ self.add_unk_id()
88
+
89
+ def train_from_iterator(
90
+ self,
91
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
92
+ vocab_size: int = 8000,
93
+ show_progress: bool = True,
94
+ ):
95
+ """Train the model using the given iterator"""
96
+
97
+ trainer = trainers.UnigramTrainer(
98
+ vocab_size=vocab_size,
99
+ special_tokens=self.special_tokens_list,
100
+ show_progress=show_progress,
101
+ )
102
+
103
+ self._tokenizer.train_from_iterator(iterator, trainer=trainer)
104
+
105
+ self.add_unk_id()
106
+
107
+ def add_unk_id(self):
108
+ tokenizer_json = json.loads(self._tokenizer.to_str())
109
+
110
+ tokenizer_json["model"]["unk_id"] = self.special_tokens["unk"]["id"]
111
+
112
+ self._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
train_tokenizer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+
3
+ from t5_tokenizer_model import SentencePieceUnigramTokenizer
4
+
5
+
6
+ vocab_size = 32_000
7
+ input_sentence_size = None
8
+
9
+ # Initialize a dataset
10
+ dataset = datasets.load_dataset("imthanhlv/binhvq_dedup", split="train")
11
+
12
+ tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
13
+
14
+
15
+ # Build an iterator over this dataset
16
+ def batch_iterator(input_sentence_size=None):
17
+ if input_sentence_size is None:
18
+ input_sentence_size = len(dataset)
19
+ batch_length = 100
20
+ for i in range(0, input_sentence_size, batch_length):
21
+ yield dataset[i: i + batch_length]["text"]
22
+
23
+
24
+ # Train tokenizer
25
+ tokenizer.train_from_iterator(
26
+ iterator=batch_iterator(input_sentence_size=input_sentence_size),
27
+ vocab_size=vocab_size,
28
+ show_progress=True,
29
+ )
30
+
31
+ # Save files to disk
32
+ tokenizer.save("./tokenizer.json")