Cletrason commited on
Commit
6395c7e
1 Parent(s): 3bc6009

Create trainer_tf.py

Browse files
Files changed (1) hide show
  1. trainer_tf.py +801 -0
trainer_tf.py ADDED
@@ -0,0 +1,801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tensorflow trainer class."""
15
+
16
+ import datetime
17
+ import math
18
+ import os
19
+ import warnings
20
+ from typing import Callable, Dict, Optional, Tuple
21
+
22
+ from .utils import ENV_VARS_TRUE_VALUES
23
+
24
+
25
+ # Integrations must be imported before ML frameworks:
26
+ # isort: off
27
+ from .integrations import (
28
+ is_comet_available,
29
+ is_wandb_available,
30
+ )
31
+
32
+ # isort: on
33
+
34
+ import numpy as np
35
+ import tensorflow as tf
36
+ from tensorflow.python.distribute.values import PerReplica
37
+
38
+ from .modeling_tf_utils import TFPreTrainedModel
39
+ from .optimization_tf import GradientAccumulator, create_optimizer
40
+ from .trainer_utils import (
41
+ PREFIX_CHECKPOINT_DIR,
42
+ EvalPrediction,
43
+ IntervalStrategy,
44
+ PredictionOutput,
45
+ enable_full_determinism,
46
+ set_seed,
47
+ )
48
+ from .training_args_tf import TFTrainingArguments
49
+ from .utils import logging
50
+
51
+
52
+ if is_wandb_available():
53
+ import wandb
54
+
55
+ if is_comet_available():
56
+ import comet_ml
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+
61
+ class TFTrainer:
62
+ """
63
+ TFTrainer is a simple but feature-complete training and eval loop for TensorFlow, optimized for 🤗 Transformers.
64
+
65
+ Args:
66
+ model ([`TFPreTrainedModel`]):
67
+ The model to train, evaluate or use for predictions.
68
+ args ([`TFTrainingArguments`]):
69
+ The arguments to tweak training.
70
+ train_dataset ([`~tf.data.Dataset`], *optional*):
71
+ The dataset to use for training. The dataset should yield tuples of `(features, labels)` where `features`
72
+ is a dict of input features and `labels` is the labels. If `labels` is a tensor, the loss is calculated by
73
+ the model by calling `model(features, labels=labels)`. If `labels` is a dict, such as when using a
74
+ QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
75
+ `model(features, **labels)`.
76
+ eval_dataset ([`~tf.data.Dataset`], *optional*):
77
+ The dataset to use for evaluation. The dataset should yield tuples of `(features, labels)` where `features`
78
+ is a dict of input features and `labels` is the labels. If `labels` is a tensor, the loss is calculated by
79
+ the model by calling `model(features, labels=labels)`. If `labels` is a dict, such as when using a
80
+ QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
81
+ `model(features, **labels)`.
82
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
83
+ The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
84
+ a dictionary string to metric values.
85
+ tb_writer (`tf.summary.SummaryWriter`, *optional*):
86
+ Object to write to TensorBoard.
87
+ optimizers (`Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]`, *optional*):
88
+ A tuple containing the optimizer and the scheduler to use. The optimizer default to an instance of
89
+ [`tf.keras.optimizers.Adam`] if `args.weight_decay_rate` is 0 else an instance of [`AdamWeightDecay`]. The
90
+ scheduler will default to an instance of [`tf.keras.optimizers.schedules.PolynomialDecay`] if
91
+ `args.num_warmup_steps` is 0 else an instance of [`WarmUp`].
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ model: TFPreTrainedModel,
97
+ args: TFTrainingArguments,
98
+ train_dataset: Optional[tf.data.Dataset] = None,
99
+ eval_dataset: Optional[tf.data.Dataset] = None,
100
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
101
+ tb_writer: Optional[tf.summary.SummaryWriter] = None,
102
+ optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = (
103
+ None,
104
+ None,
105
+ ),
106
+ ):
107
+ self.model = model
108
+ self.args = args
109
+ self.train_dataset = train_dataset
110
+ self.eval_dataset = eval_dataset
111
+ self.compute_metrics = compute_metrics
112
+ self.optimizer, self.lr_scheduler = optimizers
113
+ self.gradient_accumulator = GradientAccumulator()
114
+ self.global_step = 0
115
+ self.epoch_logging = 0
116
+ self.eval_loss = tf.keras.metrics.Sum()
117
+
118
+ warnings.warn(
119
+ "The class `TFTrainer` is deprecated and will be removed in version 5 of Transformers. "
120
+ "We recommend using native Keras instead, by calling methods like `fit()` and `predict()` "
121
+ "directly on the model object. Detailed examples of the Keras style can be found in our "
122
+ "examples at https://github.com/huggingface/transformers/tree/main/examples/tensorflow",
123
+ FutureWarning,
124
+ )
125
+
126
+ if tb_writer is not None:
127
+ self.tb_writer = tb_writer
128
+ else:
129
+ self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
130
+
131
+ if is_wandb_available():
132
+ self.setup_wandb()
133
+ elif os.getenv("WANDB_DISABLED", "").upper() not in ENV_VARS_TRUE_VALUES:
134
+ logger.info(
135
+ "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
136
+ "run `pip install wandb && wandb login` see https://docs.wandb.com/huggingface."
137
+ )
138
+
139
+ if is_comet_available():
140
+ self.setup_comet()
141
+ elif os.environ.get("COMET_MODE") != "DISABLED":
142
+ logger.info(
143
+ "To use comet_ml logging, run `pip/conda install comet_ml` "
144
+ "see https://www.comet.ml/docs/python-sdk/huggingface/"
145
+ )
146
+
147
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
148
+
149
+ def get_train_tfdataset(self) -> tf.data.Dataset:
150
+ """
151
+ Returns the training [`~tf.data.Dataset`].
152
+
153
+ Subclass and override this method if you want to inject some custom behavior.
154
+ """
155
+ if self.train_dataset is None:
156
+ raise ValueError("Trainer: training requires a train_dataset.")
157
+
158
+ self.total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
159
+ self.num_train_examples = self.train_dataset.cardinality().numpy()
160
+
161
+ if self.num_train_examples < 0:
162
+ raise ValueError("The training dataset must have an asserted cardinality")
163
+
164
+ ds = (
165
+ self.train_dataset.repeat()
166
+ .shuffle(self.num_train_examples, seed=self.args.seed)
167
+ .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
168
+ .prefetch(tf.data.experimental.AUTOTUNE)
169
+ )
170
+
171
+ return self.args.strategy.experimental_distribute_dataset(ds)
172
+
173
+ def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
174
+ """
175
+ Returns the evaluation [`~tf.data.Dataset`].
176
+
177
+ Args:
178
+ eval_dataset ([`~tf.data.Dataset`], *optional*):
179
+ If provided, will override *self.eval_dataset*. The dataset should yield tuples of `(features, labels)`
180
+ where `features` is a dict of input features and `labels` is the labels. If `labels` is a tensor, the
181
+ loss is calculated by the model by calling `model(features, labels=labels)`. If `labels` is a dict,
182
+ such as when using a QuestionAnswering head model with multiple targets, the loss is instead calculated
183
+ by calling `model(features, **labels)`.
184
+
185
+ Subclass and override this method if you want to inject some custom behavior.
186
+ """
187
+ if eval_dataset is None and self.eval_dataset is None:
188
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
189
+
190
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
191
+ num_examples = eval_dataset.cardinality().numpy()
192
+
193
+ if num_examples < 0:
194
+ raise ValueError("The training dataset must have an asserted cardinality")
195
+
196
+ approx = math.floor if self.args.dataloader_drop_last else math.ceil
197
+ steps = approx(num_examples / self.args.eval_batch_size)
198
+ ds = (
199
+ eval_dataset.repeat()
200
+ .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last)
201
+ .prefetch(tf.data.experimental.AUTOTUNE)
202
+ )
203
+
204
+ return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
205
+
206
+ def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
207
+ """
208
+ Returns a test [`~tf.data.Dataset`].
209
+
210
+ Args:
211
+ test_dataset ([`~tf.data.Dataset`]):
212
+ The dataset to use. The dataset should yield tuples of `(features, labels)` where `features` is a dict
213
+ of input features and `labels` is the labels. If `labels` is a tensor, the loss is calculated by the
214
+ model by calling `model(features, labels=labels)`. If `labels` is a dict, such as when using a
215
+ QuestionAnswering head model with multiple targets, the loss is instead calculated by calling
216
+ `model(features, **labels)`.
217
+
218
+ Subclass and override this method if you want to inject some custom behavior.
219
+ """
220
+
221
+ num_examples = test_dataset.cardinality().numpy()
222
+
223
+ if num_examples < 0:
224
+ raise ValueError("The training dataset must have an asserted cardinality")
225
+
226
+ steps = math.ceil(num_examples / self.args.eval_batch_size)
227
+ ds = test_dataset.batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
228
+
229
+ return self.args.strategy.experimental_distribute_dataset(ds), steps, num_examples
230
+
231
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
232
+ """
233
+ Setup the optimizer and the learning rate scheduler.
234
+
235
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
236
+ TFTrainer's init through `optimizers`, or subclass and override this method.
237
+ """
238
+ if not self.optimizer and not self.lr_scheduler:
239
+ warmup_steps = (
240
+ self.args.warmup_steps
241
+ if self.args.warmup_steps > 0
242
+ else math.ceil(num_training_steps * self.args.warmup_ratio)
243
+ )
244
+
245
+ self.optimizer, self.lr_scheduler = create_optimizer(
246
+ self.args.learning_rate,
247
+ num_training_steps,
248
+ warmup_steps,
249
+ adam_beta1=self.args.adam_beta1,
250
+ adam_beta2=self.args.adam_beta2,
251
+ adam_epsilon=self.args.adam_epsilon,
252
+ weight_decay_rate=self.args.weight_decay,
253
+ power=self.args.poly_power,
254
+ )
255
+
256
+ def setup_wandb(self):
257
+ """
258
+ Setup the optional Weights & Biases (`wandb`) integration.
259
+
260
+ One can subclass and override this method to customize the setup if needed. Find more information `here
261
+ <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
262
+
263
+ Environment:
264
+ WANDB_PROJECT:
265
+ (Optional): str - "huggingface" by default, set this to a custom string to store results in a different
266
+ project.
267
+ WANDB_DISABLED:
268
+ (Optional): boolean - defaults to false, set to "true" to disable wandb entirely.
269
+ """
270
+
271
+ logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
272
+ combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
273
+ wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name)
274
+
275
+ def setup_comet(self):
276
+ """
277
+ Setup the optional Comet.ml integration.
278
+
279
+ Environment:
280
+ COMET_MODE:
281
+ (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
282
+ COMET_PROJECT_NAME:
283
+ (Optional): str - Comet.ml project name for experiments
284
+ COMET_OFFLINE_DIRECTORY:
285
+ (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
286
+
287
+ For a number of configurable items in the environment, see `here
288
+ <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
289
+ """
290
+ comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
291
+ args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
292
+ experiment = None
293
+ if comet_mode == "ONLINE":
294
+ experiment = comet_ml.Experiment(**args)
295
+ logger.info("Automatic Comet.ml online logging enabled")
296
+ elif comet_mode == "OFFLINE":
297
+ args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
298
+ experiment = comet_ml.OfflineExperiment(**args)
299
+ logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
300
+ if experiment is not None:
301
+ experiment._set_model_graph(self.model, framework="transformers")
302
+ experiment._log_parameters(self.args, prefix="args/", framework="transformers")
303
+ experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
304
+
305
+ def prediction_loop(
306
+ self,
307
+ dataset: tf.data.Dataset,
308
+ steps: int,
309
+ num_examples: int,
310
+ description: str,
311
+ prediction_loss_only: Optional[bool] = None,
312
+ ) -> PredictionOutput:
313
+ """
314
+ Prediction/evaluation loop, shared by [`~TFTrainer.evaluate`] and [`~TFTrainer.predict`].
315
+
316
+ Works both with or without labels.
317
+ """
318
+
319
+ prediction_loss_only = (
320
+ prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
321
+ )
322
+
323
+ logger.info(f"***** Running {description} *****")
324
+ logger.info(f" Num examples in dataset = {num_examples}")
325
+ if description == "Evaluation":
326
+ logger.info(f" Num examples in used in evaluation = {self.args.eval_batch_size * steps}")
327
+ logger.info(f" Batch size = {self.args.eval_batch_size}")
328
+
329
+ label_ids: np.ndarray = None
330
+ preds: np.ndarray = None
331
+ self.eval_loss.reset_states()
332
+
333
+ # Reset the past mems state at the beginning of the evaluation if necessary.
334
+ if self.args.past_index >= 0:
335
+ self._past = None
336
+
337
+ for step, batch in enumerate(dataset):
338
+ logits = self.distributed_prediction_steps(batch)
339
+ _, labels = batch
340
+
341
+ if not prediction_loss_only:
342
+ if isinstance(logits, tuple):
343
+ logits = logits[0]
344
+
345
+ if isinstance(labels, tuple):
346
+ labels = labels[0]
347
+
348
+ if self.args.n_replicas > 1:
349
+ for val in logits.values:
350
+ if preds is None:
351
+ preds = val.numpy()
352
+ else:
353
+ preds = np.append(preds, val.numpy(), axis=0)
354
+
355
+ for val in labels.values:
356
+ if label_ids is None:
357
+ label_ids = val.numpy()
358
+ else:
359
+ label_ids = np.append(label_ids, val.numpy(), axis=0)
360
+ else:
361
+ if preds is None:
362
+ preds = logits.numpy()
363
+ else:
364
+ preds = np.append(preds, logits.numpy(), axis=0)
365
+
366
+ if label_ids is None:
367
+ label_ids = labels.numpy()
368
+ else:
369
+ label_ids = np.append(label_ids, labels.numpy(), axis=0)
370
+
371
+ if step == steps - 1:
372
+ break
373
+
374
+ if self.compute_metrics is not None and preds is not None and label_ids is not None:
375
+ metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
376
+ else:
377
+ metrics = {}
378
+
379
+ metrics["eval_loss"] = self.eval_loss.result().numpy() / steps
380
+
381
+ for key in list(metrics.keys()):
382
+ if not key.startswith("eval_"):
383
+ metrics[f"eval_{key}"] = metrics.pop(key)
384
+
385
+ if self.args.past_index and hasattr(self, "_past"):
386
+ # Clean the state at the end of training
387
+ delattr(self, "_past")
388
+
389
+ return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
390
+
391
+ def log(self, logs: Dict[str, float]) -> None:
392
+ """
393
+ Log `logs` on the various objects watching training.
394
+
395
+ Subclass and override this method to inject custom behavior.
396
+
397
+ Args:
398
+ logs (`Dict[str, float]`):
399
+ The values to log.
400
+ """
401
+ logs["epoch"] = self.epoch_logging
402
+
403
+ if self.tb_writer:
404
+ with self.tb_writer.as_default():
405
+ for k, v in logs.items():
406
+ tf.summary.scalar(k, v, step=self.global_step)
407
+ self.tb_writer.flush()
408
+
409
+ if is_wandb_available():
410
+ wandb.log(logs, step=self.global_step)
411
+
412
+ if is_comet_available():
413
+ experiment = comet_ml.config.get_global_experiment()
414
+ if experiment is not None:
415
+ experiment._log_metrics(
416
+ logs, step=self.global_step, epoch=self.epoch_logging, framework="transformers"
417
+ )
418
+
419
+ output = {**logs, **{"step": self.global_step}}
420
+
421
+ logger.info(output)
422
+
423
+ def evaluate(self, eval_dataset: Optional[tf.data.Dataset] = None) -> Dict[str, float]:
424
+ """
425
+ Run evaluation and returns metrics.
426
+
427
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
428
+ (pass it to the init `compute_metrics` argument).
429
+
430
+ Args:
431
+ eval_dataset ([`~tf.data.Dataset`], *optional*):
432
+ Pass a dataset if you wish to override `self.eval_dataset`. The dataset should yield tuples of
433
+ `(features, labels)` where `features` is a dict of input features and `labels` is the labels. If
434
+ `labels` is a tensor, the loss is calculated by the model by calling `model(features, labels=labels)`.
435
+ If `labels` is a dict, such as when using a QuestionAnswering head model with multiple targets, the
436
+ loss is instead calculated by calling `model(features, **labels)`.
437
+
438
+ Returns:
439
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
440
+ """
441
+ eval_ds, steps, num_examples = self.get_eval_tfdataset(eval_dataset)
442
+
443
+ output = self.prediction_loop(eval_ds, steps, num_examples, description="Evaluation")
444
+ logs = {**output.metrics}
445
+ logs["epoch"] = self.epoch_logging
446
+
447
+ self.log(logs)
448
+
449
+ return output.metrics
450
+
451
+ def prediction_step(
452
+ self, features: tf.Tensor, labels: tf.Tensor, nb_instances_in_global_batch: tf.Tensor
453
+ ) -> tf.Tensor:
454
+ """
455
+ Compute the prediction on features and update the loss with labels.
456
+
457
+ Subclass and override to inject some custom behavior.
458
+ """
459
+ per_example_loss, logits = self.run_model(features, labels, False)
460
+ scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
461
+
462
+ self.eval_loss.update_state(scaled_loss)
463
+
464
+ return logits
465
+
466
+ @tf.function
467
+ def distributed_prediction_steps(self, batch):
468
+ nb_instances_in_batch = self._compute_nb_instances(batch)
469
+ inputs = self._get_step_inputs(batch, nb_instances_in_batch)
470
+
471
+ logits = self.args.strategy.run(self.prediction_step, inputs)
472
+
473
+ return logits
474
+
475
+ def train(self) -> None:
476
+ """
477
+ Train method to train the model.
478
+ """
479
+ train_ds = self.get_train_tfdataset()
480
+
481
+ if self.args.debug:
482
+ tf.summary.trace_on(graph=True, profiler=True)
483
+
484
+ self.gradient_accumulator.reset()
485
+
486
+ num_update_steps_per_epoch = self.num_train_examples / self.total_train_batch_size
487
+
488
+ # In fact, ``self.args.dataloader_drop_last`` has no effect in `trainer_tf.py`, because
489
+ # the dataset is repeated before being batched.
490
+ # It has the effect only when TPU is used which requires explicit tensor shape in order to make
491
+ # the gradient accumulation implementation work.
492
+ approx = math.floor if self.args.dataloader_drop_last else math.ceil
493
+ num_update_steps_per_epoch = approx(num_update_steps_per_epoch)
494
+
495
+ # At least one update for each epoch.
496
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
497
+ self.steps_per_epoch = num_update_steps_per_epoch
498
+
499
+ if self.args.max_steps > 0:
500
+ t_total = self.args.max_steps
501
+ epochs = (self.args.max_steps // self.steps_per_epoch) + int(
502
+ self.args.max_steps % self.steps_per_epoch > 0
503
+ )
504
+ else:
505
+ t_total = self.steps_per_epoch * self.args.num_train_epochs
506
+ epochs = self.args.num_train_epochs
507
+
508
+ # Since ``self.args.num_train_epochs`` can be `float`, we make ``epochs`` be a `float` always.
509
+ epochs = float(epochs)
510
+
511
+ with self.args.strategy.scope():
512
+ self.create_optimizer_and_scheduler(num_training_steps=t_total)
513
+ folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR)
514
+ ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
515
+ self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit)
516
+
517
+ iterations = self.optimizer.iterations
518
+ epochs_trained = 0
519
+ steps_trained_in_current_epoch = 0
520
+ if self.model.ckpt_manager.latest_checkpoint:
521
+ logger.info(
522
+ f"Checkpoint file {self.model.ckpt_manager.latest_checkpoint} found and restoring from checkpoint"
523
+ )
524
+ ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
525
+
526
+ self.global_step = iterations.numpy()
527
+
528
+ epochs_trained = self.global_step // self.steps_per_epoch
529
+ steps_trained_in_current_epoch = self.global_step % self.steps_per_epoch
530
+
531
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
532
+ logger.info(f" Continuing training from epoch {epochs_trained}")
533
+ logger.info(f" Continuing training from global step {self.global_step}")
534
+ logger.info(f" Will skip the first {steps_trained_in_current_epoch} steps in the first epoch")
535
+
536
+ tf.summary.experimental.set_step(self.global_step)
537
+
538
+ with self.tb_writer.as_default():
539
+ tf.summary.text("args", self.args.to_json_string())
540
+
541
+ self.tb_writer.flush()
542
+
543
+ logger.info("***** Running training *****")
544
+ logger.info(f" Num examples = {self.num_train_examples}")
545
+ # TODO: We might want to print a more precise ``epochs`` if self.args.max_steps > 0 ?
546
+ logger.info(f" Num Epochs = {epochs}")
547
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
548
+ logger.info(
549
+ f" Total train batch size (w. parallel, distributed & accumulation) = {self.total_train_batch_size}"
550
+ )
551
+ logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
552
+ logger.info(f" Steps per epoch = {self.steps_per_epoch}")
553
+ logger.info(f" Total optimization steps = {t_total}")
554
+
555
+ self.train_loss = tf.keras.metrics.Sum()
556
+ start_time = datetime.datetime.now()
557
+
558
+ for epoch_iter in range(epochs_trained, int(epochs)):
559
+ # Reset the past mems state at the beginning of each epoch if necessary.
560
+ if self.args.past_index >= 0:
561
+ self._past = None
562
+
563
+ for step, batch in enumerate(train_ds):
564
+ # Skip past any already trained steps if resuming training
565
+ if steps_trained_in_current_epoch > 0:
566
+ steps_trained_in_current_epoch -= 1
567
+ continue
568
+
569
+ self.distributed_training_steps(batch)
570
+
571
+ self.global_step = iterations.numpy()
572
+ self.epoch_logging = epoch_iter + (step + 1) / self.steps_per_epoch
573
+
574
+ training_loss = self.train_loss.result() / (step + 1)
575
+
576
+ if self.args.debug:
577
+ logs = {}
578
+ logs["loss"] = training_loss.numpy()
579
+ logs["epoch"] = self.epoch_logging
580
+
581
+ self.log(logs)
582
+
583
+ if self.global_step == 1 and self.args.debug:
584
+ with self.tb_writer.as_default():
585
+ tf.summary.trace_export(
586
+ name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
587
+ )
588
+
589
+ if (
590
+ self.args.eval_steps > 0
591
+ and self.args.evaluation_strategy == IntervalStrategy.STEPS
592
+ and self.global_step % self.args.eval_steps == 0
593
+ ):
594
+ self.evaluate()
595
+
596
+ if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
597
+ self.global_step == 1 and self.args.logging_first_step
598
+ ):
599
+ logs = {}
600
+ logs["loss"] = training_loss.numpy()
601
+ logs["learning_rate"] = self.lr_scheduler(self.global_step).numpy()
602
+ logs["epoch"] = self.epoch_logging
603
+
604
+ self.log(logs)
605
+
606
+ if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
607
+ ckpt_save_path = self.model.ckpt_manager.save()
608
+
609
+ logger.info(f"Saving checkpoint for step {self.global_step} at {ckpt_save_path}")
610
+
611
+ if self.args.max_steps > 0 and self.global_step >= t_total:
612
+ break
613
+
614
+ if self.global_step % self.steps_per_epoch == 0:
615
+ break
616
+
617
+ self.train_loss.reset_states()
618
+
619
+ if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
620
+ break
621
+
622
+ end_time = datetime.datetime.now()
623
+
624
+ logger.info(f"Training took: {str(end_time - start_time)}")
625
+
626
+ if self.args.past_index and hasattr(self, "_past"):
627
+ # Clean the state at the end of training
628
+ delattr(self, "_past")
629
+
630
+ def training_step(self, features, labels, nb_instances_in_global_batch):
631
+ """
632
+ Perform a training step on features and labels.
633
+
634
+ Subclass and override to inject some custom behavior.
635
+ """
636
+ per_example_loss, _ = self.run_model(features, labels, True)
637
+ scaled_loss = per_example_loss / tf.cast(nb_instances_in_global_batch, dtype=per_example_loss.dtype)
638
+ gradients = tf.gradients(scaled_loss, self.model.trainable_variables)
639
+ gradients = [
640
+ g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables)
641
+ ]
642
+
643
+ if self.args.gradient_accumulation_steps > 1:
644
+ self.gradient_accumulator(gradients)
645
+
646
+ self.train_loss.update_state(scaled_loss)
647
+
648
+ if self.args.gradient_accumulation_steps == 1:
649
+ return gradients
650
+
651
+ def apply_gradients(self, features, labels, nb_instances_in_global_batch):
652
+ if self.args.gradient_accumulation_steps == 1:
653
+ gradients = self.training_step(features, labels, nb_instances_in_global_batch)
654
+
655
+ self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
656
+ else:
657
+ for _ in tf.range(self.args.gradient_accumulation_steps):
658
+ reduced_features = {
659
+ k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
660
+ }
661
+
662
+ if tf.is_tensor(labels):
663
+ reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
664
+ elif isinstance(labels, dict):
665
+ reduced_labels = {
666
+ k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
667
+ }
668
+ else:
669
+ raise ValueError("The labels must be either a tf.Tensor or a dict.")
670
+
671
+ self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
672
+
673
+ features = {
674
+ k: tf.concat(
675
+ [ft[self.args.train_batch_size // self.args.n_replicas :], reduced_features[k]],
676
+ axis=0,
677
+ )
678
+ for k, ft in features.items()
679
+ }
680
+
681
+ if tf.is_tensor(labels):
682
+ labels = tf.concat(
683
+ [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
684
+ )
685
+ elif isinstance(labels, dict):
686
+ labels = {
687
+ k: tf.concat(
688
+ [lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
689
+ axis=0,
690
+ )
691
+ for k, lbl in labels.items()
692
+ }
693
+ else:
694
+ raise ValueError("The labels must be either a tf.Tensor or a dict.")
695
+
696
+ gradients = self.gradient_accumulator.gradients
697
+ gradients = [
698
+ (tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients
699
+ ]
700
+
701
+ self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
702
+ self.gradient_accumulator.reset()
703
+
704
+ @tf.function
705
+ def distributed_training_steps(self, batch):
706
+ with self.args.strategy.scope():
707
+ nb_instances_in_batch = self._compute_nb_instances(batch)
708
+ inputs = self._get_step_inputs(batch, nb_instances_in_batch)
709
+
710
+ self.args.strategy.run(self.apply_gradients, inputs)
711
+
712
+ @staticmethod
713
+ def _compute_nb_instances(batch):
714
+ labels = batch[-1]
715
+ if isinstance(labels, PerReplica):
716
+ labels = tf.concat(labels.values, axis=0)
717
+
718
+ nb_instances = tf.reduce_sum(tf.cast(labels != -100, dtype=tf.int32))
719
+
720
+ return nb_instances
721
+
722
+ @staticmethod
723
+ def _get_step_inputs(batch, nb_instances):
724
+ features, labels = batch
725
+
726
+ if isinstance(labels, PerReplica):
727
+ # need to make a `PerReplica` objects for ``nb_instances``
728
+ nb_instances = PerReplica([nb_instances] * len(labels.values))
729
+
730
+ step_inputs = (features, labels, nb_instances)
731
+
732
+ return step_inputs
733
+
734
+ def run_model(self, features, labels, training):
735
+ """
736
+ Computes the loss of the given features and labels pair.
737
+
738
+ Subclass and override this method if you want to inject some custom behavior.
739
+
740
+ Args:
741
+ features (`tf.Tensor`): A batch of input features.
742
+ labels (`tf.Tensor`): A batch of labels.
743
+ training (`bool`): Whether or not to run the model in training mode.
744
+
745
+ Returns:
746
+ A tuple of two `tf.Tensor`: The loss and logits.
747
+ """
748
+
749
+ if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
750
+ features["mems"] = self._past
751
+
752
+ if isinstance(labels, (dict)):
753
+ outputs = self.model(features, training=training, **labels)[:2]
754
+ else:
755
+ outputs = self.model(features, labels=labels, training=training)[:2]
756
+
757
+ loss, logits = outputs[:2]
758
+
759
+ if self.args.past_index >= 0:
760
+ self._past = outputs[self.args.past_index]
761
+
762
+ return loss, logits
763
+
764
+ def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput:
765
+ """
766
+ Run prediction and returns predictions and potential metrics.
767
+
768
+ Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
769
+ will also return metrics, like in `evaluate()`.
770
+
771
+ Args:
772
+ test_dataset ([`~tf.data.Dataset`]):
773
+ Dataset to run the predictions on. The dataset should yield tuples of `(features, labels)` where
774
+ `features` is a dict of input features and `labels` is the labels. If `labels` is a tensor, the loss is
775
+ calculated by the model by calling `model(features, labels=labels)`. If `labels` is a dict, such as
776
+ when using a QuestionAnswering head model with multiple targets, the loss is instead calculated by
777
+ calling `model(features, **labels)`
778
+
779
+ Returns: *NamedTuple* A namedtuple with the following keys:
780
+
781
+ - predictions (`np.ndarray`): The predictions on `test_dataset`.
782
+ - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
783
+ - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
784
+ labels).
785
+ """
786
+ test_ds, steps, num_examples = self.get_test_tfdataset(test_dataset)
787
+
788
+ return self.prediction_loop(test_ds, steps, num_examples, description="Prediction")
789
+
790
+ def save_model(self, output_dir: Optional[str] = None):
791
+ """
792
+ Will save the model, so you can reload it using `from_pretrained()`.
793
+ """
794
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
795
+
796
+ logger.info(f"Saving model in {output_dir}")
797
+
798
+ if not isinstance(self.model, TFPreTrainedModel):
799
+ raise ValueError("Trainer.model appears to not be a PreTrainedModel")
800
+
801
+ self.model.save_pretrained(output_dir)