TFBartForConditionalGeneration does not work with XLA compiler

#6
by shreyansj - opened

Hi,
I followed this Hugging Face blog post to accelerate the performance of text generation using TF with XLA. On wrapping model.generate with xla_generate = tf.function(model.generate, jit_compile=True), we get the following error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-8-0858aef5d3e2> in <module>()
     26     tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
     27     start = time.time_ns()
---> 28     generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
     29     end = time.time_ns()
     30     decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

NotImplementedError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 605, in generate  *
        seed=model_kwargs.pop("seed", None),
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 1687, in _generate  *
        input_ids,
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 2854, in beam_search_body_fn  *
        log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 94, in __call__  *
        scores = processor(input_ids, scores, cur_len)
    File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 427, in __call__  *
        raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")

    NotImplementedError: TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.
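
If I read the traceback correctly, the failing TFNoRepeatNGramLogitsProcessor is only added when no_repeat_ngram_size is set, and facebook/bart-large-cnn appears to ship a non-zero value for it in its config (I believe it defaults to 3 for this checkpoint). A quick check, assuming the standard config attribute:

from transformers import AutoConfig

# Inspect the checkpoint's default generation settings.
config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
# A value > 0 here means generate() will add TFNoRepeatNGramLogitsProcessor,
# which the error above says is eager-only.
print(config.no_repeat_ngram_size)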

Example code for reproducing this error:

# Stand-alone TF XLA generate example, adapted from the blog post for an encoder-decoder model (BART).

# Note: execution times are deeply dependent on hardware.
# If you have a machine with a powerful GPU, I highly recommend trying this example there!
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# 1. Load model and tokenizer
model_name = "facebook/bart-large-cnn"
# padding-side/pad-token overrides carried over from the blog's decoder-only example
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", pad_token="</s>")
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 64}

# 3. Create your XLA generate function
# This is the only change with respect to original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls will be fast if you've done things right.
input_prompts = [f"The best thing about {country} is" for country in ["Spain", "Japan", "Angola"]]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")

Other models work fine, so the issue seems specific to BART. Is BART not supported for XLA generation?
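
In case it helps narrow things down: if the eager-only n-gram processor is the only blocker, I assume explicitly disabling it would let the XLA path run, at the cost of losing the n-gram repetition blocking (untested sketch):

# Sketch of a possible workaround (assumption: the n-gram processor is the only
# eager-only piece). Passing no_repeat_ngram_size=0 should keep generate() from
# adding TFNoRepeatNGramLogitsProcessor.
generation_kwargs = {"num_beams": 4, "max_new_tokens": 64, "no_repeat_ngram_size": 0}
generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)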

Tagging @patrickvonplaten to get your opinion on this, as suggested in the docs.
