xusong28 committed on
Commit
a39e93b
1 Parent(s): 2bb0b26
Files changed (4) hide show
  1. app.py +5 -2
  2. demo_corrector.py +46 -0
  3. demo_mlm.py +8 -2
  4. demo_sum.py +11 -7
app.py CHANGED
@@ -5,6 +5,10 @@
5
  """
6
  https://gradio.app/docs/#tabbedinterface-header
7
 
 
 
 
 
8
  """
9
 
10
  import gradio as gr
@@ -12,8 +16,7 @@ from demo_sum import sum_iface
12
  from demo_mlm import mlm_iface
13
 
14
 
15
-
16
- demo = gr.TabbedInterface([sum_iface, mlm_iface], ["商品摘要", "文本填词", "句子纠错"])
17
 
18
  if __name__ == "__main__":
19
  demo.launch()
 
5
  """
6
  https://gradio.app/docs/#tabbedinterface-header
7
 
8
+ ## 更多任务
9
+ - 抽取式摘要
10
+ - 检索式对话 、 抽取式问答
11
+ -
12
  """
13
 
14
  import gradio as gr
 
16
  from demo_mlm import mlm_iface
17
 
18
 
19
+ demo = gr.TabbedInterface([sum_iface, mlm_iface], ["生成式摘要", "文本填词", "句子纠错"])
 
20
 
21
  if __name__ == "__main__":
22
  demo.launch()
demo_corrector.py CHANGED
@@ -1,3 +1,49 @@
1
  # coding=utf-8
2
  # author: xusong <xusong28@jd.com>
3
  # time: 2022/8/23 17:08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # coding=utf-8
2
  # author: xusong <xusong28@jd.com>
3
  # time: 2022/8/23 17:08
4
+
5
+
6
+ import gradio as gr
7
+ from transformers import FillMaskPipeline
8
+ from transformers import BertTokenizer
9
+ from modeling_kplug import KplugForMaskedLM
10
+
11
+ model_dir = "models/pretrain/"
12
+ tokenizer = BertTokenizer.from_pretrained(model_dir)
13
+ model = KplugForMaskedLM.from_pretrained(model_dir)
14
+
15
+
16
+ def correct(text):
17
+ pass
18
+
19
+ # fill mask
20
+ def fill_mask(text):
21
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
22
+ outputs = fill_masker(text)
23
+ return {i["token_str"]: i["score"] for i in outputs}
24
+
25
+
26
+ mlm_examples = [
27
+ "这款连[MASK]裙真漂亮",
28
+ "这是杨[MASK]同款包包,精选优质皮料制作",
29
+ "美颜去痘洁面[MASK]",
30
+ ]
31
+
32
+ mlm_iface = gr.Interface(
33
+ fn=fill_mask,
34
+ inputs=gr.inputs.Textbox(
35
+ label="输入文本",
36
+ default="这款连[MASK]裙真漂亮"),
37
+ outputs=gr.Label(
38
+ label="填词",
39
+ show_label=False,
40
+ ),
41
+ examples=mlm_examples,
42
+ title="文本填词",
43
+ description='<div>电商领域文本摘要, 基于KPLUG预训练语言模型。</div>'
44
+
45
+ )
46
+
47
+ if __name__ == "__main__":
48
+ # fill_mask("这款连[MASK]裙真漂亮")
49
+ mlm_iface.launch()
demo_mlm.py CHANGED
@@ -5,10 +5,16 @@
5
  """
6
 
7
 
 
8
 
9
  interface = gr.Interface.load(
10
  "models/bert-base-uncased", api_key=None, alias="fill-mask"
11
  )
 
 
 
 
 
12
  """
13
 
14
  import gradio as gr
@@ -44,8 +50,8 @@ mlm_iface = gr.Interface(
44
  show_label=False,
45
  ),
46
  examples=mlm_examples,
47
- title="文本填词",
48
- description='电商领域文本摘要, 基于KPLUG预训练语言模型,'
49
  '<a href=""> K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding'
50
  ' and Generation in E-Commerce (Findings of EMNLP 2021) </a>。'
51
  )
 
