# Spaces status: Runtime error
import gradio as gr | |
import requests | |
import json | |
import os | |
# Credentials sent with every completion request (see the "apikey" and
# "apisecret" payload fields). Each is None when its environment variable
# is not set.
APIKEY = os.getenv("APIKEY")
APISECRET = os.getenv("APISECRET")
def predict(text, seed, out_seq_length, min_gen_length, sampling_strategy,
            num_beams, length_penalty, no_repeat_ngram_size,
            temperature, topk, topp):
    """Request a completion for ``text`` from the hosted GLM-130B endpoint.

    All generation settings are forwarded unchanged in the JSON payload,
    together with the module-level ``APIKEY`` / ``APISECRET`` credentials.

    Args:
        text: Prompt to complete. Empty or ``None`` is rejected locally.
        seed: Sent as ``"seed"``.
        out_seq_length: Sent as ``"max_tokens"``.
        min_gen_length: Sent as ``"min_gen_length"``.
        sampling_strategy: Sent as ``"sampling_strategy"``.
        num_beams: Sent as ``"num_beams"``.
        length_penalty: Sent as ``"length_penalty"``.
        no_repeat_ngram_size: Sent as ``"no_repeat_ngram"``.
        temperature: Sent as ``"temperature"``.
        topk: Sent as ``"top_k"``.
        topp: Sent as ``"top_p"``.

    Returns:
        The generated text with every ``[</s>]`` marker removed. On any
        failure a human-readable message is returned instead of raising, so
        the UI always has something to show in the output box.
    """
    unexpected_reply = 'The generation service returned an unexpected response.'

    # `not text` also covers None, which would otherwise be sent as a null prompt.
    if not text:
        return 'Input should not be empty!'

    url = 'https://wudao.aminer.cn/os/api/api/v2/completions_130B'
    payload = json.dumps({
        "apikey": APIKEY,
        "apisecret": APISECRET,
        "language": "zh-CN",
        "prompt": text,
        "length_penalty": length_penalty,
        "temperature": temperature,
        "top_k": topk,
        "top_p": topp,
        "min_gen_length": min_gen_length,
        "sampling_strategy": sampling_strategy,
        "num_beams": num_beams,
        "max_tokens": out_seq_length,
        "no_repeat_ngram": no_repeat_ngram_size,
        "seed": seed,
    })
    headers = {'Content-Type': 'application/json'}

    # Without a timeout a stalled server would block this worker forever.
    # NOTE(review): 60 s is a guess at a generous upper bound -- tune it to
    # the service's real latency.
    timeout_seconds = 60

    try:
        response = requests.post(
            url, headers=headers, data=payload, timeout=timeout_seconds
        ).json()
    except requests.RequestException as err:
        # Only the exception type is shown; the full message may contain
        # request details that do not belong in the UI.
        return f'Request to the generation service failed ({type(err).__name__}).'
    except ValueError:
        # Body was not valid JSON (older `requests` raise a plain ValueError).
        return 'The generation service returned an invalid (non-JSON) response.'

    if not isinstance(response, dict):
        return unexpected_reply

    # This app treats status 1 as "prompt does not fit within max_tokens".
    if response.get('status') == 1:
        return 'Please give smaller text than max_tokens or give larger max_tokens.'

    try:
        answer = response['result']['output']['raw']
    except (KeyError, TypeError):
        return unexpected_reply

    # The raw output may be a list of candidates; only the first is shown.
    if isinstance(answer, list):
        if not answer:
            return unexpected_reply
        answer = answer[0]
    if not isinstance(answer, str):
        return unexpected_reply

    return answer.replace('[</s>]', '')
if __name__ == "__main__":
    # Example prompts offered below the form: one blank-filling ([MASK]) and
    # one free-generation prompt each for English and Chinese. gr.Examples
    # expects one list per example row, hence the single-item lists.
    en_fil = ['The Starry Night is an oil-on-canvas painting by [MASK] in June 1889.']
    en_gen = ['The largest animal in the world is ']
    ch_fil = ['凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。']
    ch_gen = ['五岳是指哪五座山?回答:']
    examples = [en_fil, en_gen, ch_fil, ch_gen]

    # NOTE(review): the Row/Column nesting below is the natural reading of
    # the statement order -- confirm against the deployed layout.
    with gr.Blocks() as demo:
        gr.Markdown(
            "\n"
            "# GLM-130B\n"
            "An Open Bilingual Pre-Trained Model\n"
            "[Visit our github repo](https://github.com/THUDM/GLM-130B)\n"
        )

        # Prompt on the left (with its action buttons), result on the right.
        with gr.Row():
            with gr.Column():
                model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")
            outputs = gr.Textbox(lines=7, label='Output')

        gr.Markdown(
            "\n"
            "Generation Parameter\n"
        )

        # Settings shared by both search strategies.
        with gr.Row():
            with gr.Column():
                seed = gr.Slider(maximum=100000, value=1234, step=1, label='Seed')
                out_seq_length = gr.Slider(maximum=256, value=128, minimum=32, step=1, label='Output Sequence Length')
            with gr.Column():
                min_gen_length = gr.Slider(maximum=64, value=0, step=1, label='Min Generate Length')
                sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BeamSearchStrategy', label='Search Strategy')

        # Strategy-specific settings, one column per strategy.
        with gr.Row():
            with gr.Column():
                # beam search
                gr.Markdown(
                    "\n"
                    "BeamSearchStrategy\n"
                )
                num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label='Number of Beams')
                length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label='Length Penalty')
                no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
            with gr.Column():
                # base search
                gr.Markdown(
                    "\n"
                    "BaseStrategy\n"
                )
                temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label='Temperature')
                topk = gr.Slider(maximum=40, value=1, minimum=0, step=1, label='Top K')
                topp = gr.Slider(maximum=1, value=0, minimum=0, label='Top P')

        # Order must match the positional parameters of predict().
        inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
        gen.click(fn=predict, inputs=inputs, outputs=outputs)

        # Clearing needs no input: returning "" sets the prompt box to empty.
        clr.click(fn=lambda: "", inputs=None, outputs=model_input)

        gr.Markdown("Try this!")
        gr.Examples(examples=examples, inputs=model_input)

    demo.launch()