The GaudiTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.
Before instantiating your GaudiTrainer, create a GaudiTrainingArguments object to access all the points of customization during training.
The GaudiTrainer class is optimized for 🤗 Transformers models running on Habana Gaudi.
Here is an example of how to customize GaudiTrainer to use a weighted loss (useful when you have an unbalanced training set):
import torch
from torch import nn
from optimum.habana import GaudiTrainer


class CustomGaudiTrainer(GaudiTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights)
        # the class weights must live on the same device as the logits
        class_weights = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
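To see what the weight argument changes, the following minimal sketch recomputes a weighted cross-entropy by hand in plain Python. It assumes the default mean reduction of nn.CrossEntropyLoss, under which each example's loss is scaled by the weight of its target class and the sum is divided by the total weight of the targets. The helper name weighted_cross_entropy and the sample values are hypothetical.

```python
import math


def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-example cross-entropy losses (illustrative helper)."""
    total_loss = 0.0
    total_weight = 0.0
    for row, label in zip(logits, labels):
        # Numerically stable negative log-softmax of the target class.
        max_logit = max(row)
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in row))
        nll = log_sum_exp - row[label]
        total_loss += weights[label] * nll
        total_weight += weights[label]
    return total_loss / total_weight


# Both examples favor class 0; the second one actually belongs to class 2.
logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
labels = [0, 2]

unweighted = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
```

With these values the mistake on the heavily weighted class 2 dominates, so the weighted loss is larger than the unweighted one.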
Another way to customize the training loop behavior for the PyTorch GaudiTrainer is to use callbacks that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms…) and make decisions (like early stopping).
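As a rough illustration of a callback that takes a decision, the sketch below stops training once the evaluation loss has failed to improve for a given number of evaluations. The class name PatienceCallback and its patience parameter are hypothetical, and the code assumes the reported metrics contain an eval_loss entry. Transformers also ships a ready-made EarlyStoppingCallback.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState


class PatienceCallback(TrainerCallback):
    """Hypothetical callback: stop when eval_loss has not improved for `patience` evaluations."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_evals = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        loss = (metrics or {}).get("eval_loss")
        if loss is None:
            # Nothing to compare against; leave the control flags untouched.
            return control
        if loss < self.best_loss:
            self.best_loss = loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                # Ask the training loop to stop after the current step.
                control.should_training_stop = True
        return control
```

An instance would be passed through the callbacks argument of the trainer, for example callbacks=[PatienceCallback(patience=3)].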
class optimum.habana.GaudiTrainer
( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[] = None eval_dataset: typing.Union[, typing.Dict[str,], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )
GaudiTrainer is built on top of the Transformers Trainer to enable deployment on Habana’s Gaudi.
create_optimizer( )
Set up the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer’s init through optimizers, or subclass and override this method.
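The tuple expected by optimizers pairs an optimizer with a learning-rate scheduler, matching the Tuple[Optimizer, LambdaLR] annotation in the class signature. Here is a minimal sketch of building such a pair with plain PyTorch. The stand-in model, the learning rate, and the warmup length are assumptions chosen for illustration, not recommended settings for Gaudi.

```python
import torch
from torch import nn

# Stand-in model for illustration; in practice this is the model given to the trainer.
model = nn.Linear(4, 2)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

warmup_steps = 10  # hypothetical value


def linear_warmup(step):
    # Scale the learning rate linearly up to its base value, then keep it constant.
    return min(1.0, (step + 1) / warmup_steps)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup)

# The pair is what the `optimizers` argument expects, e.g.
# GaudiTrainer(model=model, gaudi_config=..., args=..., optimizers=(optimizer, scheduler))
optimizers = (optimizer, scheduler)
```

Passing the tuple leaves the rest of the training loop untouched, whereas overriding the method is the better fit when the optimizer depends on state that only exists inside the trainer.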
evaluation_loop( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with or without labels.
log( logs: typing.Dict[str, float] )
Log logs on the various objects watching training.
Subclass and override this method to inject custom behavior.
prediction_loop( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with or without labels.
prediction_step( model: Module inputs: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] prediction_loss_only: bool ignore_keys: typing.Optional[typing.List[str]] = None ) → Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
- model (Module) — The model to evaluate.
- inputs (Dict[str, Union[torch.Tensor, Any]]) — The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument labels. Check your model’s documentation for all accepted arguments.
- prediction_loss_only (bool) — Whether or not to return the loss only.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
Returns: Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]] — A tuple with the loss, logits and labels (each being optional).
Perform an evaluation step on model using inputs.
Subclass and override to inject custom behavior.
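The effect of ignore_keys can be pictured with a small, simplified sketch. It is not the library's implementation: the helper gather_predictions and the sample output dictionary are hypothetical, and the code only assumes that the model returns a dictionary whose loss entry is reported separately from the predictions.

```python
def gather_predictions(outputs, ignore_keys=None):
    """Keep the model outputs that count as predictions (illustrative helper)."""
    ignore_keys = list(ignore_keys or [])
    # The loss is reported separately, so it is never part of the predictions.
    skipped = set(ignore_keys + ["loss"])
    kept = tuple(value for key, value in outputs.items() if key not in skipped)
    # A single remaining output is returned on its own rather than as a 1-tuple.
    return kept[0] if len(kept) == 1 else kept


# Made-up model output with an entry that is not a prediction.
outputs = {"loss": 0.42, "logits": [[0.1, 0.9]], "past_key_values": "cache"}
predictions = gather_predictions(outputs, ignore_keys=["past_key_values"])
```

Without ignore_keys, the cache entry would be gathered alongside the logits and passed on to the metric computation.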
save_model
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
train( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )
- resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
- kwargs — Additional keyword arguments used to hide deprecated arguments.
Main training entry point.
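When resume_from_checkpoint is True, the last checkpoint in args.output_dir has to be located first. The sketch below shows one way such a lookup can work. It assumes checkpoint folders are named checkpoint-<step>, and the helper find_last_checkpoint is hypothetical rather than the trainer's own code.

```python
import os
import re
import tempfile

CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")  # assumed folder naming


def find_last_checkpoint(output_dir):
    """Return the checkpoint folder with the highest step, or None if there is none."""
    candidates = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare step numbers numerically so that checkpoint-1000 beats checkpoint-500.
    return os.path.join(output_dir, max(candidates)[1])


with tempfile.TemporaryDirectory() as output_dir:
    for step in (500, 1000, 90):
        os.mkdir(os.path.join(output_dir, f"checkpoint-{step}"))
    last = os.path.basename(find_last_checkpoint(output_dir))
```

Transformers provides a comparable helper, get_last_checkpoint in transformers.trainer_utils, which is handy for deciding in a script whether there is anything to resume from.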
class optimum.habana.GaudiSeq2SeqTrainer
( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[] = None eval_dataset: typing.Union[, typing.Dict[str,], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )
evaluate( eval_dataset: typing.Optional[] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' **gen_kwargs )
- eval_dataset (optional) — Pass a dataset if you wish to override self.eval_dataset. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “eval_bleu” if the prefix is "eval" (default).
- max_length (optional) — The maximum target length to use when predicting with the generate method.
- num_beams (optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate specific kwargs.
Runs evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init compute_metrics argument).
You can also subclass and override this method to inject custom behavior.
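A compute_metrics function receives the gathered predictions and labels and returns a dictionary of named values. The sketch below computes a token-level accuracy as a simple stand-in for a task metric such as BLEU. It assumes that the predictions are already token ids with the same shape as the labels and that ignored label positions are marked with -100. The arrays are made-up sample data.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Return task metrics as a dict keyed by metric name."""
    predictions, labels = eval_pred  # unpacks as (predictions, label_ids)
    # Positions labelled -100 are padding and must not count toward the score.
    mask = labels != -100
    token_accuracy = float((predictions[mask] == labels[mask]).mean())
    return {"token_accuracy": token_accuracy}


# Stand-in for what the trainer would pass: 2 sequences of 4 token ids each.
predictions = np.array([[5, 6, 7, 0], [8, 9, 1, 1]])
labels = np.array([[5, 6, 7, -100], [8, 2, 1, -100]])
metrics = compute_metrics((predictions, labels))
```

The function would be handed to the trainer through the compute_metrics argument of its constructor.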
predict( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' **gen_kwargs )
- test_dataset (Dataset) — Dataset to run the predictions on. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. Has to implement the method __len__.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "test") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “test_bleu” if the prefix is "test" (default).
- max_length (optional) — The maximum target length to use when predicting with the generate method.
- num_beams (optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate specific kwargs.
Runs prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in evaluate().
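The role of metric_key_prefix can be illustrated with a short sketch. The helper add_metric_prefix is hypothetical and the metric values are made up. It only mirrors the naming rule described in the parameter lists, using the signature defaults of 'eval' for evaluation and 'test' for prediction.

```python
def add_metric_prefix(metrics, metric_key_prefix):
    """Prefix every metric name that does not already carry the prefix (illustrative)."""
    prefixed = {}
    for key, value in metrics.items():
        if key.startswith(f"{metric_key_prefix}_"):
            prefixed[key] = value
        else:
            prefixed[f"{metric_key_prefix}_{key}"] = value
    return prefixed


# Made-up metric values, named as an evaluation run and a prediction run would name them.
eval_metrics = add_metric_prefix({"bleu": 27.3, "eval_loss": 1.9}, "eval")
test_metrics = add_metric_prefix({"bleu": 26.8}, "test")
```

Keeping distinct prefixes makes it possible to log evaluation and test results side by side without key collisions.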