tokenizer-arena / utils /compress_rate_util.py
eson's picture
update compress rate
367a536
raw history blame
No virus
5.08 kB
"""
Chinese data: clue, superclue
English data: glue, cnn_dailymail, gigaword
Code data:
Numeric data:
"""
import json
import os
import sys
import pandas as pd
from datasets import load_dataset
from utils.log_util import logger
from vocab import load_tokener
from typing import List
# Absolute path of the directory containing this file; cache paths are resolved relative to it.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Units reported for every tokenizer by pprint(); each one must be handled by unit_convertor().
common_units = ["g_bytes/b_tokens", "b_tokens/g_bytes", "t_bytes/t_tokens", "t_tokens/t_bytes", "n_chars/n_tokens", ]
# Corpora benchmarked by main() when no tokenizer/corpus pair is given on the command line.
# tokenize_corpus() strips the "cc100-" prefix to get the dataset config name.
common_corpuses = ["cc100-en", "cc100-zh-Hans", "cc100-es"]
# Notes on candidate corpora for other domains (not referenced by the code in this file):
# code: https://huggingface.co/datasets/codeparrot/github-code-clean python java c sql html
# math:
def get_n_bytes_of_string(string_text):
    """Return the number of bytes ``string_text`` occupies when encoded as UTF-8."""
    encoded = string_text.encode("utf-8")
    return len(encoded)
def unit_convertor(stat, unit):
    """
    Convert raw corpus counts into a compression-rate figure in the requested unit.

    :param stat: mapping with the integer counts ``"n_tokens"``, ``"n_chars"`` and ``"n_bytes"``.
    :param unit: one of ``"n_tokens/n_bytes"``, ``"n_chars/n_tokens"``, ``"n_tokens/n_chars"``,
        ``"g_bytes/b_tokens"``, ``"b_tokens/g_bytes"``, ``"t_bytes/t_tokens"``, ``"t_tokens/t_bytes"``.
    :return: the ratio rounded to two decimal places.
    :raises ValueError: if ``unit`` is not one of the supported units.
    :raises ZeroDivisionError: if the denominator count of the selected unit is zero.
    """
    n_tokens = stat["n_tokens"]
    n_chars = stat["n_chars"]
    n_bytes = stat["n_bytes"]

    # Token counts use decimal multiples (billion / trillion) ...
    n_tokens_in_billion = n_tokens / (1000 * 1000 * 1000)
    n_tokens_in_trillion = n_tokens / (1000 * 1000 * 1000 * 1000)
    # ... while byte counts use binary multiples (1024-based MB / GB / TB).
    n_bytes_in_mb = n_bytes / (1024 * 1024)
    n_bytes_in_gb = n_bytes_in_mb / 1024
    n_bytes_in_tb = n_bytes_in_gb / 1024

    # Only the selected branch divides, so a zero count matters only for the unit that uses it.
    if unit == "n_tokens/n_bytes":
        value = n_tokens / n_bytes
    elif unit == "n_chars/n_tokens":  # important: average number of characters per token
        value = n_chars / n_tokens
    elif unit == "n_tokens/n_chars":  # e.g. how many tokens one Chinese character needs
        value = n_tokens / n_chars
    elif unit == "g_bytes/b_tokens":
        value = n_bytes_in_gb / n_tokens_in_billion
    elif unit == "b_tokens/g_bytes":
        value = n_tokens_in_billion / n_bytes_in_gb
    elif unit == "t_bytes/t_tokens":  # important
        value = n_bytes_in_tb / n_tokens_in_trillion
    elif unit == "t_tokens/t_bytes":
        value = n_tokens_in_trillion / n_bytes_in_tb
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception instead.
        raise ValueError(f"measure not support: {unit!r}")
    return round(value, 2)
def pprint(stats):
    """
    Log a markdown table with one row per tokenizer and one column per unit in ``common_units``.

    :param stats: mapping of tokenizer name to a stat dict; each stat dict must provide
        ``"vocab_size"`` plus the raw counts that ``unit_convertor`` reads.
    :return: None; the table is written through ``logger.info``.
    """
    rows = []
    for name, stat in stats.items():
        row = {"tokenizer": name, "vocab_size": stat["vocab_size"]}
        for unit in common_units:
            if unit in stat:
                # NOTE(review): a unit that is already a key of ``stat`` is reported as an error
                # and left out of the row — confirm this is intended rather than reusing the value.
                logger.error(f"unit {unit} not support")
                continue
            row[unit] = unit_convertor(stat, unit)
        rows.append(row)
    frame = pd.DataFrame(rows)
    logger.info(f"\n{frame.to_markdown(index=False)}")
