Spaces:
Running
Running
| """ | |
| T5 Tokenizer | |
| --------------------------------------------------------------------- | |
| """ | |
| import transformers | |
class T5Tokenizer:
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian
    """

    # Task prefix prepended to every input, keyed by mode (the prefix tells
    # T5 which task to perform; see the T5 paper, Appendix D).
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        """
        Args:
            mode (:obj:`str`): one of the modes listed in the class docstring.
            max_length (:obj:`int`): maximum sequence length forwarded to the
                underlying tokenizer on every call.

        Raises:
            ValueError: if ``mode`` is not a supported mode.
        """
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except KeyError:
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "t5-base", use_fast=True
        )
        self.max_length = max_length

    def __call__(self, text, *args, **kwargs):
        """Prepend the task prefix and tokenize.

        Args:
            text (:obj:`str`, :obj:`List[str]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings.

        Returns:
            The encoding produced by the underlying Hugging Face tokenizer,
            called with ``max_length=self.max_length``.

        Raises:
            TypeError: if ``text`` is not a string or a sequence of strings.
                (Was an ``assert`` before, which is stripped under ``-O``.)
        """
        if isinstance(text, str):
            text = self.tokenization_prefix + text
        elif isinstance(text, (list, tuple)) and all(
            isinstance(item, str) for item in text
        ):
            # Build a fresh list instead of assigning into ``text``: the old
            # in-place loop mutated the caller's list and raised TypeError
            # for tuples even though tuples passed the validation above.
            text = [self.tokenization_prefix + item for item in text]
        else:
            raise TypeError("`text` must be a string or a list of strings.")
        return self.tokenizer(text, *args, max_length=self.max_length, **kwargs)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string."""
        return self.tokenizer.decode(ids)