markqiu committed on
Commit
9d0007e
1 Parent(s): 769b229

Create ERNIE-Bot-SDK.py

Browse files
Files changed (1) hide show
  1. ERNIE-Bot-SDK.py +737 -0
ERNIE-Bot-SDK.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import html
import math
import os
import time
from collections.abc import Iterator
from typing import List, Tuple

import faiss
import gradio as gr
import numpy as np
import requests
from tqdm import tqdm

import erniebot as eb


def parse_setup_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace: The parsed options. ``port`` is the port passed to
        the Gradio server (default: 8073).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8073)
    return parser.parse_args()


def create_ui_and_launch(args):
    """Build the demo UI (one tab per SDK feature) and start the Gradio server.

    Args:
        args: Parsed command-line options; only ``args.port`` is read.
    """
    with gr.Blocks(title="ERNIE Bot SDK Demos", theme=gr.themes.Soft()) as blocks:
        gr.Markdown("# ERNIE Bot SDK基础功能演示")
        create_chat_completion_tab()
        create_embedding_tab()
        create_image_tab()
        create_rag_tab()

    # "0.0.0.0" binds every network interface, so the demo is reachable from
    # other hosts and not only from localhost.
    blocks.launch(server_name="0.0.0.0", server_port=args.port)


def create_chat_completion_tab():
    """Build the "Chat Completion" tab: a multi-turn chat backed by the ERNIE Bot SDK.

    The conversation is kept in a ``gr.State`` dict under the ``"context"`` key
    as a flat list of ``{"role": ..., "content": ...}`` messages in which user
    and assistant messages alternate.
    """

    def _build_auth_config(api_type, access_key, secret_key, access_token):
        """Validate the credentials and return the ``_config_`` dict for the SDK.

        Raises:
            gr.Error: If neither a complete AK/SK pair nor an access token is given.
        """
        access_key = access_key.strip()
        secret_key = secret_key.strip()
        access_token = access_token.strip()

        if (access_key == "" or secret_key == "") and access_token == "":
            raise gr.Error("需要填写正确的AK/SK或access token,不能为空")

        auth_config = {"api_type": api_type}
        if access_key:
            auth_config["ak"] = access_key
        if secret_key:
            auth_config["sk"] = secret_key
        if access_token:
            auth_config["access_token"] = access_token
        return auth_config

    def _get_history(context):
        """Convert the flat message list into ``[user, assistant]`` pairs for the chatbot.

        Pairing with ``zip`` ignores a trailing unpaired message instead of
        raising ``IndexError`` on an odd-length context.
        """
        return [[user["content"], bot["content"]] for user, bot in zip(context[::2], context[1::2])]

    def _infer(
        ernie_model, content, state, top_p, temperature, api_type, access_key, secret_key, access_token
    ):
        """Send one user message and record the reply.

        Returns:
            tuple: ``(None, history, context, state)`` for the input box (cleared),
            the chatbot, the raw-context JSON view and the session state.

        Raises:
            gr.Error: If the credentials or the message are empty.
        """
        auth_config = _build_auth_config(api_type, access_key, secret_key, access_token)
        if content.strip() == "":
            raise gr.Error("输入不能为空,请在清空后重试")

        content = content.strip().replace("<br>", "")
        context = state.setdefault("context", [])
        user_message = {"role": "user", "content": content}
        data = {
            # Send a copy that includes the new message. The stored context is
            # only extended after the request succeeds, so a failed request
            # cannot leave a dangling user turn behind.
            "messages": context + [user_message],
            "top_p": top_p,
            "temperature": temperature,
        }

        if ernie_model == "chat_file":
            response = eb.ChatFile.create(_config_=auth_config, **data, stream=False)
        else:
            response = eb.ChatCompletion.create(
                _config_=auth_config, model=ernie_model, **data, stream=False
            )

        context.append(user_message)
        context.append({"role": "assistant", "content": response.result})
        return None, _get_history(context), context, state

    def _regen_response(
        ernie_model, state, top_p, temperature, api_type, access_key, secret_key, access_token
    ):
        """Regenerate the reply to the most recent user message.

        Raises:
            gr.Error: If the conversation does not contain a full turn yet.
        """
        context = state.setdefault("context", [])
        if len(context) < 2:
            raise gr.Error("请至少进行一轮对话")
        # Retry on a trimmed copy: if the request fails, the session keeps the
        # previous turn instead of losing it.
        trimmed_state = dict(state, context=list(context[:-2]))
        return _infer(
            ernie_model,
            context[-2]["content"],
            trimmed_state,
            top_p,
            temperature,
            api_type,
            access_key,
            secret_key,
            access_token,
        )

    def _rollback(state):
        """Drop the last turn and put its user message back into the input box.

        Raises:
            gr.Error: If there is no complete turn to roll back.
        """
        context = state.setdefault("context", [])
        if len(context) < 2:
            raise gr.Error("请至少进行一轮对话")
        content = context[-2]["content"]
        context = context[:-2]
        state["context"] = context
        return content, _get_history(context), context, state

    with gr.Tab("对话补全(Chat Completion)") as chat_completion_tab:
        with gr.Row():
            with gr.Column(scale=1):
                api_type = gr.Dropdown(
                    label="API Type", info="提供对话能力的后端平台", value="qianfan", choices=["qianfan", "aistudio"]
                )
                access_key = gr.Textbox(
                    label="AK", info="用于访问后端平台的AK,如果设置了access token则无需设置此参数", type="password"
                )
                secret_key = gr.Textbox(
                    label="SK", info="用于访问后端平台的SK,如果设置了access token则无需设置此参数", type="password"
                )
                access_token = gr.Textbox(
                    label="Access Token", info="用于访问后端平台的access token,如果设置了AK、SK则无需设置此参数", type="password"
                )
                ernie_model = gr.Dropdown(
                    label="Model", info="模型类型", value="ernie-bot", choices=["ernie-bot", "ernie-bot-turbo"]
                )
                top_p = gr.Slider(
                    label="Top-p", info="控制采样范围,该参数越小生成结果越稳定", value=0.7, minimum=0, maximum=1, step=0.05
                )
                temperature = gr.Slider(
                    label="Temperature",
                    info="控制采样随机性,该参数越小生成结果越稳定",
                    value=0.95,
                    minimum=0.05,
                    maximum=1,
                    step=0.05,
                )
            with gr.Column(scale=4):
                state = gr.State({})
                context_chatbot = gr.Chatbot(label="对话历史")
                input_text = gr.Textbox(label="消息内容", placeholder="请输入...")
                with gr.Row():
                    clear_btn = gr.Button("清空")
                    rollback_btn = gr.Button("撤回")
                    regen_btn = gr.Button("重新生成")
                    send_btn = gr.Button("发送")
                raw_context_json = gr.JSON(label="原始对话上下文信息")

        # Every handler writes to the same four components.
        chat_outputs = [input_text, context_chatbot, raw_context_json, state]
        sampling_and_auth = [top_p, temperature, api_type, access_key, secret_key, access_token]
        infer_inputs = [ernie_model, input_text, state] + sampling_and_auth

        # The AK/SK boxes are only shown for the "qianfan" backend.
        api_type.change(
            lambda api_type: {
                "qianfan": (gr.update(visible=True), gr.update(visible=True)),
                "aistudio": (gr.update(visible=False), gr.update(visible=False)),
            }[api_type],
            inputs=api_type,
            outputs=[access_key, secret_key],
        )
        # Selecting the tab starts a fresh conversation.
        chat_completion_tab.select(
            lambda: (None, None, None, {}),
            outputs=chat_outputs,
        )
        input_text.submit(_infer, inputs=infer_inputs, outputs=chat_outputs)
        clear_btn.click(
            lambda _: (None, None, None, {}),
            inputs=clear_btn,
            outputs=chat_outputs,
            show_progress=False,
        )
        rollback_btn.click(
            _rollback,
            inputs=[state],
            outputs=chat_outputs,
            show_progress=False,
        )
        regen_btn.click(
            _regen_response,
            inputs=[ernie_model, state] + sampling_and_auth,
            outputs=chat_outputs,
        )
        send_btn.click(_infer, inputs=infer_inputs, outputs=chat_outputs)


