|
|
|
|
|
|
|
|
|
"""

TODO: also add logic to decide whether a reply is needed at all
(not every incoming message warrants a generated response).

"""
|
|
|
import torch |
|
import gradio as gr |
|
from info import article |
|
from kplug import modeling_kplug_s2s_patch |
|
from transformers import BertTokenizer, BartForConditionalGeneration |
|
|
|
# Download (on first run) and cache the fine-tuned KPLUG seq2seq checkpoint
# for the JDDC customer-service dialogue task. Note: network I/O happens here,
# at import time.
model = BartForConditionalGeneration.from_pretrained("eson/kplug-base-jddc")

# KPLUG uses a BERT-style (character-level, Chinese) vocabulary, hence
# BertTokenizer rather than the default BART tokenizer.
tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-jddc")
|
|
|
|
|
def predict(input, history=None):
    """Generate a customer-service reply for *input* given the dialogue history.

    Concatenation scheme: the whole history is joined into a single string with
    no role markers. Simple and crude, but an encoder-decoder model keeps the
    source and target sides separate, so input and output are not confused
    (a GPT-style decoder-only model would need explicit role tokens).

    Args:
        input: the user's latest utterance (str).
        history: flat list of alternating user/bot utterances from previous
            turns. ``None`` or ``[]`` starts a fresh conversation (gradio's
            ``"state"`` component may pass ``None`` on the first call).
            The passed-in list is never mutated.

    Returns:
        (pairs, history): ``pairs`` is a list of ``(user, bot)`` tuples for
        the gradio Chatbot widget; ``history`` is the updated flat list to be
        stored back into the session state.
    """
    # Build a new list (never mutate the caller's state object); also guards
    # against history being None on the first turn.
    history = (history or []) + [input]

    # Keep only the last 500 characters so the concatenated context stays
    # within the encoder's length budget.
    bot_input_ids = tokenizer.encode("".join(history)[-500:], return_tensors='pt')

    # NOTE(review): BertTokenizer usually defines no eos token, so
    # tokenizer.eos_token_id may be None here — confirm for this checkpoint.
    response = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # Decode, then drop the whitespace the tokenizer inserts between
    # Chinese characters.
    response = "".join(tokenizer.decode(response[0], skip_special_tokens=True).split())
    history = history + [response]

    # After appending the reply, history has even length and alternates
    # user/bot, so consecutive pairs are the (user, bot) turns.
    pairs = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return pairs, history
|
|
|
|
|
# Example user utterances shown under the gradio input box (price-protection
# claim, stock inquiry, delivery status, shipping carrier questions).
jddc_examples = [



    "昨天刚买的怎么就降了几十块,应该补给我差价吧",



    "请问这个猕猴桃是有货的吗?",



    "我下的这个单怎么还没到",



    "发什么快递",

    "能发邮政吗",

]
|
|
|
# Gradio UI: a textbox plus a hidden "state" slot feeding predict(), which
# returns the chatbot transcript and the updated state for the next turn.
jddc_iface = gr.Interface(

    fn=predict,



    inputs=[

        gr.Textbox(

            label="输入文本",

            value="发什么快递"),

        "state"  # session state: the flat history list threaded through predict()

    ],

    outputs=["chatbot", "state"],

    examples=jddc_examples,

    title="电商客服-生成式对话(Response Generation)",

    article=article,

)
|
|
|
# Launch the local gradio server only when run as a script (the interface
# object stays importable without side effects).
if __name__ == "__main__":

    jddc_iface.launch()
|
|