Optimum documentation

Trainer

ORTTrainer

class optimum.onnxruntime.ORTTrainer

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None feature: str = 'default' args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None model_init: typing.Callable[[], transformers.modeling_utils.PreTrainedModel] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) onnx_model_path: typing.Union[str, os.PathLike] = None )

compute_loss_ort

( model inputs input_names output_names return_outputs = False )

How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior.
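The "first element" convention can be illustrated with a minimal sketch. The helper names `extract_loss` and `compute_loss_sketch` are hypothetical and not part of Optimum, and the `"loss"` key for dictionary-style outputs is an assumption; the code only mirrors the rule stated above.

```python
from typing import Any, Mapping, Sequence, Union


def extract_loss(outputs: Union[Mapping[str, Any], Sequence[Any]]) -> Any:
    """Hypothetical helper: pick the loss out of a model's outputs."""
    if isinstance(outputs, Mapping):
        # Assumption: dictionary-style outputs expose the loss under "loss".
        return outputs["loss"]
    # Tuple-style outputs: the loss is the first element.
    return outputs[0]


def compute_loss_sketch(outputs, return_outputs: bool = False):
    """Return the loss, optionally together with the full outputs."""
    loss = extract_loss(outputs)
    return (loss, outputs) if return_outputs else loss


tuple_loss = compute_loss_sketch((0.25, "logits"))
dict_loss = compute_loss_sketch({"loss": 0.5, "logits": "logits"})
```

A subclass that overrides the loss computation would typically replace the body of a function like `extract_loss` with its own logic, for example a weighted loss.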

evaluate

( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' inference_with_ort: bool = False )

Runs evaluation with either the ONNX Runtime or the PyTorch backend and returns the metrics. (Overridden from Trainer.evaluate().)

evaluation_loop_ort

( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )

Prediction/evaluation loop, shared by ORTTrainer.evaluate() and ORTTrainer.predict(). Works both with or without labels.

predict

( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' inference_with_ort: bool = False )

Runs prediction with either the ONNX Runtime or the PyTorch backend and returns the predictions and, when labels are available, the metrics. (Overridden from Trainer.predict().)

prediction_loop_ort

( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )

Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict(). Works both with or without labels.

train

( resume_from_checkpoint: typing.Union[bool, str, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )

Parameters

  • resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
  • ignore_keys_for_eval (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
  • kwargs — Additional keyword arguments used to hide deprecated arguments.

Main onnxruntime training entry point.
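The two forms of resume_from_checkpoint described above can be sketched as follows. This is an illustration, not the trainer's implementation: `last_checkpoint` and `resolve_checkpoint` are hypothetical helpers, and the `checkpoint-<step>` directory naming is an assumption about how checkpoints are laid out in args.output_dir.

```python
import os
import re
import tempfile
from typing import Optional, Union

# Assumption: checkpoints are saved as "checkpoint-<step>" directories.
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint directory with the highest step, if any."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    return os.path.join(output_dir, max(candidates)[1])


def resolve_checkpoint(
    resume_from_checkpoint: Union[bool, str, None], output_dir: str
) -> Optional[str]:
    """Map the argument to a checkpoint path, or None for a fresh run."""
    if isinstance(resume_from_checkpoint, str):
        return resume_from_checkpoint  # explicit local path
    if resume_from_checkpoint is True:
        return last_checkpoint(output_dir)  # latest one in output_dir
    return None


with tempfile.TemporaryDirectory() as out:
    for step in (500, 1000, 90):
        os.mkdir(os.path.join(out, f"checkpoint-{step}"))
    demo_name = os.path.basename(resolve_checkpoint(True, out))
```

The steps are compared as integers rather than strings, so `checkpoint-1000` is preferred over `checkpoint-500` and `checkpoint-90`. The sketch returns None when no checkpoint is found; the real trainer may handle that case differently.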

ORTSeq2SeqTrainer

class optimum.onnxruntime.ORTSeq2SeqTrainer

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None feature: str = 'default' args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None model_init: typing.Callable[[], transformers.modeling_utils.PreTrainedModel] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) onnx_model_path: typing.Union[str, os.PathLike] = None )

evaluate

( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' max_length: typing.Optional[int] = None num_beams: typing.Optional[int] = None inference_with_ort: bool = False )

Parameters

  • eval_dataset (Dataset, optional) — Pass a dataset if you wish to override self.eval_dataset. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
  • ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
  • metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “eval_bleu” if the prefix is "eval" (the default).
  • max_length (int, optional) — The maximum target length to use when predicting with the generate method.
  • num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
  • inference_with_ort (bool, optional) — Whether to enable inference with the ONNX Runtime backend. By default, inference runs in PyTorch.

Runs evaluation and returns metrics. The calling script is responsible for providing a method to compute metrics, as they are task-dependent (pass it to the compute_metrics argument at initialization). You can also subclass and override this method to inject custom behavior.
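As a rough illustration of how a user-supplied metrics function and metric_key_prefix fit together, consider the sketch below. It is simplified on purpose: the real compute_metrics callable receives a single EvalPrediction object, whereas this stand-in takes predictions and labels as two plain lists, and `add_prefix` is a hypothetical helper rather than a library function.

```python
from typing import Dict, Sequence


def compute_metrics(
    predictions: Sequence[int], label_ids: Sequence[int]
) -> Dict[str, float]:
    """Task-dependent metric supplied by the calling script (here: accuracy)."""
    correct = sum(int(p == y) for p, y in zip(predictions, label_ids))
    return {"accuracy": correct / len(label_ids)}


def add_prefix(
    metrics: Dict[str, float], metric_key_prefix: str = "eval"
) -> Dict[str, float]:
    """Prefix every metric name unless it already carries the prefix."""
    prefixed = {}
    for key, value in metrics.items():
        if key.startswith(f"{metric_key_prefix}_"):
            prefixed[key] = value
        else:
            prefixed[f"{metric_key_prefix}_{key}"] = value
    return prefixed


metrics = add_prefix(compute_metrics([1, 0, 1, 1], [1, 0, 0, 1]))
```

With the default prefix, `metrics` holds a single `eval_accuracy` entry; passing a different metric_key_prefix changes only the key names, not the values.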

predict

( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' max_length: typing.Optional[int] = None num_beams: typing.Optional[int] = None inference_with_ort: bool = False )

Parameters

  • test_dataset (Dataset) — Dataset to run the predictions on. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
  • ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
  • metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “eval_bleu” if the prefix is "eval" (the default).
  • max_length (int, optional) — The maximum target length to use when predicting with the generate method.
  • num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
  • inference_with_ort (bool, optional) — Whether to enable inference with the ONNX Runtime backend. By default, inference runs in PyTorch.

Runs prediction and returns the predictions and, if available, the metrics. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method also returns metrics, as evaluate() does.

If your predictions or labels have different sequence lengths (for instance because you’re doing dynamic padding in a token classification task) the predictions will be padded (on the right) to allow for concatenation into one array. The padding index is -100.
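The right-padding behaviour can be sketched with plain Python lists. The trainer works on arrays, so this is only an illustration of the idea; `pad_right` and `concat_padded` are hypothetical helpers, and only the padding index of -100 comes from the text above.

```python
from typing import List

PAD_INDEX = -100  # padding index stated in the documentation


def pad_right(
    batch: List[List[int]], length: int, pad_index: int = PAD_INDEX
) -> List[List[int]]:
    """Pad every row on the right so that it has exactly `length` items."""
    return [row + [pad_index] * (length - len(row)) for row in batch]


def concat_padded(batches: List[List[List[int]]]) -> List[List[int]]:
    """Pad all batches to the longest row, then stack them row by row."""
    width = max(len(row) for batch in batches for row in batch)
    merged: List[List[int]] = []
    for batch in batches:
        merged.extend(pad_right(batch, width))
    return merged


first = [[5, 6, 7], [8, 9, 10]]  # batch padded to length 3
second = [[1, 2, 3, 4, 5]]       # batch padded to length 5
merged = concat_padded([first, second])
```

After the call, every row in `merged` has length 5, and the rows from the shorter batch end with two -100 entries. Downstream metric code usually needs to mask out these positions before comparing predictions to labels.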

Returns: a NamedTuple with the following keys:

  • predictions (np.ndarray): The predictions on test_dataset.
  • label_ids (np.ndarray, optional): The labels (if the dataset contained some).
  • metrics (Dict[str, float], optional): The potential dictionary of metrics (if the dataset contained labels).
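The shape of the return value can be mimicked with a small stand-in type. `PredictionResult` is a hypothetical class written for this sketch, not the type the library returns, and it uses nested lists where the real fields are np.ndarray objects.

```python
from typing import Dict, List, NamedTuple, Optional


class PredictionResult(NamedTuple):
    """Stand-in for the namedtuple returned by predict()."""

    predictions: List[List[int]]
    label_ids: Optional[List[List[int]]] = None
    metrics: Optional[Dict[str, float]] = None


# A labelled test set yields all three fields.
labelled = PredictionResult(
    predictions=[[1, 2], [3, 4]],
    label_ids=[[1, 2], [3, 0]],
    metrics={"eval_accuracy": 0.75},
)

# Fields can be read by name or unpacked positionally.
predictions, label_ids, metrics = labelled

# Without labels, only the predictions are populated.
unlabelled = PredictionResult(predictions=[[1, 2]])
```

Accessing fields by name (`labelled.metrics`) tends to be more robust than positional unpacking, because it keeps working if the caller only needs one of the three values.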