File size: 3,218 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Union, List

from transformers import AutoTokenizer
import torch


class HFPTTokenizer(object):
    """Wrap a Hugging Face ``AutoTokenizer`` to produce fixed-length id tensors.

    ``tokenize`` (also reachable via ``__call__``) turns a string or a list of
    strings into a ``torch.long`` tensor holding one row of token ids per text.
    """

    def __init__(self, pt_name=None):
        """Load the pretrained tokenizer named ``pt_name``.

        Args:
            pt_name: Hugging Face tokenizer/model name passed to
                ``AutoTokenizer.from_pretrained``. ``None`` loads
                ``"bert-base-cased"``; ``self.pt_name`` keeps the value as given.
        """
        self.pt_name = pt_name
        # Count of SEP / CLS special tokens this wrapper added (0 or 1 each).
        self.added_sep_token = 0
        self.added_cls_token = 0
        self.enable_add_tokens = False
        # pt_name may be None (bert fallback below), so the helper checks for
        # None before doing the "gpt" substring test.
        self.gpt_special_case = self._detect_gpt_special_case(pt_name, self.enable_add_tokens)

        if pt_name is None:
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(pt_name)

        # Adding tokens to GPT causes NaN training loss.
        # Disabled for now until further investigation.
        if self.enable_add_tokens:
            if self.tokenizer.sep_token is None:
                self.tokenizer.add_special_tokens({"sep_token": "<SEP>"})
                self.added_sep_token = 1

            if self.tokenizer.cls_token is None:
                self.tokenizer.add_special_tokens({"cls_token": "<CLS>"})
                self.added_cls_token = 1

        if self.gpt_special_case:
            # Reuse EOS as both pad and sep so "max_length" padding works and
            # get_eot_token() has a token to encode (presumably because GPT
            # tokenizers ship without pad/sep tokens — confirm per model).
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.sep_token = self.tokenizer.eos_token

    @staticmethod
    def _detect_gpt_special_case(pt_name, enable_add_tokens):
        """Return True when a GPT-named tokenizer must reuse EOS for pad/sep."""
        return (not enable_add_tokens) and pt_name is not None and "gpt" in pt_name

    def get_eot_token(self):
        """Return the first id of the encoded sep token (used as end-of-text)."""
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)[0]

    def get_sot_token(self):
        """Return the first id of the encoded cls token (used as start-of-text)."""
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)[0]

    def get_eot_token_list(self):
        """Return all ids of the encoded sep token."""
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)

    def get_sot_token_list(self):
        """Return all ids of the encoded cls token."""
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)

    def get_tokenizer_obj(self):
        """Return the wrapped Hugging Face tokenizer."""
        return self.tokenizer

    # Language model needs to know if new tokens
    # were added to the dictionary.
    def check_added_tokens(self):
        """Return how many special tokens (0-2) this wrapper added."""
        return self.added_sep_token + self.added_cls_token

    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77):
        """Encode ``texts`` into a ``torch.long`` tensor of token ids.

        Each text is truncated/padded by the tokenizer to ``max_length``, which
        is ``context_length`` minus one slot for each CLS/SEP token this wrapper
        added. Those added ids are then prepended/appended manually. For the GPT
        special case the last id of every row is overwritten with the
        end-of-text id.

        Args:
            texts: A single string or a list of strings.
            context_length: Target number of ids per row.

        Returns:
            A ``torch.long`` tensor with one row per input text.
        """
        if isinstance(texts, str):
            texts = [texts]

        seqstart = []
        seqend = []
        max_length = context_length

        # Reserve one slot for each special token we append ourselves.
        if self.added_cls_token > 0:
            seqstart = self.get_sot_token_list()
            max_length -= 1

        if self.added_sep_token > 0:
            seqend = self.get_eot_token_list()
            max_length -= 1

        encoded = self.tokenizer(texts, padding="max_length", truncation=True, max_length=max_length)["input_ids"]
        tokens = [seqstart + ids + seqend for ids in encoded]

        if self.gpt_special_case:
            # Force every row to end with EOT, even if truncation cut it off.
            eot = self.get_eot_token()
            for row in tokens:
                row[-1] = eot

        # Build directly as int64: going through float32 would round ids > 2**24.
        return torch.tensor(tokens, dtype=torch.long)

    def get_vocab_size(self):
        """Return ``self.tokenizer.vocab_size`` (base vocabulary size)."""
        return self.tokenizer.vocab_size

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77):
        """Alias for :meth:`tokenize`."""
        return self.tokenize(texts, context_length)