def create_embedding_tab():
    """Build the "Embedding" tab: embed two texts and show how similar they are."""

    def _build_auth_config(api_type, access_key, secret_key, access_token):
        """Validate the credentials and return the ``_config_`` dict for the SDK.

        Raises:
            gr.Error: If neither a complete AK/SK pair nor an access token is given.
        """
        access_key = access_key.strip()
        secret_key = secret_key.strip()
        access_token = access_token.strip()

        if (access_key == "" or secret_key == "") and access_token == "":
            raise gr.Error("需要填写正确的AK/SK或access token,不能为空")

        auth_config = {"api_type": api_type}
        if access_key:
            auth_config["ak"] = access_key
        if secret_key:
            auth_config["sk"] = secret_key
        if access_token:
            auth_config["access_token"] = access_token
        return auth_config

    def _calc_cosine_similarity(vec_0, vec_1):
        """Return the cosine similarity of two vectors rescaled from [-1, 1] to [0, 1].

        Returns 0 when either vector has zero norm.
        """
        denom = np.linalg.norm(vec_0) * np.linalg.norm(vec_1)
        if denom == 0:
            return 0
        return 0.5 + 0.5 * (float(np.dot(vec_0, vec_1)) / denom)

    def _get_embeddings(text1, text2, api_type, access_key, secret_key, access_token):
        """Embed both texts in a single request.

        Returns:
            tuple: The two embeddings rendered as strings, followed by the
            Markdown line that reports their similarity.

        Raises:
            gr.Error: If the credentials or either text is empty.
        """
        auth_config = _build_auth_config(api_type, access_key, secret_key, access_token)

        text1, text2 = text1.strip(), text2.strip()
        if text1 == "" or text2 == "":
            raise gr.Error("两个输入均不能为空")

        embeddings = eb.Embedding.create(
            _config_=auth_config,
            model="ernie-text-embedding",
            input=[text1, text2],
        )
        emb_0 = embeddings.rbody["data"][0]["embedding"]
        emb_1 = embeddings.rbody["data"][1]["embedding"]
        cos_sim = _calc_cosine_similarity(emb_0, emb_1)
        cos_sim_text = f"## 两段文本余弦相似度: {cos_sim}"
        return str(emb_0), str(emb_1), cos_sim_text

    with gr.Tab("语义向量(Embedding)"):
        gr.Markdown("输入两段文本,分别获取两段文本的向量表示,并计算向量间的余弦相似度")
        with gr.Row():
            with gr.Column(scale=1):
                api_type = gr.Dropdown(
                    label="API Type", info="提供语义向量能力的后端平台", value="qianfan", choices=["qianfan", "aistudio"]
                )
                access_key = gr.Textbox(
                    label="AK", info="用于访问后端平台的AK,如果设置了access token则无需设置此参数", type="password"
                )
                secret_key = gr.Textbox(
                    label="SK", info="用于访问后端平台的SK,如果设置了access token则无需设置此参数", type="password"
                )
                access_token = gr.Textbox(
                    label="Access Token", info="用于访问后端平台的access token,如果设置了AK、SK则无需设置此参数", type="password"
                )
            with gr.Column(scale=4):
                with gr.Row():
                    text1 = gr.Textbox(label="第一段文本", placeholder="输入第一段文本")
                    text2 = gr.Textbox(label="第二段文本", placeholder="输入第二段文本")
                cal_emb = gr.Button("生成向量")
                cos_sim = gr.Markdown("## 余弦相似度: -")
                with gr.Row():
                    embedding1 = gr.Textbox(label="文本1向量结果")
                    embedding2 = gr.Textbox(label="文本2向量结果")

        # The AK/SK boxes are only shown for the "qianfan" backend.
        api_type.change(
            lambda api_type: {
                "qianfan": (gr.update(visible=True), gr.update(visible=True)),
                "aistudio": (gr.update(visible=False), gr.update(visible=False)),
            }[api_type],
            inputs=api_type,
            outputs=[access_key, secret_key],
        )
        cal_emb.click(
            _get_embeddings,
            inputs=[text1, text2, api_type, access_key, secret_key, access_token],
            outputs=[embedding1, embedding2, cos_sim],
        )


def create_image_tab():
    """Build the "Image Generation" tab: generate one image from a text prompt."""

    def _gen_image(prompt, w_and_h, api_type, access_key, secret_key, access_token):
        """Generate an image, download it and return the local file path.

        Args:
            prompt: Text prompt describing the image.
            w_and_h: Resolution formatted as ``"<width>x<height>"``.
            api_type, access_key, secret_key, access_token: Backend credentials.

        Returns:
            str: Path of the downloaded ``.jpg`` file (written to the current
            working directory).

        Raises:
            gr.Error: If the credentials or the prompt are empty, or the
                generated image cannot be downloaded.
        """
        access_key = access_key.strip()
        secret_key = secret_key.strip()
        access_token = access_token.strip()

        if (access_key == "" or secret_key == "") and access_token == "":
            raise gr.Error("需要填写正确的AK/SK或access token,不能为空")
        if prompt.strip() == "":
            raise gr.Error("输入不能为空")

        auth_config = {"api_type": api_type}
        if access_key:
            auth_config["ak"] = access_key
        if secret_key:
            auth_config["sk"] = secret_key
        if access_token:
            auth_config["access_token"] = access_token

        w, h = [int(x) for x in w_and_h.strip().split("x")]

        response = eb.Image.create(
            _config_=auth_config,
            model="ernie-vilg-v2",
            prompt=prompt,
            width=w,
            height=h,
            version="v2",
            image_num=1,
        )
        img_url = response.data["sub_task_result_list"][0]["final_image_list"][0]["img_url"]

        try:
            # A timeout keeps a stalled download from blocking the handler, and
            # raise_for_status() stops an HTTP error page from being saved as a .jpg.
            res = requests.get(img_url, timeout=60)
            res.raise_for_status()
        except requests.RequestException as err:
            raise gr.Error(f"图片下载失败: {err}") from err

        # Nanosecond resolution: two requests in the same second no longer
        # overwrite each other's file.
        image_path = f"{time.time_ns()}.jpg"
        with open(image_path, "wb") as f:
            f.write(res.content)
        return image_path

    with gr.Tab("文生图(Image Generation)"):
        with gr.Row():
            with gr.Column(scale=1):
                api_type = gr.Dropdown(
                    label="API Type", info="提供文生图能力的后端平台", value="yinian", choices=["yinian"]
                )
                access_key = gr.Textbox(
                    label="AK", info="用于访问后端平台的AK,如果设置了access token则无需设置此参数", type="password"
                )
                secret_key = gr.Textbox(
                    label="SK", info="用于访问后端平台的SK,如果设置了access token则无需设置此参数", type="password"
                )
                access_token = gr.Textbox(
                    label="Access Token", info="用于访问后端平台的access token,如果设置了AK、SK则无需设置此参数", type="password"
                )
            with gr.Column(scale=4):
                with gr.Row():
                    prompt = gr.Textbox(label="Prompt", placeholder="输入用于生成图片的prompt,例如: 生成一朵玫瑰花")
                    w_and_h = gr.Dropdown(
                        label="分辨率",
                        value="512x512",
                        choices=[
                            "512x512",
                            "640x360",
                            "360x640",
                            "1024x1024",
                            "1280x720",
                            "720x1280",
                            "2048x2048",
                            "2560x1440",
                            "1440x2560",
                        ],
                    )
                submit_btn = gr.Button("生成图片")
                image_show_zone = gr.Image(label="图片生成结果", type="filepath", show_download_button=True)

        submit_btn.click(
            _gen_image,
            inputs=[prompt, w_and_h, api_type, access_key, secret_key, access_token],
            outputs=image_show_zone,
        )


def create_rag_tab():
    """Build the "Retrieval Augmented QA" tab.

    Uploaded text files are split into chunks, embedded and stored in a Faiss
    inner-product index. A question is answered by retrieving the most similar
    chunks and passing them to the chat model as reference material.

    NOTE: the index and the chunk file live in a single ``data`` directory next
    to this script, so every session of the server reads and overwrites the
    same knowledge base.
    """
    REF_HTML = """

    <details style="border: 1px solid #ccc; padding: 10px; border-radius: 4px; margin-bottom: 4px">
        <summary style="display: flex; align-items: center; font-weight: bold;">
            <span style="margin-right: 10px;">[{index}] {title}</span>
            <a style="text-decoration: none; background: none !important;" target="_blank">
                <!--[Here should be a link icon]-->
                <i style="border: solid #000; border-width: 0 2px 2px 0; display: inline-block; padding: 3px;
                transform:rotate(-45deg);-webkit-transform:rotate(-45deg)">
                </i>
            </a>
        </summary>
        <p style="margin-top: 10px;">{text}</p>
    </details>

    """

    # Assembled line by line so that source indentation never leaks into the prompt.
    PROMPT_TEMPLATE = (
        "基于以下已知信息,请简洁并专业地回答用户的问题。\n"
        "如果无法从中得到答案,请说 '根据已知信息无法回答该问题' 或 '没有提供足够的相关信息'。不允许在答案中添加编造成分。\n"
        "你可以参考以下文章:\n"
        "{DOCS}\n"
        "问题:{QUERY}\n"
        "回答:"
    )

    # Number of texts sent per embedding request (the original code also batched by 16).
    EMBEDDING_BATCH_SIZE = 16

    KB_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
    INDEX_PATH = os.path.join(KB_DIR, "knowledge_embedding.index")
    CHUNKS_PATH = os.path.join(KB_DIR, "knowledge.txt")

    def _parse_settings(*args: object) -> dict:
        """Turn the values of the settings widgets into per-request settings.

        ``args`` follows the order in which the widgets are wired to the
        handlers: model, api type, access token, AK, SK, top-p, temperature.

        The credentials are returned as a ``_config_`` dict that is passed to
        each SDK call, matching the other tabs. Nothing is written to module
        globals or to a dict shared between sessions, so concurrent users do
        not overwrite each other's settings.

        Raises:
            gr.Error: If neither a complete AK/SK pair nor an access token is given.
        """
        ernie_model, api_type, access_token, access_key, secret_key, top_p, temperature = args
        access_token = str(access_token or "").strip()
        access_key = str(access_key or "").strip()
        secret_key = str(secret_key or "").strip()

        if (access_key == "" or secret_key == "") and access_token == "":
            raise gr.Error("需要填写正确的AK/SK或access token,不能为空")

        auth_config = {"api_type": api_type}
        if access_key:
            auth_config["ak"] = access_key
        if secret_key:
            auth_config["sk"] = secret_key
        if access_token:
            auth_config["access_token"] = access_token

        return {
            "ernie_model": ernie_model,
            "auth_config": auth_config,
            "top_p": top_p,
            "temperature": temperature,
        }

    def split_by_len(texts: List[str], split_token: int = 384) -> List[str]:
        """
        Split the knowledge base docs into chunks by length.

        Args:
            texts (List[str]): Knowledge Base Texts.
            split_token (int, optional): Maximum chunk length in characters. Default to 384.

        Returns:
            List[str]: Doc Chunks. Blank chunks are never emitted.

        Raises:
            ValueError: If ``split_token`` is not positive.
        """
        if split_token <= 0:
            raise ValueError("split_token must be positive")

        chunk = []
        for text in texts:
            idx = 0
            while idx + split_token < len(text):
                temp_text = text[idx : idx + split_token]
                next_idx = temp_text.rfind("。") + 1
                if next_idx != 0:
                    # The window contains a full stop: cut right after the last one.
                    chunk.append(temp_text[:next_idx])
                    idx += next_idx
                else:
                    # No full stop in the window: emit the whole window.
                    chunk.append(temp_text)
                    idx += split_token

            tail = text[idx:]
            if tail.strip():  # skip empty/blank remainders (e.g. an empty file)
                chunk.append(tail)
        return chunk

    def _get_embedding_doc(word: List[str], auth_config: dict) -> List[List[float]]:
        """
        Get the embeddings of a list of texts.

        Args:
            word (List[str]): Texts to embed.
            auth_config (dict): ``_config_`` dict passed to the SDK.

        Returns:
            List[List[float]]: One embedding per input text, in input order.
        """
        embedding: List[List[float]] = []
        batch_starts = range(0, len(word), EMBEDDING_BATCH_SIZE)
        for batch_no, start in enumerate(tqdm(batch_starts)):
            if batch_no:
                # Pause between consecutive requests (kept from the original;
                # presumably to stay under the backend's rate limit).
                time.sleep(1)
            resp = eb.Embedding.create(
                _config_=auth_config,
                model="ernie-text-embedding",
                input=word[start : start + EMBEDDING_BATCH_SIZE],
            )
            assert not isinstance(resp, Iterator)
            embedding.extend(resp.get_result())
        return embedding

    def l2_normalization(embedding: np.ndarray) -> np.ndarray:
        """Normalize vectors by their L2 norm; the result is always 2-D.

        A 1-D input is treated as a single vector. Zero vectors are returned
        unchanged instead of producing NaNs.
        """
        axis = None if embedding.ndim == 1 else 1
        norms = np.asarray(np.linalg.norm(embedding, axis=axis)).reshape(-1, 1)
        norms[norms == 0] = 1
        return embedding / norms

    def find_related_doc(
        query: str,
        origin_chunk: List[str],
        index_ip: faiss.IndexFlatIP,
        auth_config: dict,
        top_k: int = 5,
    ) -> Tuple[str, List[str]]:
        """
        Find top_k similar documents.

        Args:
            query (str): user query.
            origin_chunk (List[str]): Knowledge Base Doc.
            index_ip (faiss.IndexFlatIP): Vector DB index.
            auth_config (dict): ``_config_`` dict passed to the SDK.
            top_k (int, optional): Return top_k most similar documents. Default to 5.

        Returns:
            str, List[str]: The retrieved documents formatted for the prompt,
            and the same documents as a list.
        """
        # Faiss pads the result with -1 when fewer than top_k vectors exist; a
        # raw -1 would silently index the last chunk, so clamp and filter.
        top_k = min(top_k, len(origin_chunk))
        if top_k <= 0:
            return "", []

        # Faiss operates on float32 matrices.
        query_embedding = np.asarray(_get_embedding_doc([query], auth_config), dtype=np.float32)
        _, Idx = index_ip.search(query_embedding, top_k)

        ref_lis = [origin_chunk[i] for i in Idx.tolist()[0] if 0 <= i < len(origin_chunk)]
        res = "".join(f"[参考文章{n}]:{doc}\n\n" for n, doc in enumerate(ref_lis, start=1))
        return res, ref_lis

    def process_uploaded_file(files: List[str], *args: object) -> str:
        """
        Build the vector knowledge base from the uploaded files.

        Args:
            files: Uploaded files, either as paths or as file objects exposing
                the path in ``.name`` (what Gradio passes depends on its version).
            *args: Values of the settings widgets, see ``_parse_settings``.

        Returns:
            str: Status message shown in the UI.

        Raises:
            gr.Error: If credentials are missing, the files contain no text, or
                the number of embeddings does not match the number of chunks.
        """
        settings = _parse_settings(*args)

        content = []
        for file in files:
            path = getattr(file, "name", file)
            with open(path, "r", encoding="utf-8") as f:
                content.append(f.read())

        doc_chunk = split_by_len(content)
        if not doc_chunk:
            raise gr.Error("上传的文件内容为空")

        doc_embedding = _get_embedding_doc(doc_chunk, settings["auth_config"])
        if len(doc_embedding) != len(doc_chunk):
            raise gr.Error("shape mismatch")
        # Normalized vectors + inner-product index == cosine-similarity search.
        doc_embedding_arr = np.ascontiguousarray(
            l2_normalization(np.array(doc_embedding, dtype=np.float32)), dtype=np.float32
        )

        index_ip = faiss.IndexFlatIP(doc_embedding_arr.shape[1])
        index_ip.add(doc_embedding_arr)

        os.makedirs(KB_DIR, exist_ok=True)
        faiss.write_index(index_ip, INDEX_PATH)
        with open(CHUNKS_PATH, "w", encoding="utf-8") as f:
            # One repr() per line keeps chunks that contain newlines on a single line.
            f.writelines(repr(chunk) + "\n" for chunk in doc_chunk)

        return "已完成向量知识库搭建"

    def get_ans(query: str, *args: object) -> Tuple[str, str]:
        """
        Answer a question using the stored knowledge base.

        Args:
            query: The user's question.
            *args: Values of the settings widgets, see ``_parse_settings``.

        Returns:
            str, str: The model's answer and an HTML fragment listing the references.

        Raises:
            gr.Error: If credentials are missing or no knowledge base has been built yet.
        """
        settings = _parse_settings(*args)

        if not (os.path.exists(CHUNKS_PATH) and os.path.exists(INDEX_PATH)):
            raise gr.Error("请先上传文件以完成向量知识库搭建")

        with open(CHUNKS_PATH, "r", encoding="utf-8") as f:
            # The file holds one repr() string per line; literal_eval parses it
            # without executing arbitrary code (the original used eval()).
            doc_chunk = [ast.literal_eval(line) for line in f if line.strip()]
        index_ip = faiss.read_index(INDEX_PATH)
        related_doc, references = find_related_doc(query, doc_chunk, index_ip, settings["auth_config"])

        resp = eb.ChatCompletion.create(
            _config_=settings["auth_config"],
            model=settings["ernie_model"],
            messages=[{"role": "user", "content": PROMPT_TEMPLATE.format(DOCS=related_doc, QUERY=query)}],
            top_p=settings["top_p"],
            temperature=settings["temperature"],
        )
        assert not isinstance(resp, Iterator)
        answer = resp.get_result()

        # Reference text comes from uploaded files: escape it before embedding it in HTML.
        ref_html = "\n".join(
            REF_HTML.format(index=n, title=f"Reference{n}", text=html.escape(doc))
            for n, doc in enumerate(references, start=1)
        )
        return answer, "<h3>References (Click to Expand)</h3>" + ref_html

    with gr.Tab("知识库问答(Retrieval Augmented QA)"):
        with gr.Tabs():
            with gr.TabItem("设置栏"):
                with gr.Row():
                    with gr.Column():
                        # Extensions need the leading dot to be used as an upload filter.
                        file_upload = gr.Files(file_types=[".txt"], label="目前仅支持txt格式文件")
                        chat_box = gr.Textbox(show_label=False)
                    with gr.Column():
                        ernie_model = gr.Dropdown(
                            label="Model",
                            info="模型类型",
                            value="ernie-bot-4",
                            choices=["ernie-bot-4", "ernie-bot-turbo", "ernie-bot"],
                        )
                        api_type = gr.Dropdown(
                            label="API Type",
                            info="提供对话能力的后端平台",
                            value="aistudio",
                            choices=["aistudio", "qianfan"],
                        )
                        access_token = gr.Textbox(
                            label="Access Token",
                            info="用于访问后端平台的access token,如果选择aistudio,则需设置此参数",
                            type="password",
                        )
                        access_key = gr.Textbox(
                            label="AK", info="用于访问千帆平台的AK,如果选择qianfan,则需设置此参数", type="password"
                        )
                        secret_key = gr.Textbox(
                            label="SK", info="用于访问千帆平台的SK,如果选择qianfan,则需设置此参数", type="password"
                        )
                        top_p = gr.Slider(
                            label="Top-p",
                            info="控制采样范围,该参数越小生成结果越稳定",
                            value=0.7,
                            step=0.05,
                            minimum=0,
                            maximum=1,
                        )
                        temperature = gr.Slider(
                            label="temperature",
                            info="控制采样随机性,该参数越小生成结果越稳定",
                            value=0.95,
                            step=0.05,
                            maximum=1,
                            minimum=0,
                        )

            with gr.TabItem("问答栏"):
                with gr.Row():
                    query_box = gr.Textbox(show_label=False, placeholder="Enter question and press ENTER")

                answer_box = gr.Textbox(show_label=False, value="", lines=5)
                ref_boxes = gr.HTML(label="References")

        # The order of the settings must match what _parse_settings unpacks.
        settings_inputs = [ernie_model, api_type, access_token, access_key, secret_key, top_p, temperature]
        query_box.submit(
            get_ans,
            [query_box] + settings_inputs,
            [answer_box, ref_boxes],
        )
        file_upload.upload(
            process_uploaded_file,
            [file_upload] + settings_inputs,
            chat_box,
        )


# Script entry point: parse the CLI options, then build and serve the demo UI.
if __name__ == "__main__":
    args = parse_setup_args()
    create_ui_and_launch(args)