Adding support for new architectures

NOTE: ❗This section does not apply to decoder models running inference with autoregressive sampling, which are integrated through transformers-neuronx. If you want to add support for these models, please open an issue on the Optimum Neuron GitHub repo and ping the maintainers for help.

Want to export and run a new model on AWS Inferentia or Trainium? Follow the guidelines below and submit a pull request to 🤗 Optimum Neuron’s GitHub repo!

To support a new model architecture in the Optimum Neuron library, follow these steps:

  1. Implement a custom Neuron configuration.
  2. Export and validate the model.
  3. Contribute to the GitHub repo.

Implement a custom Neuron configuration

To support the export of a new model to a Neuron-compatible format, the first thing to do is to define a Neuron configuration that describes how to export the PyTorch model by specifying:

  1. The input names.
  2. The output names.
  3. The dummy inputs used to trace the model: the Neuron Compiler records the computational graph via tracing and works on the resulting TorchScript module.
  4. The compilation arguments used to control the trade-off between hardware efficiency (latency, throughput) and accuracy.
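
To make item 3 concrete, here is a minimal, library-free sketch of how dummy inputs for a text encoder might be generated. It assumes the traced model takes input_ids and attention_mask with a static batch size and sequence length fixed at export time. The helper make_dummy_text_inputs and its parameters are hypothetical and not part of Optimum Neuron, whose input generators build real tensors.

```python
import random
from typing import Dict, List


def make_dummy_text_inputs(
    input_names: List[str],
    batch_size: int,
    sequence_length: int,
    vocab_size: int,
    seed: int = 0,
) -> Dict[str, List[List[int]]]:
    """Build dummy inputs of shape (batch_size, sequence_length) for each input name.

    Hypothetical helper: only "input_ids" and "attention_mask" are supported.
    """
    rng = random.Random(seed)  # seeded so repeated runs produce identical dummy inputs
    dummy: Dict[str, List[List[int]]] = {}
    for name in input_names:
        if name == "input_ids":
            # Random token ids in [0, vocab_size)
            dummy[name] = [
                [rng.randrange(vocab_size) for _ in range(sequence_length)]
                for _ in range(batch_size)
            ]
        elif name == "attention_mask":
            # Every position is attended to (no padding in dummy inputs)
            dummy[name] = [[1] * sequence_length for _ in range(batch_size)]
        else:
            raise ValueError(f"No dummy generator for input {name!r}")
    return dummy


# Shapes matching the export command used later in this guide (batch_size 1, sequence_length 16)
inputs = make_dummy_text_inputs(["input_ids", "attention_mask"], batch_size=1, sequence_length=16, vocab_size=100)
print({name: (len(rows), len(rows[0])) for name, rows in inputs.items()})
# {'input_ids': (1, 16), 'attention_mask': (1, 16)}
```

Because the shapes are fixed when the dummy inputs are built, the compiled model expects inputs of those same shapes at inference time.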

Depending on the choice of model and task, we represent the data above with configuration classes. Each configuration class is associated with a specific model architecture, and follows the naming convention ArchitectureNameNeuronConfig. For instance, the configuration that specifies the Neuron export of BERT models is BertNeuronConfig.

Since many architectures share similar properties for their Neuron configuration, 🤗 Optimum adopts a 3-level class hierarchy:

  1. Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc.).
  2. Middle-end classes. These are aware of the modality. Multiple config classes can exist for the same modality, depending on the inputs they support. They specify which input generators should be used for generating the dummy inputs, but remain model-agnostic.
  3. Model-specific classes like the BertNeuronConfig mentioned above. These are the ones actually used to export models.
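
The three levels can be illustrated with plain Python classes. ToyBaseConfig, ToyTextEncoderConfig and ToyBertConfig are hypothetical stand-ins, not the Optimum classes: the base class stores shapes and declares what every config must provide, the text-level class decides how dummy inputs are built without knowing the model, and the model-specific class only declares its input names, much like EsmNeuronConfig in the example below overrides only inputs.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyBaseConfig(ABC):
    """Level 1: generic and modality-agnostic; stores the static input shapes."""

    def __init__(self, batch_size: int, sequence_length: int):
        self.batch_size = batch_size
        self.sequence_length = sequence_length

    @property
    @abstractmethod
    def inputs(self) -> List[str]:
        """Input names; provided by model-specific subclasses."""

    @abstractmethod
    def dummy_inputs(self) -> Dict[str, List[List[int]]]:
        """Dummy inputs for tracing; provided by modality-aware subclasses."""


class ToyTextEncoderConfig(ToyBaseConfig):
    """Level 2: knows the modality (text) and how to build dummy inputs for it,
    but not which model it serves: it only relies on self.inputs."""

    def dummy_inputs(self) -> Dict[str, List[List[int]]]:
        # Placeholder values; a real generator would sample token ids per input
        return {
            name: [[1] * self.sequence_length for _ in range(self.batch_size)]
            for name in self.inputs
        }


class ToyBertConfig(ToyTextEncoderConfig):
    """Level 3: model-specific; only declares what is particular to this architecture."""

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask", "token_type_ids"]


config = ToyBertConfig(batch_size=1, sequence_length=8)
print(sorted(config.dummy_inputs()))
# ['attention_mask', 'input_ids', 'token_type_ids']
```

Because ToyTextEncoderConfig leaves inputs abstract, it cannot be instantiated on its own; only the model-specific level is usable, mirroring the point that level-3 classes are the ones actually used to export models.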

Example: Adding support for ESM models

Here we use ESM models as an example. Let’s create an EsmNeuronConfig class in optimum/exporters/neuron/model_configs.py.

Since an ESM model can be interpreted as a text encoder, we can inherit from the middle-end class TextEncoderNeuronConfig. When used as an encoder, ESM's modeling and configuration are almost the same as BERT's, so we can reuse BERT's normalized configuration, obtained from NormalizedConfigManager with the model type "bert", to generate the dummy inputs used to trace the model.
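
Normalization lets one dummy-input generator read the same fields from different model configurations. The sketch below shows the idea with a hypothetical ToyNormalizedConfig wrapper that maps common attribute names to the names a given config actually uses; it is a simplified stand-in, not optimum's NormalizedConfig, and the second configuration's attribute names (n_vocab, n_embd) are invented for illustration.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


class ToyNormalizedConfig:
    """Read model-config attributes through common names.

    attribute_map maps a common name (e.g. "hidden_size") to the attribute name
    a given architecture actually uses; unmapped names are read unchanged.
    """

    def __init__(self, config: Any, attribute_map: Optional[Dict[str, str]] = None):
        self._config = config
        self._attribute_map = attribute_map or {}

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, i.e. for config attributes
        return getattr(self._config, self._attribute_map.get(name, name))


# A BERT-style config already uses the common names, so no mapping is needed
bert_style = ToyNormalizedConfig(SimpleNamespace(vocab_size=100, hidden_size=64))
# A hypothetical architecture storing the same values under other names
other_style = ToyNormalizedConfig(
    SimpleNamespace(n_vocab=100, n_embd=64),
    {"vocab_size": "n_vocab", "hidden_size": "n_embd"},
)
for normalized in (bert_style, other_style):
    print(normalized.vocab_size, normalized.hidden_size)  # prints "100 64" for both
```

Under this reading, reusing BERT's normalized config for ESM works because the attributes the generator needs are available under the names BERT's mapping already expects, so no new mapping has to be written.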

One last step: since optimum-neuron is an extension of optimum, we need to register the Neuron config we created with the TasksManager, using the register_in_tasks_manager decorator to specify the model type and the supported tasks.


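from typing import List

from optimum.exporters.neuron.config import TextEncoderNeuronConfig  # module path may vary by version
from optimum.exporters.tasks import TasksManager
from optimum.utils import NormalizedConfigManager

register_in_tasks_manager = TasksManager.create_register("neuron")  # already defined in model_configs.py

@register_in_tasks_manager("esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"])
class EsmNeuronConfig(TextEncoderNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert")
    ATOL_FOR_VALIDATION = 1e-3  # absolute tolerance when comparing Neuron outputs with PyTorch outputs on CPU

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask"]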

Export and validate the model

With the Neuron configuration class implemented, run a quick test to check that it works as expected:

  • Export
optimum-cli export neuron --model facebook/esm2_t33_650M_UR50D --task text-classification --batch_size 1 --sequence_length 16 esm_neuron/

During the export, validate_model_outputs is called to validate the outputs of your exported Neuron model by comparing them with the results of the PyTorch model on CPU. You can also validate the model manually with:

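from optimum.exporters.neuron import validate_model_outputs

# neuron_config: the EsmNeuronConfig instance used for the export
# base_model: the original PyTorch model, run on CPU as the reference
# neuron_model_path: path to the compiled Neuron model
# neuron_named_outputs: names of the Neuron model's outputs
validate_model_outputs(
    neuron_config, base_model, neuron_model_path, neuron_named_outputs, neuron_config.ATOL_FOR_VALIDATION
)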
  • Inference (optional)
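from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSequenceClassification

# Load the compiled model and the tokenizer saved with it in esm_neuron/
model = NeuronModelForSequenceClassification.from_pretrained("esm_neuron/")
tokenizer = AutoTokenizer.from_pretrained("esm_neuron/")
# Inputs must fit the compiled shapes (batch_size 1, sequence_length 16).
# ESM expects protein sequences; this sentence only serves as a smoke test.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
logits = model(**inputs).logits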
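
validate_model_outputs receives an absolute tolerance (ATOL_FOR_VALIDATION in the config above). Assuming it performs an elementwise absolute-difference check between the CPU reference and the Neuron outputs, the core comparison can be sketched as follows. check_outputs and max_abs_diff are hypothetical helpers, outputs are flattened to lists of floats for simplicity, and the numbers are made up for illustration rather than measured.

```python
from typing import Dict, List, Tuple


def max_abs_diff(reference: List[float], candidate: List[float]) -> float:
    """Largest elementwise |reference - candidate| for one (flattened) output."""
    if len(reference) != len(candidate):
        raise ValueError("Output shapes differ")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def check_outputs(
    reference: Dict[str, List[float]],
    candidate: Dict[str, List[float]],
    atol: float,
) -> Tuple[bool, Dict[str, float]]:
    """Pass only if every element of every named output is within atol of the reference."""
    if set(reference) != set(candidate):
        raise ValueError(f"Output names differ: {sorted(reference)} vs {sorted(candidate)}")
    diffs = {name: max_abs_diff(reference[name], candidate[name]) for name in reference}
    return all(d <= atol for d in diffs.values()), diffs


# Made-up values standing in for PyTorch-on-CPU and Neuron logits
cpu_outputs = {"logits": [0.1234, -1.5]}
neuron_outputs = {"logits": [0.1240, -1.5004]}
ok, diffs = check_outputs(cpu_outputs, neuron_outputs, atol=1e-3)
print(ok)  # True: the largest difference is about 6e-4
```

Reporting the largest difference per output, rather than only a pass/fail flag, makes it easier to tell whether a failing model is slightly over the tolerance or producing wrong results.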

Contribute to the GitHub repo

We are almost all set. Now submit a pull request to make your work accessible to all community members!

We usually test with smaller checkpoints to speed up CI. You can find tiny models for testing under the Hugging Face Internal Testing Organization.

You have made a new model accessible on Neuron for the community! Thanks for joining us in the endeavor of democratizing good machine learning 🤗.