Configuration classes for ONNX exports
Exporting a model to ONNX involves specifying:
- The input names.
- The output names.
- The dynamic axes. These refer to the input dimensions that can change dynamically at runtime (e.g. a batch size or sequence length). All other axes are treated as static, and hence fixed at runtime.
- Dummy inputs to trace the model. This is needed in PyTorch to record the computational graph and convert it to ONNX.
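Dynamic axes are commonly written as a mapping from each input name to `{axis index: symbolic name}`, the dict-of-dicts form that PyTorch's `torch.onnx.export` accepts for its `dynamic_axes` argument. The minimal sketch below shows which axes of a traced input stay static under that convention. The helper `static_axes` and the example input names and shapes are hypothetical illustrations, not part of Optimum.

```python
from typing import Dict, List, Tuple

# Dynamic axes per input name, as {axis index: symbolic name}.
# Hypothetical example values, in the form torch.onnx.export(dynamic_axes=...) accepts.
DYNAMIC_AXES: Dict[str, Dict[int, str]] = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "pixel_values": {0: "batch_size"},
}


def static_axes(name: str, shape: Tuple[int, ...]) -> List[int]:
    """Return the axis indices of `name` that are NOT declared dynamic.

    These axes keep the size they had in the dummy input used for tracing.
    (Hypothetical helper, for illustration only.)
    """
    dynamic = DYNAMIC_AXES.get(name, {})
    return [axis for axis in range(len(shape)) if axis not in dynamic]


if __name__ == "__main__":
    # Image input of shape (batch, channels, height, width): only the batch
    # axis is dynamic, so channels, height and width are fixed after export.
    print(static_axes("pixel_values", (2, 3, 64, 64)))  # [1, 2, 3]
    # Both text axes are dynamic, so nothing is fixed.
    print(static_axes("input_ids", (2, 16)))  # []
```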
Since this data depends on the choice of model and task, we represent it in terms of configuration classes. Each configuration class is associated with a specific model architecture, and follows the naming convention ArchitectureNameOnnxConfig. For instance, the configuration which specifies the ONNX export of BERT models is BertOnnxConfig.
Since many architectures share similar properties for their ONNX configuration, 🤗 Optimum adopts a 3-level class hierarchy:
- Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc.).
- Middle-end classes. These are aware of the modality, but several of them can exist for the same modality depending on the inputs they support. They specify which input generators should be used for the dummy inputs, but remain model-agnostic.
- Model-specific classes like the BertOnnxConfig mentioned above. These are the ones actually used to export models.
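The three levels can be pictured with plain Python inheritance. In this minimal sketch every class name (`BaseExportConfig`, `TextEncoderExportConfig`, `BertLikeExportConfig`, `DummyTextGenerator`) is hypothetical and only mirrors the roles described above: the base class fixes the interface, the middle-end class picks the dummy input generators for its modality, and the model-specific class declares the actual input axes.

```python
from typing import Dict, Tuple, Type


class DummyTextGenerator:
    """Hypothetical stand-in for a text dummy input generator."""
    SUPPORTED_INPUTS = ("input_ids", "attention_mask")


class BaseExportConfig:
    """Level 1: generic base class, agnostic to the modality."""
    DUMMY_INPUT_GENERATOR_CLASSES: Tuple[Type, ...] = ()

    def __init__(self, task: str = "default") -> None:
        self.task = task

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Left to model-specific subclasses.
        raise NotImplementedError


class TextEncoderExportConfig(BaseExportConfig):
    """Level 2: knows the modality (text) and which generators to use."""
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextGenerator,)


class BertLikeExportConfig(TextEncoderExportConfig):
    """Level 3: model-specific, declares the actual input axes."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        axes = {0: "batch_size", 1: "sequence_length"}
        # Copy the axes dict so callers cannot mutate a shared object.
        return {"input_ids": dict(axes), "attention_mask": dict(axes)}
```

The model-specific class only overrides `inputs`; the generator choice is inherited from the middle-end class, which is what lets several architectures of the same modality share one middle-end class.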
Base classes
class optimum.exporters.onnx.OnnxConfig(config: PretrainedConfig, task: str = 'default')
Base class for ONNX-exportable models, describing metadata on how to export the model to the ONNX format.
Class attributes:
- NORMALIZED_CONFIG_CLASS (Type) — A class derived from NormalizedConfig specifying how to normalize the model config.
- DUMMY_INPUT_GENERATOR_CLASSES (Tuple[Type]) — A tuple of classes derived from DummyInputGenerator specifying how to create dummy inputs.
- ATOL_FOR_VALIDATION (Union[float, Dict[str, float]]) — A float or a dictionary mapping task names to floats, where the float values represent the absolute tolerance to use when validating the converted model.
- DEFAULT_ONNX_OPSET (int, defaults to 11) — The default ONNX opset to use for the ONNX export.
- MIN_TORCH_VERSION (packaging.version.Version, defaults to optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION) — The minimum torch version supporting the export of the model to ONNX.
- PATCHING_SPECS (Optional[List[PatchingSpec]], defaults to None) — Specifies which operators / modules should be patched before performing the export, and how. This is useful, for instance, when some operator is not supported in ONNX.
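Because ATOL_FOR_VALIDATION can be either a single float or a per-task dictionary, validation code has to resolve it for the current task before comparing the reference and exported outputs. The sketch below is a hypothetical illustration of that resolution: `resolve_atol`, `within_tolerance` and the `1e-5` fallback for tasks missing from the dictionary are assumptions of this sketch, not documented Optimum behaviour.

```python
from typing import Dict, List, Union

# Either one tolerance for every task, or a per-task mapping.
AtolSpec = Union[float, Dict[str, float]]


def resolve_atol(atol: AtolSpec, task: str, fallback: float = 1e-5) -> float:
    """Pick the absolute tolerance for `task` (hypothetical helper).

    A plain float applies to every task; a dict is looked up by task name,
    using `fallback` when the task is missing (an assumption of this sketch).
    """
    if isinstance(atol, dict):
        return atol.get(task, fallback)
    return float(atol)


def within_tolerance(reference: List[float], exported: List[float], atol: float) -> bool:
    """True if both flat output lists match element-wise within `atol`."""
    if len(reference) != len(exported):
        return False
    return all(abs(r - e) <= atol for r, e in zip(reference, exported))


per_task = {"text-classification": 1e-4, "question-answering": 1e-3}
print(resolve_atol(per_task, "question-answering"))  # 0.001
print(resolve_atol(1e-4, "any-task"))  # 0.0001
```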
inputs → Dict[str, Dict[int, str]]
Returns
Dict[str, Dict[int, str]]
A mapping of each input name to a mapping of axis position to the axis symbolic name.
Dict containing the axis definition of the input tensors to provide to the model.
outputs → Dict[str, Dict[int, str]]
Returns
Dict[str, Dict[int, str]]
A mapping of each output name to a mapping of axis position to the axis symbolic name.
Dict containing the axis definition of the output tensors to provide to the model.
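The `inputs` and `outputs` mappings carry exactly the information that `torch.onnx.export` takes through its `input_names`, `output_names` and `dynamic_axes` arguments. The minimal sketch below shows one way to derive those three arguments from the two mappings; the helper `export_axes_arguments` and the example text-classification values are hypothetical, and Optimum's own export code may assemble them differently.

```python
from itertools import chain
from typing import Dict, List, Tuple

AxesSpec = Dict[str, Dict[int, str]]


def export_axes_arguments(inputs: AxesSpec, outputs: AxesSpec) -> Tuple[List[str], List[str], AxesSpec]:
    """Derive input_names, output_names and dynamic_axes (hypothetical helper).

    Dict insertion order is preserved, so the name lists follow the order in
    which the config declares them. If an input and an output shared a name,
    the output entry would win in the merged dict.
    """
    input_names = list(inputs.keys())
    output_names = list(outputs.keys())
    dynamic_axes = dict(chain(inputs.items(), outputs.items()))
    return input_names, output_names, dynamic_axes


# Example values in the documented Dict[str, Dict[int, str]] format
# (assumed for a text-classification model).
example_inputs = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
}
example_outputs = {"logits": {0: "batch_size"}}

names_in, names_out, axes = export_axes_arguments(example_inputs, example_outputs)
print(names_in)  # ['input_ids', 'attention_mask']
print(names_out)  # ['logits']
```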
generate_dummy_inputs(framework: str = 'pt', **kwargs) → Dict
Parameters
- framework (str, defaults to "pt") — The framework for which to create the dummy inputs.
- batch_size (int, defaults to 2) — The batch size to use in the dummy inputs.
- sequence_length (int, defaults to 16) — The sequence length to use in the dummy inputs.
- num_choices (int, defaults to 4) — The number of candidate answers provided for the multiple choice task.
- image_width (int, defaults to 64) — The width to use in the dummy inputs for vision tasks.
- image_height (int, defaults to 64) — The height to use in the dummy inputs for vision tasks.
- num_channels (int, defaults to 3) — The number of channels to use in the dummy inputs for vision tasks.
- feature_size (int, defaults to 80) — The number of features to use in the dummy inputs for audio tasks when the input is not raw audio, for example the number of STFT bins or MEL bins.
- nb_max_frames (int, defaults to 3000) — The number of frames to use in the dummy inputs for audio tasks when the input is not raw audio.
- audio_sequence_length (int, defaults to 16000) — The number of frames to use in the dummy inputs for audio tasks when the input is raw audio.
Returns
Dict
A dictionary mapping the input names to dummy tensors in the proper framework format.
Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.
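The interplay between the defaults listed above and explicitly passed values can be sketched without any deep learning framework. In this hypothetical helper, `sketch_dummy_text_inputs` builds nested Python lists instead of framework tensors, and the `vocab_size`/`seed` parameters and the error on unknown keys are choices of this sketch rather than Optimum's behaviour. Only the default values themselves come from the parameter list.

```python
import random
from typing import Dict, List

# Default dummy-input sizes, copied from the parameter list above.
DEFAULT_SHAPES = {
    "batch_size": 2,
    "sequence_length": 16,
    "num_choices": 4,
    "image_width": 64,
    "image_height": 64,
    "num_channels": 3,
    "feature_size": 80,
    "nb_max_frames": 3000,
    "audio_sequence_length": 16000,
}


def sketch_dummy_text_inputs(vocab_size: int = 100, seed: int = 0, **kwargs) -> Dict[str, List[List[int]]]:
    """Build framework-free dummy text inputs (hypothetical helper).

    Keyword arguments override DEFAULT_SHAPES; unknown keys raise so that a
    typo such as `batchsize=4` is not silently ignored.
    """
    unknown = set(kwargs) - set(DEFAULT_SHAPES)
    if unknown:
        raise ValueError(f"Unknown shape arguments: {sorted(unknown)}")
    shapes = {**DEFAULT_SHAPES, **kwargs}
    rng = random.Random(seed)  # seeded so the dummy tokens are reproducible
    batch, seq = shapes["batch_size"], shapes["sequence_length"]
    input_ids = [[rng.randrange(vocab_size) for _ in range(seq)] for _ in range(batch)]
    attention_mask = [[1] * seq for _ in range(batch)]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


dummy = sketch_dummy_text_inputs(batch_size=1, sequence_length=8)
print(len(dummy["input_ids"]), len(dummy["input_ids"][0]))  # 1 8
```

Only the dimensions matter for tracing, so the token values are random; in a real export they would be wrapped in tensors of the requested framework.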
class optimum.exporters.onnx.OnnxConfigWithPast(config: PretrainedConfig, task: str = 'default', use_past: bool = False, use_past_in_inputs: typing.Optional[bool] = None, use_present_in_outputs: typing.Optional[bool] = None)
Inherits from OnnxConfig. A base class to handle the ONNX configuration of decoder-only models.
with_past(config: PretrainedConfig, task: str = 'default') → OnnxConfig
Instantiates an OnnxConfig with the use_past attribute set to True.
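`with_past` acts as an alternative constructor: same arguments, with `use_past` forced to `True`. The sketch below shows that pattern on a hypothetical `DecoderConfigSketch` class and how the flag could change the declared inputs. The single `past_key_values` entry is a simplification standing in for the per-layer key/value cache tensors; the real names and axes in Optimum may differ.

```python
from typing import Any, Dict


class DecoderConfigSketch:
    """Hypothetical decoder-only config mirroring the use_past flag."""

    def __init__(self, config: Any, task: str = "default", use_past: bool = False) -> None:
        self.config = config
        self.task = task
        self.use_past = use_past

    @classmethod
    def with_past(cls, config: Any, task: str = "default") -> "DecoderConfigSketch":
        # Using `cls` means a subclass calling with_past gets a subclass instance.
        return cls(config, task=task, use_past=True)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        axes = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
        if self.use_past:
            # Cached key/values get their own dynamic "past" length axis.
            axes["past_key_values"] = {0: "batch_size", 2: "past_sequence_length"}
        return axes


cfg = DecoderConfigSketch.with_past(config={}, task="text-generation")
print(cfg.use_past, sorted(cfg.inputs))  # True ['input_ids', 'past_key_values']
```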
class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast(config: PretrainedConfig, task: str = 'default', use_past: bool = False, use_past_in_inputs: typing.Optional[bool] = None, use_present_in_outputs: typing.Optional[bool] = None, behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'>)
Inherits from OnnxConfigWithPast. A base class to handle the ONNX configuration of encoder-decoder models.
Override this to specify custom attribute changes for a given behavior.
with_behavior(behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior], use_past: bool = False) → OnnxSeq2SeqConfigWithPast
Creates a copy of the current OnnxConfig but with a different ConfigBehavior and use_past value.
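The copy-with-changes pattern behind `with_behavior` can be sketched as follows. `ConfigBehaviorSketch` and `Seq2SeqConfigSketch` are hypothetical stand-ins: only the MONOLITH member with value "monolith" appears in the signature above, so the ENCODER and DECODER members (presumably used when the encoder and decoder are exported as separate graphs) are an assumption. The sketch copies with `copy.deepcopy`; the real implementation may build the copy differently.

```python
import copy
from enum import Enum
from typing import Union


class ConfigBehaviorSketch(str, Enum):
    """Hypothetical mirror of ConfigBehavior (ENCODER/DECODER are assumed)."""
    MONOLITH = "monolith"
    ENCODER = "encoder"
    DECODER = "decoder"


class Seq2SeqConfigSketch:
    def __init__(self, task: str = "default", use_past: bool = False,
                 behavior: ConfigBehaviorSketch = ConfigBehaviorSketch.MONOLITH) -> None:
        self.task = task
        self.use_past = use_past
        self.behavior = behavior

    def with_behavior(self, behavior: Union[str, ConfigBehaviorSketch],
                      use_past: bool = False) -> "Seq2SeqConfigSketch":
        # Accept either the enum member or its string value, e.g. "decoder".
        behavior = ConfigBehaviorSketch(behavior)
        new = copy.deepcopy(self)  # the original config is left untouched
        new.behavior = behavior
        new.use_past = use_past
        return new


base = Seq2SeqConfigSketch(task="text2text-generation")
decoder_cfg = base.with_behavior("decoder", use_past=True)
print(decoder_cfg.behavior.value, decoder_cfg.use_past)  # decoder True
print(base.behavior.value, base.use_past)  # monolith False
```

Returning a copy rather than mutating `self` lets one base configuration produce several per-component configurations side by side.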
Middle-end classes
Text
class optimum.exporters.onnx.TextEncoderOnnxConfig(config: PretrainedConfig, task: str = 'default')
Handles encoder-based text architectures.
class optimum.exporters.onnx.TextDecoderOnnxConfig(config: PretrainedConfig, task: str = 'default', use_past: bool = False, use_past_in_inputs: typing.Optional[bool] = None, use_present_in_outputs: typing.Optional[bool] = None)
Handles decoder-based text architectures.
class optimum.exporters.onnx.TextSeq2SeqOnnxConfig(config: PretrainedConfig, task: str = 'default', use_past: bool = False, use_past_in_inputs: typing.Optional[bool] = None, use_present_in_outputs: typing.Optional[bool] = None, behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'>)
Handles encoder-decoder-based text architectures.
Vision
class optimum.exporters.onnx.config.VisionOnnxConfig(config: PretrainedConfig, task: str = 'default')
Handles vision architectures.
Multi-modal
class optimum.exporters.onnx.config.TextAndVisionOnnxConfig(config: PretrainedConfig, task: str = 'default')
Handles multi-modal text and vision architectures.