import shutil
import json
from queue import Queue
from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc
def load_base_tokenizer(tokenizer_path):
    print("loading", tokenizer_path)
    data = json.load(open(tokenizer_path, "r", encoding="utf-8"))
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer
def insert_token(word, index):
    pass


# Tokens that must not be removed: e.g. tokens that the initial frequency count marks as
# low-frequency (and therefore deletable), but that are contained in the newly added vocabulary.
def load_reserve_tokens(word_list, base_tokenizer):
    data, base_tokenizer = base_tokenizer
    reserved_token = set()
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        for i in range(0, len(encoding.ids)):
            reserved_token.add("".join(tokens[:i + 1]))
    return reserved_token
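# Illustration of the prefix bookkeeping above (hypothetical tokenization): if "金融"
# encodes to the tokens ["金", "融"], the reserved set gains {"金", "金融"}, so neither
# the single char nor the merged word can later be handed out as an "unused" slot.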
reserved_token = set()
def append_token(word_list, base_tokenizer, output_tokenizer_path, unused_ids=None):
    """
    Append new tokens to the end of the vocab, or recycle unused token ids
    when a queue of unused ids is given.
    """
    new_vocab = set()
    new_merges = set()
    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        if len(encoding.ids) == 1:
            continue  # already a single token, nothing to merge
        if len(encoding.ids) >= 4:
            print("[ERROR]: encoding must not exceed 4 tokens", word, encoding)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        # print("merging", word, json.dumps(tokens))
        for i in range(1, len(encoding.ids)):
            new_vocab.add("".join(tokens[:i + 1]))
            new_merges.add("".join(tokens[:i]) + " " + tokens[i])

    # append to the end of vocab
    # print("new_vocab size", len(new_vocab))
    # print("new_merges size", len(new_merges))
    if unused_ids is None:
        # Pure append: new tokens get ids after the current end of the vocab.
        for token in new_vocab:
            vocab[token] = vocab_size
            vocab_size += 1
        merges += new_merges
    else:
        # Recycle: overwrite the ids of unused tokens taken from the queue.
        for idx, token in enumerate(new_vocab):
            # print(unused_ids.qsize())
            unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            if unused_token_id == 39468:
                print("catch")  # debug hook for a specific token id
            while unused_token_str in reserved_token:
                print("skip unused token", unused_token_id, unused_token_str)
                unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            print("[%d]merging %s to unused %s %s" % (unused_ids.qsize(), json.dumps(token), unused_token_id, json.dumps(unused_token_str)))
            vocab[token] = unused_token_id
            if unused_token_id != vocab.pop(unused_token_str):
                print("ERROR: popped id does not match the unused token id")
            # assert unused_token_id == vocab.pop(unused_token_str)
            merges.remove(unused_merges)
        # print(new_merges)
        merges += new_merges

    # print("merged %d tokens in total" % (len(new_vocab)))
    # print(json.dumps(list(new_vocab)))
    with open(output_tokenizer_path, "w", encoding="utf-8") as f_out:
        json.dump(data, f_out, indent=2)
    return data, base_tokenizer
# data, base_tokenizer = load_base_tokenizer(output_tokenizer_path)
# encoding = base_tokenizer.encode(word)
# print(encoding.ids)
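# Usage sketch of the pure-append path (hypothetical word and output path; no id
# recycling, so new tokens land after the current end of the vocab):
#   base = load_base_tokenizer("../20B_tokenizer_chinese.json")
#   append_token(["例子"], base, "20B_tokenizer.example.json")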
def load_unused_id():
    unused_ids = Queue(maxsize=0)
    for line in open("word_count.corpus.remove.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = line_data["id"]
        token_str = line_data["token"]
        merges = line_data["merges"]
        unused_ids.put((token_id, token_str, merges))
    # for i in range(2000):
    #     unused_ids.get()
    return unused_ids
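# Assumed layout of word_count.corpus.remove.jsonl: one JSON object per line carrying
# the id, surface string, and merge rule of a token considered removable. A made-up
# example line, for illustration only:
#   {"id": 12345, "token": "ExampleToken", "merges": "Example Token"}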
def check_tokenize(base_tokenizer, word):
    data, base_tokenizer = base_tokenizer
    encodings = base_tokenizer.encode(word)
    assert len(encodings.ids) == 1
    assert base_tokenizer.decode(encodings.ids) == word
def add_tokens():
    unused_ids = load_unused_id()
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]

    tokenizer_path = "../20B_tokenizer_chinese.json"
    # tokenizer_path = "../../gpt_nexo_20b/20B_tokenizer.json"
    base_tokenizer = load_base_tokenizer(tokenizer_path)
    reserved_token.update(load_reserve_tokens(add_chars, base_tokenizer))

    ## add chars
    append_token(add_chars, base_tokenizer, "20B_tokenizer.1.json", unused_ids=unused_ids)
    print(unused_ids.qsize())  # 22320

    ## add words
    new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")
    append_token(add_words,
                 new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")

    # Alternative: add words one at a time, reloading the tokenizer between each.
    # while unused_ids.qsize() != 22320:
    #     unused_ids.get()
    # assert unused_ids.qsize() == 22320
    #
    # shutil.copyfile("20B_tokenizer.1.json", "20B_tokenizer.2.json")
    # while len(add_words) > 0:
    #     new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    #     append_token([add_words.pop()],
    #                  new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    #     # new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
def check_all_tokens():
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]
    # add_chars = ['吳']
    base_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    for k in add_chars:
        check_tokenize(base_tokenizer, k)
    for word in add_words:
        # print(word)
        check_tokenize(base_tokenizer, word)
if __name__ == "__main__":
    add_tokens()
    check_all_tokens()