import shutil
import json
from queue import Queue
from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc
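
# Patch a `tokenizers` BPE tokenizer json (the 20B tokenizer) by adding out-of-vocabulary
# characters and words read from oov.add.txt. New tokens either extend the vocab or
# recycle the ids of unused tokens listed in word_count.corpus.remove.jsonl, so the
# overall vocab size can stay unchanged.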

def load_base_tokenizer(tokenizer_path):
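    """Load a tokenizer json both as a raw dict (for editing) and as a Tokenizer object."""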
    print("loading", tokenizer_path)
    with open(tokenizer_path, "r", encoding="utf-8") as f_in:
        data = json.load(f_in)
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer


def insert_token(word, index):
    pass

# Tokens that must not be removed: e.g. tokens that are low-frequency in the initial
# corpus statistics (and thus candidates for deletion), but appear in the newly added vocabulary.


def load_reserve_tokens(word_list, base_tokenizer):
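    """Collect every prefix concatenation of each word's tokenization.

    If a word encodes to tokens [a, b, c], the strings "a", "ab" and "abc" are
    reserved: they are needed (or created) as merge intermediates when the word
    is added, so their vocab slots must not be handed out for recycling.
    """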
    data, base_tokenizer = base_tokenizer
    reserved_token = set()
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        for i in range(0, len(encoding.ids)):
            reserved_token.add("".join(tokens[:i+1]))
    return reserved_token


# Global set of tokens that must never be recycled; filled by add_tokens()
# via load_reserve_tokens().
reserved_token = set()


def append_token(word_list, base_tokenizer, output_tokenizer_path, unused_ids=None):
    """
    append token to the end of vocab
    """
    new_vocab = set()
    new_merges = set()

    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)

    for word in word_list:
        encoding = base_tokenizer.encode(word)
        if len(encoding.ids) == 1:
            continue

        if len(encoding.ids) >= 4:
            print("[ERROR] word encodes to too many tokens (>= 4):", word, encoding)

        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        # print("merging", word, json.dumps(tokens))
        for i in range(1, len(encoding.ids)):
            new_vocab.add("".join(tokens[:i+1]))
            new_merges.add("".join(tokens[:i]) + " " + tokens[i])

    # append to the end of vocab
    # print("new_vocab size", len(new_vocab))
    # print("new_merges size", len(new_merges))
    if unused_ids is None:
        for token in new_vocab:
            vocab[token] = vocab_size
            vocab_size += 1
        merges += new_merges
    else:
        for token in new_vocab:
            unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            # Keep drawing from the queue until the token is safe to recycle:
            # reserved tokens are merge intermediates of the words being added.
            while unused_token_str in reserved_token:
                print("skip unused token", unused_token_id, unused_token_str)
                unused_token_id, unused_token_str, unused_merges = unused_ids.get()

            print("[%d]merging %s to unused %s %s" % (unused_ids.qsize(), json.dumps(token), unused_token_id, json.dumps(unused_token_str)) )
            vocab[token] = unused_token_id
            # Drop the old vocab entry; its id must be the one we just reused.
            if unused_token_id != vocab.pop(unused_token_str):
                print("[ERROR] id mismatch when recycling", unused_token_str)
            merges.remove(unused_merges)
        # print(new_merges)
        merges += new_merges

    # print("共merge %d 个 token" % (len(new_vocab)))
    # print(json.dumps(list(new_vocab)))


    with open(output_tokenizer_path, "w", encoding="utf-8") as f_out:
        json.dump(data, f_out, indent=2)

    return data, base_tokenizer




    # data, base_tokenizer = load_base_tokenizer(output_tokenizer_path)
    # encoding = base_tokenizer.encode(word)
    # print(encoding.ids)


def load_unused_id():
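    """Load recyclable token slots from word_count.corpus.remove.jsonl into a FIFO queue.

    Each jsonl line is expected to provide "id", "token" and "merges" for one
    removable (low-frequency) token.
    """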
    unused_ids = Queue(maxsize=0)
    for line in open("word_count.corpus.remove.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = line_data["id"]
        token_str = line_data["token"]
        merges = line_data["merges"]
        unused_ids.put((token_id, token_str, merges))
    # for i in range(2000):
    #     unused_ids.get()
    return unused_ids


def check_tokenize(base_tokenizer, word):
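    """Assert that `word` encodes to exactly one token and decodes back to itself."""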
    data, base_tokenizer = base_tokenizer
    encodings = base_tokenizer.encode(word)
    assert len(encodings.ids) == 1
    assert base_tokenizer.decode(encodings.ids) == word


def add_tokens():
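    """Add OOV entries in two passes: single characters first, then multi-character words.

    Writes 20B_tokenizer.1.json after the character pass and 20B_tokenizer.2.json
    after the word pass, recycling unused token ids from load_unused_id().
    """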


    unused_ids = load_unused_id()
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]


    tokenizer_path = "../20B_tokenizer_chinese.json"
    # tokenizer_path = "../../gpt_nexo_20b/20B_tokenizer.json"
    base_tokenizer = load_base_tokenizer(tokenizer_path)
    reserved_token.update(load_reserve_tokens(add_chars, base_tokenizer))

    ## add chars
    append_token(add_chars, base_tokenizer, "20B_tokenizer.1.json", unused_ids=unused_ids)
    print(unused_ids.qsize())  # 22320
    new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")

    append_token(add_words,
                 new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")

    #
    # ## add words
    # while unused_ids.qsize() != 22320:
    #     unused_ids.get()
    # assert unused_ids.qsize() == 22320
    #
    # shutil.copyfile("20B_tokenizer.1.json", "20B_tokenizer.2.json")
    # while len(add_words) > 0:
    #     new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    #     append_token([add_words.pop()],
    #                  new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    #     # new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")


def check_all_tokens():
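    """Reload the patched tokenizer and verify every added character and word is a single token."""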
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]
    # add_chars = ['吳']
    base_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    for k in add_chars:
        check_tokenize(base_tokenizer, k)
    for word in add_words:
        # print(word)
        check_tokenize(base_tokenizer, word)

if __name__ == "__main__":
    add_tokens()
    check_all_tokens()