"""Extend the GPT-NeoX-20B tokenizer so more single characters encode to one token.

Every single-character entry of ``jd_vocab_tokens`` plus every Chinese
punctuation mark from ``zhon`` is encoded with the base tokenizer. When a
character comes out as exactly two byte-level BPE tokens, two things are added:
- a merge rule for that pair, appended to ``merges``;
- the concatenated token, in ``vocab``.

As a result the character is encoded as a single token. The extended tokenizer
is written to ``20B_tokenizer_chinese.json``. Each new id is then decoded back
as a sanity check.

Result recorded from an earlier run: 4357 tokens merged.
"""

import json

from tokenizers import Tokenizer
from data_sample.oov_base import jd_vocab_tokens
from zhon.hanzi import punctuation as zh_punc

BASE_TOKENIZER_PATH = "../gpt_nexo_20b/20B_tokenizer.json"
OUTPUT_TOKENIZER_PATH = "20B_tokenizer_chinese.json"


def load_base_tokenizer(vocab_path):
    """Load a tokenizer file both as raw JSON and as a ``tokenizers.Tokenizer``.

    Args:
        vocab_path: Path to a HuggingFace ``tokenizer.json`` file.

    Returns:
        A ``(data, tokenizer)`` tuple: the parsed JSON dict (edited in place by
        the script below) and the ``Tokenizer`` built from the same file.
    """
    # Context manager so the file handle is closed instead of leaked.
    with open(vocab_path, "r", encoding="utf-8") as f_in:
        data = json.load(f_in)
    tokenizer = Tokenizer.from_file(vocab_path)
    # Previously only the label was printed; report the actual size.
    print("vocab_size with added_tokens:", tokenizer.get_vocab_size(with_added_tokens=True))
    return data, tokenizer


def _merge_two_token_chars(tokenizer, words, vocab, merges, next_id):
    """Add a merge rule and vocab entry for each single char encoding to exactly 2 tokens.

    Mutates ``vocab`` (token -> id dict) and ``merges`` (list of "a b" strings)
    in place.

    Args:
        tokenizer: Base ``Tokenizer`` used to encode each candidate.
        words: Iterable of candidate strings; only length-1 strings are used.
        vocab: The ``model.vocab`` dict from the tokenizer JSON.
        merges: The ``model.merges`` list from the tokenizer JSON.
        next_id: First id to hand out for a genuinely new token.

    Returns:
        ``(added, next_id)``: mapping from character to its (now single) token
        id, and the next free id after allocation.
    """
    added = {}
    for word in words:
        if len(word) > 1 or word in added:
            continue
        ids = tokenizer.encode(word).ids
        # TODO: characters that split into 3 or more tokens are not handled yet.
        if len(ids) != 2:
            continue
        pieces = [tokenizer.id_to_token(token_id) for token_id in ids]
        merged = "".join(pieces)
        # Never overwrite an existing vocab entry: reassigning its id would
        # orphan the old id. Reuse it if present, otherwise allocate next_id.
        token_id = vocab.setdefault(merged, next_id)
        if token_id == next_id:
            next_id += 1
        added[word] = token_id
        merges.append(" ".join(pieces))
    return added, next_id


def _find_decode_errors(tokenizer, added):
    """Return ids whose decoded text differs from the char they were added for.

    Each mismatch is also printed as ``id word decoded``.
    """
    error_ids = []
    for word, idx in added.items():
        decode_str = tokenizer.decode([idx])
        if word != decode_str:
            error_ids.append(idx)
            print(idx, word, decode_str)
    return error_ids


data, base_tokenizer = load_base_tokenizer(BASE_TOKENIZER_PATH)
vocab = data["model"]["vocab"]
merges = data["model"]["merges"]

# Two possible id strategies:
#   1. keep the ids of the existing added_tokens unchanged;
#   2. shift the existing added_tokens to new ids.
# Strategy 1 is used here: new tokens get ids starting right after the
# current total size, which includes the added tokens.
vocab_size = base_tokenizer.get_vocab_size(with_added_tokens=True)

new_added_tokens, vocab_size = _merge_two_token_chars(
    base_tokenizer, [*jd_vocab_tokens, *zh_punc], vocab, merges, vocab_size
)
print("共merge %d 个 token" % (len(new_added_tokens)))

with open(OUTPUT_TOKENIZER_PATH, "w", encoding="utf-8") as f_out:
    json.dump(data, f_out, indent=2)

# Check: reload the written file and make sure every new id decodes back to
# the character it was created for.
tokenizer = Tokenizer.from_file(OUTPUT_TOKENIZER_PATH)
all_error_ids = _find_decode_errors(tokenizer, new_added_tokens)
print(all_error_ids)