# eson's picture
# update
# 428b731
# raw
# history blame
# No virus
# 4.65 kB
"""
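## Overview
- Comparison of the BERT and CLUE vocabularies: https://github.com/CLUEbenchmark/CLUECorpus2020#%E8%AF%8D%E8%A1%A8%E4%BB%8B%E7%BB%8D
- Related issue: https://github.com/google-research/bert/issues/396
- The BERT Chinese vocabulary has 21128 entries (about 20k).
- English letters are all lowercased (is there a variant that is not lowercased?)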
-
args:
-
-
output:
-
python bpe_oov.py \
--vocab-bpe vocab.google.txt \
--inputs ../raw/discovery_all \
--workers 60
# stderr is printed to the terminal; stdout is redirected to oov_lines
python bpe_oov.py \
--vocab-bpe vocab.clue_plus.txt \
--inputs ../raw/discovery_all \
--workers 60 > oov_lines
python bpe_oov.py \
--vocab-bpe vocab.clue_plus.txt \
--inputs ../raw/small/jd.train.raw \
--workers 60 > oov_lines
## 整词
"""
import argparse
from transformers import BertTokenizer
import contextlib
import sys
from collections import defaultdict
from multiprocessing import Pool
def main():
    """Count out-of-vocabulary tokens in the inputs and write them to ``oov``.

    Command-line arguments:
        --vocab-bpe: path to the vocabulary file handed to the tokenizer.
        --inputs: one or more input files; ``-`` (the default) reads stdin.
        --workers: number of worker processes (default 20).

    The inputs are read in lockstep (one line from each file per step) and
    every group of lines is scanned for OOV tokens in a worker pool.  The
    per-token totals are written to the file ``oov`` in the current directory
    as ``token count`` lines, most frequent first.  A progress message is
    printed to stderr every 10000 groups.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help='path to vocab.bpe',
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=['-'],
        help="input files to filter/encode",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    oov_count = defaultdict(int)
    with contextlib.ExitStack() as stack:
        # `path` rather than `input`, so the builtin is not shadowed.
        inputs = [
            sys.stdin if path == "-"
            else stack.enter_context(open(path, "r", encoding="utf-8"))
            for path in args.inputs
        ]
        encoder = MultiprocessingEncoder(args.vocab_bpe)
        # Using the pool as a context manager guarantees the worker processes
        # are torn down on exit, including when the loop below raises.
        with Pool(args.workers, initializer=encoder.initializer) as pool:
            oov_lines = pool.imap(encoder.get_oov_lines, zip(*inputs), 100)
            # Main computation: results are pulled from the workers here.
            for i, oov_line in enumerate(oov_lines, start=1):
                for oov in oov_line:
                    oov_count[oov] += 1
                if i % 10000 == 0:
                    print("processed {} lines".format(i), file=sys.stderr)

    sorted_oov = sorted(oov_count.items(), key=lambda kv: kv[1], reverse=True)
    with open('oov', 'w', encoding='utf-8') as f_out:
        f_out.write('\n'.join(['%s %d' % (k, v) for k, v in sorted_oov]))
class MultiprocessingEncoder(object):
    """Tokenizer helper whose methods can be dispatched to a ``multiprocessing.Pool``.

    The tokenizer is not stored on the instance.  ``initializer`` builds it
    into the module-level global ``bpe`` and every other method reads it from
    there.  When ``initializer`` is used as a pool initializer it runs once in
    each worker process, and keeping the tokenizer out of ``self`` means it is
    not pickled along with the instance whenever a bound method is sent to a
    worker.

    Args:
        vocab_bpe: path to the vocabulary file passed to ``BertTokenizer``.
        keep_empty: when true, ``encode_lines`` encodes blank lines instead of
            rejecting the whole group.  Defaults to ``False``.
    """

    def __init__(self, vocab_bpe, keep_empty=False):
        self.vocab_bpe = vocab_bpe
        # Previously read from a never-assigned ``self.args``; see encode_lines.
        self.keep_empty = keep_empty

    def initializer(self):
        """Build the tokenizer into the global ``bpe`` for the current process."""
        global bpe
        bpe = BertTokenizer(self.vocab_bpe)

    def get_oov(self, line):
        """Return the basic tokens of ``line`` that word-piece maps to ``[UNK]``.

        A token is appended once for every ``[UNK]`` piece it produces.  When
        at least one OOV token is found, ``<tokens joined by ','>\\t<line>`` is
        printed to stdout.
        """
        oov_tokens = []
        basic_tokens = bpe.basic_tokenizer.tokenize(
            line, never_split=bpe.all_special_tokens
        )
        for token in basic_tokens:
            for sub_token in bpe.wordpiece_tokenizer.tokenize(token):
                if sub_token == '[UNK]':
                    oov_tokens.append(token)
        # Original author's note: printing here may be unnecessary, because
        # some of these tokens obviously just need adding to the vocabulary.
        if oov_tokens:
            print(','.join(oov_tokens) + '\t' + line)
        return oov_tokens

    def encode(self, line):
        """Encode ``line`` with the tokenizer and return the ids as strings."""
        ids = bpe.encode(line)
        return [str(token_id) for token_id in ids]

    def decode(self, tokens):
        """Decode ``tokens`` back to text with the tokenizer."""
        return bpe.decode(tokens)

    def get_oov_lines(self, lines):
        """Collect the OOV tokens of every line in ``lines`` into one flat list.

        Each line is stripped of surrounding whitespace before it is scanned.
        """
        all_oov = []
        for line in lines:
            all_oov.extend(self.get_oov(line.strip()))
        return all_oov

    def encode_lines(self, lines):
        """Encode a group of lines together.

        Returns:
            ``["PASS", encoded]`` where ``encoded`` holds one space-joined id
            string per line, or ``["EMPTY", None]`` as soon as a blank line is
            met and ``keep_empty`` is false.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            # Fix: this used to read ``self.args.keep_empty``, but ``self.args``
            # is never set, so every blank line raised AttributeError.
            if not line and not self.keep_empty:
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]
def test():
    """Manual smoke check: encode a sample line, then print the line, its ids and the decoded text.

    Loads the vocabulary from ``vocab.clue_plus.txt`` in the current directory.
    """
    sample = (
        '蔲驰的,africa❸ 11111111111165000mg❗2⃣piqueddasdasddasdasda,明天25℃,面积120㎡,大约2~3米'
        '3200×1800分辨率,TAS海关密码锁,PC镜片,采用A+节能能,胶包裏,包裹,薄至6㎜,鬼塚虎,'
        '多种矿物元素,特别是锶,靚眼,门闩和便携把手,箜篌刺绣,5㎝,锐蝮蛇竞技版鼠标,滑屛式,T桖,sub+dvi,'
        '呵护牙齦,Baumatic™ ,'
    )
    enc = MultiprocessingEncoder('vocab.clue_plus.txt')
    # The tokenizer lives in a module global, so it must be built before use.
    enc.initializer()
    token_ids = enc.encode(sample)
    print(sample)
    print(token_ids)
    print(enc.decode(token_ids))
if __name__ == "__main__":
    # Run the command-line tool described in the module docstring.  The
    # previous debug call to test() ignored all command-line arguments;
    # invoke test() by hand when a quick smoke check is wanted.
    main()