import shutil
import json
from queue import Queue
from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc
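
# Patches a HuggingFace `tokenizers` BPE tokenizer JSON (the GPT-NeoX-20B
# tokenizer, per the paths used below) by adding new Chinese characters and
# words to "model.vocab" and "model.merges". When a queue of unused ids is
# supplied, new tokens reuse those existing id slots instead of growing the
# vocab; prefixes of the added words are marked as reserved so their ids are
# never recycled.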
def load_base_tokenizer(tokenizer_path):
    print("loading", tokenizer_path)
    data = json.load(open(tokenizer_path, "r", encoding="utf-8"))
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer
def insert_token(word, index):
    # unimplemented placeholder
    pass
# Tokens that must not be deleted: e.g. tokens that are low-frequency in the
# initial counts (and would otherwise be removable) but are contained in the
# newly added vocabulary, including every prefix of each added word.
def load_reserve_tokens(word_list, base_tokenizer):
    data, base_tokenizer = base_tokenizer
    reserved_token = set()
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        for i in range(0, len(encoding.ids)):
            reserved_token.add("".join(tokens[:i + 1]))
    return reserved_token
reserved_token = set()
def append_token(word_list, base_tokenizer, output_tokenizer_path, unused_ids=None):
    """
    Append tokens to the end of the vocab, or recycle unused ids when `unused_ids` is given.
    """
    new_vocab = set()
    new_merges = set()
    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        if len(encoding.ids) == 1:
            continue
        if len(encoding.ids) >= 4:
            print("[ERROR]: encoding must not exceed 4 tokens", word, encoding)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        # print("merging", word, json.dumps(tokens))
        for i in range(1, len(encoding.ids)):
            new_vocab.add("".join(tokens[:i + 1]))
            new_merges.add("".join(tokens[:i]) + " " + tokens[i])

    # append to the end of vocab
    # print("new_vocab size", len(new_vocab))
    # print("new_merges size", len(new_merges))
    if unused_ids is None:
        for token in new_vocab:
            vocab[token] = vocab_size
            vocab_size += 1
        merges += new_merges
    else:
        for token in new_vocab:
            # print(unused_ids.qsize())
            unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            if unused_token_id == 39468:  # leftover debug hook for a specific id
                print("catch")
            # reserved tokens must keep their slots; keep pulling from the
            # queue until a non-reserved unused id is found
            while unused_token_str in reserved_token:
                print("skip unused token", unused_token_id, unused_token_str)
                unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            print("[%d] merging %s into unused slot %s %s"
                  % (unused_ids.qsize(), json.dumps(token), unused_token_id, json.dumps(unused_token_str)))
            vocab[token] = unused_token_id
            if unused_token_id != vocab.pop(unused_token_str):
                print("ERROR")
            # assert unused_token_id == vocab.pop(unused_token_str)
            merges.remove(unused_merges)
        # print(new_merges)
        merges += new_merges

    # print("merged %d tokens in total" % (len(new_vocab)))
    # print(json.dumps(list(new_vocab)))
    with open(output_tokenizer_path, "w", encoding="utf-8") as f_out:
        json.dump(data, f_out, indent=2)
    return data, base_tokenizer
    # data, base_tokenizer = load_base_tokenizer(output_tokenizer_path)
    # encoding = base_tokenizer.encode(word)
    # print(encoding.ids)
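
# The file read below is expected to hold one JSON object per line with at
# least the fields "id", "token" and "merges": the id and surface form of a
# removable low-frequency token, plus the merge rule that produced it.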
def load_unused_id():
    unused_ids = Queue(maxsize=0)
    for line in open("word_count.corpus.remove.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = line_data["id"]
        token_str = line_data["token"]
        merges = line_data["merges"]
        unused_ids.put((token_id, token_str, merges))
    # for i in range(2000):
    #     unused_ids.get()
    return unused_ids
def check_tokenize(base_tokenizer, word):
    data, base_tokenizer = base_tokenizer
    encodings = base_tokenizer.encode(word)
    assert len(encodings.ids) == 1
    assert base_tokenizer.decode(encodings.ids) == word
def add_tokens():
    unused_ids = load_unused_id()
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]

    tokenizer_path = "../20B_tokenizer_chinese.json"
    # tokenizer_path = "../../gpt_nexo_20b/20B_tokenizer.json"
    base_tokenizer = load_base_tokenizer(tokenizer_path)
    reserved_token.update(load_reserve_tokens(add_chars, base_tokenizer))

    ## add chars
    append_token(add_chars, base_tokenizer, "20B_tokenizer.1.json", unused_ids=unused_ids)
    print(unused_ids.qsize())  # 22320

    new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")
    append_token(add_words,
                 new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")

    #
    # ## add words
    # while unused_ids.qsize() != 22320:
    #     unused_ids.get()
    # assert unused_ids.qsize() == 22320
    #
    # shutil.copyfile("20B_tokenizer.1.json", "20B_tokenizer.2.json")
    # while len(add_words) > 0:
    #     new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    #     append_token([add_words.pop()],
    #                  new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    #     # new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
def check_all_tokens():
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = [char for token in add_tokens for char in token]
    add_chars = list(set(add_chars))
    add_words = [token for token in add_tokens if len(token) > 1]
    # add_chars = ['吳']
    base_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    for k in add_chars:
        check_tokenize(base_tokenizer, k)
    for word in add_words:
        # print(word)
        check_tokenize(base_tokenizer, word)
if __name__ == "__main__":
    add_tokens()
    check_all_tokens()
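
# Minimal manual spot-check of the patched file (hypothetical example word;
# assumes 20B_tokenizer.2.json was produced by add_tokens() above):
#   tok = Tokenizer.from_file("20B_tokenizer.2.json")
#   print(tok.encode("你好").ids)  # a single id is expected if the word was in oov.add.txt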