"""
## Compared with the gpt_neox vocabulary, this tokenizer's vocabulary handles Chinese better.

1. Chinese characters (汉字):
   Encoding-length distribution: Counter({1: 3861, 2: 1909})
   Average encoding length: 1.330

2. Chinese punctuation (中文标点):
   Encoding-length distribution: Counter({1: 47, 2: 34})
   Average encoding length: 1.4197530864197532
"""

from collections import Counter

from transformers import AutoTokenizer, BloomTokenizerFast
from zhon.hanzi import punctuation as zh_punc

from data_sample.oov_base import jd_vocab_tokens
from utils.text_util import is_chinese

# Alternative loader, kept for comparison runs:
# tokenizer = AutoTokenizer.from_pretrained("tokenizer")
# NOTE(review): "tokenizer" is presumably a local directory relative to the
# working directory — confirm before running from elsewhere.
tokenizer = BloomTokenizerFast.from_pretrained("tokenizer")


def test_coding_length(vocab, filter=None):
    """Report how many tokens the module-level ``tokenizer`` needs per character.

    Entries of ``vocab`` longer than one character are skipped, so only
    single characters are measured. Each character that encodes to more than
    one token is printed as it is found. At the end, a histogram of token
    counts and the mean token count are printed.

    Args:
        vocab: Iterable of strings (e.g. characters or vocabulary entries).
        filter: Optional predicate. Entries for which it returns a truthy
            value are skipped. The name shadows the builtin but is kept so
            existing ``filter=...`` keyword callers keep working.
    """
    all_length = []
    for word in vocab:
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        tokens = tokenizer.encode(word)
        all_length.append(len(tokens))
        if len(tokens) > 1:
            print(word, tokens)
    print("编码长度统计:", Counter(all_length))
    if not all_length:
        # Nothing was measured (empty vocab, or every entry was skipped):
        # report that instead of dividing by zero.
        print("平均编码长度:", "N/A")
        return
    print("平均编码长度:", sum(all_length) / len(all_length))


if __name__ == "__main__":
    # Alternative run: measure only the Chinese characters of the JD
    # vocabulary (entries that are not Chinese are filtered out).
    # test_coding_length(jd_vocab_tokens, filter=lambda k: not is_chinese(k))
    test_coding_length(zh_punc)