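# Page metadata captured along with this file (not part of the code):
#   author: eson · commit d10ecd7 ("update") · size 6.23 kB · scan: "No virus"
#   (page controls: raw | history | blame)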
import shutil
import json
from queue import Queue
from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc
def load_base_tokenizer(tokenizer_path):
    """Load a tokenizer JSON file both as raw data and as a ``Tokenizer``.

    The raw dict is returned so callers can edit ``data["model"]["vocab"]``
    and ``data["model"]["merges"]`` directly and re-serialize them. The
    ``Tokenizer`` object is used for encoding and decoding.

    Args:
        tokenizer_path: path to a HuggingFace ``tokenizers`` JSON file.

    Returns:
        A ``(data, tokenizer)`` tuple.
    """
    print("loading", tokenizer_path)
    # Use a context manager so the file handle is closed instead of leaked.
    with open(tokenizer_path, "r", encoding="utf-8") as f_in:
        data = json.load(f_in)
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer
def insert_token(word, index):
    """Placeholder, not implemented.

    Presumably intended to insert ``word`` at vocab position ``index``
    (inferred from the name only). It currently does nothing and returns
    ``None``.
    """
def load_reserve_tokens(word_list, base_tokenizer):
    """Collect the token strings that must not be deleted or recycled.

    Example of such a token: one that the initial frequency statistics
    marked as rare (and therefore deletable), but that the newly added
    vocabulary still needs.

    For each word, every prefix concatenation of its encoded pieces is
    recorded: ``p0``, ``p0p1``, ``p0p1p2``, and so on.

    Args:
        word_list: iterable of strings whose pieces should be protected.
        base_tokenizer: ``(data, tokenizer)`` pair; only the tokenizer is used.

    Returns:
        A set of protected token strings.
    """
    _, tokenizer = base_tokenizer
    protected = set()
    for word in word_list:
        pieces = [tokenizer.id_to_token(piece_id) for piece_id in tokenizer.encode(word).ids]
        prefix = ""
        for piece in pieces:
            prefix += piece
            protected.add(prefix)
    return protected
# Module-level set of token strings that must not be recycled as "unused"
# vocab slots. add_tokens() fills it via load_reserve_tokens(), and
# append_token() checks it before reusing an unused id.
reserved_token = set()
def append_token(word_list, base_tokenizer, output_tokenizer_path, unused_ids=None,
                 reserved_tokens=None):
    """Add BPE vocab entries and merges so that each word encodes to one token.

    A word currently encoded as pieces ``p0 p1 ... pk`` gets:
      - the vocab entries ``p0p1``, ``p0p1p2``, ...;
      - the merges ``"p0 p1"``, ``"p0p1 p2"``, ....
    Words that already encode to a single id are skipped. ``data`` is
    mutated in place, and the result is written to ``output_tokenizer_path``.

    Args:
        word_list: iterable of strings to add.
        base_tokenizer: ``(data, tokenizer)`` tuple from ``load_base_tokenizer``.
        output_tokenizer_path: where to write the updated tokenizer JSON.
        unused_ids: controls where the new tokens go.
            - ``None``: append them after the current vocab size (added
              tokens included).
            - A ``Queue`` of ``(token_id, token_str, merge)`` tuples: each new
              token takes over a queued slot. The slot's old token and its
              merge are removed.
        reserved_tokens: token strings whose slots must never be recycled.
            Defaults to the module-level ``reserved_token`` set.

    Returns:
        ``(data, tokenizer)``. ``tokenizer`` is the object passed in; it is
        not reloaded from the new file.

    Raises:
        RuntimeError: if ``unused_ids`` runs out of non-reserved slots.
            Previously an empty queue made ``Queue.get()`` block forever.
        KeyError: if a queued token string is not in the vocab.
        ValueError: if a queued merge is not in the merges list.
    """
    data, tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]

    new_vocab, new_merges = _collect_new_tokens(word_list, tokenizer)

    if unused_ids is None:
        # Append after every existing id, including added tokens. This
        # mirrors the size reported by load_base_tokenizer.
        next_id = tokenizer.get_vocab_size(with_added_tokens=True)
        for token in new_vocab:
            vocab[token] = next_id
            next_id += 1
    else:
        if reserved_tokens is None:
            reserved_tokens = reserved_token
        for token in new_vocab:
            unused_token_id, unused_token_str, unused_merges = _next_unused_slot(
                unused_ids, reserved_tokens)
            print("[%d]merging %s to unused %s %s" % (
                unused_ids.qsize(), json.dumps(token), unused_token_id,
                json.dumps(unused_token_str)))
            # Pop the recycled token before inserting the new one. Otherwise,
            # if both strings are equal, the new entry would be deleted.
            removed_id = vocab.pop(unused_token_str)
            if removed_id != unused_token_id:
                # The queue's id disagrees with the vocab; removed_id is left
                # without a token.
                print("ERROR", unused_token_str, unused_token_id, removed_id)
            # NOTE(review): if ``token`` already exists in vocab, its old id is
            # orphaned here. Confirm this cannot happen for the input lists.
            vocab[token] = unused_token_id
            merges.remove(unused_merges)
    merges.extend(new_merges)

    with open(output_tokenizer_path, "w", encoding="utf-8") as f_out:
        json.dump(data, f_out, indent=2)
    return data, tokenizer


def _collect_new_tokens(word_list, tokenizer):
    """Return ordered, de-duplicated lists of the new vocab strings and merge rules."""
    # Dicts serve as insertion-ordered sets. A plain set's iteration order
    # depends on PYTHONHASHSEED, which made assigned ids and merge order
    # differ from run to run.
    new_vocab = {}
    new_merges = {}
    for word in word_list:
        encoding = tokenizer.encode(word)
        if len(encoding.ids) == 1:
            continue  # already a single token
        if len(encoding.ids) >= 4:
            # Message reads "encoding must not exceed 4".
            # NOTE(review): it fires at >= 4; confirm the intended threshold.
            print("[ERROR]: encoding不能超过4", word, encoding)
        tokens = [tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        for i in range(1, len(tokens)):
            new_vocab["".join(tokens[:i + 1])] = None
            new_merges["".join(tokens[:i]) + " " + tokens[i]] = None
    return list(new_vocab), list(new_merges)


def _next_unused_slot(unused_ids, reserved_tokens):
    """Take slots from ``unused_ids`` until one whose token is not reserved.

    Returns the first such ``(token_id, token_str, merge)`` tuple. Reserved
    slots are consumed and dropped. This handles any number of consecutive
    reserved slots, not just one.
    """
    while True:
        if unused_ids.empty():
            raise RuntimeError("no unused token ids left to recycle")
        slot = unused_ids.get()
        unused_token_id, unused_token_str, _ = slot
        if unused_token_str not in reserved_tokens:
            return slot
        print("skip unused token", unused_token_id, unused_token_str)
# data, base_tokenizer = load_base_tokenizer(output_tokenizer_path)
# encoding = base_tokenizer.encode(word)
# print(encoding.ids)
def load_unused_id(path="word_count.corpus.remove.jsonl"):
    """Load the recyclable token slots into a FIFO queue.

    Each non-blank line of ``path`` is a JSON object with the keys ``"id"``,
    ``"token"`` and ``"merges"``. It becomes an ``(id, token, merges)``
    tuple, queued in file order.

    Args:
        path: JSON-lines file listing removable tokens. Defaults to the file
            name the script has always used.

    Returns:
        An unbounded ``Queue`` of ``(token_id, token_str, merges)`` tuples.
    """
    unused_ids = Queue(maxsize=0)
    # Context manager: the original left this file handle open.
    with open(path, "r", encoding="utf-8") as f_in:
        for line in f_in:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing in json.loads
            line_data = json.loads(line)
            unused_ids.put((line_data["id"], line_data["token"], line_data["merges"]))
    return unused_ids
def check_tokenize(base_tokenizer, word):
    """Check that ``word`` encodes to exactly one token and decodes back to itself.

    Args:
        base_tokenizer: ``(data, tokenizer)`` tuple from ``load_base_tokenizer``.
        word: the string to check.

    Raises:
        AssertionError: if ``word`` does not encode to a single id, or if
            decoding that id does not give back ``word``. It is raised
            explicitly rather than via ``assert`` so that:
            - the check still runs under ``python -O``;
            - the message names the offending word.
    """
    _, tokenizer = base_tokenizer
    ids = tokenizer.encode(word).ids
    if len(ids) != 1:
        raise AssertionError("%r encodes to %d tokens %r, expected 1" % (word, len(ids), ids))
    decoded = tokenizer.decode(ids)
    if decoded != word:
        raise AssertionError("%r decodes to %r" % (word, decoded))
def add_tokens():
    """Extend the Chinese 20B tokenizer with the entries listed in ``oov.add.txt``.

    Steps:
      1. Read one entry per line from ``oov.add.txt``. Every distinct
         character becomes a char to add; every entry longer than one
         character becomes a word to add.
      2. Protect the encoded-piece prefixes of the new chars from being
         recycled, by updating the module-level ``reserved_token`` set.
      3. Put the chars into slots from ``load_unused_id()`` and write
         ``20B_tokenizer.1.json``.
      4. Reload that file, put the words into further slots from the same
         queue, and write ``20B_tokenizer.2.json``.
    """
    unused_ids = load_unused_id()
    # Named oov_entries so the local does not shadow this function's name.
    with open("oov.add.txt", "r", encoding="utf-8") as f_in:
        oov_entries = [line.strip() for line in f_in]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the ids
    # given to new chars are reproducible. list(set(...)) order depends on
    # the hash seed.
    add_chars = list(dict.fromkeys(char for entry in oov_entries for char in entry))
    add_words = [entry for entry in oov_entries if len(entry) > 1]

    tokenizer_path = "../20B_tokenizer_chinese.json"
    # tokenizer_path = "../../gpt_nexo_20b/20B_tokenizer.json"
    base_tokenizer = load_base_tokenizer(tokenizer_path)
    reserved_token.update(load_reserve_tokens(add_chars, base_tokenizer))

    # Add chars.
    append_token(add_chars, base_tokenizer, "20B_tokenizer.1.json", unused_ids=unused_ids)
    print(unused_ids.qsize())  # 22320 (value noted in an earlier run)

    # Add words on top of the char-extended tokenizer.
    new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")
    append_token(add_words, new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)

    # Reload once so the final vocab size is printed as a sanity check.
    load_base_tokenizer("20B_tokenizer.2.json")
def check_all_tokens():
    """Verify the final tokenizer against ``oov.add.txt``.

    Every character and every multi-character entry in ``oov.add.txt`` must
    be a single token in ``20B_tokenizer.2.json`` and must decode back to
    itself.

    Raises:
        AssertionError: from ``check_tokenize`` on the first failing entry.
    """
    with open("oov.add.txt", "r", encoding="utf-8") as f_in:
        oov_entries = [line.strip() for line in f_in]
    # Ordered de-duplication, so the first reported failure is reproducible.
    add_chars = list(dict.fromkeys(char for entry in oov_entries for char in entry))
    add_words = [entry for entry in oov_entries if len(entry) > 1]

    base_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    for char in add_chars:
        check_tokenize(base_tokenizer, char)
    for word in add_words:
        check_tokenize(base_tokenizer, word)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script. Importing this module
    # (e.g. to reuse append_token) then does not read or write any files.
    add_tokens()
    check_all_tokens()