# coding=utf-8
# author: xusong
# time: 2022/8/23 17:08

"""Gradio demo: masked-token filling ("fill-mask") with the K-PLUG model.

The UI takes a sentence containing a ``[MASK]`` token and shows the
candidate tokens proposed by the model together with their scores.

An alternative that loads a hosted model instead of the local checkpoint::

    interface = gr.Interface.load(
        "models/bert-base-uncased",
        api_key=None,
        alias="fill-mask"
    )
"""

import gradio as gr
from transformers import FillMaskPipeline
from transformers import BertTokenizer

from modeling_kplug import KplugForMaskedLM

# Local directory holding the pretrained K-PLUG checkpoint and its vocabulary.
model_dir = "models/pretrain/"

tokenizer = BertTokenizer.from_pretrained(model_dir)
model = KplugForMaskedLM.from_pretrained(model_dir)

# Build the pipeline once at import time instead of on every request;
# it only wraps the already-loaded model and tokenizer.
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)


# fill mask
def fill_mask(text):
    """Predict replacements for the ``[MASK]`` token in *text*.

    Args:
        text: Input sentence, expected to contain a ``[MASK]`` token.

    Returns:
        Mapping from each candidate token string to its score, in the
        ``{label: confidence}`` shape that ``gr.Label`` displays.
    """
    # NOTE(review): with more than one [MASK] the pipeline returns one list
    # of candidates per mask (a nested list), which this comprehension does
    # not handle — confirm whether multi-mask input needs support.
    outputs = fill_masker(text)
    return {candidate["token_str"]: candidate["score"] for candidate in outputs}


# Example inputs shown under the demo; each contains one [MASK] token.
mlm_examples = [
    "这款连[MASK]裙真漂亮",
    "这是杨[MASK]同款包包,精选优质皮料制作",
    "美颜去痘洁面[MASK]",
]

mlm_iface = gr.Interface(
    fn=fill_mask,
    # gr.Textbox(value=...) replaces the deprecated gr.inputs.Textbox(default=...).
    inputs=gr.Textbox(
        label="输入文本", value="这款连[MASK]裙真漂亮"),
    outputs=gr.Label(
        label="填词",
        show_label=False,
    ),
    examples=mlm_examples,
    title="文本填词",
    # This demo performs fill-mask (填词), not summarization (摘要).
    description='电商领域文本填词, 基于KPLUG预训练语言模型,'
                ' K-PLUG: Knowledge-injected Pre-trained Language Model for Natural Language Understanding'
                ' and Generation in E-Commerce (Findings of EMNLP 2021) 。'
)

if __name__ == "__main__":
    mlm_iface.launch()