File size: 4,109 Bytes
6b99d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

import torch
import sentencepiece
import jieba
import numpy as np

from transformers.tokenization_utils import PreTrainedTokenizer

# Add each special-token string to jieba's dictionary (in this order) so that
# segmentation can keep it as a single word instead of splitting it.
for _special_token in ('<s>', '</s>', '<eot>', '<unk>', '<sep>', '<pad>'):
    jieba.add_word(_special_token)
del _special_token  # do not leave the loop variable behind as a module attribute


class GPTPanguTokenizer(PreTrainedTokenizer):
    """PanGu-Alpha tokenizer: jieba word segmentation followed by SentencePiece encoding.

    ``tokenize`` splits text with jieba and, inside every segment, replaces spaces
    with ``\\u2582`` and newlines with ``\\u2583`` so they survive SentencePiece.
    ``convert_tokens_to_ids`` re-joins the segments with plain spaces and encodes
    them with SentencePiece. ``decode`` strips those joining spaces and maps the
    placeholder characters back to spaces/newlines.
    """

    # Ref: https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha/src/branch/master/tokenization_jieba.py
    vocab_files_names = {
        "model_file": "vocab.model"
    }

    def __init__(self, model_file, **kwargs):
        """Load the SentencePiece model from ``model_file``.

        Args:
            model_file (`str`): Path to the SentencePiece ``vocab.model`` file.
            **kwargs: Forwarded to `PreTrainedTokenizer.__init__` (special tokens etc.).
        """
        # Load the vocabulary BEFORE the base constructor runs: recent `transformers`
        # releases register special tokens inside `PreTrainedTokenizer.__init__`,
        # which queries `get_vocab()` / `len(self)` and therefore needs `self.sp`.
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.Load(model_file=model_file)
        # Space -> U+2582, newline -> U+2583 (reversed in `decode`).
        self.translator = str.maketrans(" \n", "\u2582\u2583")

        super().__init__(**kwargs)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by adding special tokens.

        When a BOS token is configured the format is:

        - single sequence: `<bos> X <eos>`
        - pair of sequences: `<bos> A <sep> B <eos>`

        Without a BOS token the leading `<bos>` is omitted.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix = [self.bos_token_id] if self.bos_token_id is not None else []
        output = prefix + token_ids_0
        if token_ids_1 is not None:
            output = output + [self.sep_token_id] + token_ids_1
        return output + [self.eos_token_id]

    def tokenize(self, text, **kwargs):
        """Segment ``text`` with jieba (precise mode), mapping spaces/newlines to placeholders."""
        return [segment.translate(self.translator) for segment in jieba.cut(text, cut_all=False)]

    def convert_tokens_to_ids(self, tokens):
        """Convert a token (``str``) or a sequence of jieba tokens to ids.

        A single string is looked up directly (added tokens first, then the
        SentencePiece vocabulary). For a sequence, runs of ordinary tokens are joined
        with spaces and encoded with SentencePiece; each special token in between is
        looked up on its own so it maps to exactly one id.

        Returns ``None`` when ``tokens`` is ``None``.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        # `all_special_tokens` is a base-class property; evaluate it once and use a
        # set for O(1) membership instead of re-reading it for every token.
        special_tokens = set(self.all_special_tokens)

        ids = []
        segment = []
        for token in tokens:
            if token in special_tokens:
                ids.extend(self.sp.encode(" ".join(segment)))
                # Same lookup as the single-string branch, so added special tokens
                # resolve to their registered id rather than SentencePiece's unk.
                ids.append(self._convert_token_to_id_with_added_voc(token))
                segment = []
            else:
                segment.append(token)
        ids.extend(self.sp.encode(" ".join(segment)))
        return ids

    def _convert_token_to_id(self, token):
        """Map a piece to its SentencePiece id."""
        return self.sp.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Map a SentencePiece id to its piece."""
        return self.sp.id_to_piece(index)

    def get_vocab(self):
        """Return the ``token -> id`` mapping: SentencePiece pieces plus added tokens."""
        vocab = {self.sp.id_to_piece(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Return the decoded text for ``ids`` (delegates to `decode`).

        NOTE: unlike the base-class contract, this returns one decoded string rather
        than a list of pieces; the behavior is kept for backward compatibility.
        """
        return self.decode(ids, skip_special_tokens=skip_special_tokens)

    def decode(self, ids, **kwargs):
        """Decode ids (list, tensor or ndarray) back into text.

        Args:
            ids: Token ids; a `torch.Tensor` / `np.ndarray` is converted with ``tolist()``.
            skip_special_tokens (`bool`, *optional*): Drop ids in ``all_special_ids``
                before decoding.

        Returns:
            `str`: The decoded text. If SentencePiece returns a batch (list), only
            the first entry is returned.
        """
        if isinstance(ids, (torch.Tensor, np.ndarray)):
            ids = ids.tolist()

        if kwargs.get('skip_special_tokens'):
            if isinstance(ids, int):
                # A lone id (e.g. from a 0-dim tensor) is not iterable.
                ids = [ids]
            special_ids = self.all_special_ids
            ids = [token_id for token_id in ids if token_id not in special_ids]
        text = self.sp.decode(ids)
        if isinstance(text, list):
            text = text[0]
        # Drop the separator spaces between segments, then restore the original
        # spaces/newlines from their placeholder characters.
        return text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return len(self.sp)