GaudiTrainer
The GaudiTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.
Before instantiating your GaudiTrainer, create a GaudiTrainingArguments object to access all the points of customization during training.
The GaudiTrainer class is optimized for 🤗 Transformers models running on Habana Gaudi.
Here is an example of how to customize GaudiTrainer to use a weighted loss (useful when you have an unbalanced training set):
import torch
from torch import nn

from optimum.habana import GaudiTrainer


class CustomGaudiTrainer(GaudiTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor must live on the same device as the logits
        class_weights = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
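To make the effect of the class weights concrete, here is a small pure-Python sketch of the weighted mean that nn.CrossEntropyLoss computes when it is given a weight tensor: each sample's loss is scaled by the weight of its target class, and the sum is divided by the sum of those weights. The function name weighted_cross_entropy and the sample values are illustrative only and are not part of optimum.habana.

```python
import math


def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean cross-entropy over a batch.

    logits: list of per-sample score lists
    labels: list of target class indices
    weights: one weight per class
    """
    total, weight_sum = 0.0, 0.0
    for scores, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_norm - scores[label]  # negative log-likelihood of the target
        total += weights[label] * nll
        weight_sum += weights[label]
    return total / weight_sum


logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 1.5]]
labels = [0, 2]
uniform = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
print(uniform, weighted)
```

With the weights [1.0, 2.0, 3.0], the second sample (class 2) contributes three times as much to the batch loss as the first one (class 0).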
Another way to customize the training loop behavior for the PyTorch GaudiTrainer is to use callbacks that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms…) and take decisions (like early stopping).
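As a rough illustration of such a callback, the sketch below requests a stop once eval_loss has failed to improve for a given number of evaluations. The class name PatienceCallback and its patience parameter are hypothetical. The sketch assumes the Transformers callback interface, where on_evaluate receives the metrics as a keyword argument and stopping is requested through control.should_training_stop.

```python
try:
    from transformers import TrainerCallback
except ImportError:
    # Fallback so the sketch can be run without transformers installed.
    TrainerCallback = object


class PatienceCallback(TrainerCallback):
    """Hypothetical callback: stop when eval_loss stalls for `patience` evaluations."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best = float("inf")
        self.bad_evals = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        loss = (metrics or {}).get("eval_loss")
        if loss is None:
            return control
        if loss < self.best:
            # New best value: reset the counter of non-improving evaluations.
            self.best, self.bad_evals = loss, 0
        else:
            self.bad_evals += 1
        if self.bad_evals >= self.patience:
            # Ask the training loop to stop at the next opportunity.
            control.should_training_stop = True
        return control
```

An instance would be passed in the callbacks list of the GaudiTrainer constructor. Transformers also ships an EarlyStoppingCallback, which is likely the better choice in practice.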
GaudiTrainer
class optimum.habana.GaudiTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: TrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )
GaudiTrainer is built on top of the Transformers Trainer to enable deployment on Habana’s Gaudi.
A helper wrapper that creates an appropriate context manager for autocast while feeding it the desired arguments, depending on the situation. Modified by Habana to enable using autocast on Gaudi devices.
Set up the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer’s init through optimizers, or subclass and override this method.
evaluation_loop
< source >( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with or without labels.
log
< source >( logs: typing.Dict[str, float] )
Log logs on the various objects watching training.
Subclass and override this method to inject custom behavior.
prediction_loop
< source >( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' )
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with or without labels.
prediction_step
< source >( model: Module inputs: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] prediction_loss_only: bool ignore_keys: typing.Optional[typing.List[str]] = None ) → Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
Parameters
- model (torch.nn.Module) — The model to evaluate.
- inputs (Dict[str, Union[torch.Tensor, Any]]) — The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument labels. Check your model’s documentation for all accepted arguments.
- prediction_loss_only (bool) — Whether or not to return the loss only.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
Returns
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
A tuple with the loss, logits and labels (each being optional).
Perform an evaluation step on model using inputs.
Subclass and override to inject custom behavior.
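The role of ignore_keys can be pictured with a small sketch. It assumes the model output is a plain dictionary and is a simplified illustration, not the library's implementation. The helper name gather_predictions and the has_labels flag are hypothetical.

```python
def gather_predictions(outputs, ignore_keys=None, has_labels=True):
    """Keep the output entries that should be gathered as predictions.

    outputs: dict returned by the model
    ignore_keys: keys to drop (for example cached states)
    has_labels: when True, the "loss" entry is reported separately and dropped here
    """
    ignore = set(ignore_keys or [])
    if has_labels:
        ignore.add("loss")
    return tuple(value for key, value in outputs.items() if key not in ignore)


outputs = {"loss": 0.3, "logits": [0.1, 0.9], "past_key_values": "cache"}
print(gather_predictions(outputs, ignore_keys=["past_key_values"]))
```

Here only the logits entry survives, which is what a compute_metrics function would typically expect to receive.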
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
train
< source >( resume_from_checkpoint: typing.Union[bool, str, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )
Parameters
- resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
- kwargs (Dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments.
Main training entry point.
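When resume_from_checkpoint=True, the last checkpoint in args.output_dir has to be located. The sketch below shows one way this lookup can work, assuming checkpoints are stored as checkpoint-<step> subfolders of the output directory. The helper name find_last_checkpoint is hypothetical. Transformers provides a comparable utility, transformers.trainer_utils.get_last_checkpoint.

```python
import os
import re

# Assumed folder naming: "checkpoint-" followed by the global step.
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the checkpoint-<step> subfolder with the highest step, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            # Compare steps numerically so that checkpoint-200 beats checkpoint-90.
            step = int(match.group(1))
            if step > best_step:
                best_step, best_path = step, path
    return best_path
```

Comparing the step as an integer matters: a plain string sort would rank checkpoint-90 after checkpoint-200.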
training_step
< source >( model: Module inputs: typing.Dict[str, typing.Union[torch.Tensor, typing.Any]] ) → torch.Tensor
Parameters
- model (torch.nn.Module) — The model to train.
- inputs (Dict[str, Union[torch.Tensor, Any]]) — The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument labels. Check your model’s documentation for all accepted arguments.
Returns
torch.Tensor
The tensor with training loss on this batch.
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
GaudiSeq2SeqTrainer
class optimum.habana.GaudiSeq2SeqTrainer
< source >( model: typing.Union[ForwardRef('PreTrainedModel'), torch.nn.modules.module.Module] = None gaudi_config: GaudiConfig = None args: GaudiTrainingArguments = None data_collator: typing.Optional[ForwardRef('DataCollator')] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[ForwardRef('PreTrainedTokenizerBase')] = None model_init: typing.Union[typing.Callable[[], ForwardRef('PreTrainedModel')], NoneType] = None compute_metrics: typing.Union[typing.Callable[[ForwardRef('EvalPrediction')], typing.Dict], NoneType] = None callbacks: typing.Optional[typing.List[ForwardRef('TrainerCallback')]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )
evaluate
< source >( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' **gen_kwargs )
Parameters
- eval_dataset (Dataset, optional) — Pass a dataset if you wish to override self.eval_dataset. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “eval_bleu” if the prefix is "eval" (default).
- max_length (int, optional) — The maximum target length to use when predicting with the generate method.
- num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate-specific kwargs.
Runs evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init compute_metrics argument).
You can also subclass and override this method to inject custom behavior.
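The effect of metric_key_prefix on the returned dictionary can be sketched as follows. This is a simplified illustration rather than the library's code, and the helper name prefix_metrics is hypothetical.

```python
def prefix_metrics(metrics, metric_key_prefix="eval"):
    """Return a copy of `metrics` in which every key starts with '<prefix>_'."""
    prefix = f"{metric_key_prefix}_"
    return {
        # Keys that already carry the prefix are left unchanged.
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in metrics.items()
    }


print(prefix_metrics({"bleu": 27.4, "eval_loss": 1.2}))
```

A compute_metrics function can therefore return bare names such as "bleu", and the same function can be reused for evaluation and prediction with different prefixes.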
predict
< source >( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' **gen_kwargs )
Parameters
- test_dataset (Dataset) — Dataset to run the predictions on. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "test") — An optional prefix to be used as the metrics key prefix. For example, the metric “bleu” will be named “test_bleu” if the prefix is "test" (default).
- max_length (int, optional) — The maximum target length to use when predicting with the generate method.
- num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate-specific kwargs.
Runs prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in evaluate().
GaudiTrainingArguments
class optimum.habana.GaudiTrainingArguments
< source >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: typing.Optional[float] = 1e-06 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: typing.Optional[bool] = False save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = False save_on_each_node: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'hpu_amp' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = 'hccl' tpu_num_cores: 
typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False sharded_ddp: typing.Union[typing.List[transformers.trainer_utils.ShardedDDPOption], str, NoneType] = '' fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Optional[str] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None deepspeed: typing.Optional[str] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str, NoneType] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Optional[typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = False ddp_bucket_cap_mb: typing.Optional[int] = 230 ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False include_inputs_for_metrics: bool = False 
fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None use_habana: typing.Optional[bool] = False gaudi_config_name: typing.Optional[str] = None use_lazy_mode: typing.Optional[bool] = False use_hpu_graphs: typing.Optional[bool] = False use_hpu_graphs_for_inference: typing.Optional[bool] = False use_hpu_graphs_for_training: typing.Optional[bool] = False distribution_strategy: typing.Optional[str] = 'ddp' throughput_warmup_steps: typing.Optional[int] = 0 adjust_throughput: bool = False pipelining_fwd_bwd: typing.Optional[bool] = False non_blocking_data_copy: typing.Optional[bool] = False profiling_warmup_steps: typing.Optional[int] = 0 profiling_steps: typing.Optional[int] = 0 )
Parameters
- use_habana (bool, optional, defaults to False) — Whether to use Habana’s HPU for running the model.
- gaudi_config_name (str, optional) — Pretrained Gaudi config name or path.
- use_lazy_mode (bool, optional, defaults to False) — Whether to use lazy mode for running the model.
- use_hpu_graphs (bool, optional, defaults to False) — Deprecated, use use_hpu_graphs_for_inference instead. Whether to use HPU graphs for performing inference.
- use_hpu_graphs_for_inference (bool, optional, defaults to False) — Whether to use HPU graphs for performing inference. It will reduce latency but may not be compatible with some operations.
- use_hpu_graphs_for_training (bool, optional, defaults to False) — Whether to use HPU graphs for performing training. It will speed up training but may not be compatible with some operations.
- distribution_strategy (str, optional, defaults to ddp) — Determines how data parallel distributed training is achieved. May be: ddp or fast_ddp.
- throughput_warmup_steps (int, optional, defaults to 0) — Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the first N steps will not be considered in the calculation of the throughput. This is especially useful in lazy mode where the first two or three iterations typically take longer.
- adjust_throughput (bool, optional, defaults to False) — Whether to remove the time taken for logging, evaluating and saving from throughput calculation.
- pipelining_fwd_bwd (bool, optional, defaults to False) — Whether to add an additional mark_step between forward and backward for pipelining host backward building and HPU forward computing.
- non_blocking_data_copy (bool, optional, defaults to False) — Whether to enable async data copy when preparing inputs.
- profiling_warmup_steps (int, optional, defaults to 0) — Number of steps to ignore for profiling.
- profiling_steps (int, optional, defaults to 0) — Number of steps to be captured when enabling profiling.
GaudiTrainingArguments is built on top of the Transformers TrainingArguments to enable deployment on Habana’s Gaudi.
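To illustrate why throughput_warmup_steps matters, the sketch below computes samples per second from a list of per-step durations, with and without discarding the first steps. It is a simplified model under assumed numbers, not the computation performed by the library. The function name samples_per_second and the durations are made up for the example.

```python
def samples_per_second(step_durations, batch_size, warmup_steps=0):
    """Throughput over the steps that remain after skipping `warmup_steps`."""
    measured = step_durations[warmup_steps:]
    if not measured:
        raise ValueError("no steps left after warmup")
    return batch_size * len(measured) / sum(measured)


# Hypothetical run: the first two steps are slow, as is typical in lazy mode.
durations = [5.0, 3.0, 0.5, 0.5, 0.5, 0.5]
print(samples_per_second(durations, batch_size=8))                  # 4.8
print(samples_per_second(durations, batch_size=8, warmup_steps=2))  # 16.0
```

In this made-up run, the two slow initial steps drag the reported throughput down by more than a factor of three, which is the distortion the argument is meant to remove.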
GaudiSeq2SeqTrainingArguments
class optimum.habana.GaudiSeq2SeqTrainingArguments
< source >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: typing.Optional[float] = 1e-06 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: typing.Optional[bool] = False save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = False save_on_each_node: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'hpu_amp' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = 'hccl' tpu_num_cores: 
typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False sharded_ddp: typing.Union[typing.List[transformers.trainer_utils.ShardedDDPOption], str, NoneType] = '' fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Optional[str] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None deepspeed: typing.Optional[str] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str, NoneType] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Optional[typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = False ddp_bucket_cap_mb: typing.Optional[int] = 230 ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False include_inputs_for_metrics: bool = False 
fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None use_habana: typing.Optional[bool] = False gaudi_config_name: typing.Optional[str] = None use_lazy_mode: typing.Optional[bool] = False use_hpu_graphs: typing.Optional[bool] = False use_hpu_graphs_for_inference: typing.Optional[bool] = False use_hpu_graphs_for_training: typing.Optional[bool] = False distribution_strategy: typing.Optional[str] = 'ddp' throughput_warmup_steps: typing.Optional[int] = 0 adjust_throughput: bool = False pipelining_fwd_bwd: typing.Optional[bool] = False non_blocking_data_copy: typing.Optional[bool] = False profiling_warmup_steps: typing.Optional[int] = 0 profiling_steps: typing.Optional[int] = 0 sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, optimum.habana.transformers.generation.configuration_utils.GaudiGenerationConfig, NoneType] = None )
Parameters
- sortish_sampler (bool, optional, defaults to False) — Whether to use a sortish sampler or not. Only possible if the underlying datasets are Seq2SeqDataset for now but will become generally available in the near future. It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness for the training set.
- predict_with_generate (bool, optional, defaults to False) — Whether to use generate to calculate generative metrics (ROUGE, BLEU).
- generation_max_length (int, optional) — The max_length to use on each evaluation loop when predict_with_generate=True. Will default to the max_length value of the model configuration.
- generation_num_beams (int, optional) — The num_beams to use on each evaluation loop when predict_with_generate=True. Will default to the num_beams value of the model configuration.
- generation_config (str or Path or transformers.generation.GenerationConfig, optional) — Allows loading a transformers.generation.GenerationConfig from the from_pretrained method. This can be either:
  - a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - a path to a directory containing a configuration file saved using the transformers.GenerationConfig.save_pretrained method, e.g., ./my_model_directory/.
  - a transformers.generation.GenerationConfig object.
GaudiSeq2SeqTrainingArguments is built on top of the Transformers Seq2SeqTrainingArguments to enable deployment on Habana’s Gaudi.
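The idea behind a sortish sampler (sorting by length to reduce padding while keeping some randomness) can be sketched in a few lines. This is a simplified illustration under assumed parameters, not the sampler used by the library. The function name sortish_indices and its chunk_size and seed arguments are hypothetical.

```python
import random


def sortish_indices(lengths, chunk_size, seed=0):
    """Shuffle indices, then sort each chunk of `chunk_size` by decreasing length."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # randomness: which samples end up in the same chunk
    ordered = []
    for start in range(0, len(indices), chunk_size):
        chunk = indices[start:start + chunk_size]
        # Within a chunk, neighbours have similar lengths, so batches need less padding.
        ordered.extend(sorted(chunk, key=lambda i: lengths[i], reverse=True))
    return ordered


lengths = [12, 87, 5, 40, 33, 91, 7, 60]
print(sortish_indices(lengths, chunk_size=4))
```

Larger chunks give tighter length grouping and less padding, and smaller chunks give an order closer to a plain shuffle.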
Serializes this instance while replacing Enum members by their values and GaudiGenerationConfig by dictionaries (for JSON serialization support). It obfuscates the token values by removing their value.
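The behavior described above can be pictured with a toy dataclass. This is a sketch under assumptions and not the library's implementation. The names ToyArguments and Strategy are hypothetical, and the rule that fields ending in "_token" are masked is an assumption made for the example.

```python
import json
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional


class Strategy(Enum):
    NO = "no"
    STEPS = "steps"


@dataclass
class ToyArguments:
    output_dir: str
    save_strategy: Strategy = Strategy.STEPS
    hub_token: Optional[str] = None

    def to_dict(self):
        result = {}
        for field in fields(self):
            value = getattr(self, field.name)
            if isinstance(value, Enum):
                value = value.value  # Enum members are not JSON-serializable
            if field.name.endswith("_token"):
                value = f"<{field.name.upper()}>"  # never write secrets to logs or disk
            result[field.name] = value
        return result


args = ToyArguments(output_dir="./out", hub_token="hf_secret")
print(json.dumps(args.to_dict()))
```

The resulting dictionary can be passed straight to json.dumps, and the secret token value never appears in the serialized output.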