# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 17:08

import time
import torch
import gradio as gr

from info import article
from transformers import FillMaskPipeline
from transformers import BertTokenizer
from kplug.modeling_kplug import KplugForMaskedLM
from pycorrector.bert.bert_corrector import BertCorrector
from pycorrector import config
from loguru import logger

# transformers pipelines take an int device: 0 = first GPU, -1 = CPU
device_id = 0 if torch.cuda.is_available() else -1


class KplugCorrector(BertCorrector):

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id):
        # bypass BertCorrector.__init__ so its default checkpoint is not loaded
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        t1 = time.time()
        # load the KPLUG encoder from the Hugging Face Hub (downloaded and cached on first use)
        # and wrap it in a fill-mask pipeline for masked-language-model based correction
        tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-encoder")
        model = KplugForMaskedLM.from_pretrained("eson/kplug-base-encoder")
        self.model = FillMaskPipeline(model=model, tokenizer=tokenizer, device=device)
        if self.model:
            self.mask = self.model.tokenizer.mask_token
            logger.debug('Loaded kplug fill-mask model, spend: %.3f s.' % (time.time() - t1))


corrector = KplugCorrector()

error_sentences = [
    '少先队员因该为老人让坐',  # "因该" should be "应该", "让坐" should be "让座"
    '机七学习是人工智能领遇最能体现智能的一个分知',  # "机七" should be "机器", "领遇" should be "领域"
    '今天心情很好',  # already correct; expected to come back unchanged
]


def mock_data():
    """Canned correction result for debugging the UI without loading the model."""
    corrected_sent = '机器学习是人工智能领遇最能体现智能的一个分知'
    errs = [('七', '器', 1, 2), ('遇', '域', 10, 11)]
    return corrected_sent, errs
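
# Each err tuple is (wrong_char, corrected_char, start, end), where [start, end) are
# character offsets into the sentence, e.g. '七' at offset 1 is corrected to '器'.
# (Format inferred from mock_data above and its use in correct() below.)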


def correct(sent):
    corrected_sent, errs = corrector.bert_correct(sent)
    # corrected_sent, errs = mock_data()
    print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
    # convert each correction into the entity dict format consumed by gr.HighlightedText
    output = [{"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]}
              for err in errs]
    return {"text": corrected_sent, "entities": output}, errs


def test():
    for sent in error_sentences:
        corrected_sent, err = corrector.bert_correct(sent)
        print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, err))


corr_iface = gr.Interface(
    fn=correct,
    inputs=gr.Textbox(
        label="输入文本",  # "Input text"
        value="少先队员因该为老人让坐"),
    outputs=[
        gr.HighlightedText(
            label="Output",
            show_legend=True,
        ),
        gr.JSON(
            label="JSON Output"
        )
    ],
    examples=error_sentences,
    title="文本纠错(Corrector)",  # "Text Correction"
    description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
    # i.e. automatically proofread Chinese text for spelling, grammar and punctuation
    # problems, highlight the error positions and return suggested corrections
    article=article
)
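
# Possible launch variants (standard gr.Interface.launch() keyword arguments, adjust as needed):
#   corr_iface.launch(share=True)        # also create a temporary public share link
#   corr_iface.launch(server_port=7860)  # pin the local port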

if __name__ == "__main__":
    # test()
    # correct("少先队员因该为老人让坐")
    corr_iface.launch()