XLA Integration for TensorFlow Models
Accelerated Linear Algebra, dubbed XLA, is a compiler for accelerating the runtime of TensorFlow models. From the official documentation: XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.
Using XLA in TensorFlow is simple. It comes packaged inside the tensorflow library, and it can be triggered with the jit_compile argument of any graph-creating function such as tf.function. When using Keras methods like fit() and predict(), you can enable XLA simply by passing the jit_compile argument to model.compile(). However, XLA is not limited to these methods; it can also be used to accelerate any arbitrary tf.function.
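For instance, a minimal sketch of enabling XLA through Keras might look like the following (the toy model, optimizer, loss, and random data are illustrative placeholders, not part of the original example):

import tensorflow as tf

# A small placeholder model.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)

# Passing jit_compile=True to model.compile() enables XLA for fit() and predict().
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", jit_compile=True)

# Random placeholder data just to exercise fit() once.
x = tf.random.normal((16, 10))
y = tf.random.uniform((16,), maxval=5, dtype=tf.int32)
model.fit(x, y, epochs=1, verbose=0)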
Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation models such as GPT2, T5 and OPT, as well as speech processing models such as Whisper.
While the exact amount of speed-up is very much model-dependent, for the TensorFlow text generation models inside 🤗 Transformers we have observed a speed-up of roughly 100x. This document explains how you can use XLA with these models to get the maximum performance, and provides links to additional resources if you want to learn more about the benchmarks and the design philosophy behind the XLA integration.
Running TF functions with XLA
Consider the following TensorFlow model:
import tensorflow as tf
model = tf.keras.Sequential(
[tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)
The above model accepts inputs of dimension (10,). To run a forward pass with the model, do the following:
# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))
# Run a forward pass.
_ = model(random_inputs)
To run the forward pass with an XLA-compiled function instead, do the following:
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)
The model's default call() function is used to compile the XLA graph. If you want to compile another model function with XLA, that is also possible, for example:
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
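As a concrete sketch, here is what that could look like with a hypothetical subclassed model exposing an extra score() method (this class is only an illustration, not part of 🤗 Transformers):

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(5)

    def call(self, inputs):
        return self.dense(inputs)

    # A hypothetical extra method we want to compile with XLA.
    def score(self, inputs):
        return tf.reduce_sum(self.call(inputs), axis=-1)

my_model = MyModel()
xla_score = tf.function(my_model.score, jit_compile=True)
_ = xla_score(tf.random.normal((16, 10)))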
Running a TF text generation model with XLA from 🤗 Transformers
To enable XLA-accelerated generation within 🤗 Transformers, you need a recent version of transformers installed. You can install it by running:
pip install transformers --upgrade
Then you can run the following code:
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version
check_min_version("4.21.0")
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]
# One line to create an XLA generation function
xla_generate = tf.function(model.generate, jit_compile=True)
tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
# Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the
Enabling XLA on generate() is just that single line of code; the rest of the code remains unchanged. However, there are a few XLA-specific gotchas in the above snippet. You need to be aware of them to realize the speed-ups that XLA can bring, and they are discussed in the following sections.
Gotchas to be aware of
When you execute an XLA-enabled function (such as xla_generate() above) for the first time, it internally tries to infer the computation graph, which is time-consuming. This process is known as "tracing".
You might notice that the generation time is not fast at first. Successive calls of xla_generate() (or any other XLA-enabled function) do not need to infer the computation graph again, provided the inputs to the function have the same shape as when the computation graph was first built. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with modalities that have variable input shapes (e.g., text).
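The effect of tracing can be seen with a plain tf.function as well (the toy function below is only an illustration, not a generation model): a Python-side print runs only while the function is being traced, so calling it with a new input shape triggers a fresh trace.

import tensorflow as tf

@tf.function(jit_compile=True)
def forward(x):
    # This print only runs during tracing, not on every call.
    print("Tracing for input shape:", x.shape)
    return tf.reduce_sum(x ** 2)

_ = forward(tf.random.normal((8, 10)))  # traces for shape (8, 10)
_ = forward(tf.random.normal((8, 10)))  # same shape: reuses the existing graph, no print
_ = forward(tf.random.normal((4, 10)))  # new shape: traced again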
To ensure that xla_generate() always operates on the same input shapes, you can specify the padding arguments when calling the tokenizer.
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]
xla_generate = tf.function(model.generate, jit_compile=True)
# Here, we call the tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
This way, you can ensure that the inputs to xla_generate() always have the shape it was traced with, which speeds up the generation time. You can verify this with the code below:
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
xla_generate = tf.function(model.generate, jit_compile=True)
for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
    tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
    start = time.time_ns()
    generated_tokens = xla_generate(**tokenized_input, num_beams=2)
    end = time.time_ns()
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
On a Tesla T4 GPU, you can expect output like the following:
Execution time -- 30819.6 ms
Execution time -- 79.0 ms
Execution time -- 78.9 ms
The first call to xla_generate() is time-consuming because of tracing, but successive calls are orders of magnitude faster. Keep in mind that any change in the generation options will trigger re-tracing and thus slow down the generation time.
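For example, in a sketch reusing tokenizer and xla_generate from the snippet above, changing num_beams between calls forces a new trace:

# Reusing tokenizer and xla_generate from the snippet above.
tokenized_input = tokenizer(["TensorFlow is"], pad_to_multiple_of=8, padding=True, return_tensors="tf")

# Traced once for num_beams=2; repeated calls with the same options are fast.
_ = xla_generate(**tokenized_input, num_beams=2)

# Changing a generation option such as num_beams triggers re-tracing,
# so this call pays the tracing cost again.
_ = xla_generate(**tokenized_input, num_beams=4)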
This document does not cover all the text generation options 🤗 Transformers provides; we encourage you to read the documentation for advanced use cases.
Additional Resources
Here are some additional resources if you want to learn more about XLA in 🤗 Transformers and in general.
- This Colab Notebook provides an interactive demo for trying out the XLA-compatible encoder-decoder (such as T5) and decoder-only (such as GPT2) text generation models.
- This blog post provides an overview of the comparison benchmarks for XLA-compatible models, along with a friendly introduction to XLA in TensorFlow.
- This blog post discusses the design philosophy behind adding XLA support to the TensorFlow models in 🤗 Transformers.
- Recommended posts for learning more about XLA and TensorFlow graphs in general: