AWS Trainium & Inferentia documentation

Configuration classes for Neuron exports

Exporting a PyTorch model to a Neuron-compiled model involves specifying:

  1. The input names.
  2. The output names.
  3. Dummy inputs used to trace the model. The Neuron compiler needs them to record the computational graph and convert it to a TorchScript module.
  4. Compilation arguments used to control the trade-off between hardware efficiency (latency, throughput) and accuracy.
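
The four pieces of information above can be pictured as fields of a single object. What follows is a minimal, library-independent sketch: `ExportSpec`, its field names, and the `auto_cast` key are hypothetical names chosen for illustration and are not the Optimum Neuron API. The dummy inputs are plain nested lists standing in for tensors.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ExportSpec:
    """Hypothetical container for what a Neuron export needs to know."""

    input_names: List[str]
    output_names: List[str]
    # Static shape used to build the dummy inputs for tracing (assumed values).
    batch_size: int = 1
    sequence_length: int = 128
    # Free-form compiler options; the key used below is illustrative only.
    compiler_args: Dict[str, Any] = field(default_factory=dict)

    def dummy_inputs(self) -> Dict[str, List[List[int]]]:
        """Build one all-ones placeholder of shape (batch, sequence) per input."""
        row = [1] * self.sequence_length
        return {
            name: [list(row) for _ in range(self.batch_size)]
            for name in self.input_names
        }


spec = ExportSpec(
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    batch_size=2,
    sequence_length=8,
    compiler_args={"auto_cast": "matmul"},  # hypothetical option name
)
dummies = spec.dummy_inputs()
# Show the (batch, sequence) shape produced for each input name.
print({name: (len(rows), len(rows[0])) for name, rows in dummies.items()})
```

Fixing the batch size and sequence length up front reflects the fact that tracing records the graph for the concrete dummy inputs it is given.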

Depending on the choice of model and task, we represent the data above with configuration classes. Each configuration class is associated with a specific model architecture, and follows the naming convention ArchitectureNameNeuronConfig. For instance, the configuration which specifies the Neuron export of BERT models is BertNeuronConfig.

Since many architectures share similar properties for their Neuron configuration, 🤗 Optimum adopts a 3-level class hierarchy:

  1. Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc.).
  2. Middle-end classes. These are aware of the modality, but multiple can exist for the same modality depending on the inputs they support. They specify which input generators should be used for the dummy inputs, but remain model-agnostic.
  3. Model-specific classes like the BertNeuronConfig mentioned above. These are the ones actually used to export models.
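
The three levels can be sketched with ordinary Python inheritance. This is an illustration under stated assumptions, not the library's implementation: the `Toy*` class names, the `DUMMY_INPUT_GENERATORS` attribute, and the `config_class_name` helper are all hypothetical, and the "dummy inputs" are reduced to shapes.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class ToyNeuronConfig(ABC):
    """Level 1: modality-agnostic base that assembles the dummy inputs."""

    # Names of the dummy-input generators to use; set by level 2 classes.
    DUMMY_INPUT_GENERATORS: Tuple[str, ...] = ()

    def __init__(self, task: str, batch_size: int = 1, sequence_length: int = 16):
        self.task = task
        self.batch_size = batch_size
        self.sequence_length = sequence_length

    @property
    @abstractmethod
    def inputs(self) -> List[str]:
        """Input names, defined by the model-specific class."""

    def generate_dummy_inputs(self) -> Dict[str, Tuple[int, int]]:
        # Only shapes are produced here; a real exporter would create tensors.
        return {name: (self.batch_size, self.sequence_length) for name in self.inputs}


class ToyTextEncoderNeuronConfig(ToyNeuronConfig):
    """Level 2: aware of the text modality, still model-agnostic."""

    DUMMY_INPUT_GENERATORS = ("text_generator",)  # hypothetical generator name


class ToyBertNeuronConfig(ToyTextEncoderNeuronConfig):
    """Level 3: model-specific class, the only one that can be instantiated."""

    @property
    def inputs(self) -> List[str]:
        return ["input_ids", "attention_mask", "token_type_ids"]


def config_class_name(architecture: str) -> str:
    """Apply the ArchitectureNameNeuronConfig naming convention."""
    return f"{architecture}NeuronConfig"


config = ToyBertNeuronConfig(task="text-classification", batch_size=2)
print(config_class_name("Bert"))
print(config.generate_dummy_inputs())
```

Because `inputs` stays abstract in the first two levels, only the model-specific class can be instantiated, which mirrors the statement that those are the classes actually used for export.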

Supported architectures

| Architecture | Tasks |
|---|---|
| ALBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| BERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| CamemBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| ConvBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| DeBERTa (INF2 only) | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| DeBERTa-v2 (INF2 only) | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| DistilBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| ELECTRA | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| FlauBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| GPT2 | text-generation |
| MobileBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| MPNet | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| RoFormer | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| XLM | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| XLM-RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |

More details on how to check the supported tasks can be found here.
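
As a rough illustration of what such a check amounts to, the sketch below encodes a hand-copied subset of the table above in a dictionary and looks up an (architecture, task) pair. The names `SUPPORTED_TASKS` and `is_supported` are hypothetical. This is not how the library performs the check, only a picture of the mapping the table describes.

```python
from typing import Dict, FrozenSet

# Task set shared by the encoder architectures in the table.
ENCODER_TASKS: FrozenSet[str] = frozenset({
    "feature-extraction",
    "fill-mask",
    "multiple-choice",
    "question-answering",
    "text-classification",
    "token-classification",
})

# Hand-copied subset of the table, keyed by lowercase architecture name.
SUPPORTED_TASKS: Dict[str, FrozenSet[str]] = {
    "bert": ENCODER_TASKS,
    "roberta": ENCODER_TASKS,
    "distilbert": ENCODER_TASKS,
    "gpt2": frozenset({"text-generation"}),
}


def is_supported(architecture: str, task: str) -> bool:
    """Return True if the (architecture, task) pair appears in the lookup."""
    return task in SUPPORTED_TASKS.get(architecture.lower(), frozenset())


print(is_supported("BERT", "fill-mask"))
print(is_supported("GPT2", "fill-mask"))
```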

More architectures coming soon, stay tuned! 🚀