TFBartForConditionalGeneration does not work with XLA compiler

#6 opened by shreyansj

Hi,
I followed this Hugging Face blog post on accelerating text generation with TensorFlow and XLA. On wrapping `model.generate` with `xla_generate = tf.function(model.generate, jit_compile=True)`, we get the following error:
```

NotImplementedError Traceback (most recent call last)
in ()
26 tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
27 start = time.time_ns()
---> 28 generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
29 end = time.time_ns()
30 decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
1145 except Exception as e: # pylint:disable=broad-except
1146 if hasattr(e, "ag_error_metadata"):
-> 1147 raise e.ag_error_metadata.to_exception(e)
1148 else:
1149 raise

NotImplementedError: in user code:

File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 605, in generate  *
    seed=model_kwargs.pop("seed", None),
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 1687, in _generate  *
    input_ids,
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_utils.py", line 2854, in beam_search_body_fn  *
    log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 94, in __call__  *
    scores = processor(input_ids, scores, cur_len)
File "/usr/local/lib/python3.7/dist-packages/transformers/generation_tf_logits_process.py", line 427, in __call__  *
    raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")

NotImplementedError: TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.
```

Example code for reproducing this error:

```
# Stand-alone TF XLA generate example for Decoder-Only Models.
# Note: execution times are deeply dependent on hardware -- if you have a machine
# with a powerful GPU, I highly recommend you try this example there!
import time

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# 1. Load model and tokenizer
model_name = "facebook/bart-large-cnn"
# remember: decoder-only models need left-padding
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", pad_token="</s>")
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 2. Prepare tokenization and generation arguments -- don't forget padding to avoid retracing!
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}
generation_kwargs = {"num_beams": 4, "max_new_tokens": 64}

# 3. Create your XLA generate function
# This is the only change with respect to the original generate workflow!
xla_generate = tf.function(model.generate, jit_compile=True)

# 4. Generate! Remember -- the first call will be slow, but all subsequent calls
# will be fast if you've done things right.
input_prompts = [f"The best thing about {country} is" for country in ["Spain", "Japan", "Angola"]]
for input_prompt in input_prompts:
    tokenized_inputs = tokenizer([input_prompt], **tokenization_kwargs)
    start = time.time_ns()
    generated_text = xla_generate(**tokenized_inputs, **generation_kwargs)
    end = time.time_ns()
    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    print(f"Original prompt -- {input_prompt}")
    print(f"Generated -- {decoded_text}")
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
```


Other models work fine, so the issue seems to be specific to BART. Is BART not supported with XLA?
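
From the traceback, the processor seems to come from the checkpoint's default generation settings rather than from anything in `generation_kwargs`. Below is a minimal sketch for checking that; the assumption that `facebook/bart-large-cnn` ships with a non-zero `no_repeat_ngram_size` is mine and not confirmed here.

```
from transformers import AutoConfig

# Inspect the checkpoint's default generation settings.
# Assumption: this config sets no_repeat_ngram_size > 0, which would make
# generate() add TFNoRepeatNGramLogitsProcessor even though generation_kwargs
# above never mentions n-grams.
config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
print(config.no_repeat_ngram_size)
```

If that is what pulls in the processor, it would also explain why other checkpoints, whose configs don't set it, run fine under XLA while this one doesn't.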

Tagging @patrickvonplaten to get your opinion on this, as suggested in the docs.
