File size: 2,988 Bytes
2bb0b26
 
 
a39e93b
c10350f
 
a39e93b
5b47e63
a39e93b
 
c10350f
 
 
 
 
 
 
 
dc5d472
 
 
 
c10350f
a39e93b
c10350f
 
 
 
a39e93b
01920f9
 
a39e93b
c10350f
 
 
 
a39e93b
 
19fb2f0
a39e93b
c10350f
 
 
 
a39e93b
 
c10350f
 
 
 
 
 
 
 
d0547d2
 
 
19fb2f0
 
c10350f
 
 
 
 
 
 
 
 
 
 
 
 
 
7faefc8
a39e93b
7faefc8
c10350f
7faefc8
dc5d472
c10350f
dc5d472
c10350f
7faefc8
d0547d2
19fb2f0
c10350f
 
 
5b47e63
dc5d472
 
a39e93b
 
 
c10350f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# coding=utf-8
# author: xusong <xusong28@jd.com>
# time: 2022/8/23 17:08

import time
import torch
import gradio as gr
from info import article
from transformers import FillMaskPipeline
from transformers import BertTokenizer
from kplug.modeling_kplug import KplugForMaskedLM
from pycorrector.bert.bert_corrector import BertCorrector
from pycorrector import config
from loguru import logger

# Device index handed to the transformers pipeline: 0 (first CUDA device)
# when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device_id = 0
else:
    device_id = -1


# Extra CSS injected into the Gradio page; hides elements with the
# `category-legend` class.
css = """
.category-legend {display: none !important}
"""

class KplugCorrector(BertCorrector):
    """Corrector that inherits pycorrector's ``BertCorrector`` logic but uses a K-PLUG masked LM.

    Correction methods such as ``bert_correct`` are inherited unchanged; only the
    fill-mask model (``self.model``) and its mask token (``self.mask``) are replaced.
    """

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id,
                 model_name_or_path="eson/kplug-base-encoder"):
        """Load the K-PLUG tokenizer/model and wrap them in a ``FillMaskPipeline``.

        Args:
            bert_model_dir: Kept only for signature compatibility with
                ``BertCorrector``; it is not used to load anything.
            device: Device index forwarded to ``FillMaskPipeline`` (the module
                default ``device_id`` is 0 with CUDA, -1 otherwise).
            model_name_or_path: Checkpoint passed to ``from_pretrained`` for both
                the tokenizer and the model. Defaults to the previously
                hard-coded ``"eson/kplug-base-encoder"``.
        """
        # super(BertCorrector, self) starts the MRO lookup *after* BertCorrector,
        # so BertCorrector.__init__ is intentionally skipped. Presumably this
        # avoids loading pycorrector's default BERT model — TODO confirm.
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        t1 = time.time()

        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        model = KplugForMaskedLM.from_pretrained(model_name_or_path)

        self.model = FillMaskPipeline(model=model, tokenizer=tokenizer, device=device)
        if self.model:
            self.mask = self.model.tokenizer.mask_token
            # Log the checkpoint that was actually loaded (previously this
            # reported bert_model_dir, which is never loaded here).
            logger.debug('Loaded bert model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))


# Single module-level corrector used by correct() and test(). Constructing it
# loads the K-PLUG tokenizer and model, so this happens once, at import time.
corrector = KplugCorrector()

# Sample inputs: iterated by test() and passed as `examples=` to the Gradio
# interface. The first two contain deliberate typos; the last is already correct.
error_sentences = [
    '少先队员因该为老人让坐',
    '机七学习是人工智能领遇最能体现智能的一个分知',
    '今天心情很好',
]


def mock_data():
    """Return a canned ``(corrected_sentence, errors)`` pair, useful for UI work without a model.

    Each error tuple has the layout that ``correct`` reads from ``bert_correct``
    output: ``(original_char, corrected_char, start, end)``.
    """
    fixes = [
        ('七', '器', 1, 2),
        ('遇', '域', 10, 11),
    ]
    fixed_text = '机器学习是人工智能领域最能体现智能的一个分知'
    return fixed_text, fixes


def correct(sent):
    """Correct ``sent`` and format the result for the Gradio outputs.

    Args:
        sent: Input text to correct.

    Returns:
        A pair ``(highlighted, errs)``:

        * ``highlighted`` is ``{"text": corrected_sent, "entities": [...]}``, the
          dict format accepted by ``gr.HighlightedText`` (see
          https://www.gradio.app/docs/highlightedtext). There is one entity per
          correction, with ``word=err[1]``, ``start=err[2]`` and ``end=err[3]``.
        * ``errs`` is the raw error list from ``corrector.bert_correct``. It is
          presumably ``(original, corrected, start, end)`` tuples, matching
          ``mock_data`` — confirm against pycorrector.
    """
    corrected_sent, errs = corrector.bert_correct(sent)
    print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
    # 0.5 is a fixed placeholder score; no per-error confidence is computed here.
    output = [
        {"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]}
        for err in errs
    ]
    return {"text": corrected_sent, "entities": output}, errs


def test():
    """Print the corrector's result for each entry of ``error_sentences``."""
    for sample in error_sentences:
        fixed, errors = corrector.bert_correct(sample)
        message = "original sentence:{} => {}, err:{}".format(sample, fixed, errors)
        print(message)


# Gradio UI: one text box in, two outputs out. correct() returns a 2-tuple
# (highlighted dict, raw errs), which maps onto `outputs` in order.
corr_iface = gr.Interface(
    fn=correct,
    # Label means "input text"; pre-filled with the first sample sentence.
    inputs=gr.Textbox(
        label="输入文本",
        value="少先队员因该为老人让坐"),
    outputs=[
        # Renders the {"text": ..., "entities": [...]} dict from correct().
        # NOTE(review): show_legend=True, yet the module-level css hides
        # `.category-legend` — confirm which behavior is intended.
        gr.HighlightedText(
            label="文本纠错",
            show_legend=True,

        ),
        # Displays the raw `errs` list returned by correct().
        gr.JSON(
            # label="JSON Output"
        )
    ],
    # Clickable example inputs shown under the form.
    examples=error_sentences,
    title="文本纠错(Corrector)",
    description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
    article=article,
    css=css
)

if __name__ == "__main__":
    # Uncomment either call for a quick console check without the web UI:
    # test()
    # correct("少先队员因该为老人让坐")
    # Start the Gradio web server for the interface defined above.
    corr_iface.launch()