Spaces:
Runtime error
Runtime error
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import gradio as gr | |
from tools import * | |
# Hugging Face Hub id of the prompt-generator checkpoint. Shared by the
# tokenizer and the model so the two can never be loaded from different repos.
MODEL_NAME = "merve/chatgpt-prompt-generator-v12"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# from_tf=True loads the weights from a TensorFlow checkpoint (TensorFlow must
# be installed). NOTE(review): presumably the repo only ships TF weights — confirm.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, from_tf=True)
# Call the model to generate a prompt automatically.
def generate(prompt, max_new_tokens=150):
    """Generate a prompt from a short user request with the local seq2seq model.

    The original author notes that a Hugging Face transformers model is run
    locally (instead of calling OpenAI) because of recent large-scale OpenAI
    account bans.

    Args:
        prompt: The request text typed by the user.
        max_new_tokens: Maximum number of tokens to generate. Defaults to 150,
            the value that used to be hard-coded.

    Returns:
        The decoded text of the first generated sequence with special tokens
        removed. A ``None`` or whitespace-only prompt returns ``""`` without
        running the model.
    """
    # Nothing to condition on: skip the (slow) model call for an empty request.
    if prompt is None or not prompt.strip():
        return ""
    batch = tokenizer(prompt, return_tensors="pt")
    # Pass the attention mask explicitly so generate() does not have to infer it.
    generated_ids = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch.get("attention_mask"),
        max_new_tokens=max_new_tokens,
    )
    # A single prompt was encoded, so only the first sequence is decoded.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
with gr.Blocks() as demo:
    # Resolve the template files once; the first file seeds every initial value
    # below. It is indexed with [0], so at least one template file must exist.
    _template_files = get_template_names(plain=True)
    _default_template_file = _template_files[0]
    # Template names from the default file (load_template mode=1), used for the
    # initial choices/value of the template dropdown.
    _default_template_names = load_template(_default_template_file, mode=1)

    # Session state holding the loaded templates (load_template mode=2 result);
    # read by get_template_content when a template is picked.
    promptTemplates = gr.State(load_template(_default_template_file, mode=2))
    gr.HTML('<h1>LLM prompt tools (UP)</h1>')
    with gr.Row():
        with gr.Tab(label="Prompt"):
            # Text box that shows the prompt (generated or loaded from a template).
            # `container=False` is a constructor argument: `.style()` was removed
            # in Gradio 4.
            systemPromptTxt = gr.Textbox(
                show_label=True,
                placeholder="在这里输入System Prompt...",
                label="Result",
                value="You are a helpful assistant.",
                lines=10,
                container=False,
            )
        # Free-form request box + submit button that calls the model.
        with gr.Column(scale=1):
            with gr.Row():
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="如果模板不能满足,可以在这输入需求,调用模型结果",
                    container=False,
                )
            with gr.Row():
                submitBtn = gr.Button("提交", variant="primary")
    with gr.Row():
        # Load existing templates; they are kept as CSV so they can be edited later.
        with gr.Accordion(label="加载Prompt模板", open=True):
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=6):
                        templateFileSelectDropdown = gr.Dropdown(
                            label="选择Prompt模板集合文件",
                            choices=_template_files,
                            multiselect=False,
                            value=_default_template_file,
                            container=False,
                        )
                    with gr.Column(scale=1):
                        templateRefreshBtn = gr.Button("刷新")
                with gr.Row():
                    with gr.Column():
                        templateSelectDropdown = gr.Dropdown(
                            label="从Prompt模板中加载",
                            choices=_default_template_names,
                            multiselect=False,
                            value=_default_template_names[0],
                            container=False,
                        )
    # On submit, run the model on the request and show the result in the prompt box.
    submitBtn.click(fn=generate, inputs=[txt], outputs=[systemPromptTxt])

    # Template actions. `show_progress` is left at its default (progress shown),
    # which matches the former `show_progress=True` across Gradio versions.
    # NOTE(review): presumably get_template_names() (no `plain`) returns an
    # update for the file dropdown's choices — confirm against tools.
    templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
    # NOTE(review): presumably load_template(filename) with its default mode
    # returns (templates, dropdown update) — confirm against tools.
    templateFileSelectDropdown.change(
        load_template,
        [templateFileSelectDropdown],
        [promptTemplates, templateSelectDropdown],
    )
    templateSelectDropdown.change(
        get_template_content,
        [promptTemplates, templateSelectDropdown, systemPromptTxt],
        [systemPromptTxt],
    )
if __name__ == '__main__':
    # Start the Gradio server for the `demo` UI defined above.
    demo.launch()