ydshieh committed on
Commit
5306066
1 Parent(s): 283180e

upload debug.py

Files changed (1)
  1. debug.py +1343 -0
debug.py ADDED
@@ -0,0 +1,1343 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library vision-encoder-decoder models for image captioning.
+ """
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ import datetime
+ from functools import partial
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import datasets
+ import nltk  # Here to have a nice missing dependency error message early on
+ import numpy as np
+ from datasets import Dataset, load_dataset, load_metric
+ from tqdm import tqdm
+ from PIL import Image
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from filelock import FileLock
+ from flax import jax_utils, traverse_util
+ from flax.jax_utils import unreplicate
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+ from huggingface_hub import Repository
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
+     AutoConfig,
+     AutoFeatureExtractor,
+     AutoTokenizer,
+     HfArgumentParser,
+     TrainingArguments,
+     is_tensorboard_available,
+     FlaxAutoModelForVision2Seq,
+ )
+ from transformers.file_utils import get_full_repo_name, is_offline_mode
+
+
+ logger = logging.getLogger(__name__)
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ try:
+     nltk.data.find("tokenizers/punkt")
+ except (LookupError, OSError):
+     if is_offline_mode():
+         raise LookupError(
+             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+         )
+     with FileLock(".lock") as lock:
+         nltk.download("punkt", quiet=True)
+
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
+ def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+     """
+     Shift input ids one token to the right.
+     """
+     shifted_input_ids = np.zeros_like(input_ids)
+     shifted_input_ids[:, 1:] = input_ids[:, :-1]
+     shifted_input_ids[:, 0] = decoder_start_token_id
+
+     shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+     return shifted_input_ids
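+ # Worked example: with pad_token_id=0 and decoder_start_token_id=2,
+ # shift_tokens_right(np.array([[5, 6, 7]]), 0, 2) -> array([[2, 5, 6]])
+ # (the start token is prepended and the last input token is dropped).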
+
+
+ @dataclass
+ class CustomTrainingArguments(TrainingArguments):
+
+     do_predict_during_training: bool = field(default=None, metadata={"help": "Whether to run prediction on the test set during training."})
+     do_predict_after_evaluation: bool = field(default=None, metadata={"help": "Whether to run prediction right after each evaluation instead of only at the end of each epoch."})
+     block_size: int = field(default=None, metadata={"help": "The number of examples per dataset chunk used to build data loaders sequentially."})
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     encoder_model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The encoder model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     decoder_model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The decoder model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default='vision-encoder-decoder',
+         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     encoder_model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass an encoder model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     decoder_model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     encoder_config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"}
+     )
+     decoder_config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"}
+     )
+     feature_extractor_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as encoder_model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     data_dir: Optional[str] = field(
+         default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
+     )
+     image_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."},
+     )
+     caption_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."},
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+     )
+     test_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
+     )
+     max_source_length: Optional[int] = field(
+         default=1024,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     max_target_length: Optional[int] = field(
+         default=128,
+         metadata={
+             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     val_max_target_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
+             "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+             "during evaluation."
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     max_predict_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+             "value if set."
+         },
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     predict_with_generate: bool = field(
+         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+     )
+     num_beams: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+             "which is used during evaluation."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+         if self.val_max_target_length is None:
+             self.val_max_target_length = self.max_target_length
+
+
+ image_captioning_name_mapping = {
+     "image_caption_dataset.py": ("image_file", "caption"),
+ }
+
+
+ class TrainState(train_state.TrainState):
+     dropout_rng: jnp.ndarray
+
+     def replicate(self):
+         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
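+ # Note: `replicate` also shards the dropout PRNG key across local devices via
+ # `shard_prng_key`, so each device draws an independent dropout mask in `pmap`-ed steps.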
+
+
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+     """
+     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+     Shuffle batches if `shuffle` is `True`.
+     """
+     steps_per_epoch = len(dataset) // batch_size
+
+     if shuffle:
+         batch_idx = jax.random.permutation(rng, len(dataset))
+     else:
+         batch_idx = jnp.arange(len(dataset))
+
+     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+     for idx in batch_idx:
+         batch = dataset[idx]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+
+         batch = shard(batch)
+
+         yield batch
+
+
+ def write_metric(summary_writer, mode, metrics, step, train_time=None):
+
+     if train_time:
+         summary_writer.scalar("train_time", train_time, step)
+
+     if mode == "train":
+         metrics = get_metrics(metrics)
+         for key, vals in metrics.items():
+             tag = f"{mode}_{key}"
+             for i, val in enumerate(vals):
+                 summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+     elif mode in ["valid", "pred"]:
+         for metric_name, value in metrics.items():
+             summary_writer.scalar(f"{mode}_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear_decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Handle the repository creation
+     if training_args.push_to_hub:
+         if training_args.hub_model_id is None:
+             repo_name = get_full_repo_name(
+                 Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+             )
+         else:
+             repo_name = training_args.hub_model_id
+         repo = Repository(training_args.output_dir, clone_from=repo_name)
+
+     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files this script will use the first column for the full texts and the second column for the
+     # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
+     #
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             data_args.dataset_name, data_args.dataset_config_name, keep_in_memory=False, data_dir=data_args.data_dir,
+             cache_dir="./dataset_cache/"
+         )
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+             extension = data_args.train_file.split(".")[-1]
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+             extension = data_args.validation_file.split(".")[-1]
+         if data_args.test_file is not None:
+             data_files["test"] = data_args.test_file
+             extension = data_args.test_file.split(".")[-1]
+         # TODO: Check
+         dataset = load_dataset(extension, data_files=data_files, cache_dir="./dataset_cache/", data_dir=data_args.data_dir)
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained model and tokenizer
+
+     encoder_cache_dir, decoder_cache_dir = None, None
+     if model_args.cache_dir:
+         encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
+         decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
+
+     # Use explicitly specified config
+     if model_args.config_name:
+         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+     # Use pretrained model's config
+     elif model_args.model_name_or_path:
+         config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+     # Use specified `model_type` (default to `vision-encoder-decoder`)
+     else:
+
+         if model_args.model_type not in MODEL_TYPES:
+             raise ValueError(
+                 f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
+             )
+         config_class = CONFIG_MAPPING[model_args.model_type]
+
+         # Deal with encoder-decoder models that require specifying encoder/decoder
+         if hasattr(config_class, "from_encoder_decoder_configs"):
+
+             # Use explicitly specified encoder config
+             if model_args.encoder_config_name:
+                 encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
+             # Use pretrained encoder model's config
+             elif model_args.encoder_model_name_or_path:
+                 encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
+             # Use specified encoder model type
+             elif model_args.encoder_model_type:
+                 encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
+                 logger.warning("You are instantiating a new config instance from scratch for the encoder.")
+             else:
+                 raise ValueError("Encoder Config: if a pretrained config or model location is not provided, `encoder_model_type` is required.")
+
+             # Use explicitly specified decoder config
+             if model_args.decoder_config_name:
+                 decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
+             # Use pretrained decoder model's config
+             elif model_args.decoder_model_name_or_path:
+                 decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
+             # Use specified decoder model type
+             elif model_args.decoder_model_type:
+                 decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
+                 logger.warning("You are instantiating a new config instance from scratch for the decoder.")
+             else:
+                 raise ValueError("Decoder Config: if a pretrained config or model location is not provided, `decoder_model_type` is required.")
+
+             logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+             decoder_config.is_decoder = True
+             decoder_config.add_cross_attention = True
+
+             config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
+         # For self-contained model
+         else:
+             config = config_class()
+             logger.warning("You are instantiating a new config instance from scratch.")
+
+     decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
+     if not decoder_start_token_id and getattr(config, "decoder", None):
+         decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
+     bos_token_id = getattr(config, "bos_token_id", None)
+     if not bos_token_id and getattr(config, "decoder", None):
+         bos_token_id = getattr(config.decoder, "bos_token_id", None)
+     eos_token_id = getattr(config, "eos_token_id", None)
+     if not eos_token_id and getattr(config, "decoder", None):
+         eos_token_id = getattr(config.decoder, "eos_token_id", None)
+     pad_token_id = getattr(config, "pad_token_id", None)
+     if not pad_token_id and getattr(config, "decoder", None):
+         pad_token_id = getattr(config.decoder, "pad_token_id", None)
+
+     if decoder_start_token_id is None:
+         decoder_start_token_id = bos_token_id
+     if pad_token_id is None:
+         pad_token_id = eos_token_id
+
+     if getattr(config, "decoder", None):
+         config.decoder.decoder_start_token_id = decoder_start_token_id
+         config.decoder.bos_token_id = bos_token_id
+         config.decoder.eos_token_id = eos_token_id
+         config.decoder.pad_token_id = pad_token_id
+
+     # Set `encoder-decoder` (top-level) specific config
+     config.decoder_start_token_id = decoder_start_token_id
+     config.bos_token_id = bos_token_id
+     config.eos_token_id = eos_token_id
+     config.pad_token_id = pad_token_id
+
+     if model_args.model_name_or_path:
+         model = FlaxAutoModelForVision2Seq.from_pretrained(
+             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+     else:
+         # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
+         model = FlaxAutoModelForVision2Seq.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+     model_class = model.__class__
+
+     # encoder_class = FlaxAutoModel
+     # decoder_class = FlaxAutoModelForCausalLM
+     module = model.module.bind(model.params)
+     encoder_class_name = type(module.encoder).__name__.replace("Module", "Model")
+     decoder_class_name = type(module.decoder).__name__.replace("Module", "Model")
+     encoder_class = getattr(transformers, encoder_class_name, None)
+     decoder_class = getattr(transformers, decoder_class_name, None)
+
+     if hasattr(model_class, "from_encoder_decoder_pretrained"):
+
+         if model_args.encoder_model_name_or_path:
+             encoder = encoder_class.from_pretrained(
+                 model_args.encoder_model_name_or_path,
+                 config=config.encoder,
+                 seed=training_args.seed,
+                 dtype=getattr(jnp, model_args.dtype)
+             )
+         else:
+             encoder = encoder_class(config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+             logger.warning("You are instantiating a new model instance from scratch for the encoder.")
+
+         if model_args.decoder_model_name_or_path:
+             decoder = decoder_class.from_pretrained(
+                 model_args.decoder_model_name_or_path,
+                 config=config.decoder,
+                 seed=training_args.seed,
+                 dtype=getattr(jnp, model_args.dtype)
+             )
+         else:
+             decoder = decoder_class(config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+             logger.warning("You are instantiating a new model instance from scratch for the decoder.")
+
+         model = model_class.from_encoder_decoder_pretrained(
+             model_args.encoder_model_name_or_path,
+             model_args.decoder_model_name_or_path,
+             encoder_model=encoder,
+             decoder_model=decoder,
+             encoder_config=config.encoder,
+             decoder_config=config.decoder,
+             encoder_seed=training_args.seed,
+             decoder_seed=training_args.seed,
+             encoder_dtype=getattr(jnp, model_args.dtype),
+             decoder_dtype=getattr(jnp, model_args.dtype),
+         )
+
+         # Set `encoder-decoder` (top-level) specific config
+         model.config.decoder_start_token_id = decoder_start_token_id
+         model.config.bos_token_id = bos_token_id
+         model.config.eos_token_id = eos_token_id
+         model.config.pad_token_id = pad_token_id
+
+     else:
+         logger.warning("You are instantiating a new model instance from scratch.")
+
+     feature_extractor = None
+     if model_args.feature_extractor_name:
+         feature_extractor = AutoFeatureExtractor.from_pretrained(
+             model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
+         )
+     elif model_args.model_name_or_path:
+         try:
+             feature_extractor = AutoFeatureExtractor.from_pretrained(
+                 model_args.model_name_or_path, cache_dir=model_args.cache_dir
+             )
+         except ValueError as e:
+             logger.warning(e)
+     # Check encoder
+     if not feature_extractor:
+         if model_args.encoder_model_name_or_path:
+             feature_extractor = AutoFeatureExtractor.from_pretrained(
+                 model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
+             )
+         else:
+             raise ValueError(
+                 "You are instantiating a new feature extractor from scratch. This is not supported by this script. "
+                 "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
+             )
+
+     def get_tokenizer():
+
+         tokenizer = None
+         if model_args.tokenizer_name:
+             tokenizer = AutoTokenizer.from_pretrained(
+                 model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+             )
+         elif model_args.model_name_or_path:
+             try:
+                 tokenizer = AutoTokenizer.from_pretrained(
+                     model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+                 )
+             except ValueError as e:
+                 logger.warning(e)
+
+         # Check decoder
+         if not tokenizer:
+             if model_args.decoder_model_name_or_path:
+                 tokenizer = AutoTokenizer.from_pretrained(
+                     model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+                 )
+             else:
+                 raise ValueError(
+                     "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+                     "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+                 )
+         tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
+
+         return tokenizer
+
+     tokenizer = get_tokenizer()
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if training_args.do_train:
+         column_names = dataset["train"].column_names
+     elif training_args.do_eval:
+         column_names = dataset["validation"].column_names
+     elif training_args.do_predict:
+         column_names = dataset["test"].column_names
+     else:
+         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+         return
+
+     # Get the column names for input/target.
+     dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
+     if data_args.image_column is None:
+         assert dataset_columns is not None
+         image_column = dataset_columns[0]
+     else:
+         image_column = data_args.image_column
+         if image_column not in column_names:
+             raise ValueError(
+                 f"'--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
+             )
+     if data_args.caption_column is None:
+         assert dataset_columns is not None
+         caption_column = dataset_columns[1]
+     else:
+         caption_column = data_args.caption_column
+         if caption_column not in column_names:
+             raise ValueError(
+                 f"'--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
+             )
+
+     # In Flax, for seq2seq models we need to pass `decoder_input_ids`
+     # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
+     # for that dynamically import the `shift_tokens_right` function from the model file
+     model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
+     shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
+
+     def filter_fn(examples):
+
+         bools = []
+         for image_file in examples[image_column]:
+             with Image.open(image_file) as image:
+                 try:
+                     feature_extractor(images=image, return_tensors="np")
+                     bools.append(True)
+                 except Exception:
+                     bools.append(False)
+
+         return bools
+
+     # Setting padding="max_length" as we need fixed length inputs for jitted functions
+     def tokenization_fn(examples, max_target_length):
+
+         captions = []
+         for caption in examples[caption_column]:
+             captions.append(caption.lower() + ' ' + tokenizer.eos_token)
+
+         targets = captions
+
+         model_inputs = {}
+
+         # Setup the tokenizer for targets
+         with tokenizer.as_target_tokenizer():
+             labels = tokenizer(
+                 targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
+             )
+
+         model_inputs["labels"] = labels["input_ids"]
+         decoder_input_ids = shift_tokens_right_fn(
+             labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
+         )
+         model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+         # We need decoder_attention_mask so we can ignore pad tokens from loss
+         model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+         model_inputs[image_column] = examples[image_column]
+
+         return model_inputs
+
+     def feature_extraction_fn(examples):
+
+         pixel_values = []
+
+         for image_file in examples[image_column]:
+             with Image.open(image_file) as image:
+                 encoder_inputs = feature_extractor(images=image, return_tensors="np")
+                 pixel_values.append(encoder_inputs.pixel_values)
+
+         pixel_values = np.concatenate(pixel_values)
+
+         model_inputs = examples
+         model_inputs['pixel_values'] = pixel_values
+
+         return model_inputs
+
+     features = datasets.Features(
+         {
+             "pixel_values": datasets.Array3D(
+                 shape=(
+                     getattr(config.encoder, "num_channels", 3),
+                     config.encoder.image_size,
+                     config.encoder.image_size,
+                 ),
+                 dtype="float32",
+             ),
+             "labels": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
+             "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
+             "decoder_attention_mask": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
+         }
+     )
+
+     if training_args.do_train:
+
+         if "train" not in dataset:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = dataset["train"]
+         # duplicate the train split to get a dataset large enough for the indexing benchmark below
+         train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
+
+         # remove problematic examples
+         s = time.time()
+         train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+         e = time.time()
+         print(f'filter time: {e-s}')
+         print(len(train_dataset))
+
+         rng = jax.random.PRNGKey(training_args.seed)
+         rng, input_rng = jax.random.split(rng)
+
+         s = time.time()
+         indices_jax = jax.random.permutation(input_rng, len(train_dataset))
+         e = time.time()
+         print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
+
+         s = time.time()
+         indices_np = np.random.permutation(len(train_dataset))
+         e = time.time()
+         print(f'get permutation indices for the whole dataset with np - time: {e-s}')
+
+         # indices = jnp.arange(len(ds))
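+         # The timings above compare building the epoch permutation with jax.random
+         # (which returns a DeviceArray and may involve device compute/transfer)
+         # against plain numpy (host-only) - presumably the first suspect in the
+         # data-loading slowdown this debug script investigates.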
+
+         block_size = 4096
+         for idx in range(4):
+
+             start_idx = block_size * idx
+             end_idx = block_size * (idx + 1)
+
+             s = time.time()
+             selected_indices_jax = indices_jax[start_idx:end_idx]
+             e = time.time()
+             print(f'get block indices with jax - time: {e-s}')
+             print(type(selected_indices_jax))
+
+             s = time.time()
+             selected_indices_np = indices_np[start_idx:end_idx]
+             e = time.time()
+             print(f'get block indices with np - time: {e-s}')
+             print(type(selected_indices_np))
+
+
+             s = time.time()
+             _ds = train_dataset.select(selected_indices_jax)
+             e = time.time()
+             print(f'select block with jax - time: {e-s}')
+
+             s = time.time()
+             _ds = train_dataset.select(selected_indices_np)
+             e = time.time()
+             print(f'select block with np - time: {e-s}')
+
+             s = time.time()
+             _selected_indices_np = np.array(selected_indices_jax)
+             e = time.time()
+             print(f'convert jax to np - time: {e-s}')
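+             # `Dataset.select` is timed with both index types: a jax DeviceArray
+             # likely has to be materialized on the host before datasets can use it,
+             # and the explicit np.array() conversion above measures that cost in
+             # isolation (an assumption this debug run is checking).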
+
+
+             batch_size = 256
+
+             steps_per_epoch = len(_ds) // batch_size
+
+             s = time.time()
+             batch_idx_jax = jax.random.permutation(rng, len(_ds))
+             e = time.time()
+             print(f'get permutation indices for the block with jax - time: {e-s}')
+             # batch_idx = jnp.arange(len(dataset))
+
+             s = time.time()
+             batch_idx_np = np.random.permutation(len(_ds))
+             e = time.time()
+             print(f'get permutation indices for the block with np - time: {e-s}')
+
+             s = time.time()
+             batch_idx_jax = batch_idx_jax[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+             e = time.time()
+             print(f'skip incomplete batch with jax - time: {e-s}')
+
+             s = time.time()
+             batch_idx_np = batch_idx_np[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+             e = time.time()
+             print(f'skip incomplete batch with np - time: {e-s}')
+
+             s = time.time()
+             batch_idx_jax = batch_idx_jax.reshape((steps_per_epoch, batch_size))
+             e = time.time()
+             print(f'reshape block indices with jax - time: {e-s}')
+
+             s = time.time()
+             batch_idx_np = batch_idx_np.reshape((steps_per_epoch, batch_size))
+             e = time.time()
+             print(f'reshape block indices with np - time: {e-s}')
+
+             for idx in batch_idx_jax:
+
+                 s = time.time()
+                 batch = _ds[idx]
+                 e = time.time()
+                 print(f'get one batch with jax - time: {e-s}')
+
+                 # s = time.time()
+                 # batch = {k: jnp.array(v) for k, v in batch.items()}
+                 # e = time.time()
+                 # print(f'convert one batch to jnp time: {e-s}')
+
+             for idx in batch_idx_np:
+
+                 s = time.time()
+                 batch = _ds[idx]
+                 e = time.time()
+                 print(f'get one batch with np - time: {e-s}')
+
+
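+         # End of the indexing benchmark: the script exits here on purpose, so the
+         # training/evaluation code below is never reached in this debug run.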
+         exit(0)
+
+
+     if training_args.do_predict:
+         if "test" not in dataset:
+             raise ValueError("--do_predict requires a test dataset")
+         predict_dataset = dataset["test"]
+         # remove problematic examples
+         predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+         if data_args.max_predict_samples is not None:
+             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+         predict_dataset = predict_dataset.map(
+             tokenization_fn,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             # keep image paths
+             remove_columns=[x for x in column_names if x != image_column],
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on prediction dataset",
+             fn_kwargs={"max_target_length": data_args.val_max_target_length},
+         )
+
+     tokenizer = get_tokenizer()
+
+     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
+     # data loader separately (in a sequential order).
+     block_size = training_args.block_size
+
+     # Store some constants
+
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+
+     if training_args.do_train:
+         steps_per_epoch = len(train_dataset) // train_batch_size
+         num_train_examples_per_epoch = steps_per_epoch * train_batch_size
+         num_epochs = int(training_args.num_train_epochs)
+         total_train_steps = steps_per_epoch * num_epochs
+     else:
+         num_train_examples_per_epoch = 0
+
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+
+     if training_args.do_eval:
+         num_eval_examples = len(eval_dataset)
+         eval_steps = num_eval_examples // eval_batch_size + int(num_eval_examples % eval_batch_size > 0)
+
+     if training_args.do_predict:
+         num_test_examples = len(predict_dataset)
+         test_steps = num_test_examples // eval_batch_size + int(num_test_examples % eval_batch_size > 0)
+
+     def get_batch_iter(rng, ds, block_size, batch_size, shuffle=False, drop_last_batch=False, keep_in_memory=False, split=""):
+
+         if not block_size:
+             block_size = len(ds)
+
+         steps_per_split = block_size // batch_size
+         num_examples = len(ds)
+         steps = num_examples // batch_size + int(num_examples % batch_size > 0 and not drop_last_batch)
+         num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
+
+         if shuffle:
+             indices = jax.random.permutation(rng, len(ds))
+         else:
+             indices = jnp.arange(len(ds))
+
+         for idx in range(num_splits):
+
+             start_idx = block_size * idx
+             end_idx = block_size * (idx + 1)
+
+             selected_indices = indices[start_idx:end_idx]
+
+             _ds = ds.select(selected_indices)
+
+             names = {
+                 "train": "train",
+                 "valid": "validation",
+                 "test": "prediction",
+             }
+
+             _ds = _ds.map(
+                 feature_extraction_fn,
+                 batched=True,
+                 num_proc=data_args.preprocessing_num_workers,
+                 remove_columns=[image_column],
+                 load_from_cache_file=not data_args.overwrite_cache,
+                 features=features,
+                 keep_in_memory=keep_in_memory,
+                 desc=f"Running feature extraction on {names[split]} dataset",
+             )
+             _ds = _ds.with_format("numpy")
+
+             # No need to shuffle here
+             loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
+
+             for batch in loader:
+                 yield batch
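+     # The idea seems to be that only one `block_size` chunk is feature-extracted
+     # (and optionally kept in memory) at a time, bounding peak memory while the
+     # iterator still streams full batches across all chunks in order.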
+
+     # Metric
+     metric = load_metric("rouge")
+
+     def postprocess_text(preds, labels):
+         preds = [pred.strip() for pred in preds]
+         labels = [label.strip() for label in labels]
+
+         # rougeLSum expects newline after each sentence
+         preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+         labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+         return preds, labels
+
+     def compute_metrics(preds, labels):
+         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         # Some simple post-processing
+         decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+
+         result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+         # Extract a few results from ROUGE
+         result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+         prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+         result["gen_len"] = np.mean(prediction_lens)
+         result = {k: round(v, 6) for k, v in result.items()}
+
+         return result, decoded_preds, decoded_labels
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         num_train_examples_per_epoch,
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     # Note that this mask is specifically adapted for FlaxBart.
+     # For FlaxT5, one should correct the layer norm parameter naming
+     # accordingly - see `run_t5_mlm_flax.py` e.g.
+     def decay_mask_fn(params):
+         flat_params = traverse_util.flatten_dict(params)
+         layer_norm_params = [
+             (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
+         ]
+         flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+         return traverse_util.unflatten_dict(flat_mask)
+
+     # create adam optimizer
+     adamw = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+         mask=decay_mask_fn,
+     )
+
+     # Setup train state
+     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+     # label smoothed cross entropy
+     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
+         """
+         The label smoothing implementation is adapted from Flax's official example:
+         https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
+         """
+         vocab_size = logits.shape[-1]
+         confidence = 1.0 - label_smoothing_factor
+         low_confidence = (1.0 - confidence) / (vocab_size - 1)
+         normalizing_constant = -(
+             confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
+         )
+         soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
+
+         loss = optax.softmax_cross_entropy(logits, soft_labels)
+         loss = loss - normalizing_constant
+
+         # ignore padded tokens from loss
+         loss = loss * padding_mask
+         loss = loss.sum() / padding_mask.sum()
+         return loss
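+     # Sanity check: with label_smoothing_factor=0.0 this reduces to standard
+     # cross entropy, since confidence=1.0, low_confidence=0.0 and the
+     # normalizing constant evaluates to 0.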
+
+     # Define gradient update step fn
+     def train_step(state, batch, label_smoothing_factor=0.0):
+         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+         def compute_loss(params):
+             labels = batch.pop("labels")
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         return new_state, metrics
+
+     # Define eval fn
+     def eval_step(params, batch, label_smoothing_factor=0.0):
+         labels = batch.pop("labels")
+         logits = model(**batch, params=params, train=False)[0]
+         loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+
+         # summarize metrics
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+         return metrics
+
+     # Define generation function
+     max_length = (
+         data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
+     )
+     num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
+     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+     def generate_step(params, batch):
+         model.params = params
+         output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
+         return output_ids.sequences
+
+     # Create parallel version of the train and eval step
+     p_train_step = jax.pmap(
+         partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+     )
+     p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+     p_generate_step = jax.pmap(generate_step, "batch")
+
+     # Replicate the train state on each device
+     state = state.replicate()
+
+     if training_args.do_train:
+         logger.info("***** Running training *****")
+         logger.info(f" Num train examples = {len(train_dataset)}")
+         logger.info(f" Num train examples per epoch = {num_train_examples_per_epoch}")
+         logger.info(f" Num Epochs = {num_epochs}")
+         logger.info(f" Instantaneous train batch size per device = {training_args.per_device_train_batch_size}")
+         logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+         logger.info(f" Optimization steps per epoch = {steps_per_epoch}")
+         logger.info(f" Total optimization steps = {total_train_steps}")
+     if training_args.do_eval:
+         logger.info(f" Num evaluation examples = {num_eval_examples}")
+         logger.info(f" Instantaneous evaluation batch size per device = {training_args.per_device_eval_batch_size}")
+         logger.info(f" Total evaluation batch size (w. parallel & distributed) = {eval_batch_size}")
+         logger.info(f" Evaluation steps = {eval_steps}")
+     if training_args.do_predict:
+         logger.info(f" Num test examples = {num_test_examples}")
+         logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
+         logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
+         logger.info(f" Test steps = {test_steps}")
+
+     # create output directory
+     if not os.path.isdir(os.path.join(training_args.output_dir)):
+         os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
+
+     def save_results(epoch, step):
+
+         # save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+             dir_name = f'ckpt_epoch_{epoch + 1}_step_{step}'
+             model.save_pretrained(os.path.join(training_args.output_dir, dir_name), params=params)
+             tokenizer.save_pretrained(os.path.join(training_args.output_dir, dir_name))
+             if training_args.push_to_hub:
+                 commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {step}"
+                 repo.push_to_hub(commit_message=commit_msg, blocking=False)
+
+     def run_eval_or_test(rng, dataset, name, is_inside_training=True):
+
+         if name not in ["valid", "test"]:
+             raise ValueError(f"`name` must be either \"valid\" or \"test\". Got {name} instead.")
+
+         logger.info(f"*** {'Predict' if name == 'test' else 'Evaluate'} ***")
+
+         metrics = []
+         preds = []
+         labels = []
+
+         batches = get_batch_iter(rng, dataset, block_size=block_size, batch_size=eval_batch_size, keep_in_memory=False, shuffle=False, split=name)
+         steps = len(dataset) // eval_batch_size + int(len(dataset) % eval_batch_size > 0)
+         for _ in tqdm(range(steps), desc=f"{'Predicting' if name == 'test' else 'Evaluating'}...", position=2, leave=False):
+             # Model forward
+             batch = next(batches)
+             _labels = batch.get("labels", None)
+             if name == "valid" and _labels is None:
+                 raise ValueError("Validation dataset requires `labels`")
+
+             if _labels is not None:
+                 _metrics = p_eval_step(state.params, batch)
+                 metrics.append(_metrics)
+
+             # generation
+             if data_args.predict_with_generate:
+                 generated_ids = p_generate_step(state.params, batch)
+                 preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                 if _labels is not None:
+                     labels.extend(jax.device_get(_labels.reshape(-1, _labels.shape[-1])))
+
+         if metrics:
+             # normalize metrics
+             metrics = get_metrics(metrics)
+             metrics = jax.tree_map(jnp.mean, metrics)
+
+         # compute ROUGE metrics
+         generations = []
+         rouge_desc = ""
+         if data_args.predict_with_generate:
+             if labels:
+                 rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
+                 metrics.update(rouge_metrics)
+                 rouge_desc = " ".join([f"{'Predict' if name == 'test' else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
+                 for pred, label in zip(decoded_preds, decoded_labels):
+                     pred = pred.replace("\n", " ")
+                     label = label.replace("\n", " ")
+                     generations.append({"label": label, "pred": pred})
+             else:
+                 decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+                 # Some simple post-processing
+                 decoded_preds = [pred.strip() for pred in decoded_preds]
+                 # rougeLSum expects newline after each sentence
+                 decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
+                 for pred in decoded_preds:
+                     pred = pred.replace("\n", " ")
+                     generations.append({"pred": pred})
+
+         if metrics:
+             # Print metrics and update progress bar
+             desc = f"{'Predict' if name == 'test' else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
+             if is_inside_training:
+                 desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
+                 epochs.write(desc)
+                 epochs.desc = desc
+             logger.info(desc)
+
+         if jax.process_index() == 0:
+
+             ckpt_dir = ""
+             if is_inside_training:
+                 ckpt_dir = f'ckpt_epoch_{epoch + 1}_step_{cur_step}'
+                 if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
+                     os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
+
+             if metrics:
+
+                 # save final metrics in json
+                 metrics = {f"{name}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
+                 path = os.path.join(training_args.output_dir, ckpt_dir, f"{name}_results.json")
+                 with open(path, "w") as f:
+                     json.dump(metrics, f, indent=4, sort_keys=True)
+
+                 # Update report
+                 with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
+                     fp.write(desc + '\n')
+
+                 # Save metrics
+                 if has_tensorboard and is_inside_training:
+                     write_metric(summary_writer, name, metrics, cur_step)
+
+             # Save generations
+             if generations:
+                 with open(os.path.join(training_args.output_dir, ckpt_dir, f'generation_{name}.json'), 'w', encoding='UTF-8') as fp:
+                     json.dump(generations, fp, ensure_ascii=False, indent=4)
+
+     input_rng = None
+
+     if training_args.do_train:
+
+         cur_step = 0
+         train_time = 0
+         epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+
+         for epoch in epochs:
+
+             # ======================== Training ================================
+
+             # Create sampling rng
+             rng, input_rng = jax.random.split(rng)
+
+             train_metrics = []
+
+             train_batches = get_batch_iter(input_rng, train_dataset, block_size=block_size, batch_size=train_batch_size, keep_in_memory=True, shuffle=True, drop_last_batch=training_args.dataloader_drop_last, split="train")
+
+             # train
+             for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
+
+                 cur_step += 1
+                 batch = next(train_batches)
+                 batch_start = time.time()
+                 state, train_metric = p_train_step(state, batch)
+                 train_metrics.append(train_metric)
+                 train_time += time.time() - batch_start
+
+                 if cur_step % training_args.logging_steps == 0 or (training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0:
+
+                     time_per_step = train_time / cur_step
+
+                     _train_metric = unreplicate(train_metric)
+                     desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
+                     epochs.desc = desc
+                     epochs.write(desc)
+                     logger.info(desc)
+                     with open(os.path.join(training_args.output_dir, 'report.txt'), 'a', encoding='UTF-8') as fp:
+                         fp.write(desc + '\n')
+
+                     # Save metrics
+                     if has_tensorboard and jax.process_index() == 0:
+                         write_metric(summary_writer, "train", train_metrics, cur_step, train_time=train_time)
+
+                 # ======================== Evaluating ==============================
+
+                 if training_args.do_eval and ((training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0) or cur_step % steps_per_epoch == 0):
+                     run_eval_or_test(input_rng, eval_dataset, name="valid", is_inside_training=True)
+
+                     # ======================== Prediction loop ==============================
+
+                     # run prediction after evaluation if specified, otherwise only after each epoch
+                     if training_args.do_predict and training_args.do_predict_during_training and training_args.do_predict_after_evaluation:
+                         run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
+
+                     # ======================== Save ==============================
+
+                     save_results(epoch, cur_step)
+
+             # run prediction after each epoch (if not done during training)
+             if training_args.do_predict and training_args.do_predict_during_training and not training_args.do_predict_after_evaluation:
+                 run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=True)
+                 save_results(epoch, cur_step)
+
+     # Create sampling rng
+     if input_rng is None:
+         rng, input_rng = jax.random.split(rng)
+
+     # run prediction at the end (if not done during training)
+     if training_args.do_predict and not (training_args.do_train and training_args.do_predict_during_training):
+         run_eval_or_test(input_rng, predict_dataset, name='test', is_inside_training=False)
+
+
+ if __name__ == "__main__":
+
+     main()