fgaim committed
Commit f3ab687
1 Parent(s): c4f527f

Add script

Files changed (1)
  1. run_flax_glue.py +526 -0
run_flax_glue.py ADDED
@@ -0,0 +1,526 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
+ import argparse
+ import logging
+ import os
+ import random
+ import time
+ from itertools import chain
+ from typing import Any, Callable, Dict, Tuple
+
+ import datasets
+ from datasets import load_dataset, load_metric
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from flax import struct, traverse_util
+ from flax.jax_utils import replicate, unreplicate
+ from flax.metrics import tensorboard
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard
+ from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
+
+
+ logger = logging.getLogger(__name__)
+
+ Array = Any
+ Dataset = datasets.arrow_dataset.Dataset
+ PRNGKey = Any
+
+
+ task_to_keys = {
+     "cola": ("sentence", None),
+     "mnli": ("premise", "hypothesis"),
+     "mrpc": ("sentence1", "sentence2"),
+     "qnli": ("question", "sentence"),
+     "qqp": ("question1", "question2"),
+     "rte": ("sentence1", "sentence2"),
+     "sst2": ("sentence", None),
+     "swahili_news": ("text", None),
+     "stsb": ("sentence1", "sentence2"),
+     "wnli": ("sentence1", "sentence2"),
+ }
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+     parser.add_argument(
+         "--task_name",
+         type=str,
+         default=None,
+         help="The name of the glue task to train on.",
+         choices=list(task_to_keys.keys()),
+     )
+     parser.add_argument(
+         "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
+     )
+     parser.add_argument(
+         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
+     )
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=128,
+         help=(
+             "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+             " sequences shorter will be padded."
+         ),
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         type=str,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+         required=True,
+     )
+     parser.add_argument(
+         "--use_slow_tokenizer",
+         action="store_true",
+         help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
+     )
+     parser.add_argument(
+         "--per_device_train_batch_size",
+         type=int,
+         default=8,
+         help="Batch size (per device) for the training dataloader.",
+     )
+     parser.add_argument(
+         "--per_device_eval_batch_size",
+         type=int,
+         default=8,
+         help="Batch size (per device) for the evaluation dataloader.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-5,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+     parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+     parser.add_argument("--seed", type=int, default=3, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
+     )
+     args = parser.parse_args()
+
+     # Sanity checks
+     if args.task_name is None and args.train_file is None and args.validation_file is None:
+         raise ValueError("Need either a task name or a training/validation file.")
+     else:
+         if args.train_file is not None:
+             extension = args.train_file.split(".")[-1]
+             assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+         if args.validation_file is not None:
+             extension = args.validation_file.split(".")[-1]
+             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+
+     if args.output_dir is not None:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     return args
+
+
+ def create_train_state(
+     model: FlaxAutoModelForSequenceClassification,
+     learning_rate_fn: Callable[[int], float],
+     is_regression: bool,
+     num_labels: int,
+     weight_decay: float,
+ ) -> train_state.TrainState:
+     """Create initial training state."""
+
+     class TrainState(train_state.TrainState):
+         """Train state with an Optax optimizer.
+
+         The two functions below differ depending on whether the task is classification
+         or regression.
+
+         Args:
+             logits_fn: Applied to last layer to obtain the logits.
+             loss_fn: Function to compute the loss.
+         """
+
+         logits_fn: Callable = struct.field(pytree_node=False)
+         loss_fn: Callable = struct.field(pytree_node=False)
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     def decay_mask_fn(params):
+         flat_params = traverse_util.flatten_dict(params)
+         flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
+         return traverse_util.unflatten_dict(flat_mask)
+
+     tx = optax.adamw(
+         learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
+     )
+
+     if is_regression:
+
+         def mse_loss(logits, labels):
+             return jnp.mean((logits[..., 0] - labels) ** 2)
+
+         return TrainState.create(
+             apply_fn=model.__call__,
+             params=model.params,
+             tx=tx,
+             logits_fn=lambda logits: logits[..., 0],
+             loss_fn=mse_loss,
+         )
+     else:  # Classification.
+
+         def cross_entropy_loss(logits, labels):
+             xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
+             return jnp.mean(xentropy)
+
+         return TrainState.create(
+             apply_fn=model.__call__,
+             params=model.params,
+             tx=tx,
+             logits_fn=lambda logits: logits.argmax(-1),
+             loss_fn=cross_entropy_loss,
+         )
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear_decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+
+ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
+     """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
+     steps_per_epoch = len(dataset) // batch_size
+     perms = jax.random.permutation(rng, len(dataset))
+     perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+     perms = perms.reshape((steps_per_epoch, batch_size))
+
+     for perm in perms:
+         batch = dataset[perm]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         batch = shard(batch)
+
+         yield batch
+
+
+ def glue_eval_data_collator(dataset: Dataset, batch_size: int):
+     """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
+     for i in range(len(dataset) // batch_size):
+         batch = dataset[i * batch_size : (i + 1) * batch_size]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         batch = shard(batch)
+
+         yield batch
+
+ def main():
+     args = parse_args()
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+     # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
+
+     # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
+     # sentences in columns called 'sentence1' and 'sentence2' if such columns exist or the first two columns not named
+     # label if at least two columns are provided.
+
+     # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
+     # single column. You can easily tweak this behavior (see below)
+
+     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+     # download the dataset.
+     if args.task_name == "swahili_news":
+         raw_datasets = load_dataset("swahili_news")
+         valid_test_split = 10
+         raw_datasets["validation"] = load_dataset(
+             "swahili_news",
+             split=f"train[:{valid_test_split}%]"
+         )
+         raw_datasets["train"] = load_dataset(
+             "swahili_news",
+             split=f"train[{valid_test_split}%:]"
+         )
+         logger.info(f"train: {len(raw_datasets['train'])}, validation: {len(raw_datasets['validation'])}")
+     elif args.task_name is not None:
+         # Downloading and loading a dataset from the hub.
+         raw_datasets = load_dataset("glue", args.task_name)
+     else:
+         # Loading the dataset from local csv or json file.
+         data_files = {}
+         if args.train_file is not None:
+             data_files["train"] = args.train_file
+         if args.validation_file is not None:
+             data_files["validation"] = args.validation_file
+         extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
+         raw_datasets = load_dataset(extension, data_files=data_files)
+     # See more about loading any type of standard or custom dataset at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Labels
+     if args.task_name is not None:
+         is_regression = args.task_name == "stsb"
+         if not is_regression:
+             label_list = raw_datasets["train"].features["label"].names
+             num_labels = len(label_list)
+         else:
+             num_labels = 1
+     else:
+         # Trying to have good defaults here, don't hesitate to tweak to your needs.
+         is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
+         if is_regression:
+             num_labels = 1
+         else:
+             # A useful fast method:
+             # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
+             label_list = raw_datasets["train"].unique("label")
+             label_list.sort()  # Let's sort it for determinism
+             num_labels = len(label_list)
+
+     # Load pretrained model and tokenizer
+     config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
+     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+     model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
+
+     # Preprocessing the datasets
+     if args.task_name is not None:
+         sentence1_key, sentence2_key = task_to_keys[args.task_name]
+     else:
+         # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
+         non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
+         if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+             sentence1_key, sentence2_key = "sentence1", "sentence2"
+         else:
+             if len(non_label_column_names) >= 2:
+                 sentence1_key, sentence2_key = non_label_column_names[:2]
+             else:
+                 sentence1_key, sentence2_key = non_label_column_names[0], None
+
+     # Some models have set the order of the labels to use, so let's make sure we do use it.
+     label_to_id = None
+     if (
+         model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+         and args.task_name is not None
+         and not is_regression
+     ):
+         # Some have all caps in their config, some don't.
+         label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+         if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+             logger.info(
+                 f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
+                 "Using it!"
+             )
+             label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
+         else:
+             logger.warning(
+                 "Your model seems to have been trained with labels, but they don't match the dataset: "
+                 f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                 "\nIgnoring the model labels as a result.",
+             )
+     elif args.task_name is None:
+         label_to_id = {v: i for i, v in enumerate(label_list)}
+
+     def preprocess_function(examples):
+         # Tokenize the texts
+         texts = (
+             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
+         )
+         result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)
+
+         if "label" in examples:
+             if label_to_id is not None:
+                 # Map labels to IDs (not necessary for GLUE tasks)
+                 result["labels"] = [label_to_id[l] for l in examples["label"]]
+             else:
+                 # In all cases, rename the column to labels because the model will expect that.
+                 result["labels"] = examples["label"]
+         return result
+
+     processed_datasets = raw_datasets.map(
+         preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
+     )
+
+     train_dataset = processed_datasets["train"]
+     eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
+
+     # Log a few random samples from the training set:
+     for index in random.sample(range(len(train_dataset)), 3):
+         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+     # Define a summary writer
+     summary_writer = tensorboard.SummaryWriter(args.output_dir)
+     summary_writer.hparams(vars(args))
+
+     def write_metric(train_metrics, eval_metrics, train_time, step):
+         summary_writer.scalar("train_time", train_time, step)
+
+         train_metrics = get_metrics(train_metrics)
+         for key, vals in train_metrics.items():
+             tag = f"train_{key}"
+             for i, val in enumerate(vals):
+                 summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+         for metric_name, value in eval_metrics.items():
+             summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+     num_epochs = int(args.num_train_epochs)
+     rng = jax.random.PRNGKey(args.seed)
+     dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+     train_batch_size = args.per_device_train_batch_size * jax.local_device_count()
+     eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count()
+
+     learning_rate_fn = create_learning_rate_fn(
+         len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
+     )
+
+     state = create_train_state(
+         model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
+     )
+
+     # define step functions
+     def train_step(
+         state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
+     ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
+         """Trains model with an optimizer (both in `state`) on `batch`, returning `(new_state, metrics, new_dropout_rng)`."""
+         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+         targets = batch.pop("labels")
+
+         def loss_fn(params):
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = state.loss_fn(logits, targets)
+             return loss
+
+         grad_fn = jax.value_and_grad(loss_fn)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+         new_state = state.apply_gradients(grads=grad)
+         metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
+         return new_state, metrics, new_dropout_rng
+
+     p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
+
+     def eval_step(state, batch):
+         logits = state.apply_fn(**batch, params=state.params, train=False)[0]
+         return state.logits_fn(logits)
+
+     p_eval_step = jax.pmap(eval_step, axis_name="batch")
+
+     if args.task_name == "swahili_news":
+         metric = load_metric("glue", "sst2")
+     elif args.task_name is not None:
+         metric = load_metric("glue", args.task_name)
+     else:
+         metric = load_metric("accuracy")
+
+     logger.info(f"===== Starting training ({num_epochs} epochs) =====")
+     train_time = 0
+
+     # make sure weights are replicated on each device
+     state = replicate(state)
+
+     for epoch in range(1, num_epochs + 1):
+         logger.info(f"Epoch {epoch}")
+         logger.info(" Training...")
+
+         train_start = time.time()
+         train_metrics = []
+         rng, input_rng = jax.random.split(rng)
+
+         # train
+         for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
+             state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
+             train_metrics.append(metrics)
+         train_time += time.time() - train_start
+         logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
+
+         logger.info(" Evaluating...")
+
+         # evaluate
+         for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
+             labels = batch.pop("labels")
+             predictions = p_eval_step(state, batch)
+             metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
+
+         # evaluate also on leftover examples (not divisible by batch_size)
+         num_leftover_samples = len(eval_dataset) % eval_batch_size
+
+         # make sure leftover batch is evaluated on one device
+         if num_leftover_samples > 0 and jax.process_index() == 0:
+             # take leftover samples
+             batch = eval_dataset[-num_leftover_samples:]
+             batch = {k: jnp.array(v) for k, v in batch.items()}
+
+             labels = batch.pop("labels")
+             predictions = eval_step(unreplicate(state), batch)
+             metric.add_batch(predictions=predictions, references=labels)
+
+         eval_metric = metric.compute()
+         logger.info(f" Done! Eval metrics: {eval_metric}")
+
+         cur_step = epoch * (len(train_dataset) // train_batch_size)
+         write_metric(train_metrics, eval_metric, train_time, cur_step)
+
+         # save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+             model.save_pretrained(
+                 args.output_dir,
+                 params=params,
+                 push_to_hub=args.push_to_hub,
+                 commit_message=f"Saving weights and logs of epoch {epoch}",
+             )
+
+
+ if __name__ == "__main__":
+     main()