File size: 4,308 Bytes
68d5b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Adapted from https://github.com/openai/gpt-2/ and modified for Chinese text.
"""
import json
import os
import sentencepiece as spm
"""
SentencePiece is an unsupervised text tokenizer and detokenizer, mainly for neural-network-based text generation
systems where the vocabulary size is fixed before the neural model is trained. SentencePiece implements
subword units (e.g., byte-pair encoding (BPE) [Sennrich et al.] and the unigram language model [Kudo]) with the
extension of direct training from raw sentences. SentencePiece allows us to build a purely end-to-end
system that does not depend on language-specific pre/postprocessing.
https://github.com/google/sentencepiece

Install from PyPI:
    pip install sentencepiece

or build from source:
    git clone https://github.com/google/sentencepiece.git
    python setup.py install

"""

def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    Args:
        word: A sliceable sequence of symbols, e.g. a string or a tuple of
            strings.

    Returns:
        A set of ``(left, right)`` tuples, one per pair of neighbouring
        symbols. The set is empty when ``word`` has fewer than two symbols.
    """
    # zip stops at the shorter argument, so empty and single-symbol inputs
    # simply yield no pairs instead of raising IndexError on word[0].
    return set(zip(word, word[1:]))


class Encoder:
    """Byte-pair-encoding tokenizer built from a vocabulary and a merge list.

    The whole input text is treated as one token: it is split into single
    characters and adjacent symbols are merged repeatedly, always picking the
    pair that appears earliest in ``bpe_merges``.
    """

    # Id returned for tokens that are missing from the vocabulary.
    UNK_ID = 1

    def __init__(self, encoder, bpe_merges):
        """Store the vocabulary and rank the merge rules.

        Args:
            encoder: Mapping from token string to integer id.
            bpe_merges: Sequence of ``(left, right)`` symbol pairs; a pair's
                position in the sequence is its merge priority (earlier
                pairs are merged first).
        """
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = {merge: rank for rank, merge in enumerate(bpe_merges)}
        # Memoises bpe() results, keyed by the original token.
        self.cache = {}
        self.max_len = 0

    def bpe(self, token):
        """Apply the merge rules to ``token``.

        Args:
            token: The string to segment.

        Returns:
            The resulting symbols joined by single spaces. A token with
            fewer than two characters is returned unchanged.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        # Fewer than two symbols means there is nothing to merge; checking
        # the length first also keeps an empty token away from get_pairs.
        if len(word) < 2:
            return token
        pairs = get_pairs(word)

        while True:
            # Pick the pair with the lowest rank; unknown pairs rank last.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                # No remaining pair has a merge rule.
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: keep the tail as is.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                # word[i] is `first` here; merge only if `second` follows it.
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        # NOTE(review): symbols are joined with a single space, so a space
        # character inside the text cannot be told apart from a separator
        # once tokenize() splits the result again.
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize ``text`` and return the list of vocabulary ids."""
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        """Concatenate the vocabulary strings for the ids in ``tokens``.

        Raises:
            KeyError: If an id is not present in the vocabulary.
        """
        return ''.join(self.decoder[token] for token in tokens)

    def tokenize(self, text):
        """Split ``text`` into BPE tokens; empty text yields an empty list."""
        if not text:
            return []
        return self.bpe(text).split(' ')

    def convert_tokens_to_ids(self, tokens):
        """Map each token to its id, using ``UNK_ID`` for unknown tokens."""
        return [self.encoder.get(token, self.UNK_ID) for token in tokens]
    
class Encoder_SP:
    """Tokenizer backed by a SentencePiece model.

    Offers the same method names as ``Encoder`` (``encode``, ``decode``,
    ``tokenize``, ``convert_tokens_to_ids``) and forwards each call to the
    SentencePiece processor held in ``self.sp``.
    """

    def __init__(self, model_path):
        """Create a SentencePiece processor and load the model at ``model_path``."""
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)

    def encode(self, text):
        """Return the result of the processor's ``EncodeAsIds`` for ``text``.

        Example input: ``text = "...."``.
        """
        ids = self.sp.EncodeAsIds(text)
        return ids

    def decode(self, tokens):
        """Return the result of the processor's ``DecodeIds`` for ``tokens``.

        ``tokens`` is an iterable such as ``[x1, x2, ...]``; every element is
        converted with ``int()`` before it is handed to the processor.
        """
        ids = list(map(int, tokens))
        return self.sp.DecodeIds(ids)

    def tokenize(self, text):
        """Return the result of the processor's ``EncodeAsPieces`` for ``text``."""
        pieces = self.sp.EncodeAsPieces(text)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        """Look up each piece in ``tokens`` with the processor's ``PieceToId``."""
        piece_to_id = self.sp.PieceToId
        return [piece_to_id(piece) for piece in tokens]
    
def get_encoder(encoder_file, bpe_file=""):
    """Build a tokenizer from files on disk.

    Two layouts are handled through the same entry point:

    * ``encoder_file`` ends in ``.model`` and no ``bpe_file`` is given:
      the file is loaded as a SentencePiece model and an ``Encoder_SP`` is
      returned;
    * otherwise ``encoder_file`` is read as a JSON vocabulary and
      ``bpe_file`` as a merges file, and an ``Encoder`` is returned.

    Args:
        encoder_file: Path to the SentencePiece model or the JSON vocabulary.
        bpe_file: Path to the merges file. Leave it empty (``""`` or
            ``None``) to select the SentencePiece branch.

    Returns:
        An ``Encoder_SP`` or an ``Encoder`` instance.
    """
    # One entry point serves both formats so callers stay compatible with
    # SentencePiece models.
    extension = os.path.splitext(encoder_file)[1]
    if extension == ".model" and not bpe_file:
        return Encoder_SP(encoder_file)

    with open(encoder_file, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(bpe_file, 'r', encoding="utf-8") as f:
        # The first line is skipped, as before (presumably a version header;
        # confirm against the merges file format).
        merge_lines = f.read().split('\n')[1:]
    # Filter blank lines rather than blindly dropping the last element, so
    # the final rule survives when the file has no trailing newline.
    bpe_merges = [tuple(line.split()) for line in merge_lines if line.strip()]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )