hanyullai committed on
Commit
07f3d5b
1 Parent(s): 8d8d984

add app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+
4
+ import requests
5
+ import json
6
+ import os
7
+
8
# Credentials for the remote completion API are taken from the environment;
# each name is None when its variable is not set.
APIKEY = os.getenv("APIKEY")
APISECRET = os.getenv("APISECRET")
10
+
11
def predict(text, seed, out_seq_length, min_gen_length, sampling_strategy,
            num_beams, length_penalty, no_repeat_ngram_size,
            temperature, topk, topp):
    """POST a completion request to the GLM-130B endpoint and return the reply text.

    The arguments are packed into a JSON body together with the module-level
    ``APIKEY`` / ``APISECRET`` credentials and a fixed ``"language": "zh-CN"``
    field, then sent to the completion URL.

    Args:
        text: Sent as ``prompt``.
        seed: Accepted so the signature lines up with the UI inputs; it is
            not included in the request body.
        out_seq_length: Sent as ``max_tokens``.
        min_gen_length: Sent as ``min_gen_length``.
        sampling_strategy: Sent as ``sampling_strategy``.
        num_beams: Sent as ``num_beams``.
        length_penalty: Sent as ``length_penalty``.
        no_repeat_ngram_size: Accepted like ``seed``; it is not included in
            the request body.
        temperature: Sent as ``temperature``.
        topk: Sent as ``top_k``.
        topp: Sent as ``top_p``.

    Returns:
        The body of the HTTP response as text, unmodified.
    """
    url = 'https://wudao.aminer.cn/os/api/api/v2/completions_130B'

    payload = json.dumps({
        "apikey": APIKEY,
        "apisecret": APISECRET,
        "language": "zh-CN",
        "prompt": text,
        "length_penalty": length_penalty,
        "temperature": temperature,
        "top_k": topk,
        "top_p": topp,
        "min_gen_length": min_gen_length,
        "sampling_strategy": sampling_strategy,
        "num_beams": num_beams,
        "max_tokens": out_seq_length
    })

    headers = {
        'Content-Type': 'application/json'
    }

    # NOTE(review): no timeout is passed, so a stalled server blocks this call
    # indefinitely; choose a limit once typical generation latency is known.
    response = requests.request("POST", url, headers=headers, data=payload)

    # Echo the raw reply to stdout for debugging.
    print(response.text)
    # Return the body of the response object itself; the previous code
    # referenced an undefined name here and raised NameError on every call.
    return response.text
41
+
42
+
43
# Build the Gradio demo page and start serving it when run as a script.
# NOTE(review): the nesting below was reconstructed from a diff view that had
# lost its indentation; confirm the layout against the rendered page.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # GLM-130B
            An Open Bilingual Pre-Trained Model
            """)

        # Prompt box with its action buttons, and the output box.
        with gr.Row():
            with gr.Column():
                model_input = gr.Textbox(lines=7, placeholder='Input something in English or Chinese', label='Input')
                with gr.Row():
                    gen = gr.Button("Generate")
                    clr = gr.Button("Clear")
            outputs = gr.Textbox(lines=7, label='Output')

        # Settings shared by both search strategies.
        seed = gr.Slider(maximum=100000, value=1234, label='Seed')
        out_seq_length = gr.Slider(maximum=256, value=128, minimum=8, label='Output Sequence Length')
        min_gen_length = gr.Slider(maximum=64, value=0, label='Min Generate Length')
        sampling_strategy = gr.Radio(choices=['BeamSearchStrategy', 'BaseStrategy'], value='BeamSearchStrategy', label='Search Strategy')

        with gr.Tabs():
            with gr.TabItem("Beam Search Parameter"):
                # beam search
                num_beams = gr.Slider(maximum=4, value=1, minimum=1, step=1, label='Number of Beams')
                length_penalty = gr.Slider(maximum=1, value=0.8, minimum=0, label='Length Penalty')
                no_repeat_ngram_size = gr.Slider(maximum=5, value=3, minimum=1, step=1, label='No Repeat Ngram Size')
            with gr.TabItem("Base Search Parameter"):
                # base search
                temperature = gr.Slider(maximum=1, value=1, minimum=0, label='Temperature')
                topk = gr.Slider(maximum=8, value=1, minimum=1, step=1, label='Top K')
                # Top-p is a fraction, so it gets the same 0-1 range and default
                # step as the other fractional sliders. The earlier 0-8 range
                # with integer steps could only produce whole numbers. The
                # default stays at 0 so the default request is unchanged.
                topp = gr.Slider(maximum=1, value=0, minimum=0, label='Top P')

        # The order must match the positional parameters of predict().
        inputs = [model_input, seed, out_seq_length, min_gen_length, sampling_strategy, num_beams, length_penalty, no_repeat_ngram_size, temperature, topk, topp]
        gen.click(fn=predict, inputs=inputs, outputs=outputs)
        # "Clear" empties the prompt box; the value passed to the lambda is ignored.
        clr.click(fn=lambda value: gr.update(value=""), inputs=clr, outputs=model_input)

    demo.launch()