Export to TorchScript

You are viewing v4.45.2 version. A newer version v4.46.3 is available.

This is the very beginning of our experiments with TorchScript, and we are still exploring its capabilities with variable-input-size models. It is a focus of interest to us, and we will deepen our analysis in upcoming releases with more code examples, a more flexible implementation, and benchmarks comparing Python-based code with compiled TorchScript.

According to the TorchScript documentation:

TorchScript is a way to create serializable and optimizable models from PyTorch code.

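PyTorch provides two ways to convert a model to TorchScript, both in the `torch.jit` module: scripting (`torch.jit.script`) and tracing (`torch.jit.trace`). They allow developers to export their models so they can be reused in other programs, such as efficiency-oriented C++ programs.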

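To make the distinction concrete, here is a minimal sketch using a hypothetical toy module, `Gate` (not part of 🤗 Transformers), whose output depends on a data-dependent branch. Scripting compiles the Python source, so both branches survive; tracing records only the operations executed for the example input, which is why the rest of this guide pays close attention to dummy inputs.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical toy module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return x * -1


# Scripting compiles the Python source, so both branches are kept
scripted = torch.jit.script(Gate())

# Tracing runs the module once on the example input and records only the
# operations that actually executed (PyTorch emits a TracerWarning here)
traced = torch.jit.trace(Gate(), torch.ones(3))

negative = -torch.ones(3)
print(scripted(negative))  # tensor([1., 1., 1.]): takes the else branch
print(traced(negative))    # tensor([-2., -2., -2.]): replays the recorded branch
```

The traced module silently returns the wrong branch for the negative input rather than raising an error, so traces should be exercised with inputs representative of real traffic.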
We provide an interface that allows you to export 🤗 Transformers models to TorchScript so they can be reused in environments other than PyTorch-based Python programs. Here, we explain how to export and use our models with TorchScript.

Exporting a model requires two things:

  • model instantiation with the torchscript flag
  • a forward pass with dummy inputs

These requirements imply several things developers should be careful about, as detailed below.

TorchScript flag and tied weights

The torchscript flag is necessary because most of the 🤗 Transformers language models have tied weights between their Embedding layer and their Decoding layer. TorchScript does not allow you to export models that have tied weights, so it is necessary to untie and clone the weights beforehand.

Models instantiated with the torchscript flag have their Embedding layer and Decoding layer separated, which means that they should not be trained down the line. Training would desynchronize the two layers, leading to unexpected results.

This is not the case for models that do not have a language model head, as those do not have tied weights. These models can be safely exported without the torchscript flag.
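One way to see the effect of the flag is to compare the input embedding matrix with the output (decoder) matrix of a masked-language-model head. The sketch below assumes a tiny, randomly initialized `BertForMaskedLM` (the sizes are arbitrary, chosen only so nothing is downloaded) and relies on the behavior we observe in v4.x, where weight tying copies the embedding matrix into a separate parameter when `torchscript=True` instead of sharing it. `make_config` and `shares_embedding_weights` are hypothetical helpers, not library functions.

```python
import torch
from transformers import BertConfig, BertForMaskedLM


def make_config(torchscript: bool) -> BertConfig:
    # Tiny, arbitrary sizes so the models are built locally without downloads
    return BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        torchscript=torchscript,
    )


def shares_embedding_weights(model: BertForMaskedLM) -> bool:
    """Hypothetical helper: True if the LM head reuses the input embedding storage."""
    input_weight = model.get_input_embeddings().weight
    output_weight = model.get_output_embeddings().weight
    return input_weight.data_ptr() == output_weight.data_ptr()


tied = BertForMaskedLM(make_config(torchscript=False))
untied = BertForMaskedLM(make_config(torchscript=True))

print(shares_embedding_weights(tied))    # True: one matrix shared by both layers
print(shares_embedding_weights(untied))  # False: the LM head holds its own copy
# The copy starts out with the same values as the embedding matrix
print(torch.equal(untied.get_input_embeddings().weight,
                  untied.get_output_embeddings().weight))  # True
```

Because the copy in `untied` is a separate parameter, an optimizer step would update the two matrices independently, which is the desynchronization that makes such models unsuitable for further training.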

Dummy inputs and standard lengths

The dummy inputs are used for a model’s forward pass. While the inputs’ values are propagated through the layers, PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used to create the trace of the model.

The trace is created relative to the inputs’ dimensions. It is therefore constrained by the dimensions of the dummy input and will not work for any other sequence length or batch size. When you try a different size, the following error is raised:

`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`

We recommend that you trace the model with a dummy input size at least as large as the largest input that will be fed to the model during inference. Padding can fill in the missing values. However, since the model is traced with a larger input size, the matrix dimensions will also be larger, resulting in more computation.

Pay attention to the total number of operations performed on each input, and monitor performance closely when exporting models with varying sequence lengths.

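A minimal sketch of the padding approach, assuming a tiny, randomly initialized `BertModel` (arbitrary sizes, no download), a maximum inference length of 16 tokens, and right-padding with id 0 (the `[PAD]` id in the bert-base-uncased vocabulary). `pad_to_max`, `MAX_LEN`, and `PAD_ID` are hypothetical names introduced for illustration. The model is traced once at the maximum length with `input_ids` and `attention_mask` (the first two positional parameters of `BertModel.forward`), and shorter inputs are padded up to that shape, with the attention mask hiding the padding.

```python
import torch
from transformers import BertConfig, BertModel

MAX_LEN = 16  # assumed longest sequence expected at inference time
PAD_ID = 0    # assumed padding id ([PAD] is 0 in the bert-base-uncased vocabulary)


def pad_to_max(ids: list, max_len: int = MAX_LEN, pad_id: int = PAD_ID):
    """Hypothetical helper: right-pad one sequence and build its attention mask."""
    if len(ids) > max_len:
        raise ValueError(f"sequence length {len(ids)} exceeds traced length {max_len}")
    n_pad = max_len - len(ids)
    input_ids = torch.tensor([ids + [pad_id] * n_pad])
    attention_mask = torch.tensor([[1] * len(ids) + [0] * n_pad])
    return input_ids, attention_mask


# Tiny, randomly initialized model with arbitrary sizes so nothing is downloaded
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, torchscript=True)
model = BertModel(config).eval()

# Trace once with a dummy input at the maximum length
dummy_ids, dummy_mask = pad_to_max(list(range(1, MAX_LEN + 1)))
traced = torch.jit.trace(model, (dummy_ids, dummy_mask))

# Shorter inputs are padded up to the traced shape; the mask hides the padding
ids, mask = pad_to_max([5, 6, 7, 8, 9])
with torch.no_grad():
    last_hidden_state, pooled_output = traced(ids, mask)
print(last_hidden_state.shape)  # torch.Size([1, 16, 32])
```

Every call now costs as much as a 16-token input, even for a 5-token sequence, which is the extra computation mentioned above.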
Using TorchScript in Python

This section demonstrates how to save and load models as well as how to use the trace for inference.

Saving a model

To export a BertModel with TorchScript, instantiate BertModel from the BertConfig class and then save it to disk under the filename traced_bert.pt:

from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)

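# BertModel.forward takes (input_ids, attention_mask, token_type_ids, ...) positionally,
# so an all-ones attention mask is passed before the segment ids
attention_mask = torch.ones_like(tokens_tensor)
dummy_input = [tokens_tensor, attention_mask, segments_tensors]

# Creating the trace
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "traced_bert.pt")

Loading a model

Now you can load the previously saved BertModel, traced_bert.pt, from disk and use it on the previously initialised dummy_input:

loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

# With torchscript=True the model returns a plain tuple rather than a ModelOutput
last_hidden_state, pooled_output = loaded_model(*dummy_input)

Using a traced model for inference

Use the traced model for inference by calling it directly, which invokes its __call__ dunder method:

traced_model(*dummy_input)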

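Before relying on a traced or reloaded model, it can be worth checking that it reproduces the eager model’s outputs. The sketch below assumes a tiny, randomly initialized `BertModel` (arbitrary sizes and random token ids, so nothing is downloaded). It traces with `input_ids`, `attention_mask`, and `token_type_ids` in the positional order of `BertModel.forward`, saves and reloads the trace from a temporary directory, and compares both outputs with `torch.allclose`; the tolerance is an arbitrary choice.

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertModel

# Tiny, randomly initialized model (arbitrary sizes) so the sketch runs offline
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, torchscript=True)
model = BertModel(config).eval()

# Positional order follows BertModel.forward: input_ids, attention_mask, token_type_ids
input_ids = torch.randint(1, config.vocab_size, (1, 14))
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.tensor([[0] * 7 + [1] * 7])
dummy_input = (input_ids, attention_mask, token_type_ids)

traced_model = torch.jit.trace(model, dummy_input)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "traced_bert.pt")
    torch.jit.save(traced_model, path)
    loaded_model = torch.jit.load(path)
    loaded_model.eval()

    with torch.no_grad():
        eager_out = model(*dummy_input)
        loaded_out = loaded_model(*dummy_input)

# Compare last_hidden_state and pooled_output element-wise
matches = all(torch.allclose(e, l, atol=1e-5) for e, l in zip(eager_out, loaded_out))
print(matches)  # True
```

The same comparison is a cheap regression check to run whenever the model or the dummy input changes, since tracing does not validate behavior for other inputs.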
Deploy Hugging Face TorchScript models to AWS with the Neuron SDK

AWS introduced the Amazon EC2 Inf1 instance family for low-cost, high-performance machine learning inference in the cloud. Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator specializing in deep learning inference workloads. AWS Neuron is the SDK for Inferentia that supports tracing and optimizing Transformers models for deployment on Inf1. The Neuron SDK provides:

  1. An easy-to-use API that requires changing only one line of code to trace and optimize a TorchScript model for inference in the cloud.
  2. Out-of-the-box performance optimizations for improved cost-performance.
  3. Support for Hugging Face Transformers models built with either PyTorch or TensorFlow.

Implications

Transformers models based on the BERT (Bidirectional Encoder Representations from Transformers) architecture, or its variants such as DistilBERT and RoBERTa, run best on Inf1 for non-generative tasks such as extractive question answering, sequence classification, and token classification. However, text generation tasks can still be adapted to run on Inf1 according to this AWS Neuron MarianMT tutorial. More information about which models can be converted out of the box on Inferentia can be found in the Model Architecture Fit section of the Neuron documentation.

Dependencies

Using AWS Neuron to convert models requires a Neuron SDK environment, which comes preconfigured on the AWS Deep Learning AMI.

Converting a model for AWS Neuron

Convert a model for AWS Neuron using the same code from Using TorchScript in Python to trace a BertModel. Import the torch.neuron framework extension to access the components of the Neuron SDK through a Python API:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron

You only need to modify the following line:

- traced_model = torch.jit.trace(model, dummy_input)
+ traced_model = torch.neuron.trace(model, dummy_input)

This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.
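If the same script has to run both on Inf1 and on machines without the Neuron SDK, one option is to pick the tracing function at import time. This is a sketch, not an official Neuron pattern: it assumes `torch.neuron.trace` (provided by the torch-neuron package in a Neuron SDK environment) is the only change, falls back to `torch.jit.trace` elsewhere, and uses a tiny, randomly initialized `BertModel` with arbitrary sizes. We have only exercised the fallback path; `trace_fn` is a hypothetical name.

```python
import torch
from transformers import BertConfig, BertModel

try:
    # Provided by the torch-neuron package in an AWS Neuron SDK environment
    import torch.neuron

    trace_fn = torch.neuron.trace
except ImportError:
    # Fallback so the same script still produces a plain TorchScript trace elsewhere
    trace_fn = torch.jit.trace

# Tiny, randomly initialized model (arbitrary sizes) so nothing is downloaded
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64, torchscript=True)
model = BertModel(config).eval()

# Dummy inputs in BertModel.forward's positional order: input_ids, attention_mask
input_ids = torch.randint(1, config.vocab_size, (1, 16))
attention_mask = torch.ones_like(input_ids)
traced = trace_fn(model, [input_ids, attention_mask])
print(trace_fn.__module__, type(traced).__name__)
```

Either way the result is a TorchScript module, so it can be saved with `torch.jit.save` as shown earlier.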

To learn more about AWS Neuron SDK features, tools, example tutorials, and the latest updates, see the AWS Neuron SDK documentation.
