xusong28 committed on
Commit 19fb2f0
1 Parent(s): c10350f
Files changed (3):
  1. app.py +15 -66
  2. app2.py +0 -22
  3. demo_corrector.py +8 -7
app.py CHANGED
@@ -1,74 +1,23 @@
  # coding=utf-8
  # author: xusong <xusong28@jd.com>
- # time: 2022/8/26 13:14
-
- import gradio as gr
- import operator
- import torch
- from transformers import BertTokenizer, BertForMaskedLM
-
- tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
- model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
-
-
- def ai_text(text):
-     with torch.no_grad():
-         outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
-
-     def to_ner(corrected_sent, errs):
-         output = [{"entity": "纠错", "word": err[1], "start": err[2], "end": err[3]} for i, err in
-                   enumerate(errs)]
-         return {"text": corrected_sent, "entities": output}
-
-     def get_errors(corrected_text, origin_text):
-         sub_details = []
-         for i, ori_char in enumerate(origin_text):
-             if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
-                 # add unk word
-                 corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
-                 continue
-             if i >= len(corrected_text):
-                 continue
-             if ori_char != corrected_text[i]:
-                 if ori_char.lower() == corrected_text[i]:
-                     # pass english upper char
-                     corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
-                     continue
-                 sub_details.append((ori_char, corrected_text[i], i, i + 1))
-         sub_details = sorted(sub_details, key=operator.itemgetter(2))
-         return corrected_text, sub_details
-
-     _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
-     corrected_text = _text[:len(text)]
-     corrected_text, details = get_errors(corrected_text, text)
-     print(text, ' => ', corrected_text, details)
-     return to_ner(corrected_text, details), details
-
-
- if __name__ == '__main__':
-     print(ai_text('少先队员因该为老人让坐'))
-
-     examples = [
-         ['真麻烦你了。希望你们好好的跳无'],
-         ['少先队员因该为老人让坐'],
-         ['机七学习是人工智能领遇最能体现智能的一个分知'],
-         ['今天心情很好'],
-         ['他法语说的很好,的语也不错'],
-         ['他们的吵翻很不错,再说他们做的咖喱鸡也好吃'],
-     ]
-
-     gr.Interface(
-         ai_text,
-         inputs="textbox",
-         outputs=[
-             gr.outputs.HighlightedText(
-                 label="Output",
-                 show_legend=True,
-             ),
-             gr.outputs.JSON()
-         ],
-         title="Chinese Spelling Correction Model shibing624/macbert4csc-base-chinese",
-         description="Copy or input error Chinese text. Submit and the machine will correct text.",
-         article="Link to <a href='https://github.com/shibing624/pycorrector' style='color:blue;' target='_blank'>Github REPO</a>",
-         examples=examples).launch()
+ # time: 2022/8/23 16:06
+
+ """
+ https://gradio.app/docs/#tabbedinterface-header
+
+ ## 更多任务
+ - 抽取式摘要
+ - 检索式对话 、 抽取式问答
+ -
+ """
+
+ import gradio as gr
+ from demo_sum import sum_iface
+ from demo_mlm import mlm_iface
+ from demo_corrector import corr_iface
+
+ demo = gr.TabbedInterface([sum_iface, mlm_iface, corr_iface], ["生成式摘要", "文本填词", "句子纠错"])
+
+ if __name__ == "__main__":
+     demo.launch()
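
Note: the rewritten app.py only assembles sub-interfaces defined in sibling modules; demo_sum and demo_mlm are not touched by this commit. A minimal sketch of what such a module is assumed to export (a ready-made gr.Interface; the function and logic below are placeholders, not the repository's actual code):

# demo_sum.py -- hypothetical sketch of the expected export (not part of this commit)
import gradio as gr

def summarize(text):
    # placeholder logic; the real module presumably runs a summarization model
    return text[:64]

sum_iface = gr.Interface(
    summarize,
    inputs="textbox",
    outputs="textbox",
    title="生成式摘要",
)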
app2.py DELETED
@@ -1,22 +0,0 @@
- # coding=utf-8
- # author: xusong <xusong28@jd.com>
- # time: 2022/8/23 16:06
-
- """
- https://gradio.app/docs/#tabbedinterface-header
-
- ## 更多任务
- - 抽取式摘要
- - 检索式对话 、 抽取式问答
- -
- """
-
- import gradio as gr
- from demo_sum import sum_iface
- from demo_mlm import mlm_iface
-
-
- demo = gr.TabbedInterface([sum_iface, mlm_iface], ["生成式摘要", "文本填词", "句子纠错"])
-
- if __name__ == "__main__":
-     demo.launch()
demo_corrector.py CHANGED
@@ -32,7 +32,7 @@ class KplugCorrector(BertCorrector):
          logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - t1))
 
 
- # corrector = KplugCorrector()
+ corrector = KplugCorrector()
 
  error_sentences = [
      '少先队员因该为老人让坐',
@@ -49,8 +49,8 @@ def mock_data():
 
  def correct(sent):
 
-     # corrected_sent, errs = corrector.bert_correct(sent)
-     corrected_sent, errs = mock_data()
+     corrected_sent, errs = corrector.bert_correct(sent)
+     # corrected_sent, errs = mock_data()
      print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
      output = [{"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]} for i, err in
                enumerate(errs)]
@@ -69,12 +69,13 @@ corr_iface = gr.Interface(
          label="输入文本",
          default="少先队员因该为老人让坐"),
      outputs=[
-         gr.HighlightedText(
-             label="纠错",
+         gr.outputs.HighlightedText(
+             label="Output",
              show_legend=True,
-             # visible=False
          ),
-         gr.JSON()
+         gr.outputs.JSON(
+             label="JSON Output"
+         )
      ],
      examples=error_sentences,
      title="文本纠错(Corrector)",
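
Note: this commit switches correct() from mock_data() to the real corrector.bert_correct(sent). Judging from how the output entities are built (err[1], err[2], err[3]), both are expected to return a (corrected_sent, errs) pair where each err is a (wrong_text, corrected_text, start, end) tuple indexed against the input sentence. A minimal illustrative stand-in with that shape (values are hypothetical, not the repository's actual mock_data):

# Hypothetical stand-in matching the return shape that correct() consumes;
# not the repository's actual mock_data() implementation.
def mock_data():
    corrected_sent = '少先队员应该为老人让座'
    # each error: (wrong_text, corrected_text, start_index, end_index)
    errs = [('因该', '应该', 4, 6), ('坐', '座', 10, 11)]
    return corrected_sent, errs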