# In-memory cache of corpus statistics filled by tokenize_corpus().
# Keys have the form "<tokenizer alias>.<corpus names joined by '.'>".
cache = {}
def tokenize_corpus(tokenizer, corpuses, cache_dir="stats/compress_rate"):
    """
    Tokenize the given corpora and return their aggregate size statistics.

    Tokenizing is slow, so results are cached independently at two levels: the
    module-level ``cache`` dict, and one JSON file per (tokenizer, corpora) pair.

    :param tokenizer: object providing ``alias``, ``vocab_size`` and ``encode(text)``.
    :param corpuses: list of corpus names such as ``"cc100-en"``; the ``"cc100-"`` prefix is
        removed to obtain the config name of the ``eson/cc100-samples`` dataset.
    :param cache_dir: directory for the JSON cache files, relative to the parent of the
        directory containing this file. It is created if missing.
    :return: dict with the keys ``"vocab_size"``, ``"n_bytes"``, ``"n_tokens"`` and ``"n_chars"``.
    """

    def _tokenize(tokenizer, datasets):
        """Accumulate byte, character and token counts over every item's ``"text"`` field."""
        n_tokens = 0
        n_chars = 0
        n_bytes = 0
        for dataset in datasets:
            for item in dataset:
                text = item["text"]
                n_bytes += get_n_bytes_of_string(text)
                n_chars += len(text)
                encodings = tokenizer.encode(text)
                n_tokens += len(encodings)
        stat = {
            "vocab_size": tokenizer.vocab_size,
            "n_bytes": n_bytes,
            "n_tokens": n_tokens,
            "n_chars": n_chars,
        }
        return stat

    tokenizer_name = tokenizer.alias
    cache_id = f"{tokenizer_name}.{'.'.join(corpuses)}"

    # L1: in-memory cache
    if cache_id in cache:
        logger.info(f"loading {cache_id} from in-memory cache")
        return cache[cache_id]

    # L2: file cache
    cache_dir = os.path.join(CURRENT_DIR, f"../{cache_dir}")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"{cache_id}.json")
    if os.path.exists(cache_path):
        logger.info(f"loading {cache_id} from file cache")
        # Context manager so the handle is closed even if parsing fails.
        with open(cache_path, "r", encoding="utf-8") as f_in:
            stat = json.load(f_in)
        cache[cache_id] = stat
        return stat

    # Cache miss on both levels: load the datasets and tokenize them.
    datasets = [load_dataset("eson/cc100-samples", corpus.replace("cc100-", ""), split="train") for corpus in corpuses]
    stat = _tokenize(tokenizer, datasets)
    logger.info(f"saving {cache_id} to {cache_path}")
    # Closing the file explicitly guarantees the JSON is flushed to disk before returning.
    with open(cache_path, "w", encoding="utf-8") as f_out:
        json.dump(stat, f_out)
    logger.info(f"saving {cache_id} to in-memory cache")
    cache[cache_id] = stat
    return stat
def test():
    """Smoke run: compute and print the stats of the ``gpt_4`` tokenizer on the English and Chinese corpora."""
    name = "gpt_4"
    corpora = ["cc100-en", "cc100-zh-Hans"]
    tokenizer = load_tokener(name)
    stat = tokenize_corpus(tokenizer, corpora)
    pprint({name: stat})
def main():
    """
    Compute and print compression statistics, one table per corpus.

    With exactly two command-line arguments (``<tokenizer> <corpus>``) only that pair is
    processed; otherwise every tokenizer in ``vocab.all_tokenizers`` is run on every
    corpus in ``common_corpuses``.
    """
    from vocab import all_tokenizers

    has_cli_pair = len(sys.argv) == 3
    tokenizers = [sys.argv[1]] if has_cli_pair else all_tokenizers
    corpuses = [sys.argv[2]] if has_cli_pair else common_corpuses

    # Shared across corpora; each tokenizer's entry is overwritten on every corpus pass.
    stats = {}
    for lang in corpuses:
        print("###" * 10 + lang)
        for tokenizer_name in tokenizers:
            stats[tokenizer_name] = tokenize_corpus(load_tokener(tokenizer_name), [lang])
        pprint(stats)
# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
    # test()