Optimum documentation

Adding support for an unsupported architecture

You are viewing version v1.7.3. A newer version, v1.17.1, is available.

If you wish to export a model whose architecture is not already supported by the library, these are the main steps to follow:

  1. Implement a custom ONNX configuration.
  2. Register the ONNX configuration in the TasksManager.
  3. Export the model to ONNX.
  4. Validate the outputs of the original and exported models.

In this section, we’ll look at how BERT was implemented to show what’s involved with each step.

Implementing a custom ONNX configuration

Let’s start with the ONNX configuration object. We provide a three-level class hierarchy. To add support for a model, inheriting from the appropriate middle-end class is the way to go most of the time. You might have to implement a middle-end class yourself if you are adding an architecture that handles a modality or use case not covered so far.

A good way to implement a custom ONNX configuration is to look at the existing configuration implementations in the optimum/exporters/onnx/model_configs.py file.

Also, if the architecture you are trying to add is (very) similar to an architecture that is already supported (for instance, adding support for ALBERT when BERT is already supported), simply inheriting from the existing configuration class might work.

When inheriting from a middle-end class, look for the one handling the same modality / category of models as the one you are trying to support.
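
The three levels can be pictured with a small, self-contained sketch. Every class name below is a hypothetical stand-in written for illustration (none of them is an Optimum class), and the bodies are reduced to the bare minimum: a modality-agnostic base, a middle-end class that fixes the modality, and a model-specific class that only declares its inputs.

```python
from typing import Dict, Tuple


class BaseExportConfig:
    """Stand-in for the modality-agnostic base level (hypothetical, not an Optimum class)."""

    # Names of the dummy-input generators this configuration relies on.
    DUMMY_INPUT_GENERATORS: Tuple[str, ...] = ()

    def __init__(self, task: str = "default"):
        self.task = task

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        raise NotImplementedError("model-specific classes must declare their inputs")


class TextEncoderExportConfig(BaseExportConfig):
    """Stand-in for a middle-end class: it only fixes the modality (text)."""

    DUMMY_INPUT_GENERATORS = ("text",)


class ToyBertExportConfig(TextEncoderExportConfig):
    """Stand-in for a model-specific class: it only declares the inputs."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        axes = {0: "batch_size", 1: "sequence_length"}
        return {"input_ids": axes, "attention_mask": axes}


class ToyAlbertExportConfig(ToyBertExportConfig):
    """A very similar architecture may be able to reuse the existing class unchanged."""


config = ToyAlbertExportConfig()
print(config.DUMMY_INPUT_GENERATORS)
print(list(config.inputs))
```

In this sketch the model-specific class inherits everything about the modality from the middle level, which is why picking the right middle-end class usually leaves little to implement.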

Example: Adding support for BERT

Since BERT is an encoder-based model for text, its configuration inherits from the middle-end class TextEncoderOnnxConfig. In optimum/exporters/onnx/model_configs.py:

# This class is actually in optimum/exporters/onnx/config.py
class TextEncoderOnnxConfig(OnnxConfig):
    # Describes how to generate the dummy inputs.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

class BertOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Sets the absolute tolerance to use when validating the exported ONNX model against the
    # reference model.
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Multiple-choice inputs carry an extra num_choices axis.
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }

First let’s explain what TextEncoderOnnxConfig is all about. While most of the features are already implemented in OnnxConfig, this class is modality-agnostic, meaning that it does not know what kind of inputs it should handle. The way input generation is handled is via the DUMMY_INPUT_GENERATOR_CLASSES attribute, which is a tuple of DummyInputGenerators. Here we are making a modality-aware configuration inheriting from OnnxConfig by specifying DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,).
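
To make the role of the generator tuple more concrete, here is a minimal sketch of how a configuration could walk such a tuple to build dummy inputs. The names ToyTextInputGenerator, ToyTextEncoderConfig, SUPPORTED_INPUT_NAMES, INPUT_NAMES and generate_dummy_inputs are assumptions made for this illustration, and plain Python lists stand in for tensors; the actual Optimum implementation may differ.

```python
import random
from typing import Dict, List


class ToyTextInputGenerator:
    """Hypothetical stand-in for a text dummy-input generator."""

    SUPPORTED_INPUT_NAMES = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(self, batch_size: int = 2, sequence_length: int = 8, vocab_size: int = 100):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size

    def generate(self, input_name: str) -> List[List[int]]:
        rows, cols = self.batch_size, self.sequence_length
        if input_name == "input_ids":
            # Random token ids below the vocabulary size.
            return [[random.randrange(self.vocab_size) for _ in range(cols)] for _ in range(rows)]
        # Attention masks are all ones, token type ids are all zeros.
        fill = 1 if input_name == "attention_mask" else 0
        return [[fill] * cols for _ in range(rows)]


class ToyTextEncoderConfig:
    """Hypothetical modality-aware configuration: it only names its generators."""

    DUMMY_INPUT_GENERATOR_CLASSES = (ToyTextInputGenerator,)
    INPUT_NAMES = ("input_ids", "attention_mask", "token_type_ids")

    def generate_dummy_inputs(self) -> Dict[str, List[List[int]]]:
        generators = [cls() for cls in self.DUMMY_INPUT_GENERATOR_CLASSES]
        dummy_inputs = {}
        for name in self.INPUT_NAMES:
            # Use the first generator that knows how to produce this input.
            for generator in generators:
                if name in generator.SUPPORTED_INPUT_NAMES:
                    dummy_inputs[name] = generator.generate(name)
                    break
            else:
                raise RuntimeError(f"no generator can produce {name!r}")
        return dummy_inputs


dummy = ToyTextEncoderConfig().generate_dummy_inputs()
for name, values in dummy.items():
    print(name, len(values), len(values[0]))
```

The design point is that the configuration never hard-codes how inputs are built: swapping the tuple is enough to change the modality.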

Then comes the model-specific class, BertOnnxConfig. Two class attributes are specified here:

  • NORMALIZED_CONFIG_CLASS: this must be a NormalizedConfig. It allows the input generator to access the model config attributes in a generic way.
  • ATOL_FOR_VALIDATION: the absolute tolerance accepted for the difference between output values when validating the exported model against the original one.

Every configuration object must implement the inputs property and return a mapping, where each key corresponds to an input name, and each value indicates the axes in that input that are dynamic. For BERT, we can see that three inputs are required: input_ids, attention_mask and token_type_ids. These inputs have the same shape of (batch_size, sequence_length) (except for the multiple-choice task) which is why we see the same axes used in the configuration.
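
As an illustration of what such a mapping expresses, the sketch below replaces the dynamic dimensions of a concrete shape with their axis names. The helpers bert_like_axes and symbolic_shape are hypothetical and written only for this example; they are not part of Optimum.

```python
from typing import Dict, Tuple, Union


def bert_like_axes(task: str) -> Dict[int, str]:
    """Return the dynamic axes of a BERT-style text input for the given task."""
    if task == "multiple-choice":
        return {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
    return {0: "batch_size", 1: "sequence_length"}


def symbolic_shape(shape: Tuple[int, ...], dynamic_axes: Dict[int, str]) -> Tuple[Union[int, str], ...]:
    """Replace every dynamic dimension with its name and keep the fixed sizes."""
    return tuple(dynamic_axes.get(index, size) for index, size in enumerate(shape))


# A (batch_size, sequence_length) input: both dimensions may vary at inference time.
print(symbolic_shape((2, 16), bert_like_axes("default")))
# The multiple-choice task adds a num_choices dimension.
print(symbolic_shape((2, 4, 16), bert_like_axes("multiple-choice")))
# A dimension that is not listed (here a hidden size of 768) stays fixed.
print(symbolic_shape((2, 16, 768), bert_like_axes("default")))
```

Only the axes listed in the mapping are free to change between calls; any other dimension keeps the size seen during export.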

Once you have implemented an ONNX configuration, you can instantiate it by providing the base model’s configuration as follows:

>>> from transformers import AutoConfig
>>> from optimum.exporters.onnx.model_configs import BertOnnxConfig
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> onnx_config = BertOnnxConfig(config)

The resulting object has several useful properties. For example, you can view the ONNX operator set that will be used during the export:

>>> print(onnx_config.DEFAULT_ONNX_OPSET)

You can also view the outputs associated with the model as follows:

>>> print(onnx_config.outputs)
OrderedDict([('last_hidden_state', {0: 'batch_size', 1: 'sequence_length'})])

Notice that the outputs property follows the same structure as the inputs: it returns an OrderedDict mapping each named output to its dynamic axes. The output structure is linked to the task that the configuration is initialized with. By default, the ONNX configuration is initialized with the default task, which corresponds to exporting a model loaded with the AutoModel class. If you want to export a model for another task, pass a different value to the task argument when you initialize the ONNX configuration. For example, if we wished to export BERT with a sequence classification head, we could use:

>>> from transformers import AutoConfig
>>> from optimum.exporters.onnx.model_configs import BertOnnxConfig

>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> onnx_config_for_seq_clf = BertOnnxConfig(config, task="sequence-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch_size'})])

Check out BartOnnxConfig for an advanced example.

Registering the ONNX configuration in the TasksManager

The TasksManager is the main entry point to load a model given a name and a task, and to get the proper configuration for a given (architecture, backend) pair. When adding support for the export to ONNX, registering the configuration in the TasksManager makes the export available in the command-line tool.

To do that, add an entry in the _SUPPORTED_MODEL_TYPE attribute:

  • If the model is already supported for backends other than ONNX, it will already have an entry, so you will only need to add an onnx key specifying the name of the configuration class.
  • Otherwise, you will have to add the whole entry.

For BERT, it looks as follows:

    "bert": supported_tasks_mapping(

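The idea behind such an entry can be sketched with a simplified stand-in. The function toy_supported_tasks_mapping, the dictionary TOY_SUPPORTED_MODEL_TYPE and the helper lookup_config_name are hypothetical names introduced here; they only mimic the shape of a registry entry and are not the TasksManager implementation. The task list is limited to tasks named elsewhere in this guide and is not the real list for BERT.

```python
from typing import Dict


def toy_supported_tasks_mapping(*tasks: str, **backends: str) -> Dict[str, Dict[str, str]]:
    """Map each backend keyword (for example onnx="...") to {task: configuration class name}."""
    return {
        backend: {task: config_name for task in tasks}
        for backend, config_name in backends.items()
    }


# Hypothetical registry: one entry per model type, one key per backend.
TOY_SUPPORTED_MODEL_TYPE = {
    "bert": toy_supported_tasks_mapping(
        "default",
        "sequence-classification",
        "multiple-choice",
        onnx="BertOnnxConfig",
    ),
}


def lookup_config_name(model_type: str, backend: str, task: str) -> str:
    """Return the configuration class name registered for this combination."""
    try:
        return TOY_SUPPORTED_MODEL_TYPE[model_type][backend][task]
    except KeyError as error:
        raise KeyError(
            f"no {backend} configuration registered for {model_type!r} with task {task!r}"
        ) from error


print(lookup_config_name("bert", "onnx", "sequence-classification"))
```

With this layout, supporting an extra backend for an existing model type amounts to adding one more keyword to its entry.
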
Exporting the model

Once you have implemented the ONNX configuration, the next step is to export the model. Here we can use the export() function provided by the optimum.exporters.onnx package. This function expects the ONNX configuration, along with the base model, and the path to save the exported file:

>>> from pathlib import Path
>>> from optimum.exporters import TasksManager
>>> from optimum.exporters.onnx import export
>>> from transformers import AutoModel

>>> base_model = AutoModel.from_pretrained("bert-base-uncased")

>>> onnx_path = Path("model.onnx")
>>> onnx_config_constructor = TasksManager.get_exporter_config_constructor("onnx", base_model)
>>> onnx_config = onnx_config_constructor(base_model.config)

>>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_path, onnx_config.DEFAULT_ONNX_OPSET)

The onnx_inputs and onnx_outputs returned by the export() function are lists of the keys defined in the inputs and outputs properties of the configuration. Once the model is exported, you can test that the model is well formed as follows:

>>> import onnx

>>> onnx_model = onnx.load("model.onnx")
>>> onnx.checker.check_model(onnx_model)

If your model is larger than 2GB, you will see that many additional files are created during the export. This is expected because ONNX uses Protocol Buffers to store the model and these have a size limit of 2GB. See the ONNX documentation for instructions on how to load models with external data.

Validating the model outputs

The final step is to validate that the outputs from the base and exported model agree within some absolute tolerance. Here we can use the validate_model_outputs() function provided by the optimum.exporters.onnx package:

>>> from optimum.exporters.onnx import validate_model_outputs

>>> validate_model_outputs(
...     onnx_config, base_model, onnx_path, onnx_outputs, onnx_config.ATOL_FOR_VALIDATION
... )
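
To show what agreement within an absolute tolerance means, here is a minimal pure-Python sketch. The helpers max_abs_diff and outputs_match and the sample values are made up for illustration; this is not how validate_model_outputs() is implemented.

```python
from typing import Sequence


def max_abs_diff(reference: Sequence[float], exported: Sequence[float]) -> float:
    """Return the largest element-wise absolute difference between two outputs."""
    if len(reference) != len(exported):
        raise ValueError("outputs must have the same number of values")
    return max((abs(r - e) for r, e in zip(reference, exported)), default=0.0)


def outputs_match(reference: Sequence[float], exported: Sequence[float], atol: float) -> bool:
    """Check that no value differs by more than the absolute tolerance."""
    return max_abs_diff(reference, exported) <= atol


# Made-up values standing in for the flattened outputs of both models.
reference_output = [0.1250, -1.5000, 3.0000]
exported_output = [0.12504, -1.50002, 2.99997]

print(outputs_match(reference_output, exported_output, atol=1e-4))
print(outputs_match(reference_output, exported_output, atol=1e-5))
```

A looser tolerance accepts the small numerical drift that an export typically introduces, while a tolerance that is too strict rejects an otherwise correct model.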

Contributing the new configuration to 🤗 Optimum

Now that support for the architecture has been implemented and validated, there are two things left:

  1. Add your model architecture to the tests in tests/exporters/test_onnx_export.py
  2. Create a PR on the optimum repo

Thanks for your contribution!