import json

from tokenizers import Tokenizer
from zhon.hanzi import punctuation as zh_punc

from data_sample.oov_base import jd_vocab_tokens


def load_base_tokenizer(tokenizer_path):
    """Load a tokenizer JSON file both as an editable dict and as a Tokenizer object."""
    print("loading", tokenizer_path)
    # Keep the raw JSON dict as well: new vocab entries and merges are added by
    # editing this dict and re-saving it, not through the Tokenizer object.
    with open(tokenizer_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer


def append_token(word_list, base_tokenizer, unused_ids=None):
    """
    Append new tokens to the end of the vocab.

    As written, this pass only scans the candidate words and reports problems;
    new_vocab and new_merges are placeholders for the entries that would be
    written back into the tokenizer JSON.
    """
    new_vocab = set()
    new_merges = set()

    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)

    for word in word_list:
        encoding = base_tokenizer.encode(word)
        # Already a single token: nothing to add.
        if len(encoding.ids) == 1:
            continue

        # Candidate words are expected to split into at most 3 base tokens.
        if len(encoding.ids) >= 4:
            print("[ERROR]: encoding must not exceed 4", word, encoding)

        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        # Debug check: flag words whose encoding contains this byte-level token.
        if "\u00e6\u00a5\u0143" in tokens:
            print(word)


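# --- Sketch (not in the original script): how a collected pair could actually be
# appended. It assumes merges are serialized as "left right" strings in this
# tokenizer JSON and that new ids simply continue from the current vocab size;
# append_pair_sketch is a hypothetical helper, not an existing tokenizers API.
def append_pair_sketch(data, left, right, next_id):
    """Register the merge left+right as a single new token with id next_id."""
    merged = left + right                      # byte-level token strings concatenate
    data["model"]["vocab"][merged] = next_id   # new entry at the end of the vocab
    data["model"]["merges"].append(f"{left} {right}")
    # A word that splits into 3 base tokens would need two chained merges.
    return next_id + 1
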
# Candidate OOV words, one per line; keep only multi-character entries.
with open("oov.add.txt", "r", encoding="utf-8") as f:
    add_tokens = [line.strip() for line in f]
add_words = [token for token in add_tokens if len(token) > 1]

new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")

append_token(add_words, new_tokenizer)
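
# --- Sketch (assumption, not shown in the original script): after the vocab and
# merges in `data` have been edited, the dict would be written back out and the
# result sanity-checked by reloading it. The output path "20B_tokenizer.2.json"
# and the helper name are made up for illustration.
def save_and_verify_sketch(data, words, out_path="20B_tokenizer.2.json"):
    """Write the edited tokenizer JSON to disk and check each word is now one token."""
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False)
    merged_tokenizer = Tokenizer.from_file(out_path)
    for word in words:
        assert len(merged_tokenizer.encode(word).ids) == 1, word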