neoguojing committed on
Commit
4d10a94
1 Parent(s): 494b300

finish rag

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app.py +51 -36
  3. llm.py +89 -0
  4. requirements.txt +2 -1
.gitignore CHANGED
@@ -2,3 +2,6 @@
2
  __pycache__/
3
  *.bin
4
  .vscode/
 
 
 
 
2
  __pycache__/
3
  *.bin
4
  .vscode/
5
+ files/input/ir2023_ashare.pdf
6
+ knowledge_bases/中国移动.faiss
7
+ knowledge_bases/中国移动.pkl
app.py CHANGED
@@ -133,23 +133,23 @@ def create_ui():
133
  components["db_view"] = gr.Dataframe(
134
  headers=["列表"],
135
  datatype=["str"],
136
- row_count=8,
137
  col_count=(1, "fixed"),
138
  interactive=False
139
  )
140
  with gr.Column(scale=2):
141
- with gr.Group():
 
142
  components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
143
- components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
144
  components["db_submit_btn"] = gr.Button(value="提交")
 
145
  with gr.Row():
146
  with gr.Column(scale=2):
147
  components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
148
-
149
  with gr.Column(scale=1):
150
- components["db_test_select"] = gr.Dropdown(
151
- choices=knowledgeBase.get_bases(),value=None,multiselect=True, label="知识库选择"
152
- )
153
  components["dbtest_submit_btn"] = gr.Button(value="检索")
154
  with gr.Row():
155
  with gr.Group():
@@ -157,16 +157,22 @@ def create_ui():
157
 
158
  with gr.Tab("问答"):
159
  with gr.Row():
160
- with gr.Column():
 
 
 
 
 
 
161
  with gr.Group():
162
  components["chatbot"] = gr.Chatbot(
163
- [(None,"What can I help you?")],
164
  elem_id="chatbot",
165
  bubble_full_width=False,
166
  height=600
167
  )
168
  components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
169
- components["db_select"] = gr.CheckboxGroup(choices=knowledgeBase.get_bases(),value=None,label="知识库", info="可选择1个或多个知识库")
170
  create_event_handlers()
171
  demo.load(init,None,gradio("db_view"))
172
  return demo
@@ -236,6 +242,10 @@ def create_event_handlers():
236
  do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
237
  )
238
 
 
 
 
 
239
  def do_refernce(algo_type,input_image):
240
  # def do_refernce():
241
  print("input image",input_image)
@@ -307,9 +317,6 @@ def do_sam_everything(im):
307
 
308
  return images
309
 
310
-
311
-
312
-
313
  def point_to_mask(pil_image):
314
  # 遍历每个像素
315
  width, height = pil_image.size
@@ -337,11 +344,11 @@ def do_llm_request(history, message):
337
  return history, gr.MultimodalTextbox(value=None, interactive=False)
338
 
339
  def do_llm_response(history,selected_dbs):
 
340
  user_input = history[-1][0]
341
  prompt = ""
342
  quote = ""
343
- print("----------",selected_dbs)
344
- if selected_dbs is not None and len(selected_dbs) != 0:
345
  knowledge = knowledgeBase.retrieve_documents(selected_dbs,user_input)
346
  print("do_llm_response context:",knowledge)
347
  prompt = f'''
@@ -349,8 +356,8 @@ def do_llm_response(history,selected_dbs):
349
  背景2:{knowledge[1]["content"]}
350
  背景3:{knowledge[2]["content"]}
351
  基于以上事实回答问题:{user_input}
352
- '''
353
- print("do_llm_response prompt:",prompt)
354
  quote = f'''
355
  > 文档:{knowledge[0]["meta"]["source"]},页码:{knowledge[0]["meta"]["page"]}
356
  > 文档:{knowledge[1]["meta"]["source"]},页码:{knowledge[1]["meta"]["page"]}
@@ -358,33 +365,41 @@ def do_llm_response(history,selected_dbs):
358
  '''
359
  else:
360
  prompt = user_input
361
-
362
- response = llm(prompt)
363
  history[-1][1] = ""
364
- response = response.removeprefix(prompt)
365
- response += quote
 
 
 
 
 
 
 
366
  for character in response:
367
  history[-1][1] += character
368
  time.sleep(0.01)
369
  yield history
370
 
371
- def llm(input):
372
- import requests
373
- API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
374
- headers = {"Authorization": "Bearer "}
375
 
376
- def query(payload):
377
- response = requests.post(API_URL, headers=headers, json=payload)
378
- return response.json()
379
-
380
- output = query({
381
- "inputs": input,
382
- })
383
- print(output)
384
- if len(output) >0:
385
- return output[0]['generated_text']
386
- return ""
 
 
 
 
 