5
  """
6
 
7
 
8
+ https://github.com/gradio-app/gradio/blob/299ba1bd1aed8040b3087c06c10fedf75901f91f/gradio/external.py#L484
9
 
10
  interface = gr.Interface.load(
11
  "models/bert-base-uncased", api_key=None, alias="fill-mask"
12
  )
13
+ ## TODO:
14
+
15
+ 1. json_output
16
+ 2. 百分数换成小数
17
+ 3.
18
  """
19
 
20
  import gradio as gr
 
50
  show_label=False,
51
  ),
52
  examples=mlm_examples,
53
+ title="文本填词(Fill Mask)",
54
+ description='基于KPLUG预训练语言模型,'
55
  '<a href=""> K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding'
56
  ' and Generation in E-Commerce (Findings of EMNLP 2021) </a>。'
57
  )
demo_sum.py CHANGED
@@ -20,6 +20,7 @@ model_dir = "models/ft_cepsum_jiadian/"
20
  model = BartForConditionalGeneration.from_pretrained(model_dir) # cnn指的是cnn daily mail
21
  tokenizer = BertTokenizer.from_pretrained(model_dir)
22
 
 
23
  def summarize(text):
24
  inputs = tokenizer([text], max_length=512, return_tensors="pt")
25
  summary_ids = model.generate(inputs["input_ids"][:, 1:], num_beams=4, min_length=20, max_length=100)
@@ -27,7 +28,7 @@ def summarize(text):
27
  return summary[0]
28
 
29
 
30
- #TODO:
31
  # 1. 下拉框,选择类目。 gr.Radio(['服饰','箱包', '鞋靴']
32
  # 2. 支持NER、LM、Corrector
33
  # beam search参数
@@ -57,13 +58,16 @@ sum_iface = gr.Interface(
57
  "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
58
  "门俯视图,全开门俯视图,预留参考图"),
59
  outputs=gr.Textbox(
60
- label="文本摘要(Summarization)",
61
- ),
 
62
  examples=sum_examples,
63
- title="电商文本摘要(Abstractive Summarization)",
64
- description='电商领域文本摘要, 基于KPLUG预训练语言模型,'
65
- '<a href=""> K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding'
66
- ' and Generation in E-Commerce (Findings of EMNLP 2021) </a>。更多样例,见 https://github.com/xu-song/k-plug/tree/master/data_sample/sum/cepsum/jiadian/raw '
 
 
67
  )
68
 
69
  if __name__ == "__main__":
 
20
  model = BartForConditionalGeneration.from_pretrained(model_dir) # cnn指的是cnn daily mail
21
  tokenizer = BertTokenizer.from_pretrained(model_dir)
22
 
23
+
24
  def summarize(text):
25
  inputs = tokenizer([text], max_length=512, return_tensors="pt")
26
  summary_ids = model.generate(inputs["input_ids"][:, 1:], num_beams=4, min_length=20, max_length=100)
 
28
  return summary[0]
29
 
30
 
31
+ # TODO:
32
  # 1. 下拉框,选择类目。 gr.Radio(['服饰','箱包', '鞋靴']
33
  # 2. 支持NER、LM、Corrector
34
  # beam search参数
 
58
  "建议左右两侧、顶部和背部需要预留10C,电源线和调平脚等。冰箱放置时为保证,菜谱推荐,半开"
59
  "门俯视图,全开门俯视图,预留参考图"),
60
  outputs=gr.Textbox(
61
+ label="文本摘要(Summarization)",
62
+ lines=4,
63
+ ),
64
  examples=sum_examples,
65
+ title="生成式摘要(Abstractive Summarization)",
66
+ description='<div>这是一个生成式摘要的demo,用于电商领域的商品营销文案写作。'
67
+ '该demo基于KPLUG预训练语言模型,输入商品信息,输出商品的营销文案。</div>'
68
+ '<div> Paper: <a href="https://aclanthology.org/2021.findings-emnlp.1/"> K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding'
69
+ ' and Generation in E-Commerce (Findings of EMNLP 2021) </a> </div>'
70
+ '<div>Github: <a href="https://github.com/xu-song/k-plug">https://github.com/xu-song/k-plug </a> </div>'
71
  )
72
 
73
  if __name__ == "__main__":