# coding=utf-8
# author: xusong
# time: 2022/8/23 17:08
"""Gradio demo: fill in [MASK] tokens with a KPLUG masked language model."""

import gradio as gr
from transformers import FillMaskPipeline
from transformers import BertTokenizer

from modeling_kplug import KplugForMaskedLM

# Directory holding the pretrained tokenizer files and model weights.
model_dir = "models/pretrain/"
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = KplugForMaskedLM.from_pretrained(model_dir)

# Built once at import time: the pipeline only wraps the already-loaded model
# and tokenizer, so there is no reason to rebuild it for every request.
_fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)


def correct(text):
    """Placeholder for text correction; not implemented, returns None."""


def fill_mask(text):
    """Predict candidate tokens for the [MASK] position in ``text``.

    Args:
        text: Input sentence containing a ``[MASK]`` token.

    Returns:
        A dict mapping each candidate token string to its score, which is the
        mapping shape handed to the ``gr.Label`` output below.
    """
    outputs = _fill_masker(text)
    # With several [MASK] tokens the pipeline returns one candidate list per
    # mask instead of a flat list; fall back to the first mask's candidates
    # so the result is still a single token -> score mapping.
    if outputs and isinstance(outputs[0], list):
        outputs = outputs[0]
    return {candidate["token_str"]: candidate["score"] for candidate in outputs}


mlm_examples = [
    "这款连[MASK]裙真漂亮",
    "这是杨[MASK]同款包包,精选优质皮料制作",
    "美颜去痘洁面[MASK]",
]

mlm_iface = gr.Interface(
    fn=fill_mask,
    # gr.Textbox(value=...) is the top-level equivalent of the deprecated
    # gr.inputs.Textbox(default=...), consistent with gr.Label used below.
    inputs=gr.Textbox(
        label="输入文本",
        value="这款连[MASK]裙真漂亮",
    ),
    outputs=gr.Label(
        label="填词",
        show_label=False,
    ),
    examples=mlm_examples,
    title="文本填词",
    # NOTE(review): the description mentions text summarization while this
    # demo performs mask filling -- confirm the intended wording. Any markup
    # that originally wrapped this text is not visible here.
    description="""
电商领域文本摘要, 基于KPLUG预训练语言模型。
""",
)

if __name__ == "__main__":
    # For a quick check without the UI, call fill_mask(mlm_examples[0]).
    mlm_iface.launch()