# GLM-130B / app.py (Hugging Face Space by hanyullai)
# Commit 253b387: "Update app.py" (4.51 kB)
import gradio as gr
import requests
import json
import os
# Credentials for the hosted GLM-130B completion API, taken from the process
# environment. Each one is None when its variable is not set.
APIKEY, APISECRET = os.getenv("APIKEY"), os.getenv("APISECRET")
def predict(text, seed, out_seq_length, min_gen_length, sampling_strategy,
num_beams, length_penalty, no_repeat_ngram_size,
temperature, topk, topp):
global APIKEY
global APISECRET
if text == '':
return 'Input should not be empty!'
url = 'https://wudao.aminer.cn/os/api/api/v2/completions_130B'
payload = json.dumps({
"apikey": APIKEY,
"apisecret": APISECRET ,
"language": "zh-CN",
"prompt": text,
"length_penalty": length_penalty,
"temperature": temperature,
"top_k": topk,
"top_p": topp,
"min_gen_length": min_gen_length,
"sampling_strategy": sampling_strategy,
"num_beams": num_beams,
"max_tokens": out_seq_length,
"no_repeat_ngram": no_repeat_ngram_size,
"seed": seed
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload).json()
if response['status'] == 1:
return 'Please give smaller text than max_tokens or give larger max_tokens.'
answer = response['result']['output']['raw']
if isinstance(answer, list):
answer = answer[0]
answer = answer.replace('[</s>]', '')
return answer
def build_demo():
    """Build and return the Gradio Blocks UI for the GLM-130B demo.

    The UI contains:
    - a header;
    - a prompt textbox with Generate and Clear buttons, and an output textbox;
    - two rows of generation-parameter widgets;
    - clickable example prompts.

    "Generate" calls ``predict`` with the widget values. "Clear" empties the
    prompt textbox.
    """
    # Example prompts: an English [MASK] blank-filling prompt, an English
    # open-ended prompt, then the same two kinds in Chinese.
    en_fil = ['The Starry Night is an oil-on-canvas painting by [MASK] in June 1889.']
    en_gen = ['The largest animal in the world is ']
    ch_fil = ['凯旋门位于意大利米兰市古城堡旁。1807年为纪念[MASK]而建,门高25米,顶上矗立两武士青铜古兵车铸像。']
    ch_gen = ['五岳是指哪五座山?回答:']
    examples = [en_fil, en_gen, ch_fil, ch_gen]

    with gr.Blocks() as demo:
        gr.Markdown(
            "\n"
            "# GLM-130B\n"
            "An Open Bilingual Pre-Trained Model\n"
            "[Visit our github repo](https://github.com/THUDM/GLM-130B)\n"
        )

        with gr.Row():
            with gr.Column():
                model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")
            outputs = gr.Textbox(lines=7, label='Output')

        gr.Markdown("\nGeneration Parameter\n")
        with gr.Row():
            with gr.Column():
                seed = gr.Slider(maximum=100000, value=1234, step=1, label='Seed')
                out_seq_length = gr.Slider(maximum=256, value=128, minimum=32, step=1, label='Output Sequence Length')
            with gr.Column():
                min_gen_length = gr.Slider(maximum=64, value=0, step=1, label='Min Generate Length')
                sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BeamSearchStrategy', label='Search Strategy')

        with gr.Row():
            with gr.Column():
                # Beam-search parameters (shown under the BeamSearchStrategy heading).
                gr.Markdown("\nBeamSearchStrategy\n")
                num_beams = gr.Slider(maximum=4, value=2, minimum=1, step=1, label='Number of Beams')
                length_penalty = gr.Slider(maximum=1, value=1, minimum=0, label='Length Penalty')
                no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
            with gr.Column():
                # Sampling parameters (shown under the BaseStrategy heading).
                gr.Markdown("\nBaseStrategy\n")
                temperature = gr.Slider(maximum=1, value=0.7, minimum=0, label='Temperature')
                topk = gr.Slider(maximum=40, value=1, minimum=0, step=1, label='Top K')
                topp = gr.Slider(maximum=1, value=0, minimum=0, label='Top P')

        # Gradio passes these values positionally, so this order must match
        # predict()'s parameter list.
        inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
        gen.click(fn=predict, inputs=inputs, outputs=outputs)
        # Clearing needs no input values, so no component is wired in.
        clr.click(fn=lambda: "", inputs=None, outputs=model_input)

        gr.Markdown("Try this!")
        gr.Examples(examples=examples, inputs=model_input)

    return demo


if __name__ == "__main__":
    build_demo().launch()