boris commited on
Commit
46cb01f
1 Parent(s): 1055c3d

feat: add run_seq2seq_flax

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +813 -0
seq2seq/run_seq2seq_flax.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for seq2seq, text to image.
18
+ Script adapted from run_summarization_flax.py
19
+ """
20
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
+
22
+ import logging
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from functools import partial
28
+ from pathlib import Path
29
+ from typing import Callable, Optional
30
+
31
+ import datasets
32
+ import nltk # Here to have a nice missing dependency error message early on
33
+ import numpy as np
34
+ from datasets import Dataset, load_dataset, load_metric
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ from filelock import FileLock
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import train_state
45
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
46
+ from transformers import (
47
+ CONFIG_MAPPING,
48
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
49
+ AutoConfig,
50
+ AutoTokenizer,
51
+ FlaxAutoModelForSeq2SeqLM,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ )
56
+ from transformers.file_utils import is_offline_mode
57
+
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ try:
62
+ nltk.data.find("tokenizers/punkt")
63
+ except (LookupError, OSError):
64
+ if is_offline_mode():
65
+ raise LookupError(
66
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
67
+ )
68
+ with FileLock(".lock") as lock:
69
+ nltk.download("punkt", quiet=True)
70
+
71
+
72
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
73
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
74
+
75
+
76
+ @dataclass
77
+ class ModelArguments:
78
+ """
79
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
80
+ """
81
+
82
+ model_name_or_path: Optional[str] = field(
83
+ default=None,
84
+ metadata={
85
+ "help": "The model checkpoint for weights initialization."
86
+ "Don't set if you want to train a model from scratch."
87
+ },
88
+ )
89
+ model_type: Optional[str] = field(
90
+ default=None,
91
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
92
+ )
93
+ config_name: Optional[str] = field(
94
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
95
+ )
96
+ tokenizer_name: Optional[str] = field(
97
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
98
+ )
99
+ cache_dir: Optional[str] = field(
100
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
101
+ )
102
+ use_fast_tokenizer: bool = field(
103
+ default=True,
104
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
105
+ )
106
+ dtype: Optional[str] = field(
107
+ default="float32",
108
+ metadata={
109
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
110
+ },
111
+ )
112
+
113
+
114
+ @dataclass
115
+ class DataTrainingArguments:
116
+ """
117
+ Arguments pertaining to what data we are going to input our model for training and eval.
118
+ """
119
+
120
+ dataset_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
122
+ )
123
+ dataset_config_name: Optional[str] = field(
124
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
125
+ )
126
+ text_column: Optional[str] = field(
127
+ default=None,
128
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
129
+ )
130
+ summary_column: Optional[str] = field(
131
+ default=None,
132
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
133
+ )
134
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
135
+ validation_file: Optional[str] = field(
136
+ default=None,
137
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
138
+ )
139
+ test_file: Optional[str] = field(
140
+ default=None,
141
+ metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
142
+ )
143
+ max_source_length: Optional[int] = field(
144
+ default=1024,
145
+ metadata={
146
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
147
+ "than this will be truncated, sequences shorter will be padded."
148
+ },
149
+ )
150
+ max_target_length: Optional[int] = field(
151
+ default=128,
152
+ metadata={
153
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
154
+ "than this will be truncated, sequences shorter will be padded."
155
+ },
156
+ )
157
+ val_max_target_length: Optional[int] = field(
158
+ default=None,
159
+ metadata={
160
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
161
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
162
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
163
+ "during evaluation."
164
+ },
165
+ )
166
+ max_train_samples: Optional[int] = field(
167
+ default=None,
168
+ metadata={
169
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
170
+ "value if set."
171
+ },
172
+ )
173
+ max_eval_samples: Optional[int] = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
177
+ "value if set."
178
+ },
179
+ )
180
+ max_predict_samples: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
184
+ "value if set."
185
+ },
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ source_prefix: Optional[str] = field(
192
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
193
+ )
194
+ predict_with_generate: bool = field(
195
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
196
+ )
197
+ num_beams: Optional[int] = field(
198
+ default=None,
199
+ metadata={
200
+ "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
201
+ "which is used during evaluation."
202
+ },
203
+ )
204
+ overwrite_cache: bool = field(
205
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
206
+ )
207
+
208
+ def __post_init__(self):
209
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
210
+ raise ValueError("Need either a dataset name or a training/validation file.")
211
+ else:
212
+ if self.train_file is not None:
213
+ extension = self.train_file.split(".")[-1]
214
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
215
+ if self.validation_file is not None:
216
+ extension = self.validation_file.split(".")[-1]
217
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
218
+ if self.val_max_target_length is None:
219
+ self.val_max_target_length = self.max_target_length
220
+
221
+
222
+ summarization_name_mapping = {
223
+ "amazon_reviews_multi": ("review_body", "review_title"),
224
+ "big_patent": ("description", "abstract"),
225
+ "cnn_dailymail": ("article", "highlights"),
226
+ "orange_sum": ("text", "summary"),
227
+ "pn_summary": ("article", "summary"),
228
+ "psc": ("extract_text", "summary_text"),
229
+ "samsum": ("dialogue", "summary"),
230
+ "thaisum": ("body", "summary"),
231
+ "xglue": ("news_body", "news_title"),
232
+ "xsum": ("document", "summary"),
233
+ "wiki_summary": ("article", "highlights"),
234
+ }
235
+
236
+
237
+ class TrainState(train_state.TrainState):
238
+ dropout_rng: jnp.ndarray
239
+
240
+ def replicate(self):
241
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
242
+
243
+
244
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
245
+ """
246
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
247
+ Shuffle batches if `shuffle` is `True`.
248
+ """
249
+ steps_per_epoch = len(dataset) // batch_size
250
+
251
+ if shuffle:
252
+ batch_idx = jax.random.permutation(rng, len(dataset))
253
+ else:
254
+ batch_idx = jnp.arange(len(dataset))
255
+
256
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
257
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
258
+
259
+ for idx in batch_idx:
260
+ batch = dataset[idx]
261
+ batch = {k: jnp.array(v) for k, v in batch.items()}
262
+
263
+ batch = shard(batch)
264
+
265
+ yield batch
266
+
267
+
268
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
269
+ summary_writer.scalar("train_time", train_time, step)
270
+
271
+ train_metrics = get_metrics(train_metrics)
272
+ for key, vals in train_metrics.items():
273
+ tag = f"train_{key}"
274
+ for i, val in enumerate(vals):
275
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
276
+
277
+ for metric_name, value in eval_metrics.items():
278
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
279
+
280
+
281
+ def create_learning_rate_fn(
282
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
283
+ ) -> Callable[[int], jnp.array]:
284
+ """Returns a linear warmup, linear_decay learning rate function."""
285
+ steps_per_epoch = train_ds_size // train_batch_size
286
+ num_train_steps = steps_per_epoch * num_train_epochs
287
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
288
+ decay_fn = optax.linear_schedule(
289
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
290
+ )
291
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
292
+ return schedule_fn
293
+
294
+
295
+ def main():
296
+ # See all possible arguments in src/transformers/training_args.py
297
+ # or by passing the --help flag to this script.
298
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
299
+
300
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
301
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
302
+ # If we pass only one argument to the script and it's the path to a json file,
303
+ # let's parse it to get our arguments.
304
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
305
+ else:
306
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
307
+
308
+ if (
309
+ os.path.exists(training_args.output_dir)
310
+ and os.listdir(training_args.output_dir)
311
+ and training_args.do_train
312
+ and not training_args.overwrite_output_dir
313
+ ):
314
+ raise ValueError(
315
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
316
+ "Use --overwrite_output_dir to overcome."
317
+ )
318
+
319
+ # Make one log on every process with the configuration for debugging.
320
+ logging.basicConfig(
321
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
322
+ datefmt="%m/%d/%Y %H:%M:%S",
323
+ level=logging.INFO,
324
+ )
325
+ # Setup logging, we only want one process per machine to log things on the screen.
326
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
327
+ if jax.process_index() == 0:
328
+ datasets.utils.logging.set_verbosity_warning()
329
+ transformers.utils.logging.set_verbosity_info()
330
+ else:
331
+ datasets.utils.logging.set_verbosity_error()
332
+ transformers.utils.logging.set_verbosity_error()
333
+
334
+ # Set the verbosity to info of the Transformers logger (on main process only):
335
+ logger.info(f"Training/evaluation parameters {training_args}")
336
+
337
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
338
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
339
+ # (the dataset will be downloaded automatically from the datasets Hub).
340
+ #
341
+ # For CSV/JSON files this script will use the first column for the full texts and the second column for the
342
+ # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
343
+ #
344
+ if data_args.dataset_name is not None:
345
+ # Downloading and loading a dataset from the hub.
346
+ dataset = load_dataset(
347
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
348
+ )
349
+ else:
350
+ data_files = {}
351
+ if data_args.train_file is not None:
352
+ data_files["train"] = data_args.train_file
353
+ extension = data_args.train_file.split(".")[-1]
354
+ if data_args.validation_file is not None:
355
+ data_files["validation"] = data_args.validation_file
356
+ extension = data_args.validation_file.split(".")[-1]
357
+ if data_args.test_file is not None:
358
+ data_files["test"] = data_args.test_file
359
+ extension = data_args.test_file.split(".")[-1]
360
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
361
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
362
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
363
+
364
+ # Load pretrained model and tokenizer
365
+
366
+ if model_args.config_name:
367
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
368
+ elif model_args.model_name_or_path:
369
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
370
+ else:
371
+ config = CONFIG_MAPPING[model_args.model_type]()
372
+ logger.warning("You are instantiating a new config instance from scratch.")
373
+
374
+ if model_args.tokenizer_name:
375
+ tokenizer = AutoTokenizer.from_pretrained(
376
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
377
+ )
378
+ elif model_args.model_name_or_path:
379
+ tokenizer = AutoTokenizer.from_pretrained(
380
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
381
+ )
382
+ else:
383
+ raise ValueError(
384
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
385
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
386
+ )
387
+
388
+ if model_args.model_name_or_path:
389
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
390
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
391
+ )
392
+ else:
393
+ model = FlaxAutoModelForSeq2SeqLM.from_config(
394
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
395
+ )
396
+
397
+ if model.config.decoder_start_token_id is None:
398
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
399
+
400
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
401
+
402
+ # Preprocessing the datasets.
403
+ # We need to tokenize inputs and targets.
404
+ if training_args.do_train:
405
+ column_names = dataset["train"].column_names
406
+ elif training_args.do_eval:
407
+ column_names = dataset["validation"].column_names
408
+ elif training_args.do_predict:
409
+ column_names = dataset["test"].column_names
410
+ else:
411
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
412
+ return
413
+
414
+ # Get the column names for input/target.
415
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
416
+ if data_args.text_column is None:
417
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
418
+ else:
419
+ text_column = data_args.text_column
420
+ if text_column not in column_names:
421
+ raise ValueError(
422
+ f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
423
+ )
424
+ if data_args.summary_column is None:
425
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
426
+ else:
427
+ summary_column = data_args.summary_column
428
+ if summary_column not in column_names:
429
+ raise ValueError(
430
+ f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
431
+ )
432
+
433
+ # Temporarily set max_target_length for training.
434
+ max_target_length = data_args.max_target_length
435
+
436
+ # In Flax, for seq2seq models we need to pass `decoder_input_ids`
437
+ # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
438
+ # for that dynamically import the `shift_tokens_right` function from the model file
439
+ model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
440
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
441
+
442
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
443
+ def preprocess_function(examples):
444
+ inputs = examples[text_column]
445
+ targets = examples[summary_column]
446
+ inputs = [prefix + inp for inp in inputs]
447
+ model_inputs = tokenizer(
448
+ inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
449
+ )
450
+
451
+ # Setup the tokenizer for targets
452
+ with tokenizer.as_target_tokenizer():
453
+ labels = tokenizer(
454
+ targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
455
+ )
456
+
457
+ model_inputs["labels"] = labels["input_ids"]
458
+ decoder_input_ids = shift_tokens_right_fn(
459
+ jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
460
+ )
461
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
462
+
463
+ # We need decoder_attention_mask so we can ignore pad tokens from loss
464
+ model_inputs["decoder_attention_mask"] = labels["attention_mask"]
465
+
466
+ return model_inputs
467
+
468
+ if training_args.do_train:
469
+ if "train" not in dataset:
470
+ raise ValueError("--do_train requires a train dataset")
471
+ train_dataset = dataset["train"]
472
+ if data_args.max_train_samples is not None:
473
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
474
+ train_dataset = train_dataset.map(
475
+ preprocess_function,
476
+ batched=True,
477
+ num_proc=data_args.preprocessing_num_workers,
478
+ remove_columns=column_names,
479
+ load_from_cache_file=not data_args.overwrite_cache,
480
+ desc="Running tokenizer on train dataset",
481
+ )
482
+
483
+ if training_args.do_eval:
484
+ max_target_length = data_args.val_max_target_length
485
+ if "validation" not in dataset:
486
+ raise ValueError("--do_eval requires a validation dataset")
487
+ eval_dataset = dataset["validation"]
488
+ if data_args.max_eval_samples is not None:
489
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
490
+ eval_dataset = eval_dataset.map(
491
+ preprocess_function,
492
+ batched=True,
493
+ num_proc=data_args.preprocessing_num_workers,
494
+ remove_columns=column_names,
495
+ load_from_cache_file=not data_args.overwrite_cache,
496
+ desc="Running tokenizer on validation dataset",
497
+ )
498
+
499
+ if training_args.do_predict:
500
+ max_target_length = data_args.val_max_target_length
501
+ if "test" not in dataset:
502
+ raise ValueError("--do_predict requires a test dataset")
503
+ predict_dataset = dataset["test"]
504
+ if data_args.max_predict_samples is not None:
505
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
506
+ predict_dataset = predict_dataset.map(
507
+ preprocess_function,
508
+ batched=True,
509
+ num_proc=data_args.preprocessing_num_workers,
510
+ remove_columns=column_names,
511
+ load_from_cache_file=not data_args.overwrite_cache,
512
+ desc="Running tokenizer on prediction dataset",
513
+ )
514
+
515
+ # Metric
516
+ metric = load_metric("rouge")
517
+
518
+ def postprocess_text(preds, labels):
519
+ preds = [pred.strip() for pred in preds]
520
+ labels = [label.strip() for label in labels]
521
+
522
+ # rougeLSum expects newline after each sentence
523
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
524
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
525
+
526
+ return preds, labels
527
+
528
+ def compute_metrics(preds, labels):
529
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
530
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
531
+
532
+ # Some simple post-processing
533
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
534
+
535
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
536
+ # Extract a few results from ROUGE
537
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
538
+
539
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
540
+ result["gen_len"] = np.mean(prediction_lens)
541
+ result = {k: round(v, 4) for k, v in result.items()}
542
+ return result
543
+
544
+ # Enable tensorboard only on the master node
545
+ has_tensorboard = is_tensorboard_available()
546
+ if has_tensorboard and jax.process_index() == 0:
547
+ try:
548
+ from flax.metrics.tensorboard import SummaryWriter
549
+
550
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
551
+ except ImportError as ie:
552
+ has_tensorboard = False
553
+ logger.warning(
554
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
555
+ )
556
+ else:
557
+ logger.warning(
558
+ "Unable to display metrics through TensorBoard because the package is not installed: "
559
+ "Please run pip install tensorboard to enable."
560
+ )
561
+
562
+ # Initialize our training
563
+ rng = jax.random.PRNGKey(training_args.seed)
564
+ rng, dropout_rng = jax.random.split(rng)
565
+
566
+ # Store some constant
567
+ num_epochs = int(training_args.num_train_epochs)
568
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
569
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
570
+ steps_per_epoch = len(train_dataset) // train_batch_size
571
+ total_train_steps = steps_per_epoch * num_epochs
572
+
573
+ # Create learning rate schedule
574
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
575
+ len(train_dataset),
576
+ train_batch_size,
577
+ training_args.num_train_epochs,
578
+ training_args.warmup_steps,
579
+ training_args.learning_rate,
580
+ )
581
+
582
+ # We use Optax's "masking" functionality to not apply weight decay
583
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
584
+ # mask boolean with the same structure as the parameters.
585
+ # The mask is True for parameters that should be decayed.
586
+ # Note that this mask is specifically adapted for FlaxBart.
587
+ # For FlaxT5, one should correct the layer norm parameter naming
588
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
589
+ def decay_mask_fn(params):
590
+ flat_params = traverse_util.flatten_dict(params)
591
+ layer_norm_params = [
592
+ (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
593
+ ]
594
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
595
+ return traverse_util.unflatten_dict(flat_mask)
596
+
597
+ # create adam optimizer
598
+ adamw = optax.adamw(
599
+ learning_rate=linear_decay_lr_schedule_fn,
600
+ b1=training_args.adam_beta1,
601
+ b2=training_args.adam_beta2,
602
+ eps=training_args.adam_epsilon,
603
+ weight_decay=training_args.weight_decay,
604
+ mask=decay_mask_fn,
605
+ )
606
+
607
+ # Setup train state
608
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
609
+
610
+ # label smoothed cross entropy
611
+ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
612
+ """
613
+ The label smoothing implementation is adapted from Flax's official example:
614
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
615
+ """
616
+ vocab_size = logits.shape[-1]
617
+ confidence = 1.0 - label_smoothing_factor
618
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
619
+ normalizing_constant = -(
620
+ confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
621
+ )
622
+ soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
623
+
624
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
625
+ loss = loss - normalizing_constant
626
+
627
+ # ignore padded tokens from loss
628
+ loss = loss * padding_mask
629
+ loss = loss.sum() / padding_mask.sum()
630
+ return loss
631
+
632
+ # Define gradient update step fn
633
+ def train_step(state, batch, label_smoothing_factor=0.0):
634
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
635
+
636
+ def compute_loss(params):
637
+ labels = batch.pop("labels")
638
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
639
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
640
+ return loss
641
+
642
+ grad_fn = jax.value_and_grad(compute_loss)
643
+ loss, grad = grad_fn(state.params)
644
+ grad = jax.lax.pmean(grad, "batch")
645
+
646
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
647
+
648
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
649
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
650
+
651
+ return new_state, metrics
652
+
653
+ # Define eval fn
654
+ def eval_step(params, batch, label_smoothing_factor=0.0):
655
+ labels = batch.pop("labels")
656
+ logits = model(**batch, params=params, train=False)[0]
657
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
658
+
659
+ # summarize metrics
660
+ metrics = {"loss": loss}
661
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
662
+ return metrics
663
+
664
+ # Define generation function
665
+ max_length = (
666
+ data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
667
+ )
668
+ num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
669
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
670
+
671
+ def generate_step(params, batch):
672
+ model.params = params
673
+ output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
674
+ return output_ids.sequences
675
+
676
+ # Create parallel version of the train and eval step
677
+ p_train_step = jax.pmap(
678
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
679
+ )
680
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
681
+ p_generate_step = jax.pmap(generate_step, "batch")
682
+
683
+ # Replicate the train state on each device
684
+ state = state.replicate()
685
+
686
+ logger.info("***** Running training *****")
687
+ logger.info(f" Num examples = {len(train_dataset)}")
688
+ logger.info(f" Num Epochs = {num_epochs}")
689
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
690
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
691
+ logger.info(f" Total optimization steps = {total_train_steps}")
692
+
693
+ train_time = 0
694
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
695
+ for epoch in epochs:
696
+ # ======================== Training ================================
697
+ train_start = time.time()
698
+
699
+ # Create sampling rng
700
+ rng, input_rng = jax.random.split(rng)
701
+ train_metrics = []
702
+
703
+ # Generate an epoch by shuffling sampling indices from the train dataset
704
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
705
+ steps_per_epoch = len(train_dataset) // train_batch_size
706
+ # train
707
+ for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
708
+ batch = next(train_loader)
709
+ state, train_metric = p_train_step(state, batch)
710
+ train_metrics.append(train_metric)
711
+
712
+ train_time += time.time() - train_start
713
+
714
+ train_metric = unreplicate(train_metric)
715
+
716
+ epochs.write(
717
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
718
+ )
719
+
720
+ # ======================== Evaluating ==============================
721
+ eval_metrics = []
722
+ eval_preds = []
723
+ eval_labels = []
724
+
725
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
726
+ eval_steps = len(eval_dataset) // eval_batch_size
727
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
728
+ # Model forward
729
+ batch = next(eval_loader)
730
+ labels = batch["labels"]
731
+
732
+ metrics = p_eval_step(state.params, batch)
733
+ eval_metrics.append(metrics)
734
+
735
+ # generation
736
+ if data_args.predict_with_generate:
737
+ generated_ids = p_generate_step(state.params, batch)
738
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
739
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
740
+
741
+ # normalize eval metrics
742
+ eval_metrics = get_metrics(eval_metrics)
743
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
744
+
745
+ # compute ROUGE metrics
746
+ rouge_desc = ""
747
+ if data_args.predict_with_generate:
748
+ rouge_metrics = compute_metrics(eval_preds, eval_labels)
749
+ eval_metrics.update(rouge_metrics)
750
+ rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
751
+
752
+ # Print metrics and update progress bar
753
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
754
+ epochs.write(desc)
755
+ epochs.desc = desc
756
+
757
+ # Save metrics
758
+ if has_tensorboard and jax.process_index() == 0:
759
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
760
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
761
+
762
+ # ======================== Prediction loop ==============================
763
+ if training_args.do_predict:
764
+ logger.info("*** Predict ***")
765
+
766
+ pred_metrics = []
767
+ pred_generations = []
768
+ pred_labels = []
769
+
770
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
771
+ pred_steps = len(predict_dataset) // eval_batch_size
772
+ for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
773
+ # Model forward
774
+ batch = next(pred_loader)
775
+ labels = batch["labels"]
776
+
777
+ metrics = p_eval_step(state.params, batch)
778
+ pred_metrics.append(metrics)
779
+
780
+ # generation
781
+ if data_args.predict_with_generate:
782
+ generated_ids = p_generate_step(state.params, batch)
783
+ pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
784
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
785
+
786
+ # normalize prediction metrics
787
+ pred_metrics = get_metrics(pred_metrics)
788
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
789
+
790
+ # compute ROUGE metrics
791
+ rouge_desc = ""
792
+ if data_args.predict_with_generate:
793
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
794
+ pred_metrics.update(rouge_metrics)
795
+ rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
796
+
797
+ # Print metrics
798
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
799
+ logger.info(desc)
800
+
801
+ # save checkpoint after each epoch and push checkpoint to the hub
802
+ if jax.process_index() == 0:
803
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
804
+ model.save_pretrained(
805
+ training_args.output_dir,
806
+ params=params,
807
+ push_to_hub=training_args.push_to_hub,
808
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
809
+ )
810
+
811
+
812
+ if __name__ == "__main__":
813
+ main()