Optimum documentation

GaudiTrainer

You are viewing version v1.7.1. A newer version, v1.19.0, is available.

The GaudiTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.

Before instantiating your GaudiTrainer, create a GaudiTrainingArguments object to access all the points of customization during training.

The GaudiTrainer class is optimized for 🤗 Transformers models running on Habana Gaudi.

Here is an example of how to customize GaudiTrainer to use a weighted loss (useful when you have an unbalanced training set):

import torch
from torch import nn
from optimum.habana import GaudiTrainer


class CustomGaudiTrainer(GaudiTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor must live on the same device as the logits
        weight = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
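
To make the effect of the class weights concrete, here is a minimal pure-Python sketch of the weighted mean that nn.CrossEntropyLoss applies by default when a weight is given: each sample's loss is scaled by the weight of its true class, and the sum is divided by the sum of those weights. The helper name weighted_cross_entropy and the toy logits are illustrative assumptions, not part of the library.

```python
import math


def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean cross-entropy over a batch of (logits row, label) pairs."""
    total, weight_sum = 0.0, 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp of the logits row.
        row_max = max(row)
        log_norm = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        nll = log_norm - row[label]  # negative log-likelihood of the true class
        total += weights[label] * nll
        weight_sum += weights[label]
    return total / weight_sum


# Toy batch: the model confidently predicts class 0 for both samples,
# but the second sample actually belongs to the heavily weighted class 2.
logits = [[5.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
labels = [0, 2]
plain = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
print(plain, weighted)  # the weighted loss is the larger of the two
```

With uniform weights the function reduces to the ordinary mean cross-entropy; with weights [1.0, 2.0, 3.0] the mistake on the class-2 sample dominates the result, which is the behavior one usually wants for an under-represented class.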

Another way to customize the training loop behavior for the PyTorch GaudiTrainer is to use callbacks that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms…) and make decisions (like early stopping).
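
The callback pattern can be sketched without any library. In the sketch below, Control, EarlyStopping and run are hypothetical stand-ins written for illustration; they are not the Transformers callback classes. The only idea borrowed from the real API is that a callback inspects the metrics after each evaluation and sets a should_training_stop flag that the loop checks.

```python
from dataclasses import dataclass


@dataclass
class Control:
    """Hypothetical stand-in for the flags a training loop checks."""
    should_training_stop: bool = False


class EarlyStopping:
    """Request a stop once the metric has not improved for `patience` evaluations."""

    def __init__(self, metric="eval_loss", patience=2):
        self.metric = metric
        self.patience = patience
        self.best = float("inf")
        self.bad_evals = 0

    def on_evaluate(self, control, metrics):
        value = metrics[self.metric]
        if value < self.best:
            self.best = value
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        if self.bad_evals >= self.patience:
            control.should_training_stop = True
        return control


def run(eval_losses, callback):
    """Toy loop with one evaluation per epoch; returns the number of epochs run."""
    control = Control()
    epochs = 0
    for loss in eval_losses:
        epochs += 1
        callback.on_evaluate(control, {"eval_loss": loss})
        if control.should_training_stop:
            break
    return epochs


epochs = run([0.9, 0.7, 0.71, 0.72, 0.5], EarlyStopping(patience=2))
print(epochs)  # 4
```

In a real training script, the equivalent object would be a transformers.TrainerCallback subclass passed through the callbacks argument of the trainer.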

GaudiTrainer

class optimum.habana.GaudiTrainer

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )

GaudiTrainer is built on top of the Transformers Trainer to enable deployment on Habana’s Gaudi.

create_optimizer

( )

Set up the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer’s init through optimizers, or subclass and override this method.

evaluation_loop

( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )

Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict(). Works both with and without labels.

log

( logs: typing.Dict[str, float] )

Parameters

  • logs (Dict[str, float]) — The values to log.

Log logs on the various objects watching training. Subclass and override this method to inject custom behavior.

prediction_loop

( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )

Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict(). Works both with and without labels.

prediction_step

( model: Module inputs: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] prediction_loss_only: bool ignore_keys: typing.Optional[typing.List[str]] = None ) Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]

Parameters

  • model (nn.Module) — The model to evaluate.
  • inputs (Dict[str, Union[torch.Tensor, Any]]) — The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument labels. Check your model’s documentation for all accepted arguments.
  • prediction_loss_only (bool) — Whether or not to return the loss only.
  • ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.

Returns

Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]

A tuple with the loss, logits and labels (each being optional).

Perform an evaluation step on model using inputs. Subclass and override to inject custom behavior.

save_model

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

Saves the model so that you can reload it using from_pretrained(). Only the main process saves.

train

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
  • kwargs — Additional keyword arguments used to hide deprecated arguments.
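
As an illustration of what "the last checkpoint in args.output_dir" means when resume_from_checkpoint is True, the following sketch picks the folder with the highest step number. It assumes checkpoint folders are named checkpoint-<step>; the helper name find_last_checkpoint is hypothetical and the code is not the trainer's own implementation.

```python
import os
import re
import tempfile

# Assumed naming scheme for checkpoint folders: "checkpoint-<step>".
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the path of the checkpoint folder with the highest step, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare steps numerically so that checkpoint-200 beats checkpoint-9.
    return os.path.join(output_dir, max(candidates)[1])


with tempfile.TemporaryDirectory() as output_dir:
    for name in ("checkpoint-9", "checkpoint-200", "checkpoint-10", "runs"):
        os.makedirs(os.path.join(output_dir, name))
    last = os.path.basename(find_last_checkpoint(output_dir))

print(last)  # checkpoint-200
```

Passing a str instead of True skips this lookup and resumes from the given path directly.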

Main training entry point.

GaudiSeq2SeqTrainer

class optimum.habana.GaudiSeq2SeqTrainer

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )

evaluate

( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' **gen_kwargs )

Parameters

  • eval_dataset (Dataset, optional) — Pass a dataset if you wish to override self.eval_dataset. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
  • ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
  • metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “eval_bleu” if the prefix is "eval" (default).
  • max_length (int, optional) — The maximum target length to use when predicting with the generate method.
  • num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
  • gen_kwargs — Additional generate specific kwargs.

Runs evaluation and returns metrics. The calling script is responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init compute_metrics argument). You can also subclass and override this method to inject custom behavior.
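
A minimal sketch of a task-specific metrics function and of the key prefixing described for metric_key_prefix. The namedtuple below is a stand-in for the EvalPrediction object so that the example runs without the library; the exact_match metric and the add_prefix helper are illustrative assumptions, not part of the API.

```python
from collections import namedtuple

# Stand-in for the EvalPrediction object handed to compute_metrics.
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])


def compute_metrics(eval_pred):
    """Fraction of predicted sequences that match the reference exactly."""
    predictions, label_ids = eval_pred
    matches = sum(list(p) == list(l) for p, l in zip(predictions, label_ids))
    return {"exact_match": matches / len(label_ids)}


def add_prefix(metrics, prefix="eval"):
    """Prefix every metric name that does not already carry the prefix."""
    return {
        key if key.startswith(f"{prefix}_") else f"{prefix}_{key}": value
        for key, value in metrics.items()
    }


eval_pred = EvalPrediction(predictions=[[1, 2], [3, 4]], label_ids=[[1, 2], [3, 5]])
metrics = add_prefix(compute_metrics(eval_pred))
print(metrics)  # {'eval_exact_match': 0.5}
```

A function with the same shape as compute_metrics is what would be passed to the trainer's compute_metrics argument.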

predict

( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' **gen_kwargs )

Parameters

  • test_dataset (Dataset) — Dataset to run the predictions on. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
  • ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
  • metric_key_prefix (str, optional, defaults to "test") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “test_bleu” if the prefix is "test" (default).
  • max_length (int, optional) — The maximum target length to use when predicting with the generate method.
  • num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
  • gen_kwargs — Additional generate specific kwargs.

Runs prediction and returns predictions and potential metrics. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in evaluate().

If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding in a token classification task) the predictions will be padded (on the right) to allow for concatenation into one array. The padding index is -100.

Returns

NamedTuple

A namedtuple with the following keys:

  • predictions (np.ndarray) — The predictions on test_dataset.
  • label_ids (np.ndarray, optional) — The labels (if the dataset contained some).
  • metrics (Dict[str, float], optional) — The potential dictionary of metrics (if the dataset contained labels).
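
The right-padding with -100 mentioned for predict can be sketched in plain Python. The helper pad_and_concat is a hypothetical illustration that works on nested lists rather than arrays; it is not the trainer's own code.

```python
PAD_INDEX = -100  # padding index stated in the description of predict


def pad_and_concat(batches, pad_index=PAD_INDEX):
    """Right-pad every sequence to the longest length, then stack all batches."""
    max_len = max(len(seq) for batch in batches for seq in batch)
    return [
        list(seq) + [pad_index] * (max_len - len(seq))
        for batch in batches
        for seq in batch
    ]


batch_a = [[1, 2], [3, 4]]  # two sequences of length 2
batch_b = [[5, 6, 7]]       # one sequence of length 3
merged = pad_and_concat([batch_a, batch_b])
print(merged)  # [[1, 2, -100], [3, 4, -100], [5, 6, 7]]
```

When post-processing predictions, positions equal to -100 are padding and should be masked out before decoding or scoring.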