Spaces:
Runtime error
Runtime error
""" | |
## 比gpt_neox的词典 对中文支持更好。 | |
1. 汉字: | |
编码长度统计: Counter({1: 3861, 2: 1909}) | |
平均编码长度: 1.330 | |
2. 中文标点 | |
编码长度统计: Counter({1: 47, 2: 34}) | |
平均编码长度: 1.4197530864197532 | |
""" | |
from collections import Counter | |
from transformers import AutoTokenizer, BloomTokenizerFast | |
from data_sample.oov_base import jd_vocab_tokens | |
from utils.text_util import is_chinese | |
from zhon.hanzi import punctuation as zh_punc | |
# Module-wide tokenizer loaded from "tokenizer" (presumably a local directory
# holding the vocabulary under test). AutoTokenizer.from_pretrained("tokenizer")
# is an alternative; the Bloom fast tokenizer class is chosen explicitly here.
tokenizer = BloomTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="tokenizer",
)
def test_coding_length(vocab, filter=None, encode=None):
    """Report how many token ids each short entry of *vocab* encodes to.

    Entries longer than one character are skipped, so effectively single
    characters are measured. Every entry that encodes to more than one id is
    printed with its ids, followed by a histogram of encoding lengths and the
    average encoding length.

    Args:
        vocab: iterable of strings (a list of vocabulary entries, or a plain
            string, which iterates character by character).
        filter: optional predicate; entries for which it returns a truthy
            value are EXCLUDED (e.g. ``lambda k: not is_chinese(k)`` keeps
            only Chinese characters).
        encode: optional callable mapping a string to a sequence of token ids.
            Defaults to the module-level ``tokenizer.encode``.

    Returns:
        collections.Counter mapping encoding length -> number of entries with
        that length (empty if no entry was measured).
    """
    if encode is None:
        # Resolved lazily so callers can inject their own encoder.
        encode = tokenizer.encode

    all_length = []
    for word in vocab:
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        tokens = encode(word)
        all_length.append(len(tokens))
        if len(tokens) > 1:
            # Entries split into several ids are the interesting cases; show them.
            print(word, tokens)

    length_counts = Counter(all_length)
    print("编码长度统计:", length_counts)
    if all_length:
        print("平均编码长度:", sum(all_length) / len(all_length))
    else:
        # Nothing survived the length/filter checks: avoid ZeroDivisionError.
        print("平均编码长度: n/a (no entries measured)")
    return length_counts


# The name starts with ``test_`` but this is a script helper that needs a
# ``vocab`` argument; stop pytest from collecting it (it would fail looking
# for a ``vocab`` fixture).
test_coding_length.__test__ = False
if __name__ == "__main__":
    # Default run: measure every Chinese punctuation character from zhon.
    # To measure Chinese characters from the JD vocabulary instead, use:
    #   test_coding_length(jd_vocab_tokens, filter=lambda k: not is_chinese(k))
    test_coding_length(vocab=zh_punc)