# coding=utf-8
# author: xusong
# time: 2022/8/23 17:08

"""Gradio demo: masked-token prediction (fill-mask) with the KPLUG encoder.

Hosted-inference equivalent, for reference:
https://github.com/gradio-app/gradio/blob/299ba1bd1aed8040b3087c06c10fedf75901f91f/gradio/external.py#L484

    interface = gr.Interface.load(
        "models/bert-base-uncased",
        api_key=None,
        alias="fill-mask"
    )

TODO:
    1. json_output
    2. show scores as decimals instead of percentages
    3.
"""

import gradio as gr
from info import article
from transformers import FillMaskPipeline
from transformers import BertTokenizer
from kplug.modeling_kplug import KplugForMaskedLM

tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-encoder")
model = KplugForMaskedLM.from_pretrained("eson/kplug-base-encoder")

# Build the pipeline once: model and tokenizer are fixed, so constructing a
# new pipeline per request only adds latency.
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)


def fill_mask(text):
    """Predict candidate tokens for the [MASK] position(s) in ``text``.

    Args:
        text: Input string containing at least one ``[MASK]`` token.

    Returns:
        A ``{label: score}`` dict for display in ``gr.Label``. With a single
        mask, each label is the predicted token string. With several masks,
        each label is prefixed with the 1-based mask index (for example
        ``"[MASK2] <token>"``) so predictions for different positions do not
        overwrite each other.
    """
    outputs = fill_masker(text)
    # For multiple masks the pipeline returns one prediction list per mask
    # (a list of lists) instead of a flat list of dicts.
    if outputs and isinstance(outputs[0], list):
        return {
            f"[MASK{idx}] {pred['token_str']}": pred["score"]
            for idx, preds in enumerate(outputs, start=1)
            for pred in preds
        }
    return {pred["token_str"]: pred["score"] for pred in outputs}


mlm_examples = [
    "这款连[MASK]裙真漂亮",
    "这是杨[MASK]同款包包,精选优质皮料制作",
    "美颜去痘洁面[MASK]",
]

mlm_iface = gr.Interface(
    fn=fill_mask,
    inputs=gr.Textbox(
        label="输入文本",
        value="这款连[MASK]裙真漂亮"),
    outputs=gr.Label(
        label="填词",
    ),
    examples=mlm_examples,
    title="文本填词(Fill Mask)",
    description='基于KPLUG预训练语言模型',
    article=article
)

if __name__ == "__main__":
    mlm_iface.launch()