yhavinga committed on
Commit
fdc4101
1 Parent(s): b27c9c1

Add config, tokenizer and training script

config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "architectures": [
+     "T5WithLMHeadModel"
+   ],
+   "d_ff": 3072,
+   "d_kv": 64,
+   "d_model": 768,
+   "decoder_start_token_id": 0,
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "gradient_checkpointing": false,
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "n_positions": 512,
+   "num_decoder_layers": 12,
+   "num_heads": 12,
+   "num_layers": 12,
+   "output_past": true,
+   "pad_token_id": 0,
+   "relative_attention_num_buckets": 32,
+   "task_specific_params": {
+     "summarization": {
+       "early_stopping": true,
+       "length_penalty": 2.0,
+       "max_length": 200,
+       "min_length": 30,
+       "no_repeat_ngram_size": 3,
+       "num_beams": 4,
+       "prefix": "summarize: "
+     },
+     "translation_en_to_de": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to German: "
+     },
+     "translation_en_to_fr": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to French: "
+     },
+     "translation_en_to_ro": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to Romanian: "
+     }
+   },
+   "transformers_version": "4.9.0.dev0",
+   "use_cache": true,
+   "vocab_size": 32128
+ }
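The config above is the stock t5-base configuration: d_model 768, d_ff 3072, 12 layers in both encoder and decoder, 12 heads, and a 32128-token vocabulary. A minimal sanity-check sketch, assuming config.json sits in the current working directory:

    from transformers import T5Config

    # Load the config straight from this repo checkout.
    config = T5Config.from_pretrained("./")
    print(config.d_model, config.num_layers, config.num_decoder_layers, config.num_heads, config.vocab_size)
    # Expected for this file: 768 12 12 12 32128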
create_config.py ADDED
@@ -0,0 +1,6 @@
+ from transformers import T5Config
+
+ model_dir = "./"  # ${MODEL_DIR}
+
+ config = T5Config.from_pretrained("t5-base")
+ config.save_pretrained(model_dir)
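create_config.py saves an unmodified copy of the t5-base config, so vocab_size stays at 32128 here; the training script later overrides it with vocab_size=len(tokenizer) when it loads the config. A hypothetical variant that bakes the custom tokenizer's size in at creation time instead, assuming the tokenizer has already been trained and saved into the model directory:

    from transformers import T5Config, T5TokenizerFast

    model_dir = "./"  # ${MODEL_DIR}

    # Size the config to the custom tokenizer up front rather than at training time.
    tokenizer = T5TokenizerFast.from_pretrained(model_dir)
    config = T5Config.from_pretrained("t5-base", vocab_size=len(tokenizer))
    config.save_pretrained(model_dir)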
run_t5.sh ADDED
@@ -0,0 +1,34 @@
+ MODEL="t5-base-dutch"
+
+ MODEL_DIR="${HOME}/${MODEL}"
+
+ mkdir -p "${MODEL_DIR}/runs"
+
+ # The T5 paper uses lr 0.01 with a batch size of 128.
+ # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2.
+ # Warmup steps are set to 6% of the training steps.
+
+ ./run_t5_mlm_flax_custom_dataset.py \
+     --output_dir="${MODEL_DIR}" \
+     --model_type="t5" \
+     --config_name="flax-community/${MODEL}" \
+     --tokenizer_name="${MODEL_DIR}" \
+     --preprocessing_num_workers="96" \
+     --do_train --do_eval \
+     --adafactor \
+     --dtype="bfloat16" \
+     --max_seq_length="512" \
+     --gradient_accumulation_steps="4" \
+     --per_device_train_batch_size="32" \
+     --per_device_eval_batch_size="32" \
+     --learning_rate="5e-3" \
+     --overwrite_output_dir \
+     --num_train_epochs="1" \
+     --logging_steps="15" \
+     --save_steps="300" \
+     --eval_steps="1000000" \
+     --push_to_hub
+
+ #git add pytorch_model.bin
+ #git commit -m "Update pytorch model after training"
+ #git push origin main
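The learning rate passed above feeds the schedule that run_t5_mlm_flax_custom_dataset.py builds internally: linear warmup from 0 to --learning_rate over the first 6% of the training steps, then linear decay back to 0. A standalone sketch of that schedule with Optax; the total step count is an assumption for illustration, since the real value depends on the dataset size:

    import optax

    learning_rate = 5e-3        # --learning_rate above
    num_train_steps = 100_000   # assumption for illustration
    warmup_steps = int(0.06 * num_train_steps)

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps)
    schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

    print(schedule(0), schedule(warmup_steps), schedule(num_train_steps))  # 0.0, peak lr, ~0.0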
run_t5_mlm_flax.py ADDED
@@ -0,0 +1 @@
+ /home/yeb/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py
run_t5_mlm_flax_custom_dataset.py ADDED
@@ -0,0 +1,941 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
+ https://huggingface.co/models?filter=t5
21
+ """
22
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ import json
28
+ import shutil
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional
32
+
33
+ import numpy as np
34
+ from datasets import load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ from flax import jax_utils, traverse_util
42
+ from flax.training import train_state
43
+ from flax.training.common_utils import get_metrics, onehot, shard
44
+ from flax.serialization import to_bytes, from_bytes
45
+ from transformers import (
46
+ CONFIG_MAPPING,
47
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
48
+ BatchEncoding,
49
+ FlaxT5ForConditionalGeneration,
50
+ T5ForConditionalGeneration,
51
+ HfArgumentParser,
52
+ PreTrainedTokenizerBase,
53
+ T5Config,
54
+ T5TokenizerFast,
55
+ TrainingArguments,
56
+ is_tensorboard_available,
57
+ set_seed,
58
+ )
59
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
60
+
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+ data_files = []
68
+
69
+
70
+ @dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
74
+ """
75
+
76
+ model_name_or_path: Optional[str] = field(
77
+ default=None,
78
+ metadata={
79
+ "help": "The model checkpoint for weights initialization."
80
+ "Don't set if you want to train a model from scratch."
81
+ },
82
+ )
83
+ model_type: Optional[str] = field(
84
+ default=None,
85
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
86
+ )
87
+ config_name: Optional[str] = field(
88
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
89
+ )
90
+ tokenizer_name: Optional[str] = field(
91
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
92
+ )
93
+ cache_dir: Optional[str] = field(
94
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
95
+ )
96
+ use_fast_tokenizer: bool = field(
97
+ default=True,
98
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
99
+ )
100
+ dtype: Optional[str] = field(
101
+ default="float32",
102
+ metadata={
103
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
104
+ },
105
+ )
106
+
107
+
108
+ @dataclass
109
+ class DataTrainingArguments:
110
+ """
111
+ Arguments pertaining to what data we are going to input our model for training and eval.
112
+ """
113
+
114
+ dataset_name: Optional[str] = field(
115
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
116
+ )
117
+ dataset_config_name: Optional[str] = field(
118
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
119
+ )
120
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
121
+ validation_file: Optional[str] = field(
122
+ default=None,
123
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
124
+ )
125
+ train_ref_file: Optional[str] = field(
126
+ default=None,
127
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
128
+ )
129
+ validation_ref_file: Optional[str] = field(
130
+ default=None,
131
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
132
+ )
133
+ overwrite_cache: bool = field(
134
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
135
+ )
136
+ validation_split_percentage: Optional[int] = field(
137
+ default=5,
138
+ metadata={
139
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
140
+ },
141
+ )
142
+ max_seq_length: Optional[int] = field(
143
+ default=None,
144
+ metadata={
145
+ "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
146
+ },
147
+ )
148
+ preprocessing_num_workers: Optional[int] = field(
149
+ default=None,
150
+ metadata={"help": "The number of processes to use for the preprocessing."},
151
+ )
152
+ mlm_probability: float = field(
153
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
154
+ )
155
+ mean_noise_span_length: float = field(
156
+ default=3.0,
157
+ metadata={"help": "Mean span length of masked tokens"},
158
+ )
159
+
160
+ def __post_init__(self):
161
+ return
162
+ # if self.dataset_name is None and self.train_file is None and self.validation_file is None:
163
+ # raise ValueError("Need either a dataset name or a training/validation file.")
164
+ # else:
165
+ # if self.train_file is not None:
166
+ # extension = self.train_file.split(".")[-1]
167
+ # assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
168
+ # if self.validation_file is not None:
169
+ # extension = self.validation_file.split(".")[-1]
170
+ # assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
171
+
172
+
173
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
174
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
175
+
176
+ Training parameters to avoid padding with random_spans_noise_mask.
177
+ When training a model with random_spans_noise_mask, we would like to set the other
178
+ training hyperparameters in a way that avoids padding.
179
+ This function helps us compute these hyperparameters.
180
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
181
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
182
+ This function tells us the required number of tokens in the raw example (for split_tokens())
183
+ as well as the length of the encoded targets. Note that this function assumes
184
+ the inputs and targets will have EOS appended and includes that in the reported length.
185
+
186
+ Args:
187
+ inputs_length: an integer - desired length of the tokenized inputs sequence
188
+ noise_density: a float
189
+ mean_noise_span_length: a float
190
+ Returns:
191
+ tokens_length: length of original text in tokens
192
+ targets_length: an integer - length in tokens of encoded targets sequence
193
+ """
194
+
195
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
196
+ num_noise_tokens = int(round(tokens_length * noise_density))
197
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
198
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
199
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
200
+ # and one EOS token.
201
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
202
+ _output_length = num_noise_tokens + num_noise_spans + 1
203
+ return _input_length, _output_length
204
+
205
+ tokens_length = inputs_length
206
+
207
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
208
+ tokens_length += 1
209
+
210
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
211
+
212
+ # minor hack to get the targets length to be equal to inputs length
213
+ # which is more likely to have been set to a nice round number.
214
+ if noise_density == 0.5 and targets_length > inputs_length:
215
+ tokens_length -= 1
216
+ targets_length -= 1
217
+ return tokens_length, targets_length
218
+
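+ # For example, with inputs_length=512, noise_density=0.15 and mean_noise_span_length=3.0 (the
+ # values the run_t5.sh invocation ends up using), this returns (568, 114): chunks of 568 raw
+ # tokens are corrupted into 512-token inputs and 114-token targets.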
219
+
220
+ @flax.struct.dataclass
221
+ class FlaxDataCollatorForT5MLM:
222
+ """
223
+ Data collator used for T5 span-masked language modeling.
224
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
225
+ For more information on how T5 span-masked language modeling works, one can take a look
226
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
227
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
228
+
229
+ Args:
230
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
231
+ The tokenizer used for encoding the data.
232
+ noise_density (:obj:`float`):
233
+ The probability with which to (randomly) mask tokens in the input.
234
+ mean_noise_span_length (:obj:`float`):
235
+ The average span length of the masked tokens.
236
+ input_length (:obj:`int`):
237
+ The expected input length after masking.
238
+ target_length (:obj:`int`):
239
+ The expected target length after masking.
240
+ pad_token_id: (:obj:`int`):
241
+ The pad token id of the model
242
+ decoder_start_token_id: (:obj:`int`):
243
+ The decoder start token id of the model
244
+ """
245
+
246
+ tokenizer: PreTrainedTokenizerBase
247
+ noise_density: float
248
+ mean_noise_span_length: float
249
+ input_length: int
250
+ target_length: int
251
+ pad_token_id: int
252
+ decoder_start_token_id: int
253
+
254
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
255
+
256
+ # convert list to dict and tensorize input
257
+ batch = BatchEncoding(
258
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
259
+ )
260
+
261
+ input_ids = batch["input_ids"]
262
+ batch_size, expanded_input_length = input_ids.shape
263
+
264
+ mask_indices = np.asarray([self.random_spans_noise_mask(expanded_input_length) for i in range(batch_size)])
265
+ labels_mask = ~mask_indices
266
+
267
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
268
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
269
+
270
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
271
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
272
+
273
+ if batch["input_ids"].shape[-1] != self.input_length:
274
+ raise ValueError(
275
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
276
+ )
277
+
278
+ if batch["labels"].shape[-1] != self.target_length:
279
+ raise ValueError(
280
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
281
+ )
282
+
283
+ # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
284
+ batch["decoder_input_ids"] = shift_tokens_right(
285
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
286
+ )
287
+
288
+ return batch
289
+
290
+ def create_sentinel_ids(self, mask_indices):
291
+ """
292
+ Sentinel ids creation given the indices that should be masked.
293
+ The start indices of each mask are replaced by the sentinel ids in increasing
294
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
295
+ """
296
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
297
+ start_indices[:, 0] = mask_indices[:, 0]
298
+
299
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
300
+ sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
301
+ sentinel_ids -= mask_indices - start_indices
302
+
303
+ return sentinel_ids
304
+
305
+ def filter_input_ids(self, input_ids, sentinel_ids):
306
+ """
307
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
308
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
309
+ """
310
+ batch_size = input_ids.shape[0]
311
+
312
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
313
+ input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
314
+ input_ids = np.concatenate(
315
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
316
+ )
317
+ return input_ids
318
+
319
+ def random_spans_noise_mask(self, length):
320
+
321
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
322
+
323
+ Noise mask consisting of random spans of noise tokens.
324
+ The number of noise tokens and the number of noise spans and non-noise spans
325
+ are determined deterministically as follows:
326
+ num_noise_tokens = round(length * noise_density)
327
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
328
+ Spans alternate between non-noise and noise, beginning with non-noise.
329
+ Subject to the above restrictions, all masks are equally likely.
330
+
331
+ Args:
332
+ length: an int32 scalar (length of the incoming token sequence)
333
+ noise_density: a float - approximate density of output mask
334
+ mean_noise_span_length: a number
335
+
336
+ Returns:
337
+ a boolean tensor with shape [length]
338
+ """
339
+
340
+ orig_length = length
341
+
342
+ num_noise_tokens = int(np.round(length * self.noise_density))
343
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
344
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
345
+ num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
346
+
347
+ # avoid degeneracy by ensuring positive number of noise spans
348
+ num_noise_spans = max(num_noise_spans, 1)
349
+ num_nonnoise_tokens = length - num_noise_tokens
350
+
351
+ # pick the lengths of the noise spans and the non-noise spans
352
+ def _random_segmentation(num_items, num_segments):
353
+ """Partition a sequence of items randomly into non-empty segments.
354
+ Args:
355
+ num_items: an integer scalar > 0
356
+ num_segments: an integer scalar in [1, num_items]
357
+ Returns:
358
+ a Tensor with shape [num_segments] containing positive integers that add
359
+ up to num_items
360
+ """
361
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
362
+ np.random.shuffle(mask_indices)
363
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
364
+ segment_id = np.cumsum(first_in_segment)
365
+ segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
366
+ return segment_length
367
+
368
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
369
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
370
+
371
+ interleaved_span_lengths = np.reshape(
372
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
373
+ )
374
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
375
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
376
+ span_start_indicator[span_starts] = True
377
+ span_num = np.cumsum(span_start_indicator)
378
+ is_noise = np.equal(span_num % 2, 1)
379
+
380
+ return is_noise[:orig_length]
381
+
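+ # Worked example: for length=568 with noise_density=0.15 and mean_noise_span_length=3.0 (the
+ # defaults), this masks round(568 * 0.15) = 85 tokens, grouped into round(85 / 3) = 28 noise
+ # spans that alternate with 28 non-noise spans.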
382
+
383
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
384
+ num_samples = len(samples_idx)
385
+ samples_to_remove = num_samples % batch_size
386
+
387
+ if samples_to_remove != 0:
388
+ samples_idx = samples_idx[:-samples_to_remove]
389
+ sections_split = num_samples // batch_size
390
+ batch_idx = np.split(samples_idx, sections_split)
391
+ return batch_idx
392
+
393
+
394
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
395
+ summary_writer.scalar("train_time", train_time, step)
396
+
397
+ train_metrics = get_metrics(train_metrics)
398
+ for key, vals in train_metrics.items():
399
+ tag = f"train_{key}"
400
+ for i, val in enumerate(vals):
401
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
402
+
403
+
404
+ def write_eval_metric(summary_writer, eval_metrics, step):
405
+ for metric_name, value in eval_metrics.items():
406
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
407
+
408
+ # utils
409
+ def mb_item(x):
410
+ return x.item() if hasattr(x, "item") else x
411
+
412
+
413
+ # checkpoint functions
414
+ def save_checkpoint(model, save_dir, state, with_opt: bool = True):
415
+ state = jax_utils.unreplicate(state)
416
+ logger.info(f"SAVING CHECKPOINT IN {save_dir}")
417
+ save_dir = f"{save_dir}/ckpt-{mb_item(state.step) - 1}"
418
+ model.save_pretrained(
419
+ save_dir,
420
+ params=state.params,
421
+ push_to_hub=False
422
+ )
423
+ if with_opt:
424
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
425
+ f.write(to_bytes(state.opt_state))
426
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
427
+ json.dump({"step": state.step.item()}, f)
428
+ logger.info(f"Updating model on the hub")
429
+ model.save_pretrained(
430
+ training_args.output_dir,
431
+ params=state.params,
432
+ push_to_hub=training_args.push_to_hub,
433
+ commit_message=f"Saving weights and logs of step {cur_step}",
434
+ )
435
+ logger.info("checkpoint saved")
436
+
437
+
438
+ def restore_checkpoint(save_dir, state):
439
+ logger.info(f"RESTORING CHECKPOINT FROM {save_dir}")
440
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
441
+ params = from_bytes(state.params, f.read())
442
+
443
+ with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
444
+ opt_state = from_bytes(state.opt_state, f.read())
445
+
446
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
447
+ training_state = json.load(f)
448
+ step = training_state["step"]
449
+
450
+ logger.info("checkpoint restored")
451
+ return state.replace(step=step, params=params, opt_state=opt_state), step
452
+
453
+
454
+ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
455
+ "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
456
+ # TODO: what to remove is decided using step number only, we might want to improve that
457
+ ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
458
+ # sort checkpoints by step
459
+ ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
460
+ ckpts_to_delete = ckpts_sorted[:-save_total_limit]
461
+ for ckpt in ckpts_to_delete:
462
+ logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
463
+ shutil.rmtree(ckpt)
464
+
465
+
466
+
467
+ if __name__ == "__main__":
468
+ # See all possible arguments in src/transformers/training_args.py
469
+ # or by passing the --help flag to this script.
470
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
471
+
472
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
473
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
474
+ # If we pass only one argument to the script and it's the path to a json file,
475
+ # let's parse it to get our arguments.
476
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
477
+ else:
478
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
479
+
480
+ if (
481
+ os.path.exists(training_args.output_dir)
482
+ and os.listdir(training_args.output_dir)
483
+ and training_args.do_train
484
+ and not training_args.overwrite_output_dir
485
+ ):
486
+ raise ValueError(
487
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
488
+ "Use --overwrite_output_dir to overcome."
489
+ )
490
+
491
+ # Setup logging
492
+ logging.basicConfig(
493
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
494
+ level="NOTSET",
495
+ datefmt="[%X]",
496
+ )
497
+
498
+ # Log on each process the small summary:
499
+ logger = logging.getLogger(__name__)
500
+
501
+ # Set the verbosity to info of the Transformers logger (on main process only):
502
+ logger.info(f"Training/evaluation parameters {training_args}")
503
+
504
+ # Set seed before initializing model.
505
+ set_seed(training_args.seed)
506
+
507
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
508
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
509
+ # (the dataset will be downloaded automatically from the datasets Hub).
510
+ #
511
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
512
+ # 'text' is found. You can easily tweak this behavior (see below).
513
+ if data_args.dataset_name is not None:
514
+ # Downloading and loading a dataset from the hub.
515
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
516
+
517
+ if "validation" not in datasets.keys():
518
+ datasets["validation"] = load_dataset(
519
+ data_args.dataset_name,
520
+ data_args.dataset_config_name,
521
+ split=f"train[:{data_args.validation_split_percentage}%]",
522
+ cache_dir=model_args.cache_dir,
523
+ )
524
+ datasets["train"] = load_dataset(
525
+ data_args.dataset_name,
526
+ data_args.dataset_config_name,
527
+ split=f"train[{data_args.validation_split_percentage}%:]",
528
+ cache_dir=model_args.cache_dir,
529
+ )
530
+ else:
531
+ data_dir = "/home/yeb"
532
+ # data_dir = "/home/yeb/Developer/data"
533
+
534
+ def train_val_files():
535
+ import glob
536
+ import random
537
+ SEED = 12345
538
+
539
+ def add_jsonlines_dir(path, filespec):
540
+ global data_files
541
+ data_files += glob.glob(f"{path}/{filespec}")
542
+ print(f"Number of files {len(data_files)} after adding {path}")
543
+
544
+ # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
545
+ add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
546
+ add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
547
+ add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
548
+ random.Random(SEED).shuffle(data_files)
549
+
550
+ print(data_files)
551
+ total = len(data_files)
552
+ print(total)
553
+ perc = 0.05
554
+ val_size = int(perc * total)
555
+ train_size = total - val_size
556
+ train = data_files[:train_size]
557
+ val = data_files[train_size:]
558
+ print(f"Got {len(train)} training files and {perc*100} % {len(val)} validation files")
559
+
560
+ assert list(set(train) & set(val)) == [], "Train overlaps with test"
561
+
562
+ return train, val
563
+
564
+ train, val = train_val_files()
565
+
566
+ datasets = load_dataset('json', data_files={'train': train, 'validation': val})
567
+
568
+ # data_files = {}
569
+ # if data_args.train_file is not None:
570
+ # data_files["train"] = data_args.train_file
571
+ # if data_args.validation_file is not None:
572
+ # data_files["validation"] = data_args.validation_file
573
+ # extension = data_args.train_file.split(".")[-1]
574
+ # if extension == "txt":
575
+ # extension = "text"
576
+ # datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
577
+
578
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
579
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
580
+
581
+ # Load pretrained model and tokenizer
582
+
583
+ if model_args.tokenizer_name:
584
+ tokenizer = T5TokenizerFast.from_pretrained(
585
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
586
+ )
587
+ elif model_args.model_name_or_path:
588
+ tokenizer = T5TokenizerFast.from_pretrained(
589
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
590
+ )
591
+ else:
592
+ raise ValueError(
593
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
594
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
595
+ )
596
+
597
+ if model_args.config_name:
598
+ config = T5Config.from_pretrained(
599
+ model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
600
+ )
601
+ elif model_args.model_name_or_path:
602
+ config = T5Config.from_pretrained(
603
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
604
+ )
605
+ else:
606
+ config = CONFIG_MAPPING[model_args.model_type]()
607
+ logger.warning("You are instantiating a new config instance from scratch.")
608
+
609
+ # Preprocessing the datasets.
610
+ # First we tokenize all the texts.
611
+ if training_args.do_train:
612
+ column_names = datasets["train"].column_names
613
+ else:
614
+ column_names = datasets["validation"].column_names
615
+ text_column_name = "text" if "text" in column_names else column_names[0]
616
+
617
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
618
+
619
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
620
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
621
+ def tokenize_function(examples):
622
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
623
+
624
+ logger.info(f"Start tokenization, remove_column_names = {column_names}")
625
+
626
+ tokenized_datasets = datasets.map(
627
+ tokenize_function,
628
+ batched=True,
629
+ num_proc=data_args.preprocessing_num_workers,
630
+ remove_columns=column_names,
631
+ load_from_cache_file=not data_args.overwrite_cache,
632
+ )
633
+
634
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
635
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
636
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
637
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
638
+ inputs_length=max_seq_length,
639
+ noise_density=data_args.mlm_probability,
640
+ mean_noise_span_length=data_args.mean_noise_span_length,
641
+ )
642
+
643
+ logger.info(f"Expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
644
+
645
+ logger.info(f"Start group_texts")
646
+
647
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
648
+ def group_texts(examples):
649
+ # Concatenate all texts.
650
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
651
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
652
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
653
+ # customize this part to your needs.
654
+ if total_length >= expanded_inputs_length:
655
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
656
+ # Split by chunks of max_len.
657
+ result = {
658
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
659
+ for k, t in concatenated_examples.items()
660
+ }
661
+ return result
662
+
663
+ # Note that with `batched=True`, this map processes 200 texts together, so group_texts throws away a
664
+ # remainder for each of those groups of 200 texts. You can adjust that batch_size here but a higher value
665
+ # might be slower to preprocess.
666
+ #
667
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
668
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
669
+ tokenized_datasets = tokenized_datasets.map(
670
+ group_texts,
671
+ batched=True,
672
+ batch_size=200,
673
+ num_proc=data_args.preprocessing_num_workers,
674
+ load_from_cache_file=not data_args.overwrite_cache,
675
+ )
676
+
677
+ # Enable tensorboard only on the master node
678
+ has_tensorboard = is_tensorboard_available()
679
+ if has_tensorboard and jax.process_index() == 0:
680
+ try:
681
+ from flax.metrics.tensorboard import SummaryWriter
682
+
683
+ summary_writer = SummaryWriter(log_dir=Path(training_args.logging_dir))
684
+ except ImportError as ie:
685
+ has_tensorboard = False
686
+ logger.warning(
687
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
688
+ )
689
+ else:
690
+ logger.warning(
691
+ "Unable to display metrics through TensorBoard because the package is not installed: "
692
+ "Please run pip install tensorboard to enable."
693
+ )
694
+
695
+ # Initialize our training
696
+ rng = jax.random.PRNGKey(training_args.seed)
697
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
698
+
699
+ if model_args.model_name_or_path:
700
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
701
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
702
+ )
703
+ else:
704
+ model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
705
+
706
+ # Data collator
707
+ # This one will take care of randomly masking the tokens.
708
+ data_collator = FlaxDataCollatorForT5MLM(
709
+ tokenizer=tokenizer,
710
+ noise_density=data_args.mlm_probability,
711
+ mean_noise_span_length=data_args.mean_noise_span_length,
712
+ input_length=max_seq_length,
713
+ target_length=targets_length,
714
+ pad_token_id=model.config.pad_token_id,
715
+ decoder_start_token_id=model.config.decoder_start_token_id,
716
+ )
717
+
718
+ # Store some constant
719
+ num_epochs = int(training_args.num_train_epochs)
720
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
721
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
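+ # Example with the run_t5.sh settings and the 8 devices mentioned there:
+ # train_batch_size = 32 * 8 * 4 = 1024 examples per optimizer step, eval_batch_size = 32 * 8 = 256.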
722
+
723
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
724
+
725
+ # Create learning rate schedule
726
+
727
+ # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 6% of training steps
728
+ warmup_steps = int(0.06 * num_train_steps)
729
+ logging.info(f"Warmup steps set to 6% = {warmup_steps} of total train steps {num_train_steps}")
730
+
731
+ warmup_fn = optax.linear_schedule(
732
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
733
+ )
734
+ decay_fn = optax.linear_schedule(
735
+ init_value=training_args.learning_rate,
736
+ end_value=0,
737
+ transition_steps=num_train_steps - warmup_steps,
738
+ )
739
+ linear_decay_lr_schedule_fn = optax.join_schedules(
740
+ schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
741
+ )
742
+
743
+ # We use Optax's "masking" functionality to not apply weight decay
744
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
745
+ # mask boolean with the same structure as the parameters.
746
+ # The mask is True for parameters that should be decayed.
747
+ def decay_mask_fn(params):
748
+ flat_params = traverse_util.flatten_dict(params)
749
+ flat_mask = {
750
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
751
+ for path in flat_params
752
+ }
753
+ return traverse_util.unflatten_dict(flat_mask)
754
+
755
+ # create adam optimizer
756
+ if training_args.adafactor:
757
+ # We use the default parameters here to initialize adafactor,
758
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
759
+ optimizer = optax.adafactor(
760
+ learning_rate=linear_decay_lr_schedule_fn,
761
+ )
762
+ else:
763
+ optimizer = optax.adamw(
764
+ learning_rate=linear_decay_lr_schedule_fn,
765
+ b1=training_args.adam_beta1,
766
+ b2=training_args.adam_beta2,
767
+ weight_decay=training_args.weight_decay,
768
+ mask=decay_mask_fn,
769
+ )
770
+
771
+ if training_args.gradient_accumulation_steps > 1:
772
+ optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
773
+ grad_accum_steps = training_args.gradient_accumulation_steps
774
+
775
+ # Setup train state
776
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
777
+
778
+ # Define gradient update step fn
779
+ def train_step(state, batch, dropout_rng):
780
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
781
+
782
+ def loss_fn(params):
783
+ labels = batch.pop("labels")
784
+
785
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
786
+
787
+ # compute loss
788
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
789
+
790
+ return loss
791
+
792
+ grad_fn = jax.value_and_grad(loss_fn)
793
+ loss, grad = grad_fn(state.params)
794
+ grad = jax.lax.pmean(grad, "batch")
795
+ new_state = state.apply_gradients(grads=grad)
796
+
797
+ metrics = jax.lax.pmean(
798
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
799
+ )
800
+
801
+ return new_state, metrics, new_dropout_rng
802
+
803
+ # Create parallel version of the train step
804
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
805
+
806
+ # Define eval fn
807
+ def eval_step(params, batch):
808
+ labels = batch.pop("labels")
809
+
810
+ logits = model(**batch, params=params, train=False)[0]
811
+
812
+ # compute loss
813
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
814
+
815
+ # compute accuracy
816
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
817
+
818
+ # summarize metrics
819
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
820
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
821
+
822
+ return metrics
823
+
824
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
825
+
826
+ logger.info("Replicate the train state on each device")
827
+
828
+ # Replicate the train state on each device
829
+ state = jax_utils.replicate(state)
830
+
831
+ steps_per_epoch = len(datasets['train']) // train_batch_size
832
+ total_train_steps = steps_per_epoch * num_epochs
833
+
834
+ logger.info("***** Running training *****")
835
+ logger.info(f" Num examples = {len(datasets['train'])}")
836
+ logger.info(f" Num Epochs = {num_epochs}")
837
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
838
+ logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
839
+ logger.info(f" Total optimization steps = {total_train_steps}")
840
+
841
+ train_time = 0
842
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
843
+ for epoch in epochs:
844
+ # ======================== Training ================================
845
+ train_start = time.time()
846
+ train_metrics = []
847
+
848
+ # Create sampling rng
849
+ rng, input_rng = jax.random.split(rng)
850
+
851
+ # Generate an epoch by shuffling sampling indices from the train dataset
852
+ num_train_samples = len(tokenized_datasets["train"])
853
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
854
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
855
+
856
+ # Gather the indexes for creating the batch and do a training step
857
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
858
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
859
+ model_inputs = data_collator(samples)
860
+
861
+ # Model forward
862
+ model_inputs = shard(model_inputs.data)
863
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
864
+ train_metrics.append(train_metric)
865
+
866
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
867
+
868
+ if cur_step % (training_args.logging_steps * grad_accum_steps) == 0 and cur_step > 0:
869
+ # Save metrics
870
+ train_metric = jax_utils.unreplicate(train_metric)
871
+ train_time += time.time() - train_start
872
+ if has_tensorboard and jax.process_index() == 0:
873
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
874
+
875
+ epochs.write(
876
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
877
+ )
878
+
879
+ train_metrics = []
880
+
881
+ if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0:
882
+ # ======================== Evaluating ==============================
883
+ num_eval_samples = len(tokenized_datasets["validation"])
884
+ eval_samples_idx = jnp.arange(num_eval_samples)
885
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
886
+
887
+ eval_metrics = []
888
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
889
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
890
+ model_inputs = data_collator(samples)
891
+
892
+ # Model forward
893
+ model_inputs = shard(model_inputs.data)
894
+ metrics = p_eval_step(state.params, model_inputs)
895
+ eval_metrics.append(metrics)
896
+
897
+ # get eval metrics
898
+ eval_metrics = get_metrics(eval_metrics)
899
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
900
+
901
+ # Update progress bar
902
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
903
+
904
+ # Save metrics
905
+ if has_tensorboard and jax.process_index() == 0:
906
+ cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
907
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
908
+
909
+ if cur_step % (training_args.save_steps * grad_accum_steps) == 0 and cur_step > 0:
910
+ logger.info(f"We should save the model here after {cur_step} steps")
911
+ # save checkpoint after each epoch and push checkpoint to the hub
912
+ if jax.process_index() == 0:
913
+ save_checkpoint(model, training_args.output_dir, state)
914
+ if training_args.save_total_limit is not None:
915
+ rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
916
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
917
+ #
918
+ # logger.info(f"Saving model after {cur_step} steps")
919
+ # model.save_pretrained(
920
+ # training_args.output_dir,
921
+ # params=params,
922
+ # push_to_hub=training_args.push_to_hub,
923
+ # commit_message=f"Saving weights and logs of step {cur_step}",
924
+ # )
925
+
926
+
927
+ # Save model at end
928
+ if jax.process_index() == 0:
929
+ save_checkpoint(model, training_args.output_dir, state, with_opt=False)
930
+
931
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
932
+ # logger.info(f"Saving model at end")
933
+ # model.save_pretrained(
934
+ # training_args.output_dir,
935
+ # params=params,
936
+ # push_to_hub=training_args.push_to_hub,
937
+ # commit_message=f"Saving weights and logs at end of run (step {cur_step})",
938
+ # )
939
+ # pt_model = T5ForConditionalGeneration.from_pretrained(training_args.output_dir, from_flax=True)
940
+ # pt_model.save_pretrained(training_args.output_dir,
941
+ # params=params)
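The commented-out tail above and the commented git lines in run_t5.sh both point at publishing a PyTorch copy of the final Flax weights. A minimal sketch of that conversion, assuming training has finished and the output directory contains flax_model.msgpack plus config.json:

    from transformers import T5ForConditionalGeneration

    model_dir = "./t5-base-dutch"  # ${MODEL_DIR}, i.e. training_args.output_dir

    # Load the Flax checkpoint into a PyTorch model and write pytorch_model.bin next to it.
    pt_model = T5ForConditionalGeneration.from_pretrained(model_dir, from_flax=True)
    pt_model.save_pretrained(model_dir)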
t5_tokenizer_model.py ADDED
@@ -0,0 +1 @@
+ /home/yeb/transformers/examples/flax/language-modeling/t5_tokenizer_model.py
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
train_tokenizer.py ADDED
@@ -0,0 +1,68 @@
+ from datasets import load_dataset
+ from t5_tokenizer_model import SentencePieceUnigramTokenizer
+
+ # from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
+
+ data_dir = "/home/yeb"
+ data_files = []
+
+
+ def train_val_files():
+     import glob
+     import random
+     SEED = 12345
+
+     def add_jsonlines_dir(path, filespec):
+         global data_files
+         data_files += glob.glob(f"{path}/{filespec}")
+         print(f"Number of files {len(data_files)} after adding {path}")
+
+     # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
+     add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
+     add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
+     add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
+     random.Random(SEED).shuffle(data_files)
+
+     print(data_files)
+     total = len(data_files)
+     print(total)
+     perc = 0.01
+     val_size = int(perc * total)
+     train_size = total - val_size
+     train = data_files[:train_size]
+     val = data_files[train_size:]
+     print(f"Got {len(train)} training files and {len(val)} ({perc * 100}%) validation files")
+
+     assert list(set(train) & set(val)) == [], "Train overlaps with validation"
+
+     return train, val
+
+
+ train, val = train_val_files()
+
+ dataset = load_dataset('json', data_files={'train': train, 'validation': val}, split='train')
+
+ model_dir = "/t5-small-dutch"  # ${MODEL_DIR}
+
+ vocab_size = 32000
+ input_sentence_size = None
+ tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
+
+
+ # Build an iterator over this dataset
+ def batch_iterator(input_sentence_size=None):
+     if input_sentence_size is None:
+         input_sentence_size = len(dataset)
+     batch_length = 100
+     for i in range(0, input_sentence_size, batch_length):
+         yield dataset[i: i + batch_length]["text"]
+
+
+ # Train tokenizer
+ tokenizer.train_from_iterator(
+     iterator=batch_iterator(input_sentence_size=input_sentence_size),
+     vocab_size=vocab_size,
+     show_progress=True,
+ )
+
+ # Save files to disk
+ tokenizer.save("./tokenizer.json")