|
""" |
|
|
|
## 相比 gpt_neox 的词典，对中文支持更好。(Compared with the gpt_neox vocabulary, this tokenizer supports Chinese better.)
|
|
|
1. 汉字: |
|
编码长度统计: Counter({1: 3861, 2: 1909}) |
|
平均编码长度: 1.330 |
|
|
|
2. 中文标点 |
|
编码长度统计: Counter({1: 47, 2: 34}) |
|
平均编码长度: 1.4197530864197532 |
|
|
|
|
|
""" |
|
|
|
from collections import Counter |
|
from transformers import AutoTokenizer, BloomTokenizerFast |
|
from data_sample.oov_base import jd_vocab_tokens |
|
from utils.text_util import is_chinese |
|
from zhon.hanzi import punctuation as zh_punc |
|
|
|
|
|
tokenizer = BloomTokenizerFast.from_pretrained("tokenizer") |
|
|
|
|
|
|
|
def test_coding_length(vocab, filter=None):
    """Print tokenizer encoding-length statistics for single-character entries.

    Iterates over *vocab*, skipping entries longer than one character and any
    entry for which *filter* returns a truthy value.  Each remaining character
    is encoded with the module-level ``tokenizer``; characters that need more
    than one token are printed individually, then a length histogram
    (``Counter``) and the mean encoded length are printed.

    :param vocab: iterable of strings; only 1-character entries are measured.
    :param filter: optional predicate ``str -> bool``; entries where it is
        truthy are skipped.  NOTE: this name shadows the builtin ``filter`` —
        kept as-is for backward compatibility with existing callers.
    :return: None (results are printed, not returned)
    """
    all_length = []
    for word in vocab:
        # Only single characters are measured; longer entries are skipped.
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        tokens = tokenizer.encode(word)
        all_length.append(len(tokens))
        if len(tokens) > 1:
            # Character needs more than one token — print for inspection.
            print(word, tokens)

    print("编码长度统计:", Counter(all_length))
    # Guard against ZeroDivisionError when nothing was measured
    # (e.g. vocab is empty, all entries multi-character, or all filtered out).
    if all_length:
        print("平均编码长度:", sum(all_length) / len(all_length))
    else:
        print("平均编码长度: N/A (no single-character entries measured)")
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Measure encoding lengths for Chinese punctuation
    # (zhon.hanzi.punctuation); recorded results are in the module docstring.
    test_coding_length(zh_punc)