387
 
 
388
 
389
 
390
  def file_handler(file_objs,name):
 
133
  components["db_view"] = gr.Dataframe(
134
  headers=["列表"],
135
  datatype=["str"],
136
+ row_count=2,
137
  col_count=(1, "fixed"),
138
  interactive=False
139
  )
140
  with gr.Column(scale=2):
141
+ with gr.Row():
142
+ with gr.Column(scale=2):
143
  components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
144
+ with gr.Column(scale=2):
145
  components["db_submit_btn"] = gr.Button(value="提交")
146
+ components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
147
  with gr.Row():
148
  with gr.Column(scale=2):
149
  components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
 
150
  with gr.Column(scale=1):
151
+ components["db_test_select"] = gr.Dropdown(knowledgeBase.get_bases(),multiselect=True, label="知识库选择")
152
+ with gr.Column(scale=1):
 
153
  components["dbtest_submit_btn"] = gr.Button(value="检索")
154
  with gr.Row():
155
  with gr.Group():
 
157
 
158
  with gr.Tab("问答"):
159
  with gr.Row():
160
+ with gr.Column(scale=1):
161
+ with gr.Group():
162
+ components["ak"] = gr.Textbox(label="appid")
163
+ components["sk"] = gr.Textbox(label="secret")
164
+ components["llm_client"] =gr.Radio(["Wenxin", "Tongyi","Huggingface"],value="Wenxin", label="llm")
165
+ components["llm_setting_btn"] = gr.Button(value="设置")
166
+ with gr.Column(scale=2):
167
  with gr.Group():
168
  components["chatbot"] = gr.Chatbot(
169
+ [(None,"你好,有什么需要帮助的?")],
170
  elem_id="chatbot",
171
  bubble_full_width=False,
172
  height=600
173
  )
174
  components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
175
+ components["db_select"] = gr.CheckboxGroup(knowledgeBase.get_bases(),label="知识库", info="可选择1个或多个知识库")
176
  create_event_handlers()
177
  demo.load(init,None,gradio("db_view"))
178
  return demo
 
242
  do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
243
  )
244
 
245
+ components['llm_setting_btn'].click(
246
+ llm, gradio('ak','sk','llm_client'), None
247
+ )
248
+
249
  def do_refernce(algo_type,input_image):
250
  # def do_refernce():
251
  print("input image",input_image)
 
317
 
318
  return images
319
 
 
 
 
320
  def point_to_mask(pil_image):
321
  # 遍历每个像素
322
  width, height = pil_image.size
 
344
  return history, gr.MultimodalTextbox(value=None, interactive=False)
345
 
346
  def do_llm_response(history,selected_dbs):
347
+ print("do_llm_response:",history,selected_dbs)
348
  user_input = history[-1][0]
349
  prompt = ""
350
  quote = ""
351
+ if len(selected_dbs) > 0:
 
352
  knowledge = knowledgeBase.retrieve_documents(selected_dbs,user_input)
353
  print("do_llm_response context:",knowledge)
354
  prompt = f'''
 
356
  背景2:{knowledge[1]["content"]}
357
  背景3:{knowledge[2]["content"]}
358
  基于以上事实回答问题:{user_input}
359
+ '''
360
+
361
  quote = f'''
362
  > 文档:{knowledge[0]["meta"]["source"]},页码:{knowledge[0]["meta"]["page"]}
363
  > 文档:{knowledge[1]["meta"]["source"]},页码:{knowledge[1]["meta"]["page"]}
 
365
  '''
366
  else:
367
  prompt = user_input
368
+
 
369
  history[-1][1] = ""
370
+ if llm_client is None:
371
+ gr.Warning("请先设置大模型")
372
+ response = "模型参数未设置"
373
+ else:
374
+ print("do_llm_response prompt:",prompt)
375
+ response = llm_client(prompt)
376
+ response = response.removeprefix(prompt)
377
+ response += quote
378
+
379
  for character in response:
380
  history[-1][1] += character
381
  time.sleep(0.01)
382
  yield history
383
 
 
 
 
 
384
 
385
llm_client = None  # currently selected backend callable; None until the 设置 button is used


