Exporting 🤗 Transformers models to ONNX
🤗 Transformers provides a transformers.onnx package that enables you to convert model checkpoints to an ONNX graph by leveraging configuration objects.
See the guide on exporting 🤗 Transformers models for more details.
ONNX Configurations
We provide three abstract classes that you should inherit from, depending on the type of model architecture you wish to export:
- Encoder-based models inherit from OnnxConfig
- Decoder-based models inherit from OnnxConfigWithPast
- Encoder-decoder models inherit from OnnxSeq2SeqConfigWithPast
OnnxConfig
class transformers.onnx.OnnxConfig
( config: PretrainedConfig, task: str = 'default', patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = None )
Base class for ONNX-exportable models. It describes the metadata needed to export a model through the ONNX format.
flatten_output_collection_property
( name: str, field: typing.Iterable[typing.Any] ) → Dict[str, Any]
Returns
Dict[str, Any]: The outputs with a flattened structure, keyed by the expanded field names.
Flatten any potential nested structure expanding the name of the field with the index of the element within the structure.
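The flattening described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: the helper name flatten_collection is hypothetical, and the sketch assumes a single level of nesting and a "<name>.<index>" key format.

```python
from itertools import chain
from typing import Any, Dict, Iterable


def flatten_collection(name: str, field: Iterable[Iterable[Any]]) -> Dict[str, Any]:
    """Flatten one level of nesting and key each element as '<name>.<index>'."""
    flat_items = chain.from_iterable(field)
    return {f"{name}.{idx}": item for idx, item in enumerate(flat_items)}


# Hypothetical two-layer cache, each layer holding a (key, value) pair.
past = [("k0", "v0"), ("k1", "v1")]
flattened = flatten_collection("past_key_values", past)
print(flattened)
```

Each nested element ends up under its own indexed name, which is what allows a nested output to be exposed as separate named entries.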
Instantiate an OnnxConfig for a specific model.
generate_dummy_inputs
( preprocessor: typing.Union[ForwardRef('PreTrainedTokenizerBase'), ForwardRef('FeatureExtractionMixin')], batch_size: int = -1, seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, framework: typing.Optional[transformers.utils.generic.TensorType] = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, tokenizer: PreTrainedTokenizerBase = None )
Parameters
- batch_size (int, optional, defaults to -1): The batch size to export the model for (-1 means dynamic axis).
- num_choices (int, optional, defaults to -1): The number of candidate answers provided for the multiple choice task (-1 means dynamic axis).
- seq_length (int, optional, defaults to -1): The sequence length to export the model for (-1 means dynamic axis).
- is_pair (bool, optional, defaults to False): Indicates whether the input is a pair (sentence 1, sentence 2).
- framework (TensorType, optional, defaults to None): The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
- num_channels (int, optional, defaults to 3): The number of channels of the generated images.
- image_width (int, optional, defaults to 40): The width of the generated images.
- image_height (int, optional, defaults to 40): The height of the generated images.
Generate inputs to provide to the ONNX exporter for the specific framework
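To make the "-1 means dynamic axis" convention concrete, here is a small plain-Python sketch. It is an illustration under stated assumptions, not the library's code: the helpers resolve_dimension and make_dummy_token_ids are hypothetical, and the fallback sizes (2 and 8) are arbitrary values chosen only so that a concrete dummy input can be built.

```python
from typing import Dict, List


def resolve_dimension(requested: int, fallback: int) -> int:
    """Return `requested` when it is a fixed size, else `fallback` (dynamic axis)."""
    return requested if requested > 0 else fallback


def make_dummy_token_ids(batch_size: int = -1, seq_length: int = -1) -> List[List[int]]:
    """Build a batch of zero-valued token ids with concrete dimensions."""
    batch = resolve_dimension(batch_size, fallback=2)   # assumed fallback size
    length = resolve_dimension(seq_length, fallback=8)  # assumed fallback size
    return [[0] * length for _ in range(batch)]


# Axes left dynamic are given names so that they can stay symbolic in the graph.
dynamic_axes: Dict[str, Dict[int, str]] = {"input_ids": {0: "batch", 1: "sequence"}}

dummy = make_dummy_token_ids()                            # both axes dynamic
fixed = make_dummy_token_ids(batch_size=1, seq_length=4)  # both axes fixed
print(len(dummy), len(dummy[0]), fixed)
```

A dummy input always needs concrete sizes to trace the model, so a dynamic axis still gets some placeholder size at export time; only the axis name marks it as variable.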
Flag indicating if the model requires using external data format
OnnxConfigWithPast
class transformers.onnx.OnnxConfigWithPast
( config: PretrainedConfig, task: str = 'default', patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = None, use_past: bool = False )
fill_with_past_key_values_
( inputs_or_outputs: typing.Mapping[str, typing.Mapping[int, str]], direction: str )
Fill the inputs_or_outputs mapping with the past_key_values dynamic axes.
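As a rough illustration of filling an axes mapping in place, consider the standalone sketch below. Everything specific in it is an assumption made for the example: the function name fill_past_axes, the explicit num_layers argument, the "present" prefix for outputs, and the axis labels are not taken from this page.

```python
from collections import OrderedDict
from typing import Dict, MutableMapping


def fill_past_axes(
    inputs_or_outputs: MutableMapping[str, Dict[int, str]],
    direction: str,
    num_layers: int,
) -> None:
    """Add one key entry and one value entry per layer to the mapping, in place."""
    if direction not in ("inputs", "outputs"):
        raise ValueError(f'direction must be "inputs" or "outputs", got "{direction}"')
    # Assumed naming: cached tensors fed in use "past_key_values", returned ones "present".
    prefix = "past_key_values" if direction == "inputs" else "present"
    for layer in range(num_layers):
        for kind in ("key", "value"):
            # Assumed layout: axis 0 is the batch, axis 2 is the cached sequence length.
            inputs_or_outputs[f"{prefix}.{layer}.{kind}"] = {
                0: "batch",
                2: "past_sequence + sequence",
            }


axes = OrderedDict([("input_ids", {0: "batch", 1: "sequence"})])
fill_past_axes(axes, direction="inputs", num_layers=2)
print(list(axes))
```

The function returns nothing and mutates its argument, which matches the trailing underscore convention in the method name.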
Instantiate an OnnxConfig with the use_past attribute set to True.
OnnxSeq2SeqConfigWithPast
class transformers.onnx.OnnxSeq2SeqConfigWithPast
( config: PretrainedConfig, task: str = 'default', patching_specs: typing.List[transformers.onnx.config.PatchingSpec] = None, use_past: bool = False )
ONNX Features
Each ONNX configuration is associated with a set of features that enable you to export models for different types of topologies or tasks.
FeaturesManager
check_supported_model_or_raise
( model: typing.Union[ForwardRef('PreTrainedModel'), ForwardRef('TFPreTrainedModel')], feature: str = 'default' )
Check whether or not the model has the requested features.
get_config
( model_type: str, feature: str ) → OnnxConfig
Gets the OnnxConfig for a model_type and feature combination.
get_model_class_for_feature
( feature: str, framework: str = 'pt' )
Attempts to retrieve an AutoModel class from a feature name.
get_model_from_feature
( feature: str, model: str, framework: str = 'pt', cache_dir: str = None )
Attempts to retrieve a model from a model’s name and the feature to be enabled.
get_supported_features_for_model_type
( model_type: str, model_name: typing.Optional[str] = None )
Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
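The idea of a "feature -> config constructor" map can be sketched with a toy registry in plain Python. This is not how FeaturesManager is implemented: ToyConfig, REGISTRY, the "toy-encoder" model type, and both helper functions are hypothetical names introduced only for this illustration.

```python
from functools import partial
from typing import Callable, Dict


class ToyConfig:
    """Stand-in for an export configuration (hypothetical)."""

    def __init__(self, model_type: str, task: str = "default") -> None:
        self.model_type = model_type
        self.task = task


# Hypothetical registry: model type -> feature name -> config constructor.
REGISTRY: Dict[str, Dict[str, Callable[[], ToyConfig]]] = {
    "toy-encoder": {
        "default": partial(ToyConfig, "toy-encoder", task="default"),
        "sequence-classification": partial(
            ToyConfig, "toy-encoder", task="sequence-classification"
        ),
    },
}


def get_supported_features(model_type: str) -> Dict[str, Callable[[], ToyConfig]]:
    """Return the feature -> constructor map, or raise if the type is unknown."""
    if model_type not in REGISTRY:
        raise KeyError(f"{model_type} is not supported; known types: {list(REGISTRY)}")
    return REGISTRY[model_type]


def get_config(model_type: str, feature: str) -> Callable[[], ToyConfig]:
    """Return the constructor registered for a model type and feature pair."""
    features = get_supported_features(model_type)
    if feature not in features:
        raise ValueError(f"{model_type} does not support {feature}; try {list(features)}")
    return features[feature]


constructor = get_config("toy-encoder", "sequence-classification")
config = constructor()
print(config.model_type, config.task)
```

Storing constructors rather than instances keeps the lookup cheap and lets the caller decide when to build the configuration.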