""" 获取超低频token,用于裁剪 """ import copy import glob import json from collections import defaultdict def word_count(): from collections import Counter from megatron.data.indexed_dataset import MMapIndexedDataset counter = Counter() for file_name in glob.glob("data/jd/*.bin"): print(file_name) file_name = file_name[:-4] dataset = MMapIndexedDataset(file_name, skip_warmup=True) for doc in dataset: counter.update(doc) f_out = open("word_count.txt", "w", encoding="utf-8") for token_id, count in counter.most_common(): f_out.write("%d\t%d\n" % (token_id, count)) def get_unused_id(): pass def print_word_count(): from tokenizers import Tokenizer tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json") data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8")) vocab = data["model"]["vocab"] merges = data["model"]["merges"] merge_dict = {} sorted_parts = [] for merge in merges: idx = merge.find(" ") token_str = merge[:idx] + merge[idx + 1:] merge_dict[token_str] = (merge[:idx], merge[idx + 1:]) sorted_parts += [token_str, merge[:idx], merge[idx + 1:]] id2vocab = {idx: token for token, idx in vocab.items()} # 补充 sorted_parts,并排序 all_tokens = [line.strip().split("\t") for line in open("word_count.corpus.txt", "r", encoding="utf-8")] raw_token_count = {int(token_id): int(count) for token_id, count in all_tokens} sorted_parts = set(sorted_parts) for token_id in raw_token_count: if token_id in [35448, 40519]: print("ddd") token_str = id2vocab[token_id] if token_str not in sorted_parts: sorted_parts.add(token_str) # print(token_id, token_str, json.dumps(token_str), raw_token_count[token_id], " not in parts") sorted_parts = sorted(set(sorted_parts), key=lambda k: len(k), reverse=True) # 重新计算merge的频率 # token_count = copy.deepcopy(raw_token_count) token_count = defaultdict(int) for token_str in sorted_parts: # 从长到短 遍历 (否则要深度遍历,) token_id = vocab[token_str] if token_id in [35448, 40519]: print("ddd") count = raw_token_count.get(token_id, 0) token_count[token_id] += count # 原token 的词频 if token_str in merge_dict: if vocab[merge_dict[token_str][0]] in [35448, 40519] or vocab[merge_dict[token_str][1]] in [35448, 40519]: print("ddd") token_count[vocab[merge_dict[token_str][0]]] += token_count[token_id] token_count[vocab[merge_dict[token_str][1]]] += token_count[token_id] else: print(token_id, json.dumps(token_str)) # 重新排序 (按频率升序排列,相同频率按长度降序排列) sorted_token_count = sorted(token_count.items(), key=lambda kv: (kv[1], -len(id2vocab[kv[0]]))) f_out = open("word_count.corpus.sort_by_count.jsonl", "w", encoding="utf-8") for token_id, count in sorted_token_count: # for token_str, count in token_count.items(): token_str = id2vocab[token_id] # token_id = vocab[token_str] decode_str = tokenizer.decode([token_id]) # 解码会失真 if token_str in merge_dict: merges = " ".join(merge_dict[token_str]) else: merges = "NULL" f_out.write(json.dumps( {"id": token_id, "token": token_str, "merges": merges, "raw_count": raw_token_count.get(token_id, 0), "count": count, "decode_str": decode_str}) + "\n") def get_remove_words(): from tokenizers import Tokenizer tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json") data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8")) added_tokens = [token["id"] for token in data["added_tokens"]] vocab = data["model"]["vocab"] merges = data["model"]["merges"] id2vocab = {idx: token for token, idx in vocab.items()} merge_dict = {k.replace(" ", "", 1): k for k in merges} token_count = {} for line in open("word_count.corpus.sort_by_count.jsonl", "r", encoding="utf-8"): line_data = json.loads(line) token_id = int(line_data["id"]) count = int(line_data["count"]) token_count[token_id] = count f_out = open("word_count.corpus.remove.jsonl", "w", encoding="utf-8") remove_vocab_set = set() # # 1. 去掉错误token # error_tokens = [54611, 54612, 54613, 54614, 54615, 54616, 54617, 54618, 54619, 54620, 54621, 54622, # 54623, 54624, 54625, 54626, 54627, 54628, 54629, 54630, 54631, 54632, 54633] # for token_id in error_tokens: # token_str = id2vocab[token_id] # # token_str = tokenizer.id_to_token(token_id) # 失真 # remove_vocab_set.add(token_id) # f_out.write(json.dumps( # {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": 0, # "type": "error-char"}) + "\n") # 2. 去掉超长token # for token_id in range(tokenizer.get_vocab_size()): # if token_id in added_tokens: # continue # token_str = id2vocab[token_id] # # token_str = tokenizer.id_to_token(token_id) # 也会失真,比如 54611 个token # decode_str = tokenizer.decode([token_id]) # decode会失真,比如 Ġ 会变成空格 # if len(decode_str) > 8 and len(set(decode_str)) < 3: # if token_id in remove_vocab_set: # continue # remove_vocab_set.add(token_id) # f_out.write( # json.dumps({"id": token_id, "token": token_str, # "merges": merge_dict.get(token_str), "count": token_count.get(token_id, 0), # "type": "按长度过滤"}, ensure_ascii=False) + "\n") # # # 删除依赖,(否则会造成 merges中存在oov的token) # # # for merge in merges: # if token_str in merge: # # if token_str + " " in merge or " " + token_str in merge: # parent_token_str = merge.replace(" ", "", 1) # parent_token_id = vocab[parent_token_str] # if parent_token_id in remove_vocab_set: # continue # remove_vocab_set.add(parent_token_id) # f_out.write( # json.dumps({"id": parent_token_id, "token": parent_token_str, # "merges": merge, "count": token_count.get(parent_token_id, 0), # "type": "按长度过滤-依赖删除"}, ensure_ascii=False) + "\n") # 3. 去掉低频token for token_id, count in list(token_count.items())[:25000]: # token_id = 6460 if token_id in added_tokens: continue if token_id in remove_vocab_set: continue token_str = tokenizer.id_to_token(token_id) # token_str = tokenizer.decode([int(token_id)]) if len(token_str.strip()) > 1: remove_vocab_set.add(token_id) f_out.write(json.dumps( {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": count, "type": "remove by frequency"}) + "\n") ######## 已经按频率排序的,就不需要删除依赖了 # # 删除依赖,(否则会造成 merges中存在oov的token) # for merge in merges: # # if token_str + " " in merge or " " + token_str in merge: # if token_str in merge: # parent_token_str = merge.replace(" ", "", 1) # parent_token_id = vocab[parent_token_str] # if parent_token_id in remove_vocab_set: # continue # remove_vocab_set.add(parent_token_id) # f_out.write( # json.dumps({"id": parent_token_id, "token": parent_token_str, # "merges": merge, "count": token_count.get(parent_token_id, 0), # "type": "按频率过滤-依赖删除"}, ensure_ascii=False) + "\n") # remove 24969 tokens print("remove %d tokens" % (len(remove_vocab_set))) def ss(): pass # word_count() # print_word_count() get_remove_words()