def llm(ak, sk, client):
    """Configure the chat backend from the settings panel.

    ak/sk are the credential textbox values and client is the radio
    choice ("Wenxin", "Tongyi" or "Huggingface").  Stores the matching
    backend callable from the llm.py helper module in the module-global
    ``llm_client`` and returns it (the click handler ignores the return
    value).
    """
    global llm_client
    import llm  # the llm.py helper module; shadows this function only inside this scope
    llm.init_param(ak, sk)
    if client == "Wenxin":
        llm_client = llm.baidu_client
    elif client == "Tongyi":
        llm_client = llm.qwen_agent_app
    elif client == "Huggingface":
        llm_client = llm.hg_client

    if ak == "" and sk == "":
        # Empty credentials mean "reset": also drop the backend so that
        # do_llm_response falls back to its "model not configured" warning
        # instead of calling a client whose keys were just cleared.
        llm_client = None
        gr.Info("重置成功")
    else:
        gr.Info("设置成功")

    return llm_client
403
 
404
 
405
  def file_handler(file_objs,name):
llm.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ from http import HTTPStatus
4
+ from dashscope import Application
5
+
6
+ ak = ""
7
+ sk = ""
8
+
9
def init_param(access_key, secret_key):
    """Store the model-service credential pair in the module globals ak/sk."""
    global ak, sk
    ak, sk = access_key, secret_key
13
+
14
+
15
def baidu_client(input):
    """Answer *input* with Baidu Wenxin (ernie-lite-8k) chat completion.

    Returns the model's reply text, or "" when the credentials are unset
    or the API responds with an error payload.
    """
    global ak, sk
    if ak == "" or sk == "":
        return ""

    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k?access_token=" + get_access_token()

    payload = json.dumps({
        "temperature": 0.95,
        "top_p": 0.7,
        "penalty_score": 1,
        "messages": [
            {
                "role": "user",
                "content": input
            }
        ],
        "system": ""
    })
    headers = {
        'Content-Type': 'application/json'
    }

    # A timeout keeps a stalled API call from hanging the caller forever.
    response = requests.request("POST", url, headers=headers, data=payload,
                                timeout=60)

    print("baidu_client", response.text)
    # Error responses carry "error_code"/"error_msg" instead of "result";
    # .get avoids a KeyError and degrades to an empty reply.
    return response.json().get("result", "")
42
+
43
+
44
def get_access_token():
    """Exchange the AK/SK pair for a Baidu OAuth access token.

    Returns the token string, or "" when the request fails or the JSON
    response has no "access_token" field.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": ak, "client_secret": sk}
    token = requests.post(url, params=params, timeout=30).json().get("access_token")
    # str(None) would produce the literal text "None" in the caller's URL;
    # return "" so a failed exchange yields an obviously-empty parameter.
    return str(token) if token is not None else ""
52
+
53
+
54
def qwen_agent_app(input):
    """Send *input* to the Tongyi (DashScope) application and return its reply text."""
    global ak, sk
    # ak holds the DashScope app id and sk the API key; both are required.
    if ak == "" or sk == "":
        return ""

    resp = Application.call(app_id=ak,
                            prompt=input,
                            api_key=sk,
                            )

    if resp.status_code == HTTPStatus.OK:
        print('request_id=%s\n output=%s\n usage=%s\n' % (resp.request_id, resp.output, resp.usage))
        return resp.output["text"]

    print('request_id=%s, code=%s, message=%s\n' % (resp.request_id, resp.status_code, resp.message))
    return ""
69
+
70
+
71
def hg_client(input):
    """Answer *input* via the Hugging Face Inference API (Mistral-7B-Instruct-v0.3).

    Uses sk as the bearer token; returns the generated text, or "" when
    the token is unset or the API reports an error.
    """
    global ak, sk
    if sk == "":
        return ""
    import requests
    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {"Authorization": f"Bearer {sk}"}

    def query(payload):
        # A timeout keeps a stalled inference call from blocking the UI.
        response = requests.post(API_URL, headers=headers, json=payload,
                                 timeout=60)
        return response.json()

    output = query({
        "inputs": input,
    })
    print(output)
    # Success is a non-empty list of {"generated_text": ...}; failures come
    # back as a dict such as {"error": ...} — guard against both shapes so a
    # bad token or loading model can't raise KeyError/TypeError here.
    if isinstance(output, list) and output and "generated_text" in output[0]:
        return output[0]['generated_text']
    return ""
requirements.txt CHANGED
@@ -20,4 +20,5 @@ faiss-cpu==1.8.0
20
  pypdf==4.2.0
21
  langchain==0.2.5
22
  langchain-community==0.2.5
23
- transformers==4.32.1
 
 
20
  pypdf==4.2.0
21
  langchain==0.2.5
22
  langchain-community==0.2.5
23
+ transformers==4.32.1
24
+ dashscope==1.20.0