# Spaces:
# Runtime error
# Runtime error
import subprocess
import sys

import gradio as gr
def generate_text(length, prefix, temperature, topk, topp, rep):
    """Run ``generate.py`` in a child process and return what it prints.

    Args:
        length: Value for ``--length``; converted with ``int()``.
        prefix: Text passed verbatim as the value of ``--prefix``.
        temperature: Value for ``--temperature``.
        topk: Value for ``--topk``; converted with ``int()``.
        topp: Value for ``--topp``.
        rep: Value for ``--repetition_penalty``.

    Returns:
        The child's standard output decoded as UTF-8. Bytes that are not
        valid UTF-8 are replaced with U+FFFD instead of raising.
    """
    # Build the command line. Use the interpreter that is running this app
    # rather than whichever "python" happens to come first on PATH; fall back
    # to the bare name if the interpreter path is unavailable.
    interpreter = sys.executable or "python"
    args = [
        interpreter,
        "generate.py",
        f"--length={int(length)}",
        "--nsamples=1",
        f"--prefix={prefix}",
        f"--temperature={temperature}",
        "--batch_size=1",
        f"--topk={int(topk)}",
        f"--topp={topp}",
        f"--repetition_penalty={rep}",
        "--fast_pattern",
        "--tokenizer_path=./vocab.txt",
        "--model_config=./config.json",
    ]

    # Run the command and collect its output. Only stdout is piped; stderr
    # stays attached to this process so the child's diagnostics remain visible
    # wherever this app's own stderr goes.
    # NOTE(review): the exit status is intentionally not checked, so a failing
    # run still returns whatever was written to stdout (possibly nothing).
    completed = subprocess.run(args, stdout=subprocess.PIPE, check=False)

    # A strict decode would abort the whole request on a single bad byte.
    return completed.stdout.decode("utf-8", errors="replace")
# Input widgets, declared in the same order as generate_text's parameters.
input_length = gr.Slider(
    label="生成文本长度",
    minimum=10,
    maximum=500,
    value=500,
    step=10,
)
input_prefix = gr.Textbox(label="起始文本")
input_temperature = gr.Slider(
    label="生成温度",
    minimum=0,
    maximum=2,
    value=1,
    step=0.01,
)
# A batch-size slider is deliberately not exposed in the UI (disabled):
# input_batchsize = gr.Slider(label="生成的batch size", minimum=1, maximum=1, value=1, step=1)
input_topk = gr.Slider(
    label="最高几选一",
    minimum=1,
    maximum=48,
    value=32,
    step=1,
)
input_topp = gr.Slider(
    label="最高积累概率",
    minimum=0,
    maximum=1,
    value=0,
    step=0.01,
)
input_repeat_penality = gr.Slider(
    label="重复罚值",
    minimum=0,
    maximum=15,
    value=10,
    step=0.01,
)

# Output widget that receives the string returned by generate_text.
output_text = gr.Textbox(label="生成的文本")

# Page heading and the hint shown beneath it.
title = "GPT2中文文本生成器"
description = "cpu推理约1s/字,温度太低基本是无意义字符"

# Assemble the interface and start serving it.
gr.Interface(
    fn=generate_text,
    inputs=[
        input_length,
        input_prefix,
        input_temperature,
        input_topk,
        input_topp,
        input_repeat_penality,
    ],
    outputs=output_text,
    title=title,
    description=description,
).launch()