# PFEemp2024's picture
# solving GPU error for previous version
# 4a1df2e
"""
T5 Tokenizer
---------------------------------------------------------------------
"""
import transformers
class T5Tokenizer:
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian

    Args:
        mode: One of the modes listed above; selects the task prefix that is
            prepended to every input string.
        max_length: Default ``max_length`` forwarded to the underlying
            ``t5-base`` tokenizer on every call. A ``max_length`` keyword
            passed to ``__call__`` overrides it.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    # Task prefix T5 expects in front of the raw input, keyed by mode.
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        # Validate the mode before loading the (expensive) pretrained tokenizer.
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except (KeyError, TypeError):
            # TypeError covers unhashable modes; all invalid modes -> ValueError.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "t5-base", use_fast=True
        )
        self.max_length = max_length

    def __call__(self, text, *args, **kwargs):
        """Prepend the task prefix to ``text`` and tokenize it.

        Args:
            text (:obj:`str`, :obj:`List[str]`):
                A single string, or a list/tuple of strings forming a batch.
                The caller's sequence is never modified.
            *args, **kwargs: Forwarded to the underlying tokenizer.
                ``max_length`` defaults to ``self.max_length`` unless given.

        Returns:
            Whatever the underlying tokenizer returns for the prefixed input.

        Raises:
            TypeError: If ``text`` is neither a string nor a list/tuple of
                strings.
        """
        if isinstance(text, str):
            prefixed = self.tokenization_prefix + text
        elif isinstance(text, (list, tuple)) and all(
            isinstance(item, str) for item in text
        ):
            # Build a fresh list: writing back into ``text`` would mutate the
            # caller's batch (re-use would double the prefix) and fail on tuples.
            prefixed = [self.tokenization_prefix + item for item in text]
        else:
            raise TypeError("`text` must be a string or a list of strings.")
        # A default rather than a fixed keyword, so callers may override it
        # without a "multiple values for keyword argument" error.
        kwargs.setdefault("max_length", self.max_length)
        return self.tokenizer(prefixed, *args, **kwargs)

    def decode(self, ids, **kwargs):
        """Converts IDs (typically generated by the model) back to a string.

        Extra keyword arguments (e.g. ``skip_special_tokens``) are forwarded
        to the underlying tokenizer's ``decode``.
        """
        return self.tokenizer.decode(ids, **kwargs)