xusong28
committed on
Commit
•
19fb2f0
1
Parent(s):
c10350f
update
Browse files
- app.py +15 -66
- app2.py +0 -22
- demo_corrector.py +8 -7
app.py
CHANGED
@@ -1,74 +1,23 @@
|
|
1 |
# coding=utf-8
|
2 |
# author: xusong <xusong28@jd.com>
|
3 |
-
# time: 2022/8/
|
4 |
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
|
12 |
-
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
|
13 |
-
|
14 |
-
|
15 |
-
def ai_text(text):
|
16 |
-
with torch.no_grad():
|
17 |
-
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
|
18 |
-
|
19 |
-
def to_ner(corrected_sent, errs):
|
20 |
-
output = [{"entity": "纠错", "word": err[1], "start": err[2], "end": err[3]} for i, err in
|
21 |
-
enumerate(errs)]
|
22 |
-
return {"text": corrected_sent, "entities": output}
|
23 |
-
|
24 |
-
def get_errors(corrected_text, origin_text):
|
25 |
-
sub_details = []
|
26 |
-
for i, ori_char in enumerate(origin_text):
|
27 |
-
if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
|
28 |
-
# add unk word
|
29 |
-
corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
|
30 |
-
continue
|
31 |
-
if i >= len(corrected_text):
|
32 |
-
continue
|
33 |
-
if ori_char != corrected_text[i]:
|
34 |
-
if ori_char.lower() == corrected_text[i]:
|
35 |
-
# pass english upper char
|
36 |
-
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
|
37 |
-
continue
|
38 |
-
sub_details.append((ori_char, corrected_text[i], i, i + 1))
|
39 |
-
sub_details = sorted(sub_details, key=operator.itemgetter(2))
|
40 |
-
return corrected_text, sub_details
|
41 |
-
|
42 |
-
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
|
43 |
-
corrected_text = _text[:len(text)]
|
44 |
-
corrected_text, details = get_errors(corrected_text, text)
|
45 |
-
print(text, ' => ', corrected_text, details)
|
46 |
-
return to_ner(corrected_text, details), details
|
47 |
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
if __name__ == '__main__':
|
50 |
-
print(ai_text('少先队员因该为老人让坐'))
|
51 |
|
52 |
-
|
53 |
-
['真麻烦你了。希望你们好好的跳无'],
|
54 |
-
['少先队员因该为老人让坐'],
|
55 |
-
['机七学习是人工智能领遇最能体现智能的一个分知'],
|
56 |
-
['今天心情很好'],
|
57 |
-
['他法语说的很好,的语也不错'],
|
58 |
-
['他们的吵翻很不错,再说他们做的咖喱鸡也好吃'],
|
59 |
-
]
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
inputs="textbox",
|
64 |
-
outputs=[
|
65 |
-
gr.outputs.HighlightedText(
|
66 |
-
label="Output",
|
67 |
-
show_legend=True,
|
68 |
-
),
|
69 |
-
gr.outputs.JSON()
|
70 |
-
],
|
71 |
-
title="Chinese Spelling Correction Model shibing624/macbert4csc-base-chinese",
|
72 |
-
description="Copy or input error Chinese text. Submit and the machine will correct text.",
|
73 |
-
article="Link to <a href='https://github.com/shibing624/pycorrector' style='color:blue;' target='_blank\'>Github REPO</a>",
|
74 |
-
examples=examples).launch()
|
|
|
1 |
# coding=utf-8
|
2 |
# author: xusong <xusong28@jd.com>
|
3 |
+
# time: 2022/8/23 16:06
|
4 |
|
5 |
+
"""
|
6 |
+
https://gradio.app/docs/#tabbedinterface-header
|
7 |
|
8 |
+
## 更多任务
|
9 |
+
- 抽取式摘要
|
10 |
+
- 检索式对话 、 抽取式问答
|
11 |
+
-
|
12 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
import gradio as gr
|
15 |
+
from demo_sum import sum_iface
|
16 |
+
from demo_mlm import mlm_iface
|
17 |
+
from demo_corrector import corr_iface
|
18 |
|
|
|
|
|
19 |
|
20 |
+
demo = gr.TabbedInterface([sum_iface, mlm_iface, corr_iface], ["生成式摘要", "文本填词", "句子纠错"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
if __name__ == "__main__":
|
23 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app2.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# coding=utf-8
|
2 |
-
# author: xusong <xusong28@jd.com>
|
3 |
-
# time: 2022/8/23 16:06
|
4 |
-
|
5 |
-
"""
|
6 |
-
https://gradio.app/docs/#tabbedinterface-header
|
7 |
-
|
8 |
-
## 更多任务
|
9 |
-
- 抽取式摘要
|
10 |
-
- 检索式对话 、 抽取式问答
|
11 |
-
-
|
12 |
-
"""
|
13 |
-
|
14 |
-
import gradio as gr
|
15 |
-
from demo_sum import sum_iface
|
16 |
-
from demo_mlm import mlm_iface
|
17 |
-
|
18 |
-
|
19 |
-
demo = gr.TabbedInterface([sum_iface, mlm_iface], ["生成式摘要", "文本填词", "句子纠错"])
|
20 |
-
|
21 |
-
if __name__ == "__main__":
|
22 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo_corrector.py
CHANGED
@@ -32,7 +32,7 @@ class KplugCorrector(BertCorrector):
|
|
32 |
logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - t1))
|
33 |
|
34 |
|
35 |
-
|
36 |
|
37 |
error_sentences = [
|
38 |
'少先队员因该为老人让坐',
|
@@ -49,8 +49,8 @@ def mock_data():
|
|
49 |
|
50 |
def correct(sent):
|
51 |
|
52 |
-
|
53 |
-
corrected_sent, errs = mock_data()
|
54 |
print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
|
55 |
output = [{"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]} for i, err in
|
56 |
enumerate(errs)]
|
@@ -69,12 +69,13 @@ corr_iface = gr.Interface(
|
|
69 |
label="输入文本",
|
70 |
default="少先队员因该为老人让坐"),
|
71 |
outputs=[
|
72 |
-
gr.HighlightedText(
|
73 |
-
label="
|
74 |
show_legend=True,
|
75 |
-
# visible=False
|
76 |
),
|
77 |
-
gr.JSON(
|
|
|
|
|
78 |
],
|
79 |
examples=error_sentences,
|
80 |
title="文本纠错(Corrector)",
|
|
|
32 |
logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - t1))
|
33 |
|
34 |
|
35 |
+
corrector = KplugCorrector()
|
36 |
|
37 |
error_sentences = [
|
38 |
'少先队员因该为老人让坐',
|
|
|
49 |
|
50 |
def correct(sent):
|
51 |
|
52 |
+
corrected_sent, errs = corrector.bert_correct(sent)
|
53 |
+
# corrected_sent, errs = mock_data()
|
54 |
print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
|
55 |
output = [{"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]} for i, err in
|
56 |
enumerate(errs)]
|
|
|
69 |
label="输入文本",
|
70 |
default="少先队员因该为老人让坐"),
|
71 |
outputs=[
|
72 |
+
gr.outputs.HighlightedText(
|
73 |
+
label="Output",
|
74 |
show_legend=True,
|
|
|
75 |
),
|
76 |
+
gr.outputs.JSON(
|
77 |
+
label="JSON Output"
|
78 |
+
)
|
79 |
],
|
80 |
examples=error_sentences,
|
81 |
title="文本纠错(Corrector)",
|