Trainer
-----------------------------------------------------------------------------------------------------------------------

The :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` classes provide an API for feature-complete
training in most standard use cases. It is used in most of the :doc:`example scripts <../examples>`.

Before instantiating your :class:`~transformers.Trainer`/:class:`~transformers.TFTrainer`, create a
:class:`~transformers.TrainingArguments`/:class:`~transformers.TFTrainingArguments` to access all the points of
customization during training.

The API supports distributed training on multiple GPUs/TPUs, and mixed precision through `NVIDIA Apex `__ for
PyTorch and :obj:`tf.keras.mixed_precision` for TensorFlow.

Both :class:`~transformers.Trainer` and :class:`~transformers.TFTrainer` contain the basic training loop supporting
the features above. To inject custom behavior, you can subclass them and override the following methods:

- **get_train_dataloader**/**get_train_tfdataset** -- Creates the training DataLoader (PyTorch) or TF Dataset.
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **setup_wandb** -- Sets up wandb (see `here `__ for more information).
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
  init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
- **training_step** -- Performs a training step.
- **prediction_step** -- Performs an evaluation/test step.
- **run_model** (TensorFlow only) -- Performs a basic forward pass through the model.
- **evaluate** -- Runs an evaluation loop and returns metrics.
- **predict** -- Returns predictions (with metrics if labels are available) on a test set.
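To illustrate why overriding a single method is enough to change behavior, here is a toy, pure-Python model of the
hook pattern described above. ``ToyTrainer`` and ``AbsLossTrainer`` are hypothetical classes, not part of
``transformers``: the real training loop also handles the optimizer, devices, gradient accumulation and logging. The
sketch only assumes, for illustration, that the loop fetches batches through ``get_train_dataloader`` and that
``training_step`` delegates to ``compute_loss``.

.. code-block:: python

    from typing import Dict, List


    class ToyTrainer:
        """Hypothetical, heavily simplified trainer (NOT transformers.Trainer).

        train() never changes; behavior is customized only through the hook methods.
        """

        def __init__(self, model, train_dataset: List[Dict[str, float]], batch_size: int = 2):
            self.model = model
            self.train_dataset = train_dataset
            self.batch_size = batch_size

        def get_train_dataloader(self) -> List[List[Dict[str, float]]]:
            # Split the dataset into consecutive batches of at most batch_size examples.
            data, n = self.train_dataset, self.batch_size
            return [data[i:i + n] for i in range(0, len(data), n)]

        def compute_loss(self, model, inputs: List[Dict[str, float]]) -> float:
            # Default loss: mean squared error between the model output and the label.
            return sum((model(ex["x"]) - ex["labels"]) ** 2 for ex in inputs) / len(inputs)

        def training_step(self, model, inputs: List[Dict[str, float]]) -> float:
            # A real trainer would also backpropagate and step the optimizer here.
            return self.compute_loss(model, inputs)

        def train(self) -> List[float]:
            # Returns the loss of every batch, in order.
            return [self.training_step(self.model, batch) for batch in self.get_train_dataloader()]


    class AbsLossTrainer(ToyTrainer):
        # Overriding one hook swaps the loss while the loop in train() stays untouched.
        def compute_loss(self, model, inputs):
            return sum(abs(model(ex["x"]) - ex["labels"]) for ex in inputs) / len(inputs)

With ``model = lambda v: 2 * v``, both classes run the same loop over the same batches; only the per-batch loss
values differ, because ``AbsLossTrainer`` replaces the squared error with the absolute error.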
Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss function:

.. code-block:: python

    import torch.nn.functional as F
    from transformers import Trainer


    def my_custom_loss(logits, labels):
        # Placeholder for your own loss; cross-entropy is used here only as an example.
        return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))


    class MyTrainer(Trainer):
        def compute_loss(self, model, inputs):
            # Remove the labels so the model returns logits instead of its built-in loss.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs[0]
            return my_custom_loss(logits, labels)

Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Trainer
    :members:


TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainer
    :members:


TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainingArguments
    :members:


TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainingArguments
    :members:


Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.EvalPrediction

.. autofunction:: transformers.set_seed

.. autofunction:: transformers.torch_distributed_zero_first
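As a sketch of how :class:`~transformers.EvalPrediction` is typically consumed: :class:`~transformers.Trainer` accepts a
``compute_metrics`` callable that receives an :class:`~transformers.EvalPrediction` (with ``predictions`` and
``label_ids`` fields) and returns a dictionary mapping metric names to values. ``compute_accuracy`` below is a
hypothetical example that assumes classification logits of shape ``(num_examples, num_labels)``; the ``NamedTuple``
fallback is only a stand-in so the snippet runs without ``transformers`` installed.

.. code-block:: python

    from typing import Dict, NamedTuple, Sequence

    try:
        # Container exported by transformers, with .predictions and .label_ids fields.
        from transformers import EvalPrediction
    except ImportError:
        # Hypothetical stand-in exposing the same two fields, used only when transformers is absent.
        class EvalPrediction(NamedTuple):
            predictions: Sequence[Sequence[float]]
            label_ids: Sequence[int]


    def compute_accuracy(p: EvalPrediction) -> Dict[str, float]:
        """Accuracy from per-example logits: the predicted class is the argmax of each row."""
        preds = [max(range(len(row)), key=row.__getitem__) for row in p.predictions]
        if not preds:
            return {"accuracy": 0.0}
        correct = sum(int(pred == label) for pred, label in zip(preds, p.label_ids))
        return {"accuracy": correct / len(preds)}

The function would then be passed as ``Trainer(..., compute_metrics=compute_accuracy)``. The argmax is computed in
plain Python to keep the sketch dependency-free; when ``predictions`` is a NumPy array, ``predictions.argmax(-1)`` is
the more idiomatic choice.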