XLA Integration for TensorFlow Models

Accelerated Linear Algebra, dubbed XLA, is a compiler for accelerating the runtime of TensorFlow Models. From the official documentation:

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.

Using XLA in TensorFlow is simple: it ships with the tensorflow library and is enabled with the jit_compile argument of any graph-creating function, such as tf.function. When using Keras methods like fit() and predict(), you can enable XLA by passing the jit_compile argument to model.compile(). XLA is not limited to these methods, however; it can also be used to accelerate any arbitrary tf.function.
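As a minimal sketch of the Keras route, the snippet below passes jit_compile=True to model.compile() and then calls fit() and predict() as usual. The toy model, the random data, and the choice of optimizer and loss are assumptions made purely for illustration; they are not taken from 🤗 Transformers.

```python
import tensorflow as tf

# Toy classifier, used only to illustrate the jit_compile flag.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(5, activation="softmax"),
    ]
)

# Request XLA compilation for the Keras training/inference steps.
model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy", jit_compile=True)

# Random features and integer labels (illustrative data only).
features = tf.random.normal((16, 10))
labels = tf.random.uniform((16,), maxval=5, dtype=tf.int32)

model.fit(features, labels, batch_size=16, epochs=1, verbose=0)
predictions = model.predict(features, batch_size=16, verbose=0)
print(predictions.shape)  # (16, 5)
```

Nothing else in the training loop changes; the flag is the only XLA-specific part of this sketch.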

Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as GPT2, T5 and OPT, as well as speech processing for models such as Whisper.

While the exact speed-up is very much model-dependent, we observed a speed-up of ~100x for TensorFlow text generation models inside 🤗 Transformers. This document explains how you can use XLA with these models to get the best possible performance. We’ll also provide links to additional resources if you’re interested in learning more about the benchmarks and the design philosophy behind the XLA integration.

Running TF functions with XLA

Let us consider the following model in TensorFlow:

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)

The model above accepts inputs of shape (10,). We can run a forward pass with it like so:

# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))

# Run a forward pass.
_ = model(random_inputs)

To run the forward pass with an XLA-compiled function, wrap the model in tf.function with jit_compile=True:

xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)

By default, the model’s call() function is what gets compiled into the XLA graph. If you want to compile a different model function with XLA, that is also possible (my_xla_fn below stands in for any such method):

my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)

Running a TF text generation model with XLA from 🤗 Transformers

To enable XLA-accelerated generation within 🤗 Transformers, you need to have a recent version of transformers installed. You can install it by running:

pip install transformers --upgrade

And then you can run the following code:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version

check_min_version("4.21.0")

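# Pad on the left so generation continues from the prompt; GPT-2 has no pad token, so one is set explicitly.
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")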
input_string = ["TensorFlow is"]

# One line to create an XLA generation function
xla_generate = tf.function(model.generate, jit_compile=True)

tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)

decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
# Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the

As you can see, enabling XLA on generate() takes just a single line of code; the rest of the code remains unchanged. However, the snippet above has a couple of XLA-specific gotchas that you need to be aware of to realize the speed-ups XLA can bring. We discuss them in the following section.

Gotchas to be aware of

When you are executing an XLA-enabled function (like xla_generate() above) for the first time, it will internally try to infer the computation graph, which is time-consuming. This process is known as “tracing”.

As a result, the first call is noticeably slow. Successive calls of xla_generate() (or any other XLA-enabled function) don’t have to infer the computation graph again, provided the inputs have the same shape as the ones the computation graph was initially built with. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with modalities whose input shapes vary (e.g., text).

To ensure xla_generate() always operates with the same input shapes, you can specify the padding arguments when calling the tokenizer.

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

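tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

# Here, we call the tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
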
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
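To see why padding to a multiple keeps input shapes stable, the rounding can be sketched in plain Python. This assumes the tokenizer’s pad_to_multiple_of option rounds the sequence length up to the next multiple of the given value; padded_length and the token counts are hypothetical and chosen only for illustration.

```python
def padded_length(num_tokens: int, multiple: int = 8) -> int:
    """Round num_tokens up to the nearest multiple (ceiling division)."""
    return -(-num_tokens // multiple) * multiple


# Hypothetical token counts for three prompts of different lengths.
token_counts = [3, 4, 5]
padded = [padded_length(n) for n in token_counts]
print(padded)  # [8, 8, 8]

# A longer prompt lands in the next bucket, i.e. a new input shape.
longer = padded_length(9)
print(longer)  # 16
```

All three short prompts end up with the same padded length, so a function traced for one of them can be reused for the others. A prompt that crosses the multiple produces a new shape and would be traced again.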

This way, you can ensure that xla_generate() always receives inputs with the shape it was traced with, which leads to speed-ups in generation time. You can verify this with the code below:

import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

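tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

xla_generate = tf.function(model.generate, jit_compile=True)

for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
    # Pad every prompt to the same multiple so the traced input shape is reused.
    tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
    start = time.time_ns()
    generated_tokens = xla_generate(**tokenized_input, num_beams=2)
    end = time.time_ns()
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")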

On a Tesla T4 GPU, you can expect output like this:

Execution time -- 30819.6 ms

Execution time -- 79.0 ms

Execution time -- 78.9 ms

The first call to xla_generate() is time-consuming because of tracing, but successive calls are orders of magnitude faster. Keep in mind that any change in the generation options at any point will trigger re-tracing and thus slow down generation.
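Re-tracing can be observed on a toy function. In the sketch below, scaled_sum and trace_count are illustrative names, not part of 🤗 Transformers. It relies on the fact that Python side effects inside a tf.function run only while the function is being traced. The Python argument scale plays the role of a generation option such as num_beams.

```python
import tensorflow as tf

trace_count = 0  # incremented only while TensorFlow traces the function


@tf.function(jit_compile=True)
def scaled_sum(x, scale):
    global trace_count
    trace_count += 1  # Python side effect: executed at trace time only
    return tf.reduce_sum(x * scale)


scaled_sum(tf.ones((2, 8)), 2)   # first call: traced and compiled
scaled_sum(tf.zeros((2, 8)), 2)  # same shape and option: trace reused
scaled_sum(tf.ones((2, 16)), 2)  # new input shape: traced again
scaled_sum(tf.ones((2, 16)), 3)  # new Python option value: traced again

print(trace_count)  # 3
```

Four calls produce three traces: only the second call reuses an existing one. The same logic applies to xla_generate(), so keeping both the padded input shapes and the generation options fixed avoids paying the tracing cost repeatedly.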

We didn’t cover all the text generation options 🤗 Transformers provides in this document. We encourage you to read the documentation for advanced use cases.