ydshieh committed on
Commit
16517d8
1 Parent(s): 91d8939

add copy to be reduced

Files changed (1)
  1. run_image_captioning_flax_reduced.py +1421 -0
run_image_captioning_flax_reduced.py ADDED
@@ -0,0 +1,1421 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library vision-encoder-decoder models for image captioning.
18
+ """
19
+
20
+ import json
21
+ import logging
22
+ import os
23
+ import sys
24
+ import time
25
+ from dataclasses import asdict, dataclass, field
26
+ from enum import Enum
27
+ from functools import partial
28
+ from pathlib import Path
29
+ from typing import Callable, Optional
30
+
31
+ import datasets
32
+ import nltk # Here to have a nice missing dependency error message early on
33
+ import numpy as np
34
+ from datasets import Dataset, load_dataset, load_metric
35
+ from tqdm import tqdm
36
+ from PIL import Image
37
+
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ import transformers
42
+ from filelock import FileLock
43
+ from flax import jax_utils, traverse_util
44
+ from flax.jax_utils import unreplicate
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
+ from huggingface_hub import Repository
48
+ from transformers import (
49
+ CONFIG_MAPPING,
50
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
51
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoFeatureExtractor,
54
+ AutoTokenizer,
55
+ HfArgumentParser,
56
+ is_tensorboard_available,
57
+ FlaxAutoModelForVision2Seq,
58
+ )
59
+ from transformers.file_utils import get_full_repo_name, is_offline_mode
60
+
61
+
62
+ logger = logging.getLogger(__name__)
63
+
64
+ try:
65
+ nltk.data.find("tokenizers/punkt")
66
+ except (LookupError, OSError):
67
+ if is_offline_mode():
68
+ raise LookupError(
69
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
70
+ )
71
+ with FileLock(".lock") as lock:
72
+ nltk.download("punkt", quiet=True)
73
+
74
+
75
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
76
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
+ DECODER_MODEL_TYPES = tuple(conf.model_type for conf in list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()))
78
+
79
+
80
+ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
81
+ def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
82
+ """
83
+ Shift input ids one token to the right.
84
+ """
85
+ shifted_input_ids = np.zeros_like(input_ids)
86
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
87
+ shifted_input_ids[:, 0] = decoder_start_token_id
88
+
89
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
90
+ return shifted_input_ids
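+
+ # Worked example (a sketch): with pad_token_id=0 and decoder_start_token_id=2,
+ # input_ids [[5, 6, -100]] becomes [[2, 5, 6]] -- each row is shifted right, the
+ # decoder start token is prepended, and any -100 label marker is mapped to the pad id.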
91
+
92
+
93
+ @dataclass
94
+ class TrainingArguments:
95
+ output_dir: str = field(
96
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
97
+ )
98
+ overwrite_output_dir: bool = field(
99
+ default=False,
100
+ metadata={
101
+ "help": (
102
+ "Overwrite the content of the output directory. "
103
+ "Use this to continue training if output_dir points to a checkpoint directory."
104
+ )
105
+ },
106
+ )
107
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
108
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
109
+ do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
110
+ per_device_train_batch_size: int = field(
111
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
112
+ )
113
+ per_device_eval_batch_size: int = field(
114
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
115
+ )
116
+ _block_size_doc = \
117
+ """
118
+ The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
119
+ cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
120
+ good option if your disk space is large enough to store the whole processed dataset.
121
+ If a positive value is given, the captions in the dataset will be tokenized before training and the results are
122
+ cached. During training, it iterates the dataset in chunks of size `block_size`. On each block, images are
123
+ transformed by the feature extractor with the results being kept in memory (no cache), and batches of size
124
+ `batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
125
+ dataset is large.
126
+ """
127
+ block_size: int = field(
128
+ default=0,
129
+ metadata={"help": _block_size_doc}
130
+ )
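+ # Example (a sketch, using the flags defined in this file): `--block_size 0` preprocesses
+ # everything up front and caches it, while e.g. `--block_size 8192` tokenizes up front and
+ # runs image feature extraction in chunks of 8192 examples during training (the value must
+ # be a multiple of the global batch size, which is checked later in `main()`).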
131
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
132
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
133
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
134
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
135
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
136
+ label_smoothing_factor: float = field(
137
+ default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
138
+ )
139
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
140
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
141
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
142
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
143
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
144
+ push_to_hub: bool = field(
145
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
146
+ )
147
+ hub_model_id: str = field(
148
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
149
+ )
150
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
151
+
152
+ def __post_init__(self):
153
+ if self.output_dir is not None:
154
+ self.output_dir = os.path.expanduser(self.output_dir)
155
+
156
+ def to_dict(self):
157
+ """
158
+ Serializes this instance while replacing `Enum` members with their values (for JSON serialization support). It obfuscates
159
+ the token values by replacing them with placeholders.
160
+ """
161
+ d = asdict(self)
162
+ for k, v in d.items():
163
+ if isinstance(v, Enum):
164
+ d[k] = v.value
165
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
166
+ d[k] = [x.value for x in v]
167
+ if k.endswith("_token"):
168
+ d[k] = f"<{k.upper()}>"
169
+ return d
170
+
171
+
172
+ @dataclass
173
+ class ModelArguments:
174
+ """
175
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
176
+ """
177
+
178
+ model_name_or_path: Optional[str] = field(
179
+ default=None,
180
+ metadata={
181
+ "help": "The model checkpoint for weights initialization."
182
+ "Don't set if you want to train a model from scratch."
183
+ },
184
+ )
185
+ encoder_model_name_or_path: Optional[str] = field(
186
+ default=None,
187
+ metadata={
188
+ "help": "The encoder model checkpoint for weights initialization."
189
+ "Don't set if you want to train an encoder model from scratch."
190
+ },
191
+ )
192
+ decoder_model_name_or_path: Optional[str] = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "The decoder model checkpoint for weights initialization."
196
+ "Don't set if you want to train a decoder model from scratch."
197
+ },
198
+ )
199
+ model_type: Optional[str] = field(
200
+ default='vision-encoder-decoder',
201
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}
202
+ )
203
+ encoder_model_type: Optional[str] = field(
204
+ default=None,
205
+ metadata={"help": "If training from scratch, pass a vision encoder model type from the library. For example, 'vit'"}
206
+ )
207
+ decoder_model_type: Optional[str] = field(
208
+ default=None,
209
+ metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(DECODER_MODEL_TYPES)}
210
+ )
211
+ config_name: Optional[str] = field(
212
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
213
+ )
214
+ encoder_config_name: Optional[str] = field(
215
+ default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
216
+ )
217
+ decoder_config_name: Optional[str] = field(
218
+ default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
219
+ )
220
+ feature_extractor_name: Optional[str] = field(
221
+ default=None, metadata={"help": "Pretrained encoder feature extractor_name or path if not the same as encoder_model_name"}
222
+ )
223
+ tokenizer_name: Optional[str] = field(
224
+ default=None, metadata={"help": "Pretrained decoder tokenizer name or path if not the same as decoder_model_name"}
225
+ )
226
+ cache_dir: Optional[str] = field(
227
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
228
+ )
229
+ use_fast_tokenizer: bool = field(
230
+ default=True,
231
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
232
+ )
233
+ dtype: Optional[str] = field(
234
+ default="float32",
235
+ metadata={
236
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
237
+ },
238
+ )
239
+
240
+
241
+ @dataclass
242
+ class DataTrainingArguments:
243
+ """
244
+ Arguments pertaining to what data we are going to input our model for training and eval.
245
+ """
246
+
247
+ dataset_name: Optional[str] = field(
248
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
249
+ )
250
+ dataset_config_name: Optional[str] = field(
251
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
252
+ )
253
+ data_dir: Optional[str] = field(
254
+ default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
255
+ )
256
+ image_column: Optional[str] = field(
257
+ default=None,
258
+ metadata={"help": "The name of the column in the datasets containing the full image file paths."},
259
+ )
260
+ caption_column: Optional[str] = field(
261
+ default=None,
262
+ metadata={"help": "The name of the column in the datasets containing the image captions."},
263
+ )
264
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
265
+ validation_file: Optional[str] = field(
266
+ default=None,
267
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
268
+ )
269
+ test_file: Optional[str] = field(
270
+ default=None,
271
+ metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
272
+ )
273
+ max_target_length: Optional[int] = field(
274
+ default=128,
275
+ metadata={
276
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
277
+ "than this will be truncated, sequences shorter will be padded."
278
+ },
279
+ )
280
+ val_max_target_length: Optional[int] = field(
281
+ default=None,
282
+ metadata={
283
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
284
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
285
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
286
+ "during evaluation."
287
+ },
288
+ )
289
+ max_train_samples: Optional[int] = field(
290
+ default=None,
291
+ metadata={
292
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
293
+ "value if set."
294
+ },
295
+ )
296
+ max_eval_samples: Optional[int] = field(
297
+ default=None,
298
+ metadata={
299
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
300
+ "value if set."
301
+ },
302
+ )
303
+ max_predict_samples: Optional[int] = field(
304
+ default=None,
305
+ metadata={
306
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
307
+ "value if set."
308
+ },
309
+ )
310
+ preprocessing_num_workers: Optional[int] = field(
311
+ default=None,
312
+ metadata={"help": "The number of processes to use for the preprocessing."},
313
+ )
314
+ predict_with_generate: bool = field(
315
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
316
+ )
317
+ num_beams: Optional[int] = field(
318
+ default=None,
319
+ metadata={
320
+ "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
321
+ "which is used during evaluation."
322
+ },
323
+ )
324
+ overwrite_cache: bool = field(
325
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
326
+ )
327
+
328
+ def __post_init__(self):
329
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
330
+ raise ValueError("Need either a dataset name or a training/validation file.")
331
+ else:
332
+ if self.train_file is not None:
333
+ extension = self.train_file.split(".")[-1]
334
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
335
+ if self.validation_file is not None:
336
+ extension = self.validation_file.split(".")[-1]
337
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
338
+ if self.val_max_target_length is None:
339
+ self.val_max_target_length = self.max_target_length
340
+
341
+
342
+ image_captioning_name_mapping = {
343
+ "image_caption_dataset.py": ("image_path", "caption"),
344
+ }
345
+
346
+
347
+ class TrainState(train_state.TrainState):
348
+ dropout_rng: jnp.ndarray
349
+
350
+ def replicate(self):
351
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
352
+
353
+
354
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
355
+ """
356
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
357
+ Shuffle batches if `shuffle` is `True`.
358
+ """
359
+ steps = len(dataset) // batch_size # Skip incomplete batch.
360
+
361
+ if shuffle:
362
+ batch_idx = jax.random.permutation(rng, len(dataset))
363
+ batch_idx = np.asarray(batch_idx)
364
+ else:
365
+ batch_idx = np.arange(len(dataset))
366
+
367
+ for idx in range(steps):
368
+
369
+ start_idx = batch_size * idx
370
+ end_idx = batch_size * (idx + 1)
371
+
372
+ selected_indices = batch_idx[start_idx:end_idx]
373
+ batch = dataset[selected_indices]
374
+ batch = shard(batch)
375
+
376
+ yield batch
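+
+ # Note: `shard` reshapes each array in the batch from (batch_size, ...) to
+ # (local_device_count, batch_size // local_device_count, ...) so that the
+ # pmap-ed train/eval steps below receive one slice per device.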
377
+
378
+
379
+ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):
380
+
381
+ if train_time:
382
+ summary_writer.scalar("train_time", train_time, step)
383
+
384
+ metrics = get_metrics(metrics)
385
+ for key, vals in metrics.items():
386
+ tag = f"{metric_key_prefix}_{key}"
387
+ for i, val in enumerate(vals):
388
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
389
+
390
+ else:
391
+ for metric_name, value in metrics.items():
392
+ summary_writer.scalar(f"{metric_key_prefix}_{metric_name}", value, step)
393
+
394
+
395
+ def create_learning_rate_fn(
396
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
397
+ ) -> Callable[[int], jnp.array]:
398
+ """Returns a linear warmup, linear_decay learning rate function."""
399
+ steps_per_epoch = train_ds_size // train_batch_size
400
+ num_train_steps = steps_per_epoch * num_train_epochs
401
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
402
+ decay_fn = optax.linear_schedule(
403
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
404
+ )
405
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
406
+ return schedule_fn
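+
+ # Sketch of the resulting schedule: with 10,000 training examples, a global batch size of 64,
+ # 3 epochs and 100 warmup steps, there are 156 * 3 = 468 total steps; the rate ramps from 0 to
+ # `learning_rate` over the first 100 steps, then decays linearly back to 0 over the remaining 368.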
407
+
408
+
409
+ def main():
410
+ # See all possible arguments in src/transformers/training_args.py
411
+ # or by passing the --help flag to this script.
412
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
413
+
414
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
415
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
416
+ # If we pass only one argument to the script and it's the path to a json file,
417
+ # let's parse it to get our arguments.
418
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
419
+ else:
420
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
421
+
422
+ if (
423
+ os.path.exists(training_args.output_dir)
424
+ and os.listdir(training_args.output_dir)
425
+ and training_args.do_train
426
+ and not training_args.overwrite_output_dir
427
+ ):
428
+ raise ValueError(
429
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
430
+ "Use --overwrite_output_dir to overcome."
431
+ )
432
+
433
+ # Make one log on every process with the configuration for debugging.
434
+ logging.basicConfig(
435
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
436
+ datefmt="%m/%d/%Y %H:%M:%S",
437
+ level=logging.INFO,
438
+ )
439
+ # Setup logging, we only want one process per machine to log things on the screen.
440
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
441
+ if jax.process_index() == 0:
442
+ datasets.utils.logging.set_verbosity_warning()
443
+ transformers.utils.logging.set_verbosity_info()
444
+ else:
445
+ datasets.utils.logging.set_verbosity_error()
446
+ transformers.utils.logging.set_verbosity_error()
447
+
448
+ # Set the verbosity of the Transformers logger to info (on the main process only):
449
+ logger.info(f"Training/evaluation parameters {training_args}")
450
+
451
+ # Handle the repository creation
452
+ if training_args.push_to_hub:
453
+ if training_args.hub_model_id is None:
454
+ repo_name = get_full_repo_name(
455
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
456
+ )
457
+ else:
458
+ repo_name = training_args.hub_model_id
459
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
460
+
461
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
462
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
463
+ # (the dataset will be downloaded automatically from the datasets Hub).
464
+ #
465
+ # For CSV/JSON files this script will use the first column for the full image path and the second column for the
466
+ # captions (unless you specify column names for this with the `image_column` and `caption_column` arguments).
467
+ #
468
+ if data_args.dataset_name is not None:
469
+ # Downloading and loading a dataset from the hub.
470
+ dataset = load_dataset(
471
+ data_args.dataset_name,
472
+ data_args.dataset_config_name,
473
+ cache_dir=model_args.cache_dir,
474
+ keep_in_memory=False,
475
+ data_dir=data_args.data_dir,
476
+ )
477
+ else:
478
+ data_files = {}
479
+ if data_args.train_file is not None:
480
+ data_files["train"] = data_args.train_file
481
+ extension = data_args.train_file.split(".")[-1]
482
+ if data_args.validation_file is not None:
483
+ data_files["validation"] = data_args.validation_file
484
+ extension = data_args.validation_file.split(".")[-1]
485
+ if data_args.test_file is not None:
486
+ data_files["test"] = data_args.test_file
487
+ extension = data_args.test_file.split(".")[-1]
488
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
489
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
490
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
491
+
492
+ # Load pretrained model and tokenizer
493
+
494
+ encoder_cache_dir, decoder_cache_dir = None, None
495
+ if model_args.cache_dir:
496
+ encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
497
+ decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
498
+
499
+ # Use explicit specified config
500
+ if model_args.config_name:
501
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
502
+ # Use pretrained model's config
503
+ elif model_args.model_name_or_path:
504
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
505
+ # Use specified `model_type` (default to `vision-encoder-decoder`)
506
+ else:
507
+
508
+ if model_args.model_type not in MODEL_TYPES:
509
+ raise ValueError(
510
+ f"Unrecognized model identifier: {model_args.model_type}. Should contain one of {', '.join(MODEL_TYPES)}."
511
+ )
512
+ config_class = CONFIG_MAPPING[model_args.model_type]
513
+
514
+ # Deal with encoder-decoder models that require specifying encoder/decoder
515
+ if hasattr(config_class, "from_encoder_decoder_configs"):
516
+
517
+ # Use explicit specified encoder config
518
+ if model_args.encoder_config_name:
519
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
520
+ # Use pretrained encoder model's config
521
+ elif model_args.encoder_model_name_or_path:
522
+ encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
523
+ # Use specified encoder model type
524
+ elif model_args.encoder_model_type:
525
+ encoder_config = AutoConfig.for_model(model_args.encoder_model_type)
526
+ logger.warning("You are instantiating a new config instance from scratch for the encoder.")
527
+ else:
528
+ raise ValueError("Encoder Config: if pretrained config or model location is not provided, `encoder_model_type` is required.")
529
+
530
+ # Use explicit specified decoder config
531
+ if model_args.decoder_config_name:
532
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
533
+ # Use pretrained decoder model's config
534
+ elif model_args.decoder_model_name_or_path:
535
+ decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
536
+ # Use specified decoder model type
537
+ elif model_args.decoder_model_type:
538
+ decoder_config = AutoConfig.for_model(model_args.decoder_model_type)
539
+ logger.warning("You are instantiating a new config instance from scratch for the decoder.")
540
+ else:
541
+ raise ValueError("Decoder Config: if pretrained config or model location is not provided, `decoder_model_type` is required.")
542
+
543
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
544
+ decoder_config.is_decoder = True
545
+ decoder_config.add_cross_attention = True
546
+
547
+ config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
548
+ # For self-contained model
549
+ else:
550
+ config = config_class()
551
+ logger.warning("You are instantiating a new config instance from scratch.")
552
+
553
+ decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
554
+ if not decoder_start_token_id and getattr(config, "decoder", None):
555
+ decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
556
+ bos_token_id = getattr(config, "bos_token_id", None)
557
+ if not bos_token_id and getattr(config, "decoder", None):
558
+ bos_token_id = getattr(config.decoder, "bos_token_id", None)
559
+ eos_token_id = getattr(config, "eos_token_id", None)
560
+ if not eos_token_id and getattr(config, "decoder", None):
561
+ eos_token_id = getattr(config.decoder, "eos_token_id", None)
562
+ pad_token_id = getattr(config, "pad_token_id", None)
563
+ if not pad_token_id and getattr(config, "decoder", None):
564
+ pad_token_id = getattr(config.decoder, "pad_token_id", None)
565
+
566
+ if decoder_start_token_id is None:
567
+ decoder_start_token_id = bos_token_id
568
+ if pad_token_id is None:
569
+ pad_token_id = eos_token_id
570
+
571
+ if getattr(config, "decoder", None):
572
+ config.decoder.decoder_start_token_id = decoder_start_token_id
573
+ config.decoder.bos_token_id = bos_token_id
574
+ config.decoder.eos_token_id = eos_token_id
575
+ config.decoder.pad_token_id = pad_token_id
576
+
577
+ # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
578
+ config.decoder_start_token_id = decoder_start_token_id
579
+ config.bos_token_id = bos_token_id
580
+ config.eos_token_id = eos_token_id
581
+ config.pad_token_id = pad_token_id
582
+
583
+ if model_args.model_name_or_path:
584
+ model = FlaxAutoModelForVision2Seq.from_pretrained(
585
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
586
+ )
587
+ else:
588
+ # model_class = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING[config.__class__]
589
+ model = FlaxAutoModelForVision2Seq.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
590
+ model_class = model.__class__
591
+
592
+ # encoder_class = FlaxAutoModel
593
+ # decoder_class = FlaxAutoModelForCausalLM
594
+ module = model.module.bind(model.params)
595
+ encoder_class_name = type(module.encoder).__name__.replace("Module", "Model")
596
+ decoder_class_name = type(module.decoder).__name__.replace("Module", "Model")
597
+ encoder_class = getattr(transformers, encoder_class_name, None)
598
+ decoder_class = getattr(transformers, decoder_class_name, None)
599
+
600
+ if hasattr(model_class, "from_encoder_decoder_pretrained"):
601
+
602
+ if model_args.encoder_model_name_or_path:
603
+ encoder = encoder_class.from_pretrained(
604
+ model_args.encoder_model_name_or_path,
605
+ config=config.encoder,
606
+ seed=training_args.seed,
607
+ dtype=getattr(jnp, model_args.dtype)
608
+ )
609
+ else:
610
+ encoder = encoder_class(config=config.encoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
611
+ logger.warning("You are instantiating a new model instance from scratch for the encoder.")
612
+
613
+ if model_args.decoder_model_name_or_path:
614
+ decoder = decoder_class.from_pretrained(
615
+ model_args.decoder_model_name_or_path,
616
+ config=config.decoder,
617
+ seed=training_args.seed,
618
+ dtype=getattr(jnp, model_args.dtype)
619
+ )
620
+ else:
621
+ decoder = decoder_class(config=config.decoder, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
622
+ logger.warning("You are instantiating a new model instance from scratch for the decoder.")
623
+
624
+ model = model_class.from_encoder_decoder_pretrained(
625
+ model_args.encoder_model_name_or_path,
626
+ model_args.decoder_model_name_or_path,
627
+ encoder_model=encoder,
628
+ decoder_model=decoder,
629
+ encoder_config=config.encoder,
630
+ decoder_config=config.decoder,
631
+ encoder_seed=training_args.seed,
632
+ decoder_seed=training_args.seed,
633
+ encoder_dtype=getattr(jnp, model_args.dtype),
634
+ decoder_dtype=getattr(jnp, model_args.dtype),
635
+ )
636
+
637
+ # Set `encoder-decoder` (top-level) specific config (not always necessary, but can avoid generate() error sometimes)
638
+ model.config.decoder_start_token_id = decoder_start_token_id
639
+ model.config.bos_token_id = bos_token_id
640
+ model.config.eos_token_id = eos_token_id
641
+ model.config.pad_token_id = pad_token_id
642
+
643
+ else:
644
+ logger.warning("You are instantiating a new model instance from scratch.")
645
+
646
+ feature_extractor = None
647
+ if model_args.feature_extractor_name:
648
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
649
+ model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
650
+ )
651
+ elif model_args.model_name_or_path:
652
+ try:
653
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
654
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir
655
+ )
656
+ except ValueError as e:
657
+ logger.warning(e)
658
+ # Check encoder
659
+ if not feature_extractor:
660
+ if model_args.encoder_model_name_or_path:
661
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
662
+ model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
663
+ )
664
+ else:
665
+ raise ValueError(
666
+ "You are instantiating a new feature extractor from scratch. This is not supported by this script."
667
+ "You can do it from another script, save it, and load it from here, using --feature_extractor_name."
668
+ )
669
+
670
+ tokenizer = None
671
+ if model_args.tokenizer_name:
672
+ tokenizer = AutoTokenizer.from_pretrained(
673
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
674
+ )
675
+ elif model_args.model_name_or_path:
676
+ try:
677
+ tokenizer = AutoTokenizer.from_pretrained(
678
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
679
+ )
680
+ except ValueError as e:
681
+ logger.warning(e)
682
+
683
+ # Check decoder
684
+ if not tokenizer:
685
+ if model_args.decoder_model_name_or_path:
686
+ tokenizer = AutoTokenizer.from_pretrained(
687
+ model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
688
+ )
689
+ else:
690
+ raise ValueError(
691
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
692
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
693
+ )
694
+ tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
695
+
696
+ # Preprocessing the datasets.
697
+ # We need to tokenize inputs and targets.
698
+ if training_args.do_train:
699
+ column_names = dataset["train"].column_names
700
+ elif training_args.do_eval:
701
+ column_names = dataset["validation"].column_names
702
+ elif training_args.do_predict:
703
+ column_names = dataset["test"].column_names
704
+ else:
705
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
706
+ return
707
+
708
+ # Get the column names for input/target.
709
+ dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
710
+ if data_args.image_column is None:
711
+ assert dataset_columns is not None
712
+ image_column = dataset_columns[0]
713
+ else:
714
+ image_column = data_args.image_column
715
+ if image_column not in column_names:
716
+ raise ValueError(
717
+ f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
718
+ )
719
+ if data_args.caption_column is None:
720
+ assert dataset_columns is not None
721
+ caption_column = dataset_columns[1]
722
+ else:
723
+ caption_column = data_args.caption_column
724
+ if caption_column not in column_names:
725
+ raise ValueError(
726
+ f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
727
+ )
728
+
729
+ # In Flax, seq2seq models need `decoder_input_ids` to be passed explicitly.
730
+ # Since the Flax models don't accept `labels`, we prepare the `decoder_input_ids` here;
731
+ # for that, we dynamically import the `shift_tokens_right` function from the model file.
732
+ model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
733
+ shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
734
+
735
+ def filter_fn(examples):
736
+ """remove problematic images"""
737
+
738
+ bools = []
739
+ for image_file in examples[image_column]:
740
+ try:
741
+ image = Image.open(image_file)
742
+ encoder_inputs = feature_extractor(images=image, return_tensors="np")
743
+ bools.append(True)
744
+ except Exception:  # unreadable or corrupt image
745
+ bools.append(False)
746
+
747
+ return bools
748
+
749
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
750
+ def tokenization_fn(examples, max_target_length):
751
+ """Run tokenization on captions."""
752
+
753
+ captions = []
754
+ for caption in examples[caption_column]:
755
+ captions.append(caption.lower() + ' ' + tokenizer.eos_token)
756
+
757
+ targets = captions
758
+
759
+ model_inputs = {}
760
+
761
+ # Setup the tokenizer for targets
762
+ with tokenizer.as_target_tokenizer():
763
+ labels = tokenizer(
764
+ targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
765
+ )
766
+
767
+ model_inputs["labels"] = labels["input_ids"]
768
+ decoder_input_ids = shift_tokens_right_fn(
769
+ labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
770
+ )
771
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
772
+
773
+ # We need decoder_attention_mask so we can ignore pad tokens from loss
774
+ model_inputs["decoder_attention_mask"] = labels["attention_mask"]
775
+
776
+ model_inputs[image_column] = examples[image_column]
777
+
778
+ return model_inputs
779
+
780
+ def feature_extraction_fn(examples, check_image=True):
781
+ """
782
+ Run feature extraction on images
783
+
784
+ If `check_image` is `True`, examples that fail during `Image.open()` will be caught and discarded.
785
+ Otherwise, an exception will be thrown.
786
+ """
787
+
788
+ model_inputs = {}
789
+
790
+ if check_image:
791
+ images = []
792
+ to_keep = []
793
+ for image_file in examples[image_column]:
794
+ try:
795
+ img = Image.open(image_file)
796
+ images.append(img)
797
+ to_keep.append(True)
798
+ except Exception:  # unreadable or corrupt image
799
+ to_keep.append(False)
800
+
801
+ for k, v in examples.items():
802
+ if k != image_column:
803
+ model_inputs[k] = v[to_keep]
804
+ else:
805
+ images = [Image.open(image_file) for image_file in examples[image_column]]
806
+
807
+ encoder_inputs = feature_extractor(images=images, return_tensors="np")
808
+ model_inputs["pixel_values"] = encoder_inputs.pixel_values
809
+
810
+ return model_inputs
811
+
812
+ def preprocess_fn(examples, max_target_length, check_image=True):
813
+ """Run tokenization + image feature extraction"""
814
+
815
+ model_inputs = {}
816
+ # This contains image path column
817
+ model_inputs.update(tokenization_fn(examples, max_target_length))
818
+ model_inputs.update(feature_extraction_fn(model_inputs, check_image=check_image))
819
+ # Remove image path column
820
+ model_inputs.pop(image_column)
821
+
822
+ return model_inputs
823
+
824
+ features = datasets.Features(
825
+ {
826
+ "pixel_values": datasets.Array3D(
827
+ shape=(
828
+ getattr(config.encoder, "num_channels", 3),
829
+ config.encoder.image_size,
830
+ config.encoder.image_size,
831
+ ),
832
+ dtype="float32",
833
+ ),
834
+ "labels": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
835
+ "decoder_input_ids": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
836
+ "decoder_attention_mask": datasets.Sequence(feature=datasets.Value(dtype='int32', id=None), length=-1, id=None),
837
+ }
838
+ )
839
+
840
+ # If `block_size` is `0`, tokenization & image feature extraction are done at the beginning
841
+ run_feat_ext_at_beginning = training_args.block_size == 0
842
+ # Used in .map() below
843
+ function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
844
+ # `features` is used only for the final preprocessed dataset (for performance purposes).
845
+ features_kwarg = features if run_feat_ext_at_beginning else None
846
+ # Keep `image_column` if the feature extraction is done during training
847
+ remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
848
+ processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
849
+
850
+ if training_args.do_train:
851
+ if "train" not in dataset:
852
+ raise ValueError("--do_train requires a train dataset")
853
+ train_dataset = dataset["train"]
854
+ if data_args.max_train_samples is not None:
855
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
856
+ # remove problematic examples
857
+ # (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
858
+ # instead here.)
859
+ if not run_feat_ext_at_beginning:
860
+ train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
861
+ train_dataset = train_dataset.map(
862
+ function=function_kwarg,
863
+ batched=True,
864
+ num_proc=data_args.preprocessing_num_workers,
865
+ # keep the image path column when feature extraction is done during training
866
+ remove_columns=remove_columns_kwarg,
867
+ load_from_cache_file=not data_args.overwrite_cache,
868
+ desc=f"Running {processor_names} on train dataset",
869
+ fn_kwargs={"max_target_length": data_args.max_target_length},
870
+ features=features_kwarg,
871
+ )
872
+ if run_feat_ext_at_beginning:
873
+ # set format (for performance) since the dataset is ready to be used
874
+ train_dataset = train_dataset.with_format("numpy")
875
+
876
+ if training_args.do_eval:
877
+ if "validation" not in dataset:
878
+ raise ValueError("--do_eval requires a validation dataset")
879
+ eval_dataset = dataset["validation"]
880
+ if data_args.max_eval_samples is not None:
881
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
882
+ # remove problematic examples
883
+ # (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
884
+ # instead here.)
885
+ if not run_feat_ext_at_beginning:
886
+ eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
887
+ eval_dataset = eval_dataset.map(
888
+ function=function_kwarg,
889
+ batched=True,
890
+ num_proc=data_args.preprocessing_num_workers,
891
+ # keep the image path column when feature extraction is done during training
892
+ remove_columns=remove_columns_kwarg,
893
+ load_from_cache_file=not data_args.overwrite_cache,
894
+ desc=f"Running {processor_names} on validation dataset",
895
+ fn_kwargs={"max_target_length": data_args.val_max_target_length},
896
+ features=features_kwarg,
897
+ )
898
+ if run_feat_ext_at_beginning:
899
+ # set format (for performance) since the dataset is ready to be used
900
+ eval_dataset = eval_dataset.with_format("numpy")
901
+
902
+ if training_args.do_predict:
903
+ if "test" not in dataset:
904
+ raise ValueError("--do_predict requires a test dataset")
905
+ predict_dataset = dataset["test"]
906
+ if data_args.max_predict_samples is not None:
907
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
908
+ # remove problematic examples
909
+ # (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
910
+ # instead here.)
911
+ if not run_feat_ext_at_beginning:
912
+ predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
913
+ predict_dataset = predict_dataset.map(
914
+ function=function_kwarg,
915
+ batched=True,
916
+ num_proc=data_args.preprocessing_num_workers,
917
+ # keep the image path column when feature extraction is done during training
918
+ remove_columns=remove_columns_kwarg,
919
+ load_from_cache_file=not data_args.overwrite_cache,
920
+ desc=f"Running {processor_names} on prediction dataset",
921
+ fn_kwargs={"max_target_length": data_args.val_max_target_length},
922
+ features=features_kwarg,
923
+ )
924
+ if run_feat_ext_at_beginning:
925
+ # set format (for performance) since the dataset is ready to be used
926
+ predict_dataset = predict_dataset.with_format("numpy")
927
+
928
+ # Store some constants
929
+
930
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
931
+
932
+ if training_args.block_size % train_batch_size > 0:
933
+ raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
934
+
935
+ if training_args.do_train:
936
+ steps_per_epoch = len(train_dataset) // train_batch_size
937
+ num_train_examples_per_epoch = steps_per_epoch * train_batch_size
938
+ num_epochs = int(training_args.num_train_epochs)
939
+ total_train_steps = steps_per_epoch * num_epochs
940
+ else:
941
+ num_train_examples_per_epoch = 0
942
+
943
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
944
+
945
+ if training_args.do_eval:
946
+ num_eval_examples = len(eval_dataset)
947
+ eval_steps = num_eval_examples // eval_batch_size
948
+
949
+ if training_args.do_predict:
950
+ num_test_examples = len(predict_dataset)
951
+ test_steps = num_test_examples // eval_batch_size
952
+
953
+ def blockwise_data_loader(
954
+ rng: jax.random.PRNGKey,
955
+ ds: Dataset,
956
+ block_size: int,
957
+ batch_size: int,
958
+ shuffle: bool = False,
959
+ keep_in_memory: bool = False,
960
+ split: str = ""
961
+ ):
962
+ """
963
+ Wraps the simple `data_loader` in a block-wise way if `block_size` > 0; otherwise it behaves the same as `data_loader`.
964
+
965
+ If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image feature
966
+ extraction (with the column name being specified by `image_column`). The tokenization should be done before
967
+ training in this case.
968
+ """
969
+
970
+ if shuffle:
971
+ indices = jax.random.permutation(rng, len(ds))
972
+ indices = np.asarray(indices)
973
+ else:
974
+ indices = np.arange(len(ds))
975
+
976
+ _block_size = len(ds) if not block_size else block_size
977
+
978
+ steps_per_block = _block_size // batch_size
979
+ num_examples = len(ds)
980
+ steps = num_examples // batch_size
981
+ num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)
982
+
983
+ for idx in range(num_splits):
984
+
985
+ if not block_size:
986
+ _ds = ds
987
+ else:
988
+
989
+ start_idx = block_size * idx
990
+ end_idx = block_size * (idx + 1)
991
+
992
+ selected_indices = indices[start_idx:end_idx]
993
+
994
+ _ds = ds.select(selected_indices)
995
+
996
+ _ds = _ds.map(
997
+ feature_extraction_fn,
998
+ batched=True,
999
+ num_proc=data_args.preprocessing_num_workers,
1000
+ remove_columns=[image_column],
1001
+ load_from_cache_file=not data_args.overwrite_cache,
1002
+ features=features,
1003
+ keep_in_memory=keep_in_memory,
1004
+ # The images are already checked either in `.filter()` or in `preprocess_fn()`
1005
+ fn_kwargs={"check_image": False},
1006
+ desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
1007
+ )
1008
+ _ds = _ds.with_format("numpy")
1009
+
1010
+ # No need to shuffle here
1011
+ loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
1012
+
1013
+ for batch in loader:
1014
+ yield batch
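+
+ # Example (a sketch): with 50,000 examples, block_size=8192 and batch_size=64, each block yields
+ # 8192 // 64 = 128 batches and the loader runs 781 // 128 + 1 = 7 blocks per pass; feature
+ # extraction is done per block, so only one block of pixel values is materialized at a time.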
1015
+
1016
+ # Metric
1017
+ metric = load_metric("rouge")
1018
+
1019
+ def postprocess_text(preds, labels):
1020
+ preds = [pred.strip() for pred in preds]
1021
+ labels = [label.strip() for label in labels]
1022
+
1023
+ # rougeLSum expects newline after each sentence
1024
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
1025
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
1026
+
1027
+ return preds, labels
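+
+ # e.g. "a man rides a horse. a dog runs." -> "a man rides a horse.\na dog runs.",
+ # since rougeLsum scores sentence-level splits separated by newlines.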
1028
+
1029
+ def compute_metrics(preds, labels):
1030
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
1031
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
1032
+
1033
+ # Some simple post-processing
1034
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
1035
+
1036
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
1037
+ # Extract a few results from ROUGE
1038
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
1039
+
1040
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
1041
+ result["gen_len"] = np.mean(prediction_lens)
1042
+ result = {k: round(v, 6) for k, v in result.items()}
1043
+
1044
+ return result, decoded_preds, decoded_labels
1045
+
1046
+ # Enable tensorboard only on the master node
1047
+ has_tensorboard = is_tensorboard_available()
1048
+ if has_tensorboard and jax.process_index() == 0:
1049
+ try:
1050
+ from flax.metrics.tensorboard import SummaryWriter
1051
+
1052
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1053
+ except ImportError as ie:
1054
+ has_tensorboard = False
1055
+ logger.warning(
1056
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1057
+ )
1058
+ else:
1059
+ logger.warning(
1060
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1061
+ "Please run pip install tensorboard to enable."
1062
+ )
1063
+
1064
+ # Initialize our training
1065
+ rng = jax.random.PRNGKey(training_args.seed)
1066
+ rng, dropout_rng = jax.random.split(rng)
1067
+
1068
+ # Create learning rate schedule
1069
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1070
+ num_train_examples_per_epoch,
1071
+ train_batch_size,
1072
+ training_args.num_train_epochs,
1073
+ training_args.warmup_steps,
1074
+ training_args.learning_rate,
1075
+ )
1076
+
1077
+ # We use Optax's "masking" functionality to not apply weight decay
1078
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1079
+ # mask boolean with the same structure as the parameters.
1080
+ # The mask is True for parameters that should be decayed.
1081
+ # Note that this mask is specifically adapted for FlaxBart.
1082
+ # For FlaxT5, one should correct the layer norm parameter naming
1083
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1084
+ def decay_mask_fn(params):
1085
+ flat_params = traverse_util.flatten_dict(params)
1086
+ layer_norm_params = [
1087
+ (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1088
+ ]
1089
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1090
+ return traverse_util.unflatten_dict(flat_mask)
1091
+
1092
+ # create adam optimizer
1093
+ adamw = optax.adamw(
1094
+ learning_rate=linear_decay_lr_schedule_fn,
1095
+ b1=training_args.adam_beta1,
1096
+ b2=training_args.adam_beta2,
1097
+ eps=training_args.adam_epsilon,
1098
+ weight_decay=training_args.weight_decay,
1099
+ mask=decay_mask_fn,
1100
+ )
1101
+
1102
+ # Setup train state
1103
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
1104
+
1105
+ # label smoothed cross entropy
1106
+ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
1107
+ """
1108
+ The label smoothing implementation is adapted from Flax's official example:
1109
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
1110
+ """
1111
+ vocab_size = logits.shape[-1]
1112
+ confidence = 1.0 - label_smoothing_factor
1113
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
1114
+ normalizing_constant = -(
1115
+ confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
1116
+ )
1117
+ soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
1118
+
1119
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
1120
+ loss = loss - normalizing_constant
1121
+
1122
+ # ignore padded tokens from loss
1123
+ loss = loss * padding_mask
1124
+ loss = loss.sum() / padding_mask.sum()
1125
+ return loss
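+
+ # Sketch: with label_smoothing_factor=0.1 and a vocabulary of 50,257 tokens, the soft target puts
+ # 0.9 on the gold token and 0.1 / 50,256 ~= 2e-6 on every other token; `normalizing_constant`
+ # shifts the loss so that perfectly matching the smoothed distribution gives (close to) zero.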
1126
+
1127
+ # Define gradient update step fn
1128
+ def train_step(state, batch, label_smoothing_factor=0.0):
1129
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1130
+
1131
+ def compute_loss(params):
1132
+ labels = batch.pop("labels")
1133
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
1134
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
1135
+ return loss
1136
+
1137
+ grad_fn = jax.value_and_grad(compute_loss)
1138
+ loss, grad = grad_fn(state.params)
1139
+ grad = jax.lax.pmean(grad, "batch")
1140
+
1141
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
1142
+
1143
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1144
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1145
+
1146
+ return new_state, metrics
1147
+
1148
+ # Define eval fn
1149
+ def eval_step(params, batch, label_smoothing_factor=0.0):
1150
+ labels = batch.pop("labels")
1151
+ logits = model(**batch, params=params, train=False)[0]
1152
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
1153
+
1154
+ # summarize metrics
1155
+ metrics = {"loss": loss}
1156
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1157
+ return metrics
1158
+
1159
+ # Define generation function
1160
+ max_length = (
1161
+ data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
1162
+ )
1163
+ num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
1164
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
1165
+
1166
+ def generate_step(params, batch):
1167
+ model.params = params
1168
+ output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
1169
+ return output_ids.sequences
1170
+
1171
+ # Create parallel version of the train and eval step
1172
+ p_train_step = jax.pmap(
1173
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
1174
+ )
1175
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
1176
+ p_generate_step = jax.pmap(generate_step, "batch")
1177
+
1178
+ # Replicate the train state on each device
1179
+ state = state.replicate()
1180
+
1181
+ if training_args.do_train:
1182
+ logger.info("***** Running training *****")
1183
+ logger.info(f" Num train examples = {num_train_examples_per_epoch}")
1184
+ logger.info(f" Num Epochs = {num_epochs}")
1185
+ logger.info(f" Instantaneous train batch size per device = {training_args.per_device_train_batch_size}")
1186
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
1187
+ logger.info(f" Optimization steps per epoch = {steps_per_epoch}")
1188
+ logger.info(f" Total optimization steps = {total_train_steps}")
1189
+ if training_args.do_eval:
1190
+ logger.info(f" Num evaluation examples = {num_eval_examples}")
1191
+ logger.info(f" Instantaneous evaluation batch size per device = {training_args.per_device_eval_batch_size}")
1192
+ logger.info(f" Total evaluation batch size (w. parallel & distributed) = {eval_batch_size}")
1193
+ logger.info(f" Evaluation steps = {eval_steps}")
1194
+ if training_args.do_predict:
1195
+ logger.info(f" Num test examples = {num_test_examples}")
1196
+ logger.info(f" Instantaneous test batch size per device = {training_args.per_device_eval_batch_size}")
1197
+ logger.info(f" Total test batch size (w. parallel & distributed) = {eval_batch_size}")
1198
+ logger.info(f" Test steps = {test_steps}")
1199
+
1200
+ # create output directory
1201
+ if not os.path.isdir(os.path.join(training_args.output_dir)):
1202
+ os.makedirs(os.path.join(training_args.output_dir), exist_ok=True)
1203
+
1204
+ def save_ckpt(ckpt_dir: str, commit_msg: str = ""):
1205
+ """save checkpoints and push to Hugging Face Hub if specified"""
1206
+
1207
+ # save checkpoint after each epoch and push checkpoint to the hub
1208
+ if jax.process_index() == 0:
1209
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1210
+ model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
1211
+ tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
1212
+ if training_args.push_to_hub:
1213
+ repo.push_to_hub(commit_message=commit_msg, blocking=False)
1214
+
1215
+     def evaluation_loop(rng: jax.random.PRNGKey, dataset: Dataset, metric_key_prefix: str = "eval", ckpt_dir: str = "", is_prediction=False):
+
+         logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")
+
+         metrics = []
+         preds = []
+         labels = []
+
+         batches = blockwise_data_loader(
+             rng,
+             dataset,
+             block_size=training_args.block_size,
+             batch_size=eval_batch_size,
+             keep_in_memory=False,
+             shuffle=False,
+             split="prediction" if is_prediction else "validation",
+         )
+         steps = len(dataset) // eval_batch_size
+         for _ in tqdm(range(steps), desc=f"{'Predicting' if is_prediction else 'Evaluating'}...", position=2, leave=False):
+             # Model forward
+             batch = next(batches)
+             _labels = batch.get("labels", None)
+             if not is_prediction and _labels is None:
+                 raise ValueError("Evaluation requires the validation dataset to have `labels`")
+
+             if _labels is not None:
+                 _metrics = p_eval_step(state.params, batch)
+                 metrics.append(_metrics)
+
+             # generation
+             if data_args.predict_with_generate:
+                 generated_ids = p_generate_step(state.params, batch)
+                 preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                 if _labels is not None:
+                     labels.extend(jax.device_get(_labels.reshape(-1, _labels.shape[-1])))
+
+         if metrics:
+             # normalize metrics
+             metrics = get_metrics(metrics)
+             metrics = jax.tree_map(jnp.mean, metrics)
+
+         # compute ROUGE metrics
+         generations = []
+         rouge_desc = ""
+         if data_args.predict_with_generate:
+             if labels:
+                 rouge_metrics, decoded_preds, decoded_labels = compute_metrics(preds, labels)
+                 metrics.update(rouge_metrics)
+                 rouge_desc = " ".join([f"{'Predict' if is_prediction else 'Eval'} {key}: {value} |" for key, value in rouge_metrics.items()])
+                 for pred, label in zip(decoded_preds, decoded_labels):
+                     pred = pred.replace("\n", " ")
+                     label = label.replace("\n", " ")
+                     generations.append({"label": label, "pred": pred})
+             else:
+                 decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+                 # Some simple post-processing
+                 decoded_preds = [pred.strip() for pred in decoded_preds]
+                 # rougeLSum expects newline after each sentence
+                 decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
+                 for pred in decoded_preds:
+                     pred = pred.replace("\n", " ")
+                     generations.append({"pred": pred})
+
+         if metrics:
+             # Print metrics and update progress bar
+             desc = f"{'Predict' if is_prediction else 'Eval'} Loss: {metrics['loss']} | {rouge_desc})"
+             if training_args.do_train and not is_prediction:
+                 desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | " + desc
+             epochs.write(desc)
+             epochs.desc = desc
+             logger.info(desc)
+
+         if jax.process_index() == 0:
+
+             if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
+                 os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)
+
+             if metrics:
+
+                 # Save metrics (only for the evaluation/prediction being done along with training)
+                 if has_tensorboard and training_args.do_train:
+                     write_metric(summary_writer, metrics, train_time=None, step=cur_step, metric_key_prefix=metric_key_prefix)
+
+                 # save final metrics in json
+                 metrics = {f"{metric_key_prefix}_{metric_name}": round(value.item(), 6) for metric_name, value in metrics.items()}
+                 _path = os.path.join(training_args.output_dir, ckpt_dir, f"{metric_key_prefix}_results.json")
+                 with open(_path, "w") as f:
+                     json.dump(metrics, f, indent=4, sort_keys=True)
+
+                 # Update report
+                 with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
+                     fp.write(desc + '\n')
+
+             # Save generations
+             if generations:
+                 with open(os.path.join(training_args.output_dir, ckpt_dir, f'{metric_key_prefix}_generation.json'), 'w', encoding='UTF-8') as fp:
+                     json.dump(generations, fp, ensure_ascii=False, indent=4)
+
+     def evaluate(rng: jax.random.PRNGKey, dataset: Dataset, ckpt_dir: str = ""):
+         evaluation_loop(rng, dataset, metric_key_prefix='eval', ckpt_dir=ckpt_dir)
+
+     def predict(rng: jax.random.PRNGKey, dataset: Dataset):
+         evaluation_loop(rng, dataset, metric_key_prefix='test', is_prediction=True)
+
+     input_rng = None
+
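+     # Training loop: batches are drawn from `blockwise_data_loader` with a fresh rng split
+     # each epoch, progress is logged every `logging_steps`, and evaluation + checkpointing
+     # run every `eval_steps` (or at the end of each epoch when `eval_steps` is not set).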
+     if training_args.do_train:
+
+         cur_step = 0
+         train_time = 0
+         epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+
+         for epoch in epochs:
+
+             # ======================== Training ================================
+
+             # Create sampling rng
+             rng, input_rng = jax.random.split(rng)
+
+             train_metrics = []
+
+             train_batches = blockwise_data_loader(
+                 input_rng,
+                 train_dataset,
+                 block_size=training_args.block_size,
+                 batch_size=train_batch_size,
+                 keep_in_memory=True,
+                 shuffle=True,
+                 split="train"
+             )
+
+             # train
+             for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
+
+                 cur_step += 1
+                 batch = next(train_batches)
+                 batch_start = time.time()
+                 state, train_metric = p_train_step(state, batch)
+                 train_metrics.append(train_metric)
+                 train_time += time.time() - batch_start
+
+                 time_per_step = train_time / cur_step
+                 _train_metric = unreplicate(train_metric)
+                 desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
+                 epochs.desc = desc
+                 epochs.write(desc)
+
+                 # log and save info
+                 if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
+
+                     logger.info(desc)
+
+                     with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
+                         fp.write(desc + '\n')
+
+                     # Save metrics
+                     if has_tensorboard and jax.process_index() == 0:
+                         write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+
+                 # ======================== Evaluating (inside an epoch) ==============================
+
+                 if training_args.do_eval and (training_args.eval_steps is not None and training_args.eval_steps > 0) and cur_step % training_args.eval_steps == 0:
+                     ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
+                     commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
+                     evaluate(input_rng, eval_dataset, ckpt_dir)
+                     save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
+
+             # ======================== Epoch End ==============================
+
+             # log and save info
+             if training_args.logging_steps <= 0:
+
+                 logger.info(desc)
+
+                 with open(os.path.join(training_args.output_dir, 'log'), 'a', encoding='UTF-8') as fp:
+                     fp.write(desc + '\n')
+
+                 # Save metrics
+                 if has_tensorboard and jax.process_index() == 0:
+                     write_metric(summary_writer, train_metrics, train_time=train_time, step=cur_step, metric_key_prefix="train")
+
+             # ======================== Evaluating (after each epoch) ==============================
+
+             if training_args.do_eval and (training_args.eval_steps is None or training_args.eval_steps <= 0):
+                 ckpt_dir = f"ckpt_epoch_{epoch + 1}_step_{cur_step}"
+                 commit_msg = f"Saving weights and logs of epoch {epoch + 1} - step {cur_step}"
+                 evaluate(input_rng, eval_dataset, ckpt_dir)
+                 save_ckpt(ckpt_dir=ckpt_dir, commit_msg=commit_msg)
+
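+     # If training did not run, `input_rng` is still `None`, so split a fresh rng for the
+     # standalone evaluation / prediction data loaders below.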
+     # ======================== Evaluating | Predicting ==============================
+
+     # Create sampling rng
+     if input_rng is None:
+         rng, input_rng = jax.random.split(rng)
+
+     # run evaluation without training
+     if training_args.do_eval and not training_args.do_train:
+         evaluate(input_rng, eval_dataset)
+
+     # run prediction after (or without) training
+     if training_args.do_predict:
+         predict(input_rng, predict_dataset)
+
+
+ if __name__ == "__main__":
+
+     main()