|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
Reference:
https://github.com/gradio-app/gradio/blob/299ba1bd1aed8040b3087c06c10fedf75901f91f/gradio/external.py#L484

    interface = gr.Interface.load(
        "models/bert-base-uncased", api_key=None, alias="fill-mask"
    )
|
## TODO:

1. json_output
2. Convert percentages to decimals
3.
|
""" |
|
|
|
import gradio as gr |
|
from info import article |
|
from transformers import FillMaskPipeline |
|
from transformers import BertTokenizer |
|
from kplug.modeling_kplug import KplugForMaskedLM |
|
|
|
# Name of the pretrained checkpoint; the tokenizer and the masked-LM model
# are both loaded from this same identifier.
_CHECKPOINT = "eson/kplug-base-encoder"

tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
model = KplugForMaskedLM.from_pretrained(_CHECKPOINT)
|
|
|
|
|
# Lazily created pipeline shared by all requests (see _get_fill_masker).
_FILL_MASKER = None


def _get_fill_masker():
    """Return the shared fill-mask pipeline, building it on first use."""
    global _FILL_MASKER
    if _FILL_MASKER is None:
        _FILL_MASKER = FillMaskPipeline(model=model, tokenizer=tokenizer)
    return _FILL_MASKER


def fill_mask(text):
    """Predict candidate tokens for the masked position in ``text``.

    Args:
        text: Input sentence containing the tokenizer's mask token.

    Returns:
        A dict mapping each candidate token string to its score.
    """
    # Reuse one pipeline instead of constructing a new one per request.
    outputs = _get_fill_masker()(text)

    # With several mask tokens the pipeline returns one list of candidates
    # per mask rather than a flat list of dicts; indexing such a list by
    # "token_str" would raise TypeError, so fall back to the candidates of
    # the first mask.
    if outputs and isinstance(outputs[0], list):
        outputs = outputs[0]

    return {item["token_str"]: item["score"] for item in outputs}
|
|
|
|
|
# Sample sentences offered as clickable examples in the UI.
mlm_examples = [
    "这款连[MASK]裙真漂亮",
    "这是杨[MASK]同款包包,精选优质皮料制作",
    "美颜去痘洁面[MASK]",
]

# Input box, pre-filled with the first example sentence.
_text_input = gr.Textbox(label="输入文本", value=mlm_examples[0])

# Output widget fed with the dict returned by fill_mask.
_label_output = gr.Label(label="填词")

mlm_iface = gr.Interface(
    fn=fill_mask,
    inputs=_text_input,
    outputs=_label_output,
    examples=mlm_examples,
    title="文本填词(Fill Mask)",
    description="基于KPLUG预训练语言模型",
    article=article,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    mlm_iface.launch()
|
|