""" | |
获取超低频token,用于裁剪 | |
""" | |
import copy
import glob
import json
from collections import defaultdict
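# Pipeline overview (invoked at the bottom of this file; earlier steps are
# commented out once their outputs exist):
#   1. word_count()       - count token-id frequencies over the tokenized corpus
#   2. print_word_count() - propagate counts through the BPE merge rules and sort
#   3. get_remove_words() - select low-frequency tokens to prune from the vocabulary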
def word_count():
    """Count token-id frequencies over all tokenized .bin shards."""
    from collections import Counter
    from megatron.data.indexed_dataset import MMapIndexedDataset
    counter = Counter()
    for file_name in glob.glob("data/jd/*.bin"):
        print(file_name)
        file_name = file_name[:-4]  # MMapIndexedDataset takes the path without the .bin suffix
        dataset = MMapIndexedDataset(file_name, skip_warmup=True)
        for doc in dataset:
            counter.update(doc)
    with open("word_count.txt", "w", encoding="utf-8") as f_out:
        for token_id, count in counter.most_common():
            f_out.write("%d\t%d\n" % (token_id, count))
def get_unused_id():
    pass
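# The stub above is unimplemented. Below is a minimal sketch of what it
# presumably does (list token ids that never occur in the corpus counts),
# assuming word_count.txt produced by word_count(); the name
# get_unused_id_sketch and its logic are illustrative, not the author's code.
def get_unused_id_sketch():
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json")
    # ids observed at least once in the corpus
    seen = {int(line.split("\t")[0])
            for line in open("word_count.txt", "r", encoding="utf-8")}
    # any vocab id that never appeared is unused and a pruning candidate
    return [i for i in range(tokenizer.get_vocab_size()) if i not in seen]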
def print_word_count():
    """Propagate corpus counts through the BPE merge rules and write a sorted report."""
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json")
    data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8"))
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    merge_dict = {}
    sorted_parts = []
    for merge in merges:
        idx = merge.find(" ")
        token_str = merge[:idx] + merge[idx + 1:]
        merge_dict[token_str] = (merge[:idx], merge[idx + 1:])
        sorted_parts += [token_str, merge[:idx], merge[idx + 1:]]
    id2vocab = {idx: token for token, idx in vocab.items()}
    # Add corpus tokens missing from sorted_parts, then sort.
    all_tokens = [line.strip().split("\t") for line in open("word_count.corpus.txt", "r", encoding="utf-8")]
    raw_token_count = {int(token_id): int(count) for token_id, count in all_tokens}
    sorted_parts = set(sorted_parts)
    for token_id in raw_token_count:
        token_str = id2vocab[token_id]
        if token_str not in sorted_parts:
            sorted_parts.add(token_str)
            # print(token_id, token_str, json.dumps(token_str), raw_token_count[token_id], " not in parts")
    sorted_parts = sorted(sorted_parts, key=len, reverse=True)
    # Recompute counts including merges. Iterating longest-to-shortest means a
    # token's accumulated count is final before it is pushed down to its two
    # parts (otherwise a depth-first traversal would be needed).
    # token_count = copy.deepcopy(raw_token_count)
    token_count = defaultdict(int)
    for token_str in sorted_parts:
        token_id = vocab[token_str]
        count = raw_token_count.get(token_id, 0)
        token_count[token_id] += count  # the token's own corpus count
        if token_str in merge_dict:
            token_count[vocab[merge_dict[token_str][0]]] += token_count[token_id]
            token_count[vocab[merge_dict[token_str][1]]] += token_count[token_id]
        else:
            print(token_id, json.dumps(token_str))  # base token with no merge rule
    # Re-sort: ascending by count, ties broken by descending token length.
    sorted_token_count = sorted(token_count.items(), key=lambda kv: (kv[1], -len(id2vocab[kv[0]])))
    with open("word_count.corpus.sort_by_count.jsonl", "w", encoding="utf-8") as f_out:
        for token_id, count in sorted_token_count:
            token_str = id2vocab[token_id]
            decode_str = tokenizer.decode([token_id])  # decoding is lossy (e.g. Ġ becomes a space)
            if token_str in merge_dict:
                merges_str = " ".join(merge_dict[token_str])  # renamed so it doesn't shadow the merges list
            else:
                merges_str = "NULL"
            f_out.write(json.dumps(
                {"id": token_id, "token": token_str, "merges": merges_str,
                 "raw_count": raw_token_count.get(token_id, 0),
                 "count": count, "decode_str": decode_str}) + "\n")
def get_remove_words():
    """Select low-frequency tokens to remove, writing one JSON record per removal."""
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file("../20B_tokenizer_chinese.json")
    data = json.load(open("../20B_tokenizer_chinese.json", "r", encoding="utf-8"))
    added_tokens = [token["id"] for token in data["added_tokens"]]
    vocab = data["model"]["vocab"]
    merges = data["model"]["merges"]
    id2vocab = {idx: token for token, idx in vocab.items()}
    merge_dict = {k.replace(" ", "", 1): k for k in merges}
    token_count = {}
    for line in open("word_count.corpus.sort_by_count.jsonl", "r", encoding="utf-8"):
        line_data = json.loads(line)
        token_id = int(line_data["id"])
        count = int(line_data["count"])
        token_count[token_id] = count
    f_out = open("word_count.corpus.remove.jsonl", "w", encoding="utf-8")
    remove_vocab_set = set()
    # # 1. Remove known-bad tokens
    # error_tokens = [54611, 54612, 54613, 54614, 54615, 54616, 54617, 54618, 54619, 54620, 54621, 54622,
    #                 54623, 54624, 54625, 54626, 54627, 54628, 54629, 54630, 54631, 54632, 54633]
    # for token_id in error_tokens:
    #     token_str = id2vocab[token_id]
    #     # token_str = tokenizer.id_to_token(token_id)  # lossy
    #     remove_vocab_set.add(token_id)
    #     f_out.write(json.dumps(
    #         {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": 0,
    #          "type": "error-char"}) + "\n")
    # 2. Remove overlong tokens
    # for token_id in range(tokenizer.get_vocab_size()):
    #     if token_id in added_tokens:
    #         continue
    #     token_str = id2vocab[token_id]
    #     # token_str = tokenizer.id_to_token(token_id)  # also lossy, e.g. for token 54611
    #     decode_str = tokenizer.decode([token_id])  # decode is lossy, e.g. Ġ becomes a space
    #     if len(decode_str) > 8 and len(set(decode_str)) < 3:
    #         if token_id in remove_vocab_set:
    #             continue
    #         remove_vocab_set.add(token_id)
    #         f_out.write(
    #             json.dumps({"id": token_id, "token": token_str,
    #                         "merges": merge_dict.get(token_str), "count": token_count.get(token_id, 0),
    #                         "type": "filtered by length"}, ensure_ascii=False) + "\n")
    #
    #         # Also remove dependents (otherwise merges would reference OOV tokens)
    #         for merge in merges:
    #             if token_str in merge:
    #                 # if token_str + " " in merge or " " + token_str in merge:
    #                 parent_token_str = merge.replace(" ", "", 1)
    #                 parent_token_id = vocab[parent_token_str]
    #                 if parent_token_id in remove_vocab_set:
    #                     continue
    #                 remove_vocab_set.add(parent_token_id)
    #                 f_out.write(
    #                     json.dumps({"id": parent_token_id, "token": parent_token_str,
    #                                 "merges": merge, "count": token_count.get(parent_token_id, 0),
    #                                 "type": "filtered by length - dependency removal"}, ensure_ascii=False) + "\n")
    # 3. Remove low-frequency tokens. The jsonl is sorted ascending by count,
    # so the first 25000 entries are the lowest-frequency ones.
    for token_id, count in list(token_count.items())[:25000]:
        if token_id in added_tokens:
            continue
        if token_id in remove_vocab_set:
            continue
        token_str = tokenizer.id_to_token(token_id)
        # token_str = tokenizer.decode([int(token_id)])
        if len(token_str.strip()) > 1:  # keep single-character tokens for coverage
            remove_vocab_set.add(token_id)
            f_out.write(json.dumps(
                {"id": token_id, "token": token_str, "merges": merge_dict.get(token_str), "count": count,
                 "type": "remove by frequency"}) + "\n")
            ######## Already sorted by frequency, so dependency removal is not
            ######## needed: counts propagate from a merged token to its parts,
            ######## so a part never sorts before the token that depends on it.
            # # Remove dependents (otherwise merges would reference OOV tokens)
            # for merge in merges:
            #     # if token_str + " " in merge or " " + token_str in merge:
            #     if token_str in merge:
            #         parent_token_str = merge.replace(" ", "", 1)
            #         parent_token_id = vocab[parent_token_str]
            #         if parent_token_id in remove_vocab_set:
            #             continue
            #         remove_vocab_set.add(parent_token_id)
            #         f_out.write(
            #             json.dumps({"id": parent_token_id, "token": parent_token_str,
            #                         "merges": merge, "count": token_count.get(parent_token_id, 0),
            #                         "type": "filtered by frequency - dependency removal"}, ensure_ascii=False) + "\n")
    # removes 24969 tokens on the reference run
    print("remove %d tokens" % (len(remove_vocab_set)))
    f_out.close()
if __name__ == "__main__":
    # word_count()
    # print_word_count()
    get_remove_words()