|
|
|
|
|
|
|
|
|
|
|
import functools

import gradio as gr
from transformers import BertTokenizer
from transformers import FillMaskPipeline

from modeling_kplug import KplugForMaskedLM
|
|
|
# Checkpoint directory: both the tokenizer vocabulary and the KPLUG
# masked-LM weights are loaded from this single location.
model_dir = "models/pretrain/"

# Loaded once at import time (tokenizer first, then model); fill_mask below
# reads these module-level objects on every request.
tokenizer, model = (
    BertTokenizer.from_pretrained(model_dir),
    KplugForMaskedLM.from_pretrained(model_dir),
)
|
|
|
|
|
def correct(text):
    """Placeholder entry point; not implemented yet.

    Judging by its name this is presumably meant for a text-correction
    demo -- confirm with the author before wiring it into an interface.

    Args:
        text: Input text (currently ignored).

    Returns:
        None, because the body is still a stub.
    """
    # TODO: implement.
    return None
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_fill_masker():
    """Build the fill-mask pipeline once and reuse it for every request."""
    return FillMaskPipeline(model=model, tokenizer=tokenizer)


def fill_mask(text):
    """Predict candidate tokens for the mask placeholder(s) in *text*.

    Args:
        text: Input sentence containing the tokenizer's mask token
            (``[MASK]``).

    Returns:
        dict mapping each predicted token string to its score, in the
        shape ``gr.Label`` expects. If *text* contains several masks and
        the installed transformers version returns one candidate list per
        mask, each key is prefixed with the 1-based mask position, e.g.
        ``"2: 裙"``.
    """
    outputs = _get_fill_masker()(text)

    # With more than one mask token the pipeline returns a list of candidate
    # lists (one per mask) instead of a flat list of candidate dicts.
    if outputs and isinstance(outputs[0], list):
        candidates = [
            (f"{pos}: {item['token_str']}", item["score"])
            for pos, group in enumerate(outputs, start=1)
            for item in group
        ]
    else:
        candidates = [(item["token_str"], item["score"]) for item in outputs]

    scores = {}
    for label, score in candidates:
        # Candidates arrive sorted by descending score; if two decode to the
        # same string, keep the first (higher) score rather than letting the
        # lower one overwrite it.
        scores.setdefault(label, score)
    return scores
|
|
|
|
|
# Clickable demo inputs for the fill-mask interface: short e-commerce
# phrases, each with a single [MASK] placeholder for the model to fill.
mlm_examples = [
    "这款连[MASK]裙真漂亮",  # "This [MASK]-dress is really pretty"
    "这是杨[MASK]同款包包,精选优质皮料制作",  # celebrity-style bag, quality leather
    "美颜去痘洁面[MASK]",  # acne-clearing facial cleanser [MASK]
]
|
|
|
# Fill-mask demo: the user enters a sentence containing [MASK] and the
# candidate tokens returned by fill_mask are shown with their scores.
mlm_iface = gr.Interface(
    fn=fill_mask,
    # gr.inputs.* is deprecated in Gradio 3 and removed in Gradio 4; the
    # top-level component (with ``value=`` instead of ``default=``) works in
    # both, and matches the top-level gr.Label used for the output.
    inputs=gr.Textbox(
        label="输入文本",
        value=mlm_examples[0],  # pre-fill with the first example sentence
    ),
    outputs=gr.Label(
        label="填词",
        show_label=False,
    ),
    examples=mlm_examples,
    title="文本填词",
    # 文本填词 = text mask filling, which is what this interface does.
    description='<div>电商领域文本填词, 基于KPLUG预训练语言模型。</div>',
)
|
|
|
# Start the Gradio server only when executed as a script; importing this
# module just builds the interface without launching it.
if __name__ == "__main__":
    mlm_iface.launch()
|
|