import shutil  # used only by the commented-out word-by-word merge path in add_tokens()
import json
from queue import Queue

from tokenizers import Tokenizer

# from data_sample.oov_base import jd_vocab_tokens  # not used in this script
# from zhon.hanzi import punctuation as zh_punc     # not used in this script


def load_base_tokenizer(tokenizer_path):
    """Load a tokenizer both as raw JSON (for editing) and as a Tokenizer object (for encoding)."""
    print("loading", tokenizer_path)
    data = json.load(open(tokenizer_path, "r", encoding="utf-8"))
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer


def insert_token(word, index):
    # Placeholder: inserting a token at a specific vocab index is not implemented.
    pass


# Tokens that must not be deleted: e.g. tokens that are low-frequency in the initial
# corpus statistics (and therefore removal candidates) but appear in the newly added vocabulary.
def load_reserve_tokens(word_list, base_tokenizer):
    data, base_tokenizer = base_tokenizer
    reserved = set()
    for word in word_list:
        encoding = base_tokenizer.encode(word)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        for i in range(0, len(encoding.ids)):
            reserved.add("".join(tokens[:i + 1]))
    return reserved


reserved_token = set()


def append_token(word_list, base_tokenizer, output_tokenizer_path, unused_ids=None):
    """Add new tokens to the vocab: append them at the end, or recycle the ids of
    unused tokens when `unused_ids` is given."""
    new_vocab = set()
    new_merges = set()
    data, base_tokenizer = base_tokenizer
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)

    for word in word_list:
        encoding = base_tokenizer.encode(word)
        if len(encoding.ids) == 1:
            continue
        if len(encoding.ids) >= 4:
            print("[ERROR]: encoding must not exceed 4 tokens", word, encoding)
        tokens = [base_tokenizer.id_to_token(token_id) for token_id in encoding.ids]
        # print("merging", word, json.dumps(tokens))
        for i in range(1, len(encoding.ids)):
            new_vocab.add("".join(tokens[:i + 1]))
            new_merges.add("".join(tokens[:i]) + " " + tokens[i])

    # print("new_vocab size", len(new_vocab))
    # print("new_merges size", len(new_merges))
    if unused_ids is None:
        # Append to the end of the vocab.
        for token in new_vocab:
            vocab[token] = vocab_size
            vocab_size += 1
        merges += new_merges
    else:
        # Recycle the ids of unused tokens.
        for token in new_vocab:
            # print(unused_ids.qsize())
            unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            while unused_token_str in reserved_token:
                # Reserved tokens must stay in the vocab; skip them and take the next unused id.
                print("skip unused token", unused_token_id, unused_token_str)
                unused_token_id, unused_token_str, unused_merges = unused_ids.get()
            print("[%d]merging %s to unused %s %s"
                  % (unused_ids.qsize(), json.dumps(token), unused_token_id, json.dumps(unused_token_str)))
            vocab[token] = unused_token_id
            if unused_token_id != vocab.pop(unused_token_str):
                print("ERROR")
            # assert unused_token_id == vocab.pop(unused_token_str)
            merges.remove(unused_merges)
        # print(new_merges)
        merges += new_merges

    # print("merged %d tokens in total" % len(new_vocab))
    # print(json.dumps(list(new_vocab)))
    with open(output_tokenizer_path, "w", encoding="utf-8") as f_out:
        json.dump(data, f_out, indent=2)

    return data, base_tokenizer


def load_unused_id():
    """Read (token_id, token_str, merges) triples of removable tokens into a FIFO queue."""
    unused_ids = Queue(maxsize=0)
    for line in open("word_count.corpus.remove.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = line_data["id"]
        token_str = line_data["token"]
        merges = line_data["merges"]
        unused_ids.put((token_id, token_str, merges))
    return unused_ids


def check_tokenize(base_tokenizer, word):
    """The word should now encode to exactly one id and decode back to itself."""
    data, base_tokenizer = base_tokenizer
    encodings = base_tokenizer.encode(word)
    assert len(encodings.ids) == 1
    assert base_tokenizer.decode(encodings.ids) == word


def add_tokens():
    unused_ids = load_unused_id()
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = list(set(char for token in add_tokens for char in token))
    add_words = [token for token in add_tokens if len(token) > 1]

    tokenizer_path = "../20B_tokenizer_chinese.json"
    # tokenizer_path = "../../gpt_nexo_20b/20B_tokenizer.json"
    base_tokenizer = load_base_tokenizer(tokenizer_path)
    reserved_token.update(load_reserve_tokens(add_chars, base_tokenizer))

    ## add chars
    append_token(add_chars, base_tokenizer, "20B_tokenizer.1.json", unused_ids=unused_ids)
    print(unused_ids.qsize())  # 22320

    ## add words
    new_tokenizer = load_base_tokenizer("20B_tokenizer.1.json")
    append_token(add_words, new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")

    # Alternative: add the words one at a time, reloading the tokenizer after each merge.
    # while unused_ids.qsize() != 22320:
    #     unused_ids.get()
    # assert unused_ids.qsize() == 22320
    #
    # shutil.copyfile("20B_tokenizer.1.json", "20B_tokenizer.2.json")
    # while len(add_words) > 0:
    #     new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    #     append_token([add_words.pop()],
    #                  new_tokenizer, "20B_tokenizer.2.json", unused_ids=unused_ids)
    #
    # new_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")


def check_all_tokens():
    """Verify that every added char and word now maps to exactly one token id."""
    add_tokens = [line.strip() for line in open("oov.add.txt", "r", encoding="utf-8")]
    add_chars = list(set(char for token in add_tokens for char in token))
    add_words = [token for token in add_tokens if len(token) > 1]
    # add_chars = ['吳']
    base_tokenizer = load_base_tokenizer("20B_tokenizer.2.json")
    for char in add_chars:
        check_tokenize(base_tokenizer, char)
    for word in add_words:
        # print(word)
        check_tokenize(base_tokenizer, word)


add_tokens()
check_all_tokens()
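
# --- Usage sketch (illustrative; not executed by this script) --------------------
# The format of word_count.corpus.remove.jsonl is inferred from load_unused_id():
# one JSON object per line with "id", "token" and "merges" keys, e.g.
# (the values below are made up):
#   {"id": 12345, "token": "sometoken", "merges": "some token"}
#
# Once add_tokens() has written 20B_tokenizer.2.json, the result can also be
# inspected directly with the tokenizers API, mirroring check_tokenize():
#
#   from tokenizers import Tokenizer
#   tok = Tokenizer.from_file("20B_tokenizer.2.json")
#   encoding = tok.encode("some added word")   # any entry from oov.add.txt
#   print(encoding.ids)                        # should now be a single id
#   print(tok.decode(encoding.ids))            # should round-trip to the word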