"""
T5 Tokenizer
---------------------------------------------------------------------

"""

import transformers


class T5Tokenizer:
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian
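
    Example (illustrative usage, not part of the original API docs;
    downloads the ``t5-base`` checkpoint on first use)::

        tokenizer = T5Tokenizer(mode="summarization")
        encoded = tokenizer("A long news article ...", truncation=True)
        text = tokenizer.decode(encoded["input_ids"])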
    """

    def __init__(self, mode="english_to_german", max_length=64):
        if mode == "english_to_german":
            self.tokenization_prefix = "translate English to German: "
        elif mode == "english_to_french":
            self.tokenization_prefix = "translate English to French: "
        elif mode == "english_to_romanian":
            self.tokenization_prefix = "translate English to Romanian: "
        elif mode == "summarization":
            self.tokenization_prefix = "summarize: "
        else:
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.")

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "t5-base", use_fast=True
        )
        self.max_length = max_length

    def __call__(self, text, *args, **kwargs):
        """
        Args:
            text (:obj:`str`, :obj:`List[str]`):
                    The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings.
        """
        assert isinstance(text, str) or (
            isinstance(text, (list, tuple))
            and (len(text) == 0 or isinstance(text[0], str))
        ), "`text` must be a string or a list of strings."
        if isinstance(text, str):
            text = self.tokenization_prefix + text
        else:
            # Build a new list rather than mutating the caller's input;
            # item assignment would also fail on tuples, which the
            # assertion above accepts.
            text = [self.tokenization_prefix + t for t in text]
        # Note: `max_length` only takes effect when the caller enables
        # truncation or padding (e.g. truncation=True).
        return self.tokenizer(text, *args, max_length=self.max_length, **kwargs)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string."""
        return self.tokenizer.decode(ids)
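

if __name__ == "__main__":
    # Minimal usage sketch (an illustration added here, not part of the
    # original module); requires network access to fetch the ``t5-base``
    # checkpoint from the Hugging Face hub on first run.
    tok = T5Tokenizer(mode="english_to_german")
    encoded = tok("The house is wonderful.", truncation=True)
    print(encoded["input_ids"])
    print(tok.decode(encoded["input_ids"]))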