pere commited on
Commit
da9a4c8
1 Parent(s): e986b9d

first commmit - translation

Browse files
Files changed (1) hide show
  1. run_translation_t5_flax.py +808 -0
run_translation_t5_flax.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for summarization.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import dataclass, field
26
+ from functools import partial
27
+ from pathlib import Path
28
+ from typing import Callable, Optional
29
+
30
+ import datasets
31
+ import nltk # Here to have a nice missing dependency error message early on
32
+ import numpy as np
33
+ from datasets import Dataset, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import jax
37
+ import jax.numpy as jnp
38
+ import optax
39
+ import transformers
40
+ from filelock import FileLock
41
+ from flax import jax_utils, traverse_util
42
+ from flax.jax_utils import unreplicate
43
+ from flax.training import train_state
44
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
+ from transformers import (
46
+ CONFIG_MAPPING,
47
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
48
+ AutoConfig,
49
+ AutoTokenizer,
50
+ FlaxAutoModelForSeq2SeqLM,
51
+ HfArgumentParser,
52
+ TrainingArguments,
53
+ is_tensorboard_available,
54
+ )
55
+ from transformers.file_utils import is_offline_mode
56
+
57
+
58
+ logger = logging.getLogger(__name__)
59
+
60
+ try:
61
+ nltk.data.find("tokenizers/punkt")
62
+ except (LookupError, OSError):
63
+ if is_offline_mode():
64
+ raise LookupError(
65
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
66
+ )
67
+ with FileLock(".lock") as lock:
68
+ nltk.download("punkt", quiet=True)
69
+
70
+
71
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
72
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
73
+
74
+
75
+ @dataclass
76
+ class ModelArguments:
77
+ """
78
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
79
+ """
80
+
81
+ model_name_or_path: Optional[str] = field(
82
+ default=None,
83
+ metadata={
84
+ "help": "The model checkpoint for weights initialization."
85
+ "Don't set if you want to train a model from scratch."
86
+ },
87
+ )
88
+ model_type: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
91
+ )
92
+ config_name: Optional[str] = field(
93
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
94
+ )
95
+ tokenizer_name: Optional[str] = field(
96
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
97
+ )
98
+ cache_dir: Optional[str] = field(
99
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
100
+ )
101
+ use_fast_tokenizer: bool = field(
102
+ default=True,
103
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
104
+ )
105
+ dtype: Optional[str] = field(
106
+ default="float32",
107
+ metadata={
108
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
109
+ },
110
+ )
111
+
112
+
113
+ @dataclass
114
+ class DataTrainingArguments:
115
+ """
116
+ Arguments pertaining to what data we are going to input our model for training and eval.
117
+ """
118
+
119
+ dataset_name: Optional[str] = field(
120
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
121
+ )
122
+ dataset_config_name: Optional[str] = field(
123
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
124
+ )
125
+ text_column: Optional[str] = field(
126
+ default=None,
127
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
128
+ )
129
+ summary_column: Optional[str] = field(
130
+ default=None,
131
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
132
+ )
133
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
134
+ validation_file: Optional[str] = field(
135
+ default=None,
136
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
137
+ )
138
+ max_source_length: Optional[int] = field(
139
+ default=1024,
140
+ metadata={
141
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
142
+ "than this will be truncated, sequences shorter will be padded."
143
+ },
144
+ )
145
+ max_target_length: Optional[int] = field(
146
+ default=128,
147
+ metadata={
148
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
149
+ "than this will be truncated, sequences shorter will be padded."
150
+ },
151
+ )
152
+ val_max_target_length: Optional[int] = field(
153
+ default=None,
154
+ metadata={
155
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
156
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
157
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
158
+ "during evaluation."
159
+ },
160
+ )
161
+ max_train_samples: Optional[int] = field(
162
+ default=None,
163
+ metadata={
164
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
165
+ "value if set."
166
+ },
167
+ )
168
+ max_eval_samples: Optional[int] = field(
169
+ default=None,
170
+ metadata={
171
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
172
+ "value if set."
173
+ },
174
+ )
175
+ max_predict_samples: Optional[int] = field(
176
+ default=None,
177
+ metadata={
178
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
179
+ "value if set."
180
+ },
181
+ )
182
+ preprocessing_num_workers: Optional[int] = field(
183
+ default=None,
184
+ metadata={"help": "The number of processes to use for the preprocessing."},
185
+ )
186
+ source_prefix: Optional[str] = field(
187
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
188
+ )
189
+ predict_with_generate: bool = field(
190
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
191
+ )
192
+ num_beams: Optional[int] = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
196
+ "which is used during evaluation."
197
+ },
198
+ )
199
+ overwrite_cache: bool = field(
200
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
201
+ )
202
+
203
+ def __post_init__(self):
204
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
205
+ raise ValueError("Need either a dataset name or a training/validation file.")
206
+ else:
207
+ if self.train_file is not None:
208
+ extension = self.train_file.split(".")[-1]
209
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
210
+ if self.validation_file is not None:
211
+ extension = self.validation_file.split(".")[-1]
212
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
213
+ if self.val_max_target_length is None:
214
+ self.val_max_target_length = self.max_target_length
215
+
216
+
217
+ summarization_name_mapping = {
218
+ "amazon_reviews_multi": ("review_body", "review_title"),
219
+ "big_patent": ("description", "abstract"),
220
+ "cnn_dailymail": ("article", "highlights"),
221
+ "orange_sum": ("text", "summary"),
222
+ "pn_summary": ("article", "summary"),
223
+ "psc": ("extract_text", "summary_text"),
224
+ "samsum": ("dialogue", "summary"),
225
+ "thaisum": ("body", "summary"),
226
+ "xglue": ("news_body", "news_title"),
227
+ "xsum": ("document", "summary"),
228
+ "wiki_summary": ("article", "highlights"),
229
+ }
230
+
231
+
232
+ class TrainState(train_state.TrainState):
233
+ dropout_rng: jnp.ndarray
234
+
235
+ def replicate(self):
236
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
237
+
238
+
239
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
240
+ """
241
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
242
+ Shuffle batches if `shuffle` is `True`.
243
+ """
244
+ steps_per_epoch = len(dataset) // batch_size
245
+
246
+ if shuffle:
247
+ batch_idx = jax.random.permutation(rng, len(dataset))
248
+ else:
249
+ batch_idx = jnp.arange(len(dataset))
250
+
251
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
252
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
253
+
254
+ for idx in batch_idx:
255
+ batch = dataset[idx]
256
+ batch = {k: jnp.array(v) for k, v in batch.items()}
257
+
258
+ batch = shard(batch)
259
+
260
+ yield batch
261
+
262
+
263
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
264
+ summary_writer.scalar("train_time", train_time, step)
265
+
266
+ train_metrics = get_metrics(train_metrics)
267
+ for key, vals in train_metrics.items():
268
+ tag = f"train_{key}"
269
+ for i, val in enumerate(vals):
270
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
271
+
272
+ for metric_name, value in eval_metrics.items():
273
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
274
+
275
+
276
+ def create_learning_rate_fn(
277
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
278
+ ) -> Callable[[int], jnp.array]:
279
+ """Returns a linear warmup, linear_decay learning rate function."""
280
+ steps_per_epoch = train_ds_size // train_batch_size
281
+ num_train_steps = steps_per_epoch * num_train_epochs
282
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
283
+ decay_fn = optax.linear_schedule(
284
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
285
+ )
286
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
287
+ return schedule_fn
288
+
289
+
290
+ def main():
291
+ # See all possible arguments in src/transformers/training_args.py
292
+ # or by passing the --help flag to this script.
293
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
294
+
295
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
296
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
297
+ # If we pass only one argument to the script and it's the path to a json file,
298
+ # let's parse it to get our arguments.
299
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
300
+ else:
301
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
302
+
303
+ if (
304
+ os.path.exists(training_args.output_dir)
305
+ and os.listdir(training_args.output_dir)
306
+ and training_args.do_train
307
+ and not training_args.overwrite_output_dir
308
+ ):
309
+ raise ValueError(
310
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
311
+ "Use --overwrite_output_dir to overcome."
312
+ )
313
+
314
+ # Make one log on every process with the configuration for debugging.
315
+ logging.basicConfig(
316
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
317
+ datefmt="%m/%d/%Y %H:%M:%S",
318
+ level=logging.INFO,
319
+ )
320
+ # Setup logging, we only want one process per machine to log things on the screen.
321
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
322
+ if jax.process_index() == 0:
323
+ datasets.utils.logging.set_verbosity_warning()
324
+ transformers.utils.logging.set_verbosity_info()
325
+ else:
326
+ datasets.utils.logging.set_verbosity_error()
327
+ transformers.utils.logging.set_verbosity_error()
328
+
329
+ # Set the verbosity to info of the Transformers logger (on main process only):
330
+ logger.info(f"Training/evaluation parameters {training_args}")
331
+
332
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
333
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
334
+ # (the dataset will be downloaded automatically from the datasets Hub).
335
+ #
336
+ # For CSV/JSON files this script will use the first column for the full texts and the second column for the
337
+ # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
338
+ #
339
+ if data_args.dataset_name is not None:
340
+ # Downloading and loading a dataset from the hub.
341
+ dataset = load_dataset(
342
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
343
+ )
344
+ else:
345
+ data_files = {}
346
+ if data_args.train_file is not None:
347
+ data_files["train"] = data_args.train_file
348
+ extension = data_args.train_file.split(".")[-1]
349
+ if data_args.validation_file is not None:
350
+ data_files["validation"] = data_args.validation_file
351
+ extension = data_args.validation_file.split(".")[-1]
352
+ if data_args.test_file is not None:
353
+ data_files["test"] = data_args.test_file
354
+ extension = data_args.test_file.split(".")[-1]
355
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
356
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
357
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
358
+
359
+ # Load pretrained model and tokenizer
360
+
361
+ if model_args.config_name:
362
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
363
+ elif model_args.model_name_or_path:
364
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
365
+ else:
366
+ config = CONFIG_MAPPING[model_args.model_type]()
367
+ logger.warning("You are instantiating a new config instance from scratch.")
368
+
369
+ if model_args.tokenizer_name:
370
+ tokenizer = AutoTokenizer.from_pretrained(
371
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
372
+ )
373
+ elif model_args.model_name_or_path:
374
+ tokenizer = AutoTokenizer.from_pretrained(
375
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
376
+ )
377
+ else:
378
+ raise ValueError(
379
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
380
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
381
+ )
382
+
383
+ if model_args.model_name_or_path:
384
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
385
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
386
+ )
387
+ else:
388
+ model = FlaxAutoModelForSeq2SeqLM.from_config(
389
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
390
+ )
391
+
392
+ if model.config.decoder_start_token_id is None:
393
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
394
+
395
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
396
+
397
+ # Preprocessing the datasets.
398
+ # We need to tokenize inputs and targets.
399
+ if training_args.do_train:
400
+ column_names = dataset["train"].column_names
401
+ elif training_args.do_eval:
402
+ column_names = dataset["validation"].column_names
403
+ elif training_args.do_predict:
404
+ column_names = dataset["test"].column_names
405
+ else:
406
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
407
+ return
408
+
409
+ # Get the column names for input/target.
410
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
411
+ if data_args.text_column is None:
412
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
413
+ else:
414
+ text_column = data_args.text_column
415
+ if text_column not in column_names:
416
+ raise ValueError(
417
+ f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
418
+ )
419
+ if data_args.summary_column is None:
420
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
421
+ else:
422
+ summary_column = data_args.summary_column
423
+ if summary_column not in column_names:
424
+ raise ValueError(
425
+ f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
426
+ )
427
+
428
+ # Temporarily set max_target_length for training.
429
+ max_target_length = data_args.max_target_length
430
+
431
+ # In Flax, for seq2seq models we need to pass `decoder_input_ids`
432
+ # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
433
+ # for that dynamically import the `shift_tokens_right` function from the model file
434
+ model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
435
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
436
+
437
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
438
+ def preprocess_function(examples):
439
+ inputs = examples[text_column]
440
+ targets = examples[summary_column]
441
+ inputs = [prefix + inp for inp in inputs]
442
+ model_inputs = tokenizer(
443
+ inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
444
+ )
445
+
446
+ # Setup the tokenizer for targets
447
+ with tokenizer.as_target_tokenizer():
448
+ labels = tokenizer(
449
+ targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
450
+ )
451
+
452
+ model_inputs["labels"] = labels["input_ids"]
453
+ decoder_input_ids = shift_tokens_right_fn(
454
+ jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
455
+ )
456
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
457
+
458
+ # We need decoder_attention_mask so we can ignore pad tokens from loss
459
+ model_inputs["decoder_attention_mask"] = labels["attention_mask"]
460
+
461
+ return model_inputs
462
+
463
+ if training_args.do_train:
464
+ if "train" not in dataset:
465
+ raise ValueError("--do_train requires a train dataset")
466
+ train_dataset = dataset["train"]
467
+ if data_args.max_train_samples is not None:
468
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
469
+ train_dataset = train_dataset.map(
470
+ preprocess_function,
471
+ batched=True,
472
+ num_proc=data_args.preprocessing_num_workers,
473
+ remove_columns=column_names,
474
+ load_from_cache_file=not data_args.overwrite_cache,
475
+ desc="Running tokenizer on train dataset",
476
+ )
477
+
478
+ if training_args.do_eval:
479
+ max_target_length = data_args.val_max_target_length
480
+ if "validation" not in dataset:
481
+ raise ValueError("--do_eval requires a validation dataset")
482
+ eval_dataset = dataset["validation"]
483
+ if data_args.max_eval_samples is not None:
484
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
485
+ eval_dataset = eval_dataset.map(
486
+ preprocess_function,
487
+ batched=True,
488
+ num_proc=data_args.preprocessing_num_workers,
489
+ remove_columns=column_names,
490
+ load_from_cache_file=not data_args.overwrite_cache,
491
+ desc="Running tokenizer on validation dataset",
492
+ )
493
+
494
+ if training_args.do_predict:
495
+ max_target_length = data_args.val_max_target_length
496
+ if "test" not in dataset:
497
+ raise ValueError("--do_predict requires a test dataset")
498
+ predict_dataset = dataset["test"]
499
+ if data_args.max_predict_samples is not None:
500
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
501
+ predict_dataset = predict_dataset.map(
502
+ preprocess_function,
503
+ batched=True,
504
+ num_proc=data_args.preprocessing_num_workers,
505
+ remove_columns=column_names,
506
+ load_from_cache_file=not data_args.overwrite_cache,
507
+ desc="Running tokenizer on prediction dataset",
508
+ )
509
+
510
+ # Metric
511
+ metric = load_metric("rouge")
512
+
513
+ def postprocess_text(preds, labels):
514
+ preds = [pred.strip() for pred in preds]
515
+ labels = [label.strip() for label in labels]
516
+
517
+ # rougeLSum expects newline after each sentence
518
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
519
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
520
+
521
+ return preds, labels
522
+
523
+ def compute_metrics(preds, labels):
524
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
525
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
526
+
527
+ # Some simple post-processing
528
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
529
+
530
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
531
+ # Extract a few results from ROUGE
532
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
533
+
534
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
535
+ result["gen_len"] = np.mean(prediction_lens)
536
+ result = {k: round(v, 4) for k, v in result.items()}
537
+ return result
538
+
539
+ # Enable tensorboard only on the master node
540
+ has_tensorboard = is_tensorboard_available()
541
+ if has_tensorboard and jax.process_index() == 0:
542
+ try:
543
+ from flax.metrics.tensorboard import SummaryWriter
544
+
545
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
546
+ except ImportError as ie:
547
+ has_tensorboard = False
548
+ logger.warning(
549
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
550
+ )
551
+ else:
552
+ logger.warning(
553
+ "Unable to display metrics through TensorBoard because the package is not installed: "
554
+ "Please run pip install tensorboard to enable."
555
+ )
556
+
557
+ # Initialize our training
558
+ rng = jax.random.PRNGKey(training_args.seed)
559
+ rng, dropout_rng = jax.random.split(rng)
560
+
561
+ # Store some constant
562
+ num_epochs = int(training_args.num_train_epochs)
563
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
564
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
565
+ steps_per_epoch = len(train_dataset) // train_batch_size
566
+ total_train_steps = steps_per_epoch * num_epochs
567
+
568
+ # Create learning rate schedule
569
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
570
+ len(train_dataset),
571
+ train_batch_size,
572
+ training_args.num_train_epochs,
573
+ training_args.warmup_steps,
574
+ training_args.learning_rate,
575
+ )
576
+
577
+ # We use Optax's "masking" functionality to not apply weight decay
578
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
579
+ # mask boolean with the same structure as the parameters.
580
+ # The mask is True for parameters that should be decayed.
581
+ # Note that this mask is specifically adapted for FlaxBart.
582
+ # For FlaxT5, one should correct the layer norm parameter naming
583
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
584
+ def decay_mask_fn(params):
585
+ flat_params = traverse_util.flatten_dict(params)
586
+ layer_norm_params = [
587
+ (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
588
+ ]
589
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
590
+ return traverse_util.unflatten_dict(flat_mask)
591
+
592
+ # create adam optimizer
593
+ adamw = optax.adamw(
594
+ learning_rate=linear_decay_lr_schedule_fn,
595
+ b1=training_args.adam_beta1,
596
+ b2=training_args.adam_beta2,
597
+ eps=training_args.adam_epsilon,
598
+ weight_decay=training_args.weight_decay,
599
+ mask=decay_mask_fn,
600
+ )
601
+
602
+ # Setup train state
603
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
604
+
605
+ # label smoothed cross entropy
606
+ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
607
+ """
608
+ The label smoothing implementation is adapted from Flax's official example:
609
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
610
+ """
611
+ vocab_size = logits.shape[-1]
612
+ confidence = 1.0 - label_smoothing_factor
613
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
614
+ normalizing_constant = -(
615
+ confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
616
+ )
617
+ soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
618
+
619
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
620
+ loss = loss - normalizing_constant
621
+
622
+ # ignore padded tokens from loss
623
+ loss = loss * padding_mask
624
+ loss = loss.sum() / padding_mask.sum()
625
+ return loss
626
+
627
+ # Define gradient update step fn
628
+ def train_step(state, batch, label_smoothing_factor=0.0):
629
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
630
+
631
+ def compute_loss(params):
632
+ labels = batch.pop("labels")
633
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
634
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
635
+ return loss
636
+
637
+ grad_fn = jax.value_and_grad(compute_loss)
638
+ loss, grad = grad_fn(state.params)
639
+ grad = jax.lax.pmean(grad, "batch")
640
+
641
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
642
+
643
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
644
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
645
+
646
+ return new_state, metrics
647
+
648
+ # Define eval fn
649
+ def eval_step(params, batch, label_smoothing_factor=0.0):
650
+ labels = batch.pop("labels")
651
+ logits = model(**batch, params=params, train=False)[0]
652
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
653
+
654
+ # summarize metrics
655
+ metrics = {"loss": loss}
656
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
657
+ return metrics
658
+
659
+ # Define generation function
660
+ max_length = (
661
+ data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
662
+ )
663
+ num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
664
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
665
+
666
+ def generate_step(params, batch):
667
+ model.params = params
668
+ output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
669
+ return output_ids.sequences
670
+
671
+ # Create parallel version of the train and eval step
672
+ p_train_step = jax.pmap(
673
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
674
+ )
675
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
676
+ p_generate_step = jax.pmap(generate_step, "batch")
677
+
678
+ # Replicate the train state on each device
679
+ state = state.replicate()
680
+
681
+ logger.info("***** Running training *****")
682
+ logger.info(f" Num examples = {len(train_dataset)}")
683
+ logger.info(f" Num Epochs = {num_epochs}")
684
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
685
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
686
+ logger.info(f" Total optimization steps = {total_train_steps}")
687
+
688
+ train_time = 0
689
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
690
+ for epoch in epochs:
691
+ # ======================== Training ================================
692
+ train_start = time.time()
693
+
694
+ # Create sampling rng
695
+ rng, input_rng = jax.random.split(rng)
696
+ train_metrics = []
697
+
698
+ # Generate an epoch by shuffling sampling indices from the train dataset
699
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
700
+ steps_per_epoch = len(train_dataset) // train_batch_size
701
+ # train
702
+ for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
703
+ batch = next(train_loader)
704
+ state, train_metric = p_train_step(state, batch)
705
+ train_metrics.append(train_metric)
706
+
707
+ train_time += time.time() - train_start
708
+
709
+ train_metric = unreplicate(train_metric)
710
+
711
+ epochs.write(
712
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
713
+ )
714
+
715
+ # ======================== Evaluating ==============================
716
+ eval_metrics = []
717
+ eval_preds = []
718
+ eval_labels = []
719
+
720
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
721
+ eval_steps = len(eval_dataset) // eval_batch_size
722
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
723
+ # Model forward
724
+ batch = next(eval_loader)
725
+ labels = batch["labels"]
726
+
727
+ metrics = p_eval_step(state.params, batch)
728
+ eval_metrics.append(metrics)
729
+
730
+ # generation
731
+ if data_args.predict_with_generate:
732
+ generated_ids = p_generate_step(state.params, batch)
733
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
734
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
735
+
736
+ # normalize eval metrics
737
+ eval_metrics = get_metrics(eval_metrics)
738
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
739
+
740
+ # compute ROUGE metrics
741
+ rouge_desc = ""
742
+ if data_args.predict_with_generate:
743
+ rouge_metrics = compute_metrics(eval_preds, eval_labels)
744
+ eval_metrics.update(rouge_metrics)
745
+ rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
746
+
747
+ # Print metrics and update progress bar
748
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
749
+ epochs.write(desc)
750
+ epochs.desc = desc
751
+
752
+ # Save metrics
753
+ if has_tensorboard and jax.process_index() == 0:
754
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
755
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
756
+
757
+ # ======================== Prediction loop ==============================
758
+ if training_args.do_predict:
759
+ logger.info("*** Predict ***")
760
+
761
+ pred_metrics = []
762
+ pred_generations = []
763
+ pred_labels = []
764
+
765
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
766
+ pred_steps = len(predict_dataset) // eval_batch_size
767
+ for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
768
+ # Model forward
769
+ batch = next(pred_loader)
770
+ labels = batch["labels"]
771
+
772
+ metrics = p_eval_step(state.params, batch)
773
+ pred_metrics.append(metrics)
774
+
775
+ # generation
776
+ if data_args.predict_with_generate:
777
+ generated_ids = p_generate_step(state.params, batch)
778
+ pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
779
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
780
+
781
+ # normalize prediction metrics
782
+ pred_metrics = get_metrics(pred_metrics)
783
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
784
+
785
+ # compute ROUGE metrics
786
+ rouge_desc = ""
787
+ if data_args.predict_with_generate:
788
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
789
+ pred_metrics.update(rouge_metrics)
790
+ rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
791
+
792
+ # Print metrics
793
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
794
+ logger.info(desc)
795
+
796
+ # save checkpoint after each epoch and push checkpoint to the hub
797
+ if jax.process_index() == 0:
798
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
799
+ model.save_pretrained(
800
+ training_args.output_dir,
801
+ params=params,
802
+ push_to_hub=training_args.push_to_hub,
803
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
804
+ )
805
+
806
+
807
+ if __name__ == "__main__":
808
+ main()