XLA Integration for TensorFlow Models
Accelerated Linear Algebra, dubbed XLA, is a compiler for accelerating the runtime of TensorFlow models. From the official documentation:
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.
Using XLA in TensorFlow is simple: it comes packaged inside the tensorflow library, and it can be triggered with the jit_compile argument in any graph-creating function such as tf.function. When using Keras methods like fit() and predict(), you can enable XLA simply by passing the jit_compile argument to model.compile(). However, XLA is not limited to these methods; it can also be used to accelerate any arbitrary tf.function.
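As a minimal sketch of the Keras route described above, the snippet below compiles a small model with jit_compile=True and then calls fit() and predict(). The layer sizes, optimizer, loss, and random data are assumptions chosen only for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy classifier; the architecture is an assumption for illustration.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(5, activation="softmax"),
])

# jit_compile=True asks Keras to run its train and predict steps through XLA.
model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy", jit_compile=True)

# Random stand-in data, not a real dataset.
features = np.random.rand(32, 10).astype("float32")
labels = np.random.randint(0, 5, size=(32,))

model.fit(features, labels, batch_size=16, epochs=1, verbose=0)
predictions = model.predict(features, batch_size=16, verbose=0)
print(predictions.shape)
```

The batch size divides the dataset evenly here, so every step sees the same input shape.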
Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as GPT2, T5 and OPT, as well as speech processing for models such as Whisper.
While the exact amount of speed-up is very much model-dependent, for TensorFlow text generation models inside 🤗 Transformers, we noticed a speed-up of ~100x. This document explains how you can use XLA for these models to get the maximum amount of performance. We also provide links to additional resources if you are interested in learning more about the benchmarks and our design philosophy behind the XLA integration.
Running TF functions with XLA
Let us consider the following model in TensorFlow:
import tensorflow as tf
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)
The above model accepts inputs having a dimension of (10, ). We can use the model for running a forward pass like so:
# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))
# Run a forward pass.
_ = model(random_inputs)
In order to run the forward pass with an XLA-compiled function, we’d need to do:
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)
The default call() function of the model is used for compiling the XLA graph. But if there's any other model function you want to compile into XLA, that's also possible with:
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
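The name my_xla_fn above is a placeholder. As a hedged illustration, the sketch below defines a hypothetical subclassed model with an extra method, hidden_features(), and compiles that method instead of call(). The class and method names are invented for this example and are not part of any library.

```python
import numpy as np
import tensorflow as tf


class TinyClassifier(tf.keras.Model):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(10, activation="relu")
        self.head = tf.keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        return self.head(self.hidden(inputs))

    def hidden_features(self, inputs):
        # A second entry point that skips the classification head.
        return self.hidden(inputs)


model = TinyClassifier()
random_inputs = tf.random.normal((16, 10))
_ = model(random_inputs)  # eager call first so the layer weights exist

# Compile the extra method, not call(), with XLA.
xla_hidden = tf.function(model.hidden_features, jit_compile=True)
xla_out = xla_hidden(random_inputs)
eager_out = model.hidden_features(random_inputs)

# The compiled and eager paths should agree up to floating-point noise.
print(np.allclose(xla_out.numpy(), eager_out.numpy(), atol=1e-5))
```

Comparing against the eager result is a cheap sanity check whenever a new function is wrapped with jit_compile=True.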
Running a TF text generation model with XLA from 🤗 Transformers
To enable XLA-accelerated generation within 🤗 Transformers, you need to have a recent version of transformers installed. You can install it by running:
pip install transformers --upgrade
And then you can run the following code:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version
check_min_version("4.21.0")
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]
# One line to create an XLA generation function
xla_generate = tf.function(model.generate, jit_compile=True)
tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
# Generated -- TensorFlow is an open-source, open-source, distributed-source application
# framework for the
As you can see, enabling XLA on generate() takes just a single line of code. The rest of the code remains unchanged. However, there are a couple of gotchas in the above code snippet that are specific to XLA. You need to be aware of them to realize the speed-ups that XLA can bring. We discuss these in the following section.
Gotchas to be aware of
When you are executing an XLA-enabled function (like xla_generate() above) for the first time, it will internally try to infer the computation graph, which is time-consuming. This process is known as “tracing”.
You might notice that this first call is slow. Successive calls of xla_generate() (or any other XLA-enabled function) won’t have to infer the computation graph, provided the inputs to the function have the same shape as the inputs with which the computation graph was initially built. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with variable input shape modalities (e.g., text).
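The effect of input shapes on tracing can be observed without any model. The sketch below is an illustration under assumptions: the function double() and the list trace_log are invented for this example. It relies on the fact that Python side effects inside a tf.function run only while the function is being traced.

```python
import tensorflow as tf

trace_log = []  # records the input shape each time the function is traced


@tf.function(jit_compile=True)
def double(x):
    # This append is a Python side effect, so it runs during tracing only.
    trace_log.append(tuple(x.shape))
    return x * 2.0


double(tf.ones((1, 8)))    # first call: traced
double(tf.zeros((1, 8)))   # same shape and dtype: the existing trace is reused
double(tf.ones((1, 16)))   # new shape: traced again

print(len(trace_log))
```

Only two of the three calls add an entry to trace_log, because the second call matches the shape of the first.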
To ensure xla_generate() always operates with the same input shapes, you can specify the padding arguments when calling the tokenizer.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]
xla_generate = tf.function(model.generate, jit_compile=True)
# Here, we call the tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
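To see why pad_to_multiple_of=8 keeps shapes stable, it helps to look at the rounding it implies. The helper below, padded_length(), is a hypothetical name used only for this sketch. It reproduces the arithmetic of rounding a sequence length up to the next multiple and is not the tokenizer's own code.

```python
import math


def padded_length(num_tokens: int, multiple: int = 8) -> int:
    """Return the smallest multiple of `multiple` that is >= num_tokens."""
    return math.ceil(num_tokens / multiple) * multiple


# Prompts of 3, 4, or 8 tokens all land in the same 8-token bucket,
# while a 9-token prompt moves to the next bucket of 16.
for n in (3, 4, 8, 9):
    print(n, "->", padded_length(n))
```

Short prompts of different lengths therefore share one padded shape, and a new trace is needed only when a prompt crosses into the next bucket.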
With padding in place, xla_generate() always receives inputs with the shape it was traced with, which leads to speed-ups in the generation time. You can verify this with the code below:
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
xla_generate = tf.function(model.generate, jit_compile=True)
for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
    tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
    start = time.time_ns()
    generated_tokens = xla_generate(**tokenized_input, num_beams=2)
    end = time.time_ns()
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
On a Tesla T4 GPU, you can expect output like this:
Execution time -- 30819.6 ms
Execution time -- 79.0 ms
Execution time -- 78.9 ms
The first call to xla_generate() is time-consuming because of tracing, but the successive calls are orders of magnitude faster. Keep in mind that any change in the generation options at any point will trigger re-tracing and thus lead to slow-downs in the generation time.
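The same behaviour can be reproduced with a plain tf.function, since Python values passed as arguments are baked into the trace. In the sketch below, scale() and trace_log are invented for illustration; the Python integer factor stands in for a generation option such as num_beams.

```python
import tensorflow as tf

trace_log = []  # records the Python argument each time the function is traced


@tf.function(jit_compile=True)
def scale(x, factor):
    # Python side effect: runs only when a new trace is created.
    trace_log.append(factor)
    return x * factor


x = tf.ones((2, 8))
scale(x, 2)  # first call: traced with factor=2
scale(x, 2)  # same tensor shape and same Python value: no new trace
scale(x, 3)  # new Python value: traced again despite the identical shape

print(trace_log)
```

Keeping generation options fixed across calls avoids paying the tracing cost again.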
We didn’t cover all the text generation options 🤗 Transformers provides in this document. We encourage you to read the documentation for advanced use cases.
Additional Resources
Here, we leave you with some additional resources if you want to delve deeper into XLA in 🤗 Transformers and in general.
- This Colab Notebook provides an interactive demonstration if you want to fiddle with the XLA-compatible encoder-decoder (like T5) and decoder-only (like GPT2) text generation models.
- This blog post provides an overview of the comparison benchmarks for XLA-compatible models along with a friendly introduction to XLA in TensorFlow.
- This blog post discusses our design philosophy behind adding XLA support to the TensorFlow models in 🤗 Transformers.
- Recommended posts for learning more about XLA and TensorFlow graphs in general: