"""
获取超低频token,用于裁剪
"""
import copy
import glob
import json
from collections import defaultdict


def word_count():
    """Count token-id frequencies over the tokenized Megatron datasets."""
    from collections import Counter
    from megatron.data.indexed_dataset import MMapIndexedDataset

    counter = Counter()
    for file_name in glob.glob("data/jd/*.bin"):
        print(file_name)
        file_name = file_name[:-4]  # MMapIndexedDataset expects the path without the .bin suffix
        dataset = MMapIndexedDataset(file_name, skip_warmup=True)
        for doc in dataset:  # each doc is a sequence of token ids
            counter.update(doc)

    with open("word_count.txt", "w", encoding="utf-8") as f_out:
        for token_id, count in counter.most_common():
            f_out.write("%d\t%d\n" % (token_id, count))


def get_unused_id():
    pass


def print_word_count():
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json")
    data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8"))
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]

    merge_dict = {}
    sorted_parts = []
    for merge in merges:
        idx = merge.find(" ")
        token_str = merge[:idx] + merge[idx + 1:]
        merge_dict[token_str] = (merge[:idx], merge[idx + 1:])
        sorted_parts += [token_str, merge[:idx], merge[idx + 1:]]

    id2vocab = {idx: token for token, idx in vocab.items()}

    # Add the tokens observed in the corpus to sorted_parts, then sort
    all_tokens = [line.strip().split("\t") for line in open("word_count.corpus.txt", "r", encoding="utf-8")]
    raw_token_count = {int(token_id): int(count) for token_id, count in all_tokens}
    sorted_parts = set(sorted_parts)
    for token_id in raw_token_count:
        if token_id in [35448, 40519]:  # debug probe for two specific token ids
            print("ddd")
        token_str = id2vocab[token_id]
        if token_str not in sorted_parts:
            sorted_parts.add(token_str)
            # print(token_id, token_str, json.dumps(token_str), raw_token_count[token_id], " not in parts")
    sorted_parts = sorted(set(sorted_parts), key=lambda k: len(k), reverse=True)

    # Re-compute counts through the merges
    # token_count = copy.deepcopy(raw_token_count)
    token_count = defaultdict(int)
    for token_str in sorted_parts:  # iterate from longest to shortest (otherwise a depth-first traversal would be needed)
        token_id = vocab[token_str]
        if token_id in [35448, 40519]:  # debug probe
            print("ddd")
        count = raw_token_count.get(token_id, 0)
        token_count[token_id] += count  # count of the token itself
        if token_str in merge_dict:
            if vocab[merge_dict[token_str][0]] in [35448, 40519] or vocab[merge_dict[token_str][1]] in [35448, 40519]:  # debug probe
                print("ddd")
            # propagate the accumulated count down to both merge parts
            token_count[vocab[merge_dict[token_str][0]]] += token_count[token_id]
            token_count[vocab[merge_dict[token_str][1]]] += token_count[token_id]
        else:
            # token not produced by any merge (e.g. a base/byte-level token)
            print(token_id, json.dumps(token_str))

    # Re-sort: ascending by count; ties broken by descending token length
    sorted_token_count = sorted(token_count.items(), key=lambda kv: (kv[1], -len(id2vocab[kv[0]])))

    f_out = open("word_count.corpus.sort_by_count.jsonl", "w", encoding="utf-8")
    for token_id, count in sorted_token_count:
        token_str = id2vocab[token_id]
        decode_str = tokenizer.decode([token_id])  # decoding is lossy (e.g. "Ġ" becomes a space)
        if token_str in merge_dict:
            merges = " ".join(merge_dict[token_str])  # note: shadows the merges list, which is not used after this point
        else:
            merges = "NULL"
        f_out.write(json.dumps(
            {"id": token_id, "token": token_str, "merges": merges, "raw_count": raw_token_count.get(token_id, 0),
             "count": count, "decode_str": decode_str}) + "\n")
    f_out.close()


def get_remove_words():
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json")
    data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8"))
    added_tokens = [token["id"] for token in data["added_tokens"]]
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    id2vocab = {idx: token for token, idx in vocab.items()}
    merge_dict = {k.replace(" ", "", 1): k for k in merges}

    # load the propagated counts produced by print_word_count()
    token_count = {}
    for line in open("word_count.corpus.sort_by_count.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = int(line_data["id"])
        count = int(line_data["count"])
        token_count[token_id] = count

    f_out = open("word_count.corpus.remove.jsonl", "w", encoding="utf-8")
    remove_vocab_set = set()

    # # 1. Remove broken tokens
    # error_tokens = [54611, 54612, 54613, 54614, 54615, 54616, 54617, 54618, 54619, 54620, 54621, 54622,
    #                 54623, 54624, 54625, 54626, 54627, 54628, 54629, 54630, 54631, 54632, 54633]
    # for token_id in error_tokens:
    #     token_str = id2vocab[token_id]
    #     # token_str = tokenizer.id_to_token(token_id)  # lossy
    #     remove_vocab_set.add(token_id)
    #     f_out.write(json.dumps(
    #         {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": 0,
    #          "type": "error-char"}) + "\n")

    # 2. Remove overly long tokens
    # for token_id in range(tokenizer.get_vocab_size()):
    #     if token_id in added_tokens:
    #         continue
    #     token_str = id2vocab[token_id]
    #     # token_str = tokenizer.id_to_token(token_id)  # also lossy, e.g. for token 54611
    #     decode_str = tokenizer.decode([token_id])  # decode is lossy, e.g. "Ġ" becomes a space
    #     if len(decode_str) > 8 and len(set(decode_str)) < 3:
    #         if token_id in remove_vocab_set:
    #             continue
    #         remove_vocab_set.add(token_id)
    #         f_out.write(
    #             json.dumps({"id": token_id, "token": token_str,
    #                         "merges": merge_dict.get(token_str), "count": token_count.get(token_id, 0),
    #                         "type": "filtered by length"}, ensure_ascii=False) + "\n")
    #
    #         # Remove dependents (otherwise merges would reference OOV tokens)
    #         for merge in merges:
    #             if token_str in merge:
    #                 # if token_str + " " in merge or " " + token_str in merge:
    #                 parent_token_str = merge.replace(" ", "", 1)
    #                 parent_token_id = vocab[parent_token_str]
    #                 if parent_token_id in remove_vocab_set:
    #                     continue
    #                 remove_vocab_set.add(parent_token_id)
    #                 f_out.write(
    #                     json.dumps({"id": parent_token_id, "token": parent_token_str,
    #                                 "merges": merge, "count": token_count.get(parent_token_id, 0),
    #                                 "type": "filtered by length - dependency removal"}, ensure_ascii=False) + "\n")

    # 3. Remove low-frequency tokens
    for token_id, count in list(token_count.items())[:25000]:
        # token_id = 6460
        if token_id in added_tokens:
            continue
        if token_id in remove_vocab_set:
            continue
        token_str = tokenizer.id_to_token(token_id)
        # token_str = tokenizer.decode([int(token_id)])
        if len(token_str.strip()) > 1:  # keep single-character tokens
            remove_vocab_set.add(token_id)
            f_out.write(json.dumps(
                {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": count,
                 "type": "remove by frequency"}) + "\n")

            ######## tokens are already sorted by propagated count, so dependency removal is not needed here
            # # Remove dependents (otherwise merges would reference OOV tokens)
            # for merge in merges:
            #     # if token_str + " " in merge or " " + token_str in merge:
            #     if token_str in merge:
            #         parent_token_str = merge.replace(" ", "", 1)
            #         parent_token_id = vocab[parent_token_str]
            #         if parent_token_id in remove_vocab_set:
            #             continue
            #         remove_vocab_set.add(parent_token_id)
            #         f_out.write(
            #             json.dumps({"id": parent_token_id, "token": parent_token_str,
            #                         "merges": merge, "count": token_count.get(parent_token_id, 0),
            #                         "type": "filtered by frequency - dependency removal"}, ensure_ascii=False) + "\n")

    f_out.close()
    # remove 24969 tokens
    print("remove %d tokens" % (len(remove_vocab_set)))


def ss():
    pass


if __name__ == "__main__":
    # word_count()
    # print_word_count()
    get_remove_words()