Files changed (1)
  1. app.py +179 -76
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
 import clueai
 import torch
 from transformers import T5Tokenizer, T5ForConditionalGeneration
+
 tokenizer = T5Tokenizer.from_pretrained("ClueAI/ChatYuan-large-v2")
 model = T5ForConditionalGeneration.from_pretrained("ClueAI/ChatYuan-large-v2")
 # 使用
@@ -11,76 +12,123 @@ model.to(device)
 model.half()

 base_info = ""
+
+
 def preprocess(text):
-  text = f"{base_info}{text}"
-  text = text.replace("\n", "\\n").replace("\t", "\\t")
-  return text
+    text = f"{base_info}{text}"
+    text = text.replace("\n", "\\n").replace("\t", "\\t")
+    return text
+

 def postprocess(text):
-  return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ')#.replace(" ", "&nbsp;")
+    return text.replace("\\n", "\n").replace("\\t", "\t").replace(
+        '%20', ' ')  #.replace(" ", "&nbsp;")
+

+generate_config = {
+    'do_sample': True,
+    'top_p': 0.9,
+    'top_k': 50,
+    'temperature': 0.7,
+    'num_beams': 1,
+    'max_length': 1024,
+    'min_length': 3,
+    'no_repeat_ngram_size': 5,
+    'length_penalty': 0.6,
+    'return_dict_in_generate': True,
+    'output_scores': True
+}


-generate_config = {'do_sample': True, 'top_p': 0.9, 'top_k': 50, 'temperature': 0.7,
-                   'num_beams': 1, 'max_length': 1024, 'min_length': 3, 'no_repeat_ngram_size': 5,
-                   'length_penalty': 0.6, 'return_dict_in_generate': True, 'output_scores': True}
-def answer(text, sample=True, top_p=0.9, temperature=0.7):
-  '''sample:是否抽样。生成任务,可以设置为True;
+def answer(
+    text,
+    top_p,
+    temperature,
+    sample=True,
+):
+    '''sample:是否抽样。生成任务,可以设置为True;
   top_p:0-1之间,生成的内容越多样'''
-  text = preprocess(text)
-  encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=1024, return_tensors="pt").to(device)
-  if not sample:
-    out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, num_beams=1, length_penalty=0.6)
-  else:
-    out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=1024, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=12)
-  #out=model.generate(**encoding, **generate_config)
-  out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
-  return postprocess(out_text[0])
+    text = preprocess(text)
+    encoding = tokenizer(text=[text],
+                         truncation=True,
+                         padding=True,
+                         max_length=1024,
+                         return_tensors="pt").to(device)
+    if not sample:
+        out = model.generate(**encoding,
+                             return_dict_in_generate=True,
+                             output_scores=False,
+                             max_new_tokens=1024,
+                             num_beams=1,
+                             length_penalty=0.6)
+    else:
+        out = model.generate(**encoding,
+                             return_dict_in_generate=True,
+                             output_scores=False,
+                             max_new_tokens=1024,
+                             do_sample=True,
+                             top_p=top_p,
+                             temperature=temperature,
+                             no_repeat_ngram_size=12)
+    #out=model.generate(**encoding, **generate_config)
+    out_text = tokenizer.batch_decode(out["sequences"],
+                                      skip_special_tokens=True)
+    return postprocess(out_text[0])
+

 def clear_session():
     return '', None

-def chatyuan_bot(input, history):
+
+def chatyuan_bot(input, history, top_p, temperature, num):
     history = history or []
-    if len(history) > 5:
-        history = history[-5:]
+    if len(history) > num:
+        history = history[-num:]

-    context = "\n".join([f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history])
+    context = "\n".join([
+        f"用户:{input_text}\n小元:{answer_text}"
+        for input_text, answer_text in history
+    ])
     #print(context)

     input_text = context + "\n用户:" + input + "\n小元:"
     input_text = input_text.strip()
-    output_text = answer(input_text)
+    output_text = answer(input_text, top_p, temperature)
     print("open_model".center(20, "="))
     print(f"{input_text}\n{output_text}")
     #print("="*20)
     history.append((input, output_text))
     #print(history)
-    return history, history
-def chatyuan_bot_regenerate(input, history):
-
+    return '', history, history
+
+
+def chatyuan_bot_regenerate(input, history, top_p, temperature, num):
+
     history = history or []
-
+
     if history:
-        input=history[-1][0]
-        history=history[:-1]
-
-
-    if len(history) > 5:
-        history = history[-5:]
+        input = history[-1][0]
+        history = history[:-1]
+
+    if len(history) > num:
+        history = history[-num:]

-    context = "\n".join([f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history])
+    context = "\n".join([
+        f"用户:{input_text}\n小元:{answer_text}"
+        for input_text, answer_text in history
+    ])
     #print(context)

     input_text = context + "\n用户:" + input + "\n小元:"
     input_text = input_text.strip()
-    output_text = answer(input_text)
+    output_text = answer(input_text, top_p, temperature)
     print("open_model".center(20, "="))
     print(f"{input_text}\n{output_text}")
     history.append((input, output_text))
     #print(history)
-    return history, history
-
+    return '', history, history
+
+
 block = gr.Blocks()

 with block as demo:
@@ -88,27 +136,58 @@ with block as demo:
 <font size=4>回答来自ChatYuan, 是模型生成的结果, 请谨慎辨别和参考, 不代表任何人观点 | Answer generated by ChatYuan model</font>
 <font size=4>注意:gradio对markdown代码格式展示有限</font>
     """)
-    chatbot = gr.Chatbot(label='ChatYuan')
-    message = gr.Textbox()
-    state = gr.State()
-    message.submit(chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])
     with gr.Row():
-        clear_history = gr.Button("👋 清除历史对话 | Clear History")
-        clear = gr.Button('🧹 清除发送框 | Clear Input')
-        send = gr.Button("🚀 发送 | Send")
-        regenerate = gr.Button("🚀 重新生成本次结果 | regenerate")
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label='ChatYuan').style(height=400)

+        with gr.Column(scale=1):

-    regenerate.click(chatyuan_bot_regenerate, inputs=[message, state], outputs=[chatbot, state])
-    send.click(chatyuan_bot, inputs=[message, state], outputs=[chatbot, state])
-    clear.click(lambda: None, None, message, queue=False)
-    clear_history.click(fn=clear_session , inputs=[], outputs=[chatbot, state], queue=False)
-
+            num = gr.Slider(minimum=4,
+                            maximum=10,
+                            label="最大的对话轮数",
+                            value=5,
+                            step=1)
+            top_p = gr.Slider(minimum=0,
+                              maximum=1,
+                              label="top_p",
+                              value=1,
+                              step=0.1)
+            temperature = gr.Slider(minimum=0,
+                                    maximum=1,
+                                    label="temperature",
+                                    value=0.7,
+                                    step=0.1)
+            clear_history = gr.Button("👋 清除历史对话 | Clear History")
+            send = gr.Button("🚀 发送 | Send")
+            regenerate = gr.Button("🚀 重新生成本次结果 | regenerate")
+    message = gr.Textbox()
+    state = gr.State()
+    message.submit(chatyuan_bot,
+                   inputs=[message, state, top_p, temperature, num],
+                   outputs=[message, chatbot, state])
+    regenerate.click(chatyuan_bot_regenerate,
+                     inputs=[message, state, top_p, temperature, num],
+                     outputs=[message, chatbot, state])
+    send.click(chatyuan_bot,
+               inputs=[message, state, top_p, temperature, num],
+               outputs=[message, chatbot, state])

-def ChatYuan(api_key, text_prompt):
+    clear_history.click(fn=clear_session,
+                        inputs=[],
+                        outputs=[chatbot, state],
+                        queue=False)

-    cl = clueai.Client(api_key,
-                       check_api_key=True)
+
+def ChatYuan(api_key, text_prompt, top_p):
+    generate_config = {
+        "do_sample": True,
+        "top_p": top_p,
+        "max_length": 128,
+        "min_length": 10,
+        "length_penalty": 1.0,
+        "num_beams": 1
+    }
+    cl = clueai.Client(api_key, check_api_key=True)
     # generate a prediction for a prompt
     # 需要返回得分的话,指定return_likelihoods="GENERATION"
     prediction = cl.generate(model_name='ChatYuan-large', prompt=text_prompt)
@@ -119,26 +198,28 @@ def ChatYuan(api_key, text_prompt):
         response = "很抱歉,我无法回答这个问题"

     return response
-
-def chatyuan_bot_api(api_key, input, history):
+
+
+def chatyuan_bot_api(api_key, input, history, top_p, num):
     history = history or []

-    if len(history) > 5:
-        history = history[-5:]
+    if len(history) > num:
+        history = history[-num:]

-    context = "\n".join([f"用户:{input_text}\n小元:{answer_text}" for input_text, answer_text in history])
-    #print(context)
+    context = "\n".join([
+        f"用户:{input_text}\n小元:{answer_text}"
+        for input_text, answer_text in history
+    ])

     input_text = context + "\n用户:" + input + "\n小元:"
     input_text = input_text.strip()
-    output_text = ChatYuan(api_key, input_text)
+    output_text = ChatYuan(api_key, input_text, top_p)
     print("api".center(20, "="))
     print(f"api_key:{api_key}\n{input_text}\n{output_text}")
-    #print("="*20)
+
     history.append((input, output_text))
-    #print(history)
-    return history, history

+    return '', history, history


 block = gr.Blocks()
@@ -149,19 +230,40 @@ with block as demo_1:
 <font size=4>注意:gradio对markdown代码格式展示有限</font>
 <font size=4>在使用此功能前,你需要有个API key. API key 可以通过这个<a href='https://www.clueai.cn/' target="_blank">平台</a>获取</font>
     """)
-    api_key = gr.inputs.Textbox(label="请输入你的api-key(必填)", default="", type='password')
-    chatbot = gr.Chatbot(label='ChatYuan')
+    with gr.Row():
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(label='ChatYuan').style(height=400)
+
+        with gr.Column(scale=1):
+            api_key = gr.inputs.Textbox(label="请输入你的api-key(必填)",
+                                        default="",
+                                        type='password')
+            num = gr.Slider(minimum=4,
+                            maximum=10,
+                            label="最大的对话轮数",
+                            value=5,
+                            step=1)
+            top_p = gr.Slider(minimum=0,
+                              maximum=1,
+                              label="top_p",
+                              value=1,
+                              step=0.1)
+            clear_history = gr.Button("👋 清除历史对话 | Clear History")
+            send = gr.Button("🚀 发送 | Send")
+
     message = gr.Textbox()
     state = gr.State()
-    message.submit(chatyuan_bot_api, inputs=[api_key,message, state], outputs=[chatbot, state])
-    with gr.Row():
-        clear_history = gr.Button("👋 清除历史对话 | Clear Context")
-        clear = gr.Button('🧹 清除发送框 | Clear Input')
-        send = gr.Button("🚀 发送 | Send")
+    message.submit(chatyuan_bot_api,
+                   inputs=[api_key, message, state, top_p, num],
+                   outputs=[message, chatbot, state])

-    send.click(chatyuan_bot_api, inputs=[api_key,message, state], outputs=[chatbot, state])
-    clear.click(lambda: None, None, message, queue=False)
-    clear_history.click(fn=clear_session , inputs=[], outputs=[chatbot, state], queue=False)
+    send.click(chatyuan_bot_api,
+               inputs=[api_key, message, state, top_p, num],
+               outputs=[message, chatbot, state])
+    clear_history.click(fn=clear_session,
+                        inputs=[],
+                        outputs=[chatbot, state],
+                        queue=False)

 block = gr.Blocks()
 with block as introduction:
@@ -202,6 +304,7 @@ Based on the original functions of Chatyuan-large-v1, we optimized the model as
 <center><a href="https://clustrmaps.com/site/1bts0" title="Visit tracker"><img src="//www.clustrmaps.com/map_v2.png?d=ycVCe17noTYFDs30w7AmkFaE-TwabMBukDP1802_Lts&cl=ffffff" /></a></center>
     """)

-
-gui = gr.TabbedInterface(interface_list=[introduction,demo, demo_1], tab_names=["相关介绍Introduction","开源模型Online Demo", "API调用"])
-gui.launch(quiet=True,show_api=False, share = False)
+gui = gr.TabbedInterface(
+    interface_list=[introduction, demo, demo_1],
+    tab_names=["相关介绍 | Introduction", "开源模型 | Online Demo", "API调用"])
+gui.launch(quiet=True, show_api=False, share=False)
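For reference, the callback pattern this change adopts — threading the slider values (top_p, temperature, max turns) through every handler and returning an empty string as the first output so the send box is cleared — can be seen in isolation in the minimal sketch below. It assumes a Gradio 3.x-style environment (list-of-tuples chat history), as this Space uses; fake_answer is a hypothetical stand-in for the ChatYuan model call and is not part of this change.

import gradio as gr


def fake_answer(text, top_p, temperature):
    # Hypothetical stand-in for answer(); the real app runs ChatYuan here.
    return f"echo: {text}"


def bot(input, history, top_p, temperature, num):
    history = history or []
    if len(history) > num:  # keep only the last `num` turns as context
        history = history[-num:]
    output = fake_answer(input, top_p, temperature)
    history.append((input, output))
    # First value clears the Textbox; the other two update Chatbot and State.
    return '', history, history


with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    state = gr.State()
    top_p = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="top_p")
    temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="temperature")
    num = gr.Slider(minimum=4, maximum=10, value=5, step=1, label="max turns")
    message = gr.Textbox()
    message.submit(bot,
                   inputs=[message, state, top_p, temperature, num],
                   outputs=[message, chatbot, state])

sketch.launch()

Listing message first in outputs is what makes the new return '', history, history lines empty the input box after each send; the diff reuses the same output triple for the send and regenerate buttons.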