Export 🤗 Transformers models to ONNX

This page documents Transformers v4.44.2.

🤗 Transformers provides a transformers.onnx package that lets you convert model checkpoints to an ONNX graph by leveraging configuration objects.

For more details, see the guide on exporting 🤗 Transformers models.

ONNX Configurations

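We provide three abstract classes that you should inherit from, depending on the type of model architecture you want to export:

  • Encoder-based models inherit from OnnxConfig
  • Decoder-based models inherit from OnnxConfigWithPast
  • Encoder-decoder models inherit from OnnxSeq2SeqConfigWithPast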

OnnxConfig

class transformers.onnx.OnnxConfig

(config: PretrainedConfig, task: str = 'default', patching_specs: List = None)

Base class for ONNX-exportable models, describing metadata on how to export the model to the ONNX format.

flatten_output_collection_property

(name: str, field: Iterable) → Dict[str, Any]

Returns

(Dict[str, Any])

Outputs with flattened structure and key mapping this new structure.

Flattens any potentially nested structure, expanding the field name with the index of each element within the structure.
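To make the key scheme concrete, here is a minimal standalone sketch in plain Python. It does not import transformers, and `flatten_collection` is a hypothetical name; it assumes one level of nesting whose elements are numbered consecutively across the whole collection, which matches the description above but may differ from the library in details.

```python
from itertools import chain
from typing import Any, Dict, Iterable


def flatten_collection(name: str, field: Iterable[Iterable[Any]]) -> Dict[str, Any]:
    """Hypothetical stand-in for OnnxConfig.flatten_output_collection_property.

    Chains the inner iterables together (one level of nesting) and keys each
    element as "<name>.<flat index>".
    """
    flat = chain.from_iterable(field)
    return {f"{name}.{idx}": item for idx, item in enumerate(flat)}


# Two layers, each holding a (key, value) pair, the usual shape of past_key_values.
past = [("k0", "v0"), ("k1", "v1")]
print(flatten_collection("past_key_values", past))
# {'past_key_values.0': 'k0', 'past_key_values.1': 'v0', 'past_key_values.2': 'k1', 'past_key_values.3': 'v1'}
```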

from_model_config

(config: PretrainedConfig, task: str = 'default')

Instantiates an OnnxConfig for a specific model.

generate_dummy_inputs

(preprocessor: Union, batch_size: int = -1, seq_length: int = -1, num_choices: int = -1, is_pair: bool = False, framework: Optional = None, num_channels: int = 3, image_width: int = 40, image_height: int = 40, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220, tokenizer: PreTrainedTokenizerBase = None)

Parameters

  • batch_size (int, optional, defaults to -1) — The batch size to export the model for (-1 means dynamic axis).
  • num_choices (int, optional, defaults to -1) — The number of candidate answers provided for multiple-choice tasks (-1 means dynamic axis).
  • seq_length (int, optional, defaults to -1) — The sequence length to export the model for (-1 means dynamic axis).
  • is_pair (bool, optional, defaults to False) — Whether the input is a pair (sentence 1, sentence 2).
  • framework (TensorType, optional, defaults to None) — The framework (PyTorch or TensorFlow) for which the tokenizer will generate tensors.
  • num_channels (int, optional, defaults to 3) — The number of channels of the generated images.
  • image_width (int, optional, defaults to 40) — The width of the generated images.
  • image_height (int, optional, defaults to 40) — The height of the generated images.
  • sampling_rate (int, optional, defaults to 22050) — The sampling rate for audio data generation.
  • time_duration (float, optional, defaults to 5.0) — Total duration, in seconds, of the generated audio.
  • frequency (int, optional, defaults to 220) — The frequency, in Hz, of the generated audio.

Generates dummy inputs to feed to the ONNX exporter for the given framework.
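A value of -1 marks an axis as dynamic in the exported graph, yet tracing still needs concrete tensors, so dynamic axes have to be replaced by fixed sizes when the dummy inputs are built. The sketch below illustrates that step and the audio parameters in plain Python. The fallback sizes (batch 2, sequence 8), the subtraction of tokenizer special tokens, and the sine-wave signal are assumptions, and `resolve_axis` and `dummy_audio` are hypothetical names, not library functions.

```python
import math
from typing import List


def resolve_axis(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    """Replace a dynamic axis (-1, or any non-positive value) with a fixed size.

    num_token_to_add reserves room for special tokens the tokenizer will add
    (e.g. [CLS] and [SEP]), so the tokenized length lands on the target size.
    """
    if dimension <= 0:
        dimension = fixed_dimension
    return dimension - num_token_to_add


def dummy_audio(batch_size: int = 2, sampling_rate: int = 22050,
                time_duration: float = 5.0, frequency: int = 220) -> List[List[float]]:
    """Generate batch_size copies of a pure sine wave at `frequency` Hz."""
    num_samples = int(time_duration * sampling_rate)
    # Sample n is taken at time t = n / sampling_rate seconds.
    wave = [0.5 * math.sin(2 * math.pi * frequency * n / sampling_rate)
            for n in range(num_samples)]
    return [list(wave) for _ in range(batch_size)]


batch = resolve_axis(-1, fixed_dimension=2)                    # dynamic -> 2
seq = resolve_axis(-1, fixed_dimension=8, num_token_to_add=2)  # 8 minus 2 special tokens -> 6
audio = dummy_audio(batch_size=batch)
print(batch, seq, len(audio), len(audio[0]))  # 2 6 2 110250
```

With the default arguments, each dummy clip holds int(5.0 * 22050) = 110250 samples.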

generate_dummy_inputs_onnxruntime

(reference_model_inputs: Mapping) → Mapping[str, Tensor]

Parameters

  • reference_model_inputs (Mapping[str, Tensor]) — Reference inputs for the model.

Returns

Mapping[str, Tensor]

The mapping holding the keyword arguments to pass to the model’s forward function.

Generates inputs for ONNX Runtime from the reference model inputs. Override this method to run inference with seq2seq models whose encoder and decoder are exported as separate ONNX files.
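As a rough illustration of when an override helps, consider a seq2seq model whose encoder was exported on its own: reference inputs built for the full model contain decoder entries that the encoder-only ONNX graph does not accept. The sketch below uses hypothetical stand-in classes rather than the real OnnxConfig, assumes the default implementation passes the reference inputs through unchanged, and assumes decoder inputs are recognizable by a `decoder_` prefix.

```python
from typing import Any, Dict, Mapping


class BaseConfigSketch:
    """Hypothetical stand-in for OnnxConfig: the default passes inputs through."""

    def generate_dummy_inputs_onnxruntime(
        self, reference_model_inputs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        return reference_model_inputs


class EncoderOnlyConfigSketch(BaseConfigSketch):
    """Hypothetical config for an encoder exported as its own ONNX file."""

    def generate_dummy_inputs_onnxruntime(
        self, reference_model_inputs: Mapping[str, Any]
    ) -> Dict[str, Any]:
        # Drop decoder-side kwargs: the encoder-only graph declares no such inputs.
        return {k: v for k, v in reference_model_inputs.items() if not k.startswith("decoder_")}


reference = {"input_ids": [[0, 1]], "attention_mask": [[1, 1]], "decoder_input_ids": [[0]]}
print(sorted(EncoderOnlyConfigSketch().generate_dummy_inputs_onnxruntime(reference)))
# ['attention_mask', 'input_ids']
```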

use_external_data_format

(num_parameters: int)

Flag indicating whether the model, given its number of parameters, requires the ONNX external data format.
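ONNX models are serialized as Protocol Buffers, which cap a single message at 2 GB, so very large models must keep their weights in separate external data files. A minimal sketch of the size check, assuming 4 bytes per parameter (float32) and a 2 GiB threshold; `needs_external_data` is a hypothetical name and the library's exact constant and rounding may differ.

```python
BYTES_PER_FLOAT32 = 4
PROTOBUF_LIMIT_BYTES = 2 * 1024 ** 3  # assumed 2 GiB cap on a single protobuf message


def needs_external_data(num_parameters: int, bytes_per_param: int = BYTES_PER_FLOAT32) -> bool:
    """Return True if the serialized weights alone would reach the protobuf cap."""
    return num_parameters * bytes_per_param >= PROTOBUF_LIMIT_BYTES


print(needs_external_data(110_000_000))    # about 440 MB, BERT-base scale -> False
print(needs_external_data(1_500_000_000))  # about 6 GB -> True
```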

OnnxConfigWithPast

class transformers.onnx.OnnxConfigWithPast

(config: PretrainedConfig, task: str = 'default', patching_specs: List = None, use_past: bool = False)

fill_with_past_key_values_

(inputs_or_outputs: Mapping, direction: str, inverted_values_shape: bool = False)

Fills the inputs_or_outputs mapping in place with the dynamic axes of the past_key_values, where direction is either "inputs" or "outputs".
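For decoder models exported with use_past=True, each layer contributes a key and a value tensor whose sequence axis grows during generation, so those axes must be declared dynamic. The sketch below fills such a mapping in place. It is a standalone approximation, not the library code: the naming scheme (past_key_values.{i}.key for inputs, present.{i}.key for outputs), the axis labels, and the explicit num_layers argument are assumptions.

```python
from collections import OrderedDict
from typing import Dict, MutableMapping


def fill_past_axes(inputs_or_outputs: MutableMapping[str, Dict[int, str]], direction: str,
                   num_layers: int, inverted_values_shape: bool = False) -> None:
    """Add one dynamic-axes entry per layer for keys and values (mutates in place)."""
    if direction not in ("inputs", "outputs"):
        raise ValueError(f'direction must be "inputs" or "outputs", got {direction!r}')
    prefix = "past_key_values" if direction == "inputs" else "present"
    for i in range(num_layers):
        # Keys: axis 0 is the batch, axis 2 the (past + current) sequence length.
        inputs_or_outputs[f"{prefix}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
        # Some architectures store values with the sequence on axis 1 instead.
        seq_axis = 1 if inverted_values_shape else 2
        inputs_or_outputs[f"{prefix}.{i}.value"] = {0: "batch", seq_axis: "past_sequence + sequence"}


axes = OrderedDict(input_ids={0: "batch", 1: "sequence"})
fill_past_axes(axes, "inputs", num_layers=2)
print(list(axes))
# ['input_ids', 'past_key_values.0.key', 'past_key_values.0.value', 'past_key_values.1.key', 'past_key_values.1.value']
```

The trailing underscore in fill_with_past_key_values_ follows the common Python convention for methods that modify their argument in place, which is why the sketch returns None.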

with_past

(config: PretrainedConfig, task: str = 'default')

Instantiates an OnnxConfig with the use_past attribute set to True.
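from_model_config and with_past act as alternate constructors that differ in the use_past flag. The stand-in dataclass below is hypothetical (it is not OnnxConfigWithPast) and sketches that relationship under the assumption that both constructors simply forward config and task to __init__.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ConfigWithPastSketch:
    """Hypothetical stand-in for OnnxConfigWithPast."""
    config: Any
    task: str = "default"
    use_past: bool = False

    @classmethod
    def from_model_config(cls, config: Any, task: str = "default") -> "ConfigWithPastSketch":
        return cls(config, task=task)  # use_past keeps its default (False)

    @classmethod
    def with_past(cls, config: Any, task: str = "default") -> "ConfigWithPastSketch":
        return cls(config, task=task, use_past=True)  # export with past_key_values inputs


plain = ConfigWithPastSketch.from_model_config({"n_layer": 12}, task="causal-lm")
cached = ConfigWithPastSketch.with_past({"n_layer": 12}, task="causal-lm")
print(plain.use_past, cached.use_past)  # False True
```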

OnnxSeq2SeqConfigWithPast

class transformers.onnx.OnnxSeq2SeqConfigWithPast

(config: PretrainedConfig, task: str = 'default', patching_specs: List = None, use_past: bool = False)

ONNX Features

Each ONNX configuration is associated with a set of features that enable you to export models for different types of topologies or tasks.

FeaturesManager

class transformers.onnx.FeaturesManager

()

check_supported_model_or_raise

(model: Union, feature: str = 'default')

Checks whether the model supports the requested feature.

determine_framework

(model: str, framework: str = None)

Parameters

  • model (str) — The name of the model to export.
  • framework (str, optional, defaults to None) — The framework to use for the export. If none is provided, it is chosen according to the priority order below.

Determines the framework to use for the export.

The priority is in the following order:

  1. User input via framework.
  2. If a local checkpoint is provided, the same framework as the checkpoint.
  3. The framework available in the environment, with priority given to PyTorch.
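The three-step priority can be expressed as a small decision function. This is a sketch, not FeaturesManager.determine_framework: `pick_framework` is a hypothetical name, inferring a local checkpoint's framework from the presence of pytorch_model.bin or tf_model.h5 is an assumption, and framework availability is passed in explicitly instead of being probed from the environment.

```python
import os
from typing import Optional


def pick_framework(model: str, framework: Optional[str] = None,
                   torch_available: bool = True, tf_available: bool = False) -> str:
    # 1. An explicit user choice always wins.
    if framework is not None:
        return framework
    # 2. A local checkpoint directory decides based on which weights it contains.
    if os.path.isdir(model):
        if os.path.isfile(os.path.join(model, "pytorch_model.bin")):
            return "pt"
        if os.path.isfile(os.path.join(model, "tf_model.h5")):
            return "tf"
        raise FileNotFoundError(f"Cannot determine framework from local checkpoint {model!r}")
    # 3. Otherwise fall back to what is installed, preferring PyTorch.
    if torch_available:
        return "pt"
    if tf_available:
        return "tf"
    raise EnvironmentError("Neither PyTorch nor TensorFlow is available for export.")


print(pick_framework("distilbert-base-uncased"))                  # pt
print(pick_framework("distilbert-base-uncased", framework="tf"))  # tf
print(pick_framework("distilbert-base-uncased", torch_available=False, tf_available=True))  # tf
```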

get_config

(model_type: str, feature: str) → OnnxConfig

Parameters

  • model_type (str) — The model type to retrieve the config for.
  • feature (str) — The feature to retrieve the config for.

Returns

OnnxConfig

The OnnxConfig for the given combination.

Gets the OnnxConfig for a model_type and feature combination.
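Conceptually, this is a two-level lookup: model type first, then feature. A minimal sketch with a hypothetical registry; the entries and the strings standing in for OnnxConfig objects are illustrative only and do not reflect the real supported-model table.

```python
from typing import Dict

# Hypothetical registry: model_type -> feature -> name of the OnnxConfig to build.
SUPPORTED_MODEL_TYPES: Dict[str, Dict[str, str]] = {
    "bert": {"default": "BertOnnxConfig", "sequence-classification": "BertOnnxConfig"},
    "gpt2": {"causal-lm": "GPT2OnnxConfig", "causal-lm-with-past": "GPT2OnnxConfig (use_past=True)"},
}


def lookup_config(model_type: str, feature: str) -> str:
    """Two-level lookup; raises KeyError for an unknown model type or feature."""
    return SUPPORTED_MODEL_TYPES[model_type][feature]


print(lookup_config("gpt2", "causal-lm-with-past"))  # GPT2OnnxConfig (use_past=True)
```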

get_model_class_for_feature

(feature: str, framework: str = 'pt')

Parameters

  • feature (str) — The feature required.
  • framework (str, optional, defaults to "pt") — The framework to use for the export.

Attempts to retrieve an AutoModel class from a feature name.
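A feature name maps onto a task-specific AutoModel class, while a -with-past suffix only changes the ONNX config, not the model head. The sketch below captures that mapping for a few common features. The table is a partial, illustrative subset given as class-name strings, `model_class_for_feature` is a hypothetical name, and the suffix-stripping rule is an assumption about how features such as causal-lm-with-past are handled.

```python
# Illustrative subset of feature -> AutoModel class names (PyTorch side).
FEATURE_TO_AUTOMODEL = {
    "default": "AutoModel",
    "masked-lm": "AutoModelForMaskedLM",
    "causal-lm": "AutoModelForCausalLM",
    "seq2seq-lm": "AutoModelForSeq2SeqLM",
    "sequence-classification": "AutoModelForSequenceClassification",
    "token-classification": "AutoModelForTokenClassification",
    "multiple-choice": "AutoModelForMultipleChoice",
    "question-answering": "AutoModelForQuestionAnswering",
}


def model_class_for_feature(feature: str) -> str:
    task = feature.replace("-with-past", "")  # past caching does not change the model head
    if task not in FEATURE_TO_AUTOMODEL:
        raise KeyError(f"Unknown feature {feature!r}; known: {sorted(FEATURE_TO_AUTOMODEL)}")
    return FEATURE_TO_AUTOMODEL[task]


print(model_class_for_feature("causal-lm-with-past"))  # AutoModelForCausalLM
```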

get_model_from_feature

(feature: str, model: str, framework: str = None, cache_dir: str = None)

Parameters

  • feature (str) — The feature required.
  • model (str) — The name of the model to export.
  • framework (str, optional, defaults to None) — The framework to use for the export. See FeaturesManager.determine_framework for the priority should none be provided.

Attempts to retrieve a model from a model’s name and the feature to be enabled.

get_supported_features_for_model_type

(model_type: str, model_name: Optional = None)

Parameters

  • model_type (str) — The model type to retrieve the supported features for.
  • model_name (str, optional) — The name attribute of the model object, only used for the exception message.

Tries to retrieve the feature -> OnnxConfig constructor map from the model type.
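Since model_name is only used for the exception message, the lookup itself depends on model_type alone. A sketch of that behavior, using a hypothetical registry and the hypothetical name `supported_features` rather than the real FeaturesManager data; the error wording is illustrative.

```python
from typing import Dict, Optional

# Hypothetical registry: model_type -> {feature: config constructor name}.
SUPPORTED_FEATURES: Dict[str, Dict[str, str]] = {
    "bert": {"default": "BertOnnxConfig", "masked-lm": "BertOnnxConfig"},
}


def supported_features(model_type: str, model_name: Optional[str] = None) -> Dict[str, str]:
    if model_type not in SUPPORTED_FEATURES:
        # model_name only makes the error easier to trace back to a checkpoint.
        label = f"{model_type} ({model_name})" if model_name else model_type
        raise KeyError(f"{label} is not supported yet. Supported model types: {sorted(SUPPORTED_FEATURES)}")
    return SUPPORTED_FEATURES[model_type]


print(list(supported_features("bert")))  # ['default', 'masked-lm']
try:
    supported_features("my-arch", model_name="acme/my-arch-base")
except KeyError as err:
    print(err)
```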
