Exporting 🤗 Transformers models to ONNX
🤗 Transformers provides the transformers.onnx package.
Using configuration objects, you can convert model checkpoints to an ONNX graph.
See the guide for details.
ONNX Configurations
We provide three abstract classes to inherit from, depending on the type of model architecture you want to export:
- Encoder-based models inherit from OnnxConfig.
- Decoder-based models inherit from OnnxConfigWithPast.
- Encoder-decoder models inherit from OnnxSeq2SeqConfigWithPast.
OnnxConfig
class transformers.onnx.OnnxConfig
( config: PretrainedConfig, task: str = 'default', patching_specs: List = None )
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
flatten_output_collection_property
( name: str, field: Iterable ) → (Dict[str, Any])
Returns
(Dict[str, Any])
Outputs with flattened structure and key mapping this new structure.
Flatten any potential nested structure expanding the name of the field with the index of the element within the structure.
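The naming scheme described above can be illustrated with a minimal sketch. This is not the library's implementation: `flatten_collection` is a hypothetical stand-alone helper, and it assumes the field has exactly one level of nesting (for example, one tuple per layer).

```python
from itertools import chain
from typing import Any, Dict, Iterable


def flatten_collection(name: str, field: Iterable[Iterable[Any]]) -> Dict[str, Any]:
    """Flatten one level of nesting, keying each element as '<name>.<index>'.

    Hypothetical helper for illustration; assumes a single level of nesting.
    """
    flat_items = chain.from_iterable(field)
    return {f"{name}.{idx}": item for idx, item in enumerate(flat_items)}


# Two layers, each holding a (key, value) pair; strings stand in for tensors.
past = [("k0", "v0"), ("k1", "v1")]
flattened = flatten_collection("present", past)
print(flattened)
```

Each element ends up under a key made of the field name and its position in the flattened sequence, so nested outputs can be addressed by a flat name.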
Instantiate an OnnxConfig for a specific model.
generate_dummy_inputs
( preprocessor: Union, batch_size: int = -1, seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, framework: Optional = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220, tokenizer: PreTrainedTokenizerBase = None )
Parameters
- batch_size (int, optional, defaults to -1): The batch size to export the model for (-1 means dynamic axis).
- num_choices (int, optional, defaults to -1): The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
- seq_length (int, optional, defaults to -1): The sequence length to export the model for (-1 means dynamic axis).
- is_pair (bool, optional, defaults to False): Indicate if the input is a pair (sentence 1, sentence 2).
- framework (TensorType, optional, defaults to None): The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
- num_channels (int, optional, defaults to 3): The number of channels of the generated images.
- image_width (int, optional, defaults to 40): The width of the generated images.
- image_height (int, optional, defaults to 40): The height of the generated images.
- sampling_rate (int, optional, defaults to 22050): The sampling rate for audio data generation.
- time_duration (float, optional, defaults to 5.0): Total seconds of sampling for audio data generation.
- frequency (int, optional, defaults to 220): The desired natural frequency of generated audio.
Generate inputs to provide to the ONNX exporter for the specific framework
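The "-1 means dynamic axis" convention in the parameters above can be sketched as follows. This is an illustration only, not the library's code: `resolve_axis` and `make_dummy_token_ids` are hypothetical helpers, and the fallback sizes 2 and 8 are arbitrary assumptions chosen for the example.

```python
from typing import List

DYNAMIC = -1  # sentinel used by batch_size, seq_length and num_choices


def resolve_axis(dimension: int, fallback: int) -> int:
    """Return `dimension`, or `fallback` when the axis is marked dynamic."""
    if dimension == DYNAMIC:
        return fallback
    if dimension <= 0:
        raise ValueError(f"Axis size must be positive or -1, got {dimension}")
    return dimension


def make_dummy_token_ids(batch_size: int = -1, seq_length: int = -1) -> List[List[int]]:
    """Build a batch of placeholder token ids with concrete sizes."""
    rows = resolve_axis(batch_size, fallback=2)  # assumed fallback
    cols = resolve_axis(seq_length, fallback=8)  # assumed fallback
    return [[0] * cols for _ in range(rows)]


dummy = make_dummy_token_ids()  # both axes dynamic
fixed = make_dummy_token_ids(batch_size=4, seq_length=16)  # both axes fixed
print(len(dummy), len(dummy[0]), len(fixed), len(fixed[0]))
```

A dynamic axis still needs some concrete size when dummy data is built for tracing; the sentinel only signals that the exported graph should not be tied to that size.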
generate_dummy_inputs_onnxruntime
( reference_model_inputs: Mapping ) → Mapping[str, Tensor]
Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq models which have the encoder and decoder exported as separate ONNX files.
Flag indicating if the model requires using external data format
OnnxConfigWithPast
class transformers.onnx.OnnxConfigWithPast
( config: PretrainedConfig, task: str = 'default', patching_specs: List = None, use_past: bool = False )
fill_with_past_key_values_
( inputs_or_outputs: Mapping, direction: str, inverted_values_shape: bool = False )
Fill the inputs_or_outputs mapping with past_key_values dynamic axes, taking the direction into account.
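As a rough sketch of what filling a mapping with past key/value dynamic axes might look like: the function below is a hypothetical stand-alone helper, not the library method. The entry names, the axis labels, and the explicit `num_layers` argument are assumptions made for illustration.

```python
from collections import OrderedDict
from typing import Dict, MutableMapping


def fill_past_axes(
    inputs_or_outputs: MutableMapping[str, Dict[int, str]],
    direction: str,
    num_layers: int,
    inverted_values_shape: bool = False,
) -> None:
    """Add per-layer key/value entries describing dynamic axes, in place."""
    if direction not in ("inputs", "outputs"):
        raise ValueError(f'direction must be "inputs" or "outputs", got {direction!r}')
    # Assumed naming: cached states fed in vs. updated states returned.
    prefix = "past_key_values" if direction == "inputs" else "present"
    # Assumed layout: the sequence axis of values moves when the shape is inverted.
    value_seq_axis = 1 if inverted_values_shape else 2
    for layer in range(num_layers):
        inputs_or_outputs[f"{prefix}.{layer}.key"] = {0: "batch", 2: "past_sequence + sequence"}
        inputs_or_outputs[f"{prefix}.{layer}.value"] = {0: "batch", value_seq_axis: "past_sequence + sequence"}


axes = OrderedDict([("input_ids", {0: "batch", 1: "sequence"})])
fill_past_axes(axes, direction="inputs", num_layers=2)
print(list(axes))
```

The mapping is modified in place, which matches the trailing underscore in the documented method name.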
Instantiate an OnnxConfig with the use_past attribute set to True.
OnnxSeq2SeqConfigWithPast
class transformers.onnx.OnnxSeq2SeqConfigWithPast
( config: PretrainedConfig, task: str = 'default', patching_specs: List = None, use_past: bool = False )
ONNX Features
Each ONNX configuration is associated with a set of features that enable you to export models for different types of topologies or tasks.
FeaturesManager
Check whether or not the model has the requested features.
determine_framework
( model: str, framework: str = None )
Determines the framework to use for the export.
The priority is in the following order:
- User input via framework.
- If a local checkpoint is provided, use the same framework as the checkpoint.
- Available framework in the environment, with priority given to PyTorch.
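The priority order above can be sketched as a small function. This is a hypothetical illustration, not the library's implementation: `pick_framework` and its arguments are invented for the example, and the short names "pt" (PyTorch) and "tf" (TensorFlow) are assumed identifiers.

```python
from typing import Optional, Set


def pick_framework(
    user_choice: Optional[str],
    checkpoint_framework: Optional[str],
    available: Set[str],
) -> str:
    """Pick an export framework following the documented priority order."""
    # 1. An explicit user choice always wins.
    if user_choice is not None:
        return user_choice
    # 2. Otherwise follow the framework of the local checkpoint, if any.
    if checkpoint_framework is not None:
        return checkpoint_framework
    # 3. Otherwise use what is installed, preferring PyTorch.
    for candidate in ("pt", "tf"):
        if candidate in available:
            return candidate
    raise EnvironmentError("Neither PyTorch nor TensorFlow is available.")


print(pick_framework(None, None, {"pt", "tf"}))
```

Here `checkpoint_framework` stands for whatever was detected from a local checkpoint directory; how that detection works is outside the scope of this sketch.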
get_config
( model_type: str, feature: str ) → OnnxConfig
Gets the OnnxConfig for a model_type and feature combination.
get_model_class_for_feature
( feature: str, framework: str = 'pt' )
Attempts to retrieve an AutoModel class from a feature name.
get_model_from_feature
( feature: str, model: str, framework: str = None, cache_dir: str = None )
Attempts to retrieve a model from a model’s name and the feature to be enabled.
get_supported_features_for_model_type
( model_type: str, model_name: Optional = None )
Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
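One way to picture a "feature -> OnnxConfig constructor" map is a nested dictionary of partially applied constructors. The sketch below is hypothetical and does not use the library: `SketchOnnxConfig`, `REGISTRY`, `supported_features`, and `get_constructor` are invented names, and the "bert" entry with its two features is an illustrative assumption.

```python
from functools import partial
from typing import Callable, Dict


class SketchOnnxConfig:
    """Stand-in for an ONNX configuration; holds a model config and a task."""

    def __init__(self, model_config: dict, task: str = "default"):
        self.model_config = model_config
        self.task = task


# model_type -> feature -> constructor taking only the model config.
REGISTRY: Dict[str, Dict[str, Callable[[dict], SketchOnnxConfig]]] = {
    "bert": {
        "default": partial(SketchOnnxConfig, task="default"),
        "sequence-classification": partial(SketchOnnxConfig, task="sequence-classification"),
    },
}


def supported_features(model_type: str) -> Dict[str, Callable[[dict], SketchOnnxConfig]]:
    """Return the feature -> constructor map for a model type."""
    if model_type not in REGISTRY:
        raise KeyError(f"{model_type} is not in this sketch registry")
    return REGISTRY[model_type]


def get_constructor(model_type: str, feature: str) -> Callable[[dict], SketchOnnxConfig]:
    """Look up the constructor for a model type and feature combination."""
    features = supported_features(model_type)
    if feature not in features:
        raise KeyError(f"{model_type} does not support feature {feature!r}")
    return features[feature]


constructor = get_constructor("bert", "sequence-classification")
onnx_config = constructor({"hidden_size": 768})
print(onnx_config.task)
```

Binding the task with `partial` means callers only supply the model configuration, while the feature name selects which task the resulting configuration targets.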