# coding=utf-8
# author: xusong
# time: 2022/8/23 17:08
"""Gradio demo for Chinese text correction backed by the KPLUG encoder.

A pycorrector ``BertCorrector`` is re-pointed at the KPLUG masked-LM and
exposed through a ``gr.Interface`` showing the corrected text with highlighted
corrections plus the raw error list as JSON.
"""

import time

import torch
import gradio as gr
from loguru import logger
from pycorrector import config
from pycorrector.bert.bert_corrector import BertCorrector
from transformers import BertTokenizer
from transformers import FillMaskPipeline

from info import article
from kplug.modeling_kplug import KplugForMaskedLM

# Hub id (or local path) of the KPLUG checkpoint used for tokenizer and model.
KPLUG_MODEL_NAME = "eson/kplug-base-encoder"

# transformers pipelines take device=-1 for CPU, or a CUDA device index.
device_id = 0 if torch.cuda.is_available() else -1

# NOTE(review): this hides the HighlightedText category legend even though the
# component below sets show_legend=True — confirm which behavior is intended.
css = """
.category-legend {display: none !important}
"""


class KplugCorrector(BertCorrector):
    """``BertCorrector`` that runs the KPLUG masked-LM instead of a BERT model.

    Args:
        bert_model_dir: kept only for signature compatibility with
            ``BertCorrector``; nothing is loaded from it.
        device: device passed to ``FillMaskPipeline`` (-1 = CPU, >=0 = GPU).
        model_name_or_path: hub id or local path of the KPLUG checkpoint.
    """

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id,
                 model_name_or_path=KPLUG_MODEL_NAME):
        # Deliberately skips BertCorrector.__init__ by calling the next class
        # in the MRO — presumably so pycorrector does not load its default
        # BERT model before it is replaced by KPLUG here (confirm).
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        t1 = time.time()
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        model = KplugForMaskedLM.from_pretrained(model_name_or_path)
        # Mirrors the attributes BertCorrector sets up (self.model / self.mask),
        # presumably what bert_correct() reads — verify against pycorrector.
        self.model = FillMaskPipeline(model=model, tokenizer=tokenizer, device=device)
        if self.model:
            self.mask = self.model.tokenizer.mask_token
            # Report the checkpoint actually loaded, not bert_model_dir (unused).
            logger.debug('Loaded kplug model: {}, spend: {:.3f} s.',
                         model_name_or_path, time.time() - t1)


corrector = KplugCorrector()

error_sentences = [
    '少先队员因该为老人让坐',
    '机七学习是人工智能领遇最能体现智能的一个分知',
    '今天心情很好',
]


def mock_data():
    """Return a canned ``(corrected_sentence, errors)`` pair for UI work without the model.

    Each error is ``(wrong_char, corrected_char, start, end)``; ``end`` is exclusive.
    """
    corrected_sent = '机器学习是人工智能领域最能体现智能的一个分知'
    errs = [('七', '器', 1, 2), ('遇', '域', 10, 11)]
    return corrected_sent, errs


def correct(sent):
    """Correct *sent* and format the result for the two gradio outputs.

    Args:
        sent: input text from the textbox.

    Returns:
        ``(highlighted, errs)``. ``highlighted`` is
        ``{"text": corrected_sent, "entities": [...]}``, the format gradio's
        HighlightedText expects (https://www.gradio.app/docs/highlightedtext).
        ``errs`` is the raw error list from ``bert_correct``, shown as JSON.
    """
    if not sent:
        # Empty textbox: nothing to correct, so skip the model call.
        return {"text": "", "entities": []}, []
    # For UI work without the model, swap this call for mock_data().
    corrected_sent, errs = corrector.bert_correct(sent)
    logger.info("original sentence:{} => {}, err:{}", sent, corrected_sent, errs)
    # Errors are presumably (wrong, right, start, end) tuples as in mock_data();
    # the span is highlighted in the corrected text. 0.5 is a fixed placeholder
    # score, not a model confidence.
    output = [
        {"entity": "纠错", "score": 0.5, "word": right, "start": start, "end": end}
        for _wrong, right, start, end, *_rest in errs
    ]
    return {"text": corrected_sent, "entities": output}, errs


def test():
    """Run the corrector over ``error_sentences`` and log each result."""
    for sent in error_sentences:
        corrected_sent, errs = corrector.bert_correct(sent)
        logger.info("original sentence:{} => {}, err:{}", sent, corrected_sent, errs)


corr_iface = gr.Interface(
    fn=correct,
    inputs=gr.Textbox(label="输入文本", value=error_sentences[0]),
    outputs=[
        gr.HighlightedText(label="文本纠错", show_legend=True),
        gr.JSON(),
    ],
    examples=error_sentences,
    title="文本纠错(Corrector)",
    description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
    article=article,
    css=css,
)

if __name__ == "__main__":
    # Call test() instead for a console-only check of the corrector.
    corr_iface.launch()