neoguojing committed on
Commit
494b300
1 Parent(s): ac510cd
Files changed (4) hide show
  1. app.py +152 -1
  2. embedding.py +69 -0
  3. requirements.txt +7 -1
  4. retriever.py +150 -0
app.py CHANGED
@@ -6,7 +6,8 @@ from inference import ModelFactory
6
  from face import FaceAlgo
7
  from sam_everything import SamAnything
8
  from ocr import do_ocr
9
-
 
10
 
11
  components = {}
12
 
@@ -125,10 +126,55 @@ def create_ui():
125
  with gr.Row():
126
  with gr.Group():
127
  components["ocr_json_output"] = gr.JSON(label="推理结果")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  create_event_handlers()
 
130
  return demo
131
 
 
 
 
 
132
 
133
  def create_event_handlers():
134
  params["algo_type"] = gr.State("全景分割")
@@ -172,6 +218,24 @@ def create_event_handlers():
172
  do_ocr,gradio('ocr_type','ocr_input'),gradio("ocr_output","ocr_json_output")
173
  )
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def do_refernce(algo_type,input_image):
176
  # def do_refernce():
177
  print("input image",input_image)
@@ -261,6 +325,93 @@ def point_to_mask(pil_image):
261
  points_array_reshaped = points_array.reshape(-1, 2)
262
  return points_array_reshaped
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if __name__ == "__main__":
265
  demo = create_ui()
266
  # demo.launch(server_name="10.151.124.137")
 
6
  from face import FaceAlgo
7
  from sam_everything import SamAnything
8
  from ocr import do_ocr
9
+ from retriever import knowledgeBase
10
+ import time
11
 
12
  components = {}
13
 
 
126
  with gr.Row():
127
  with gr.Group():
128
  components["ocr_json_output"] = gr.JSON(label="推理结果")
129
+ with gr.Tab("知识库"):
130
+ with gr.Row():
131
+ with gr.Column(scale=1):
132
+ with gr.Group():
133
+ components["db_view"] = gr.Dataframe(
134
+ headers=["列表"],
135
+ datatype=["str"],
136
+ row_count=8,
137
+ col_count=(1, "fixed"),
138
+ interactive=False
139
+ )
140
+ with gr.Column(scale=2):
141
+ with gr.Group():
142
+ components["db_name"] = gr.Textbox(label="名称", info="请输入库名称", lines=1, value="")
143
+ components["file_upload"] = gr.File(elem_id='file_upload',file_count='multiple',label='文档上传', file_types=[".pdf", ".doc", '.docx', '.json', '.csv'])
144
+ components["db_submit_btn"] = gr.Button(value="提交")
145
+ with gr.Row():
146
+ with gr.Column(scale=2):
147
+ components["db_input"] = gr.Textbox(label="关键词", lines=1, value="")
148
+
149
+ with gr.Column(scale=1):
150
+ components["db_test_select"] = gr.Dropdown(
151
+ choices=knowledgeBase.get_bases(),value=None,multiselect=True, label="知识库选择"
152
+ )
153
+ components["dbtest_submit_btn"] = gr.Button(value="检索")
154
+ with gr.Row():
155
+ with gr.Group():
156
+ components["db_search_result"] = gr.JSON(label="检索结果")
157
 
158
+ with gr.Tab("问答"):
159
+ with gr.Row():
160
+ with gr.Column():
161
+ with gr.Group():
162
+ components["chatbot"] = gr.Chatbot(
163
+ [(None,"What can I help you?")],
164
+ elem_id="chatbot",
165
+ bubble_full_width=False,
166
+ height=600
167
+ )
168
+ components["chat_input"] = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
169
+ components["db_select"] = gr.CheckboxGroup(choices=knowledgeBase.get_bases(),value=None,label="知识库", info="可选择1个或多个知识库")
170
  create_event_handlers()
171
+ demo.load(init,None,gradio("db_view"))
172
  return demo
173
 
174
def init():
    """Return the knowledge-base listing as a DataFrame for the UI table.

    Used as the ``demo.load`` callback so the "知识库" Dataframe is populated
    when the page first renders.
    """
    return knowledgeBase.get_df_bases()
178
 
179
  def create_event_handlers():
180
  params["algo_type"] = gr.State("全景分割")
 
218
  do_ocr,gradio('ocr_type','ocr_input'),gradio("ocr_output","ocr_json_output")
219
  )
220
 
221
+ components["db_submit_btn"].click(
222
+ file_handler,gradio('file_upload','db_name'),gradio("db_view",'db_select',"db_test_select")
223
+ )
224
+
225
+ components["chat_input"].submit(
226
+ do_llm_request, gradio("chatbot", "chat_input"), gradio("chatbot", "chat_input")
227
+ ).then(
228
+ do_llm_response, gradio("chatbot","db_select"), gradio("chatbot"), api_name="bot_response"
229
+ ).then(
230
+ lambda: gr.MultimodalTextbox(interactive=True), None, gradio('chat_input')
231
+ )
232
+
233
+ # components["chatbot"].like(print_like_dislike, None, None)
234
+
235
+ components['dbtest_submit_btn'].click(
236
+ do_search, gradio('db_test_select','db_input'), gradio('db_search_result')
237
+ )
238
+
239
  def do_refernce(algo_type,input_image):
240
  # def do_refernce():
241
  print("input image",input_image)
 
325
  points_array_reshaped = points_array.reshape(-1, 2)
326
  return points_array_reshaped
327
 
328
+
329
def print_like_dislike(x: gr.LikeData):
    """Log a chatbot like/dislike event (message index, text, liked flag)."""
    print(x.index, x.value, x.liked)
331
+
332
def do_llm_request(history, message):
    """Append the user's multimodal message to the chat history.

    Each uploaded file becomes its own ``((path,), None)`` history entry and
    non-empty text becomes a final text entry.  Returns the updated history
    plus a cleared, disabled input box (re-enabled after the bot responds).
    """
    for file_path in message["files"]:
        history.append(((file_path,), None))
    # Truthiness check, not `is not None`: submitting only files leaves
    # message["text"] as "" and the original appended a blank user turn.
    if message["text"]:
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)
338
+
339
def do_llm_response(history, selected_dbs):
    """Generate the assistant reply for the last user turn, streamed one
    character at a time into the chat history (generator for Gradio).

    When knowledge bases are selected, retrieved passages are prepended to
    the prompt as numbered context and a source-citation block is appended
    to the reply.
    """
    user_input = history[-1][0]
    prompt = user_input
    quote = ""
    print("----------", selected_dbs)
    if selected_dbs:
        knowledge = knowledgeBase.retrieve_documents(selected_dbs, user_input)
        print("do_llm_response context:", knowledge)
        if knowledge:
            # Use however many passages came back; the original indexed
            # knowledge[0], [1] and [2] unconditionally and raised IndexError
            # whenever retrieval returned fewer than three documents.
            context = "\n".join(
                f"背景{i + 1}:{doc['content']}" for i, doc in enumerate(knowledge)
            )
            prompt = f"{context}\n基于以上事实回答问题:{user_input}"
            print("do_llm_response prompt:", prompt)
            quote = "\n" + "\n".join(
                f"> 文档:{doc['meta']['source']},页码:{doc['meta']['page']}"
                for doc in knowledge
            )

    response = llm(prompt)
    history[-1][1] = ""
    # The HF text-generation endpoint echoes the prompt; strip it off.
    response = response.removeprefix(prompt)
    response += quote
    for character in response:
        history[-1][1] += character
        time.sleep(0.01)  # small delay to produce a typing effect
        yield history
370
+
371
def llm(input):
    """Query the Mistral-7B-Instruct HF inference API and return the
    generated text, or "" when the API errors or returns no generations.
    """
    import requests
    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    # NOTE(review): the bearer token is empty — requests run unauthenticated
    # and will be rate-limited or rejected; inject a real token via config.
    headers = {"Authorization": "Bearer "}

    def query(payload):
        # Bound the wait so a hung endpoint cannot stall the chat forever.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
        return response.json()

    output = query({
        "inputs": input,
    })
    print(output)
    # Error responses are dicts like {"error": ...}; only a non-empty list
    # carries generations.  The original `len(output) > 0` was true for an
    # error dict and then crashed on output[0].
    if isinstance(output, list) and output:
        return output[0].get("generated_text", "")
    return ""
387
+
388
+
389
+
390
def file_handler(file_objs, name):
    """Ingest uploaded files into the knowledge base called ``name``.

    Moves each uploaded temp file into ./files/input/, indexes all of them,
    then returns refreshed values for the KB table and both KB selectors.
    """
    import shutil
    import os

    print("file_obj:", file_objs)

    upload_dir = "./files/input/"
    os.makedirs(upload_dir, exist_ok=True)

    # Collect every destination path, then index them in one call; the
    # original called add_documents_to_kb after the loop with only the
    # last iteration's file_path, silently dropping earlier uploads.
    file_paths = []
    for file in file_objs:
        print(file)
        file_path = os.path.join(upload_dir, os.path.basename(file.name))
        if not os.path.exists(file_path):
            shutil.move(file.name, upload_dir)
        file_paths.append(file_path)

    if file_paths:
        knowledgeBase.add_documents_to_kb(name, file_paths)

    dbs = knowledgeBase.get_bases()
    dfs = knowledgeBase.get_df_bases()
    return dfs, gr.CheckboxGroup(dbs, label="知识库", info="可选择1个或多个知识库"), gr.Dropdown(dbs, multiselect=True, label="知识库选择")
409
+
410
def do_search(selected_dbs, user_input):
    """Run a retrieval query against the selected knowledge bases and
    return the matching documents for display in the JSON panel."""
    print("do_search:", selected_dbs, user_input)
    return knowledgeBase.retrieve_documents(selected_dbs, user_input)
414
+
415
  if __name__ == "__main__":
416
  demo = create_ui()
417
  # demo.launch(server_name="10.151.124.137")
embedding.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ from sklearn.preprocessing import normalize
3
+ from typing import Any, List, Mapping, Optional,Union
4
+ from langchain.callbacks.manager import (
5
+ CallbackManagerForLLMRun
6
+ )
7
+ from langchain_core.embeddings import Embeddings
8
+ import torch
9
+
10
class Embedding(Embeddings):
    """LangChain-compatible embedding wrapper around BAAI/bge-small-zh-v1.5.

    Vectors are produced by CLS pooling followed by L2 normalisation, so dot
    products between returned vectors are cosine similarities.
    """

    # Single source of truth for the pretrained checkpoint id.
    MODEL_ID = 'BAAI/bge-small-zh-v1.5'

    def __init__(self, **kwargs):
        # kwargs accepted for interface compatibility but currently unused.
        self.model = AutoModel.from_pretrained(self.MODEL_ID)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        self.model.eval()  # inference only — disables dropout etc.

    @property
    def _llm_type(self) -> str:
        return "BAAI/bge-small-zh-v1.5"

    @property
    def model_name(self) -> str:
        return "embedding"

    def _call(
        self,
        prompt: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ):
        """Embed a batch of texts.

        Returns an (n, dim) numpy array of L2-normalised embeddings.  (The
        original annotated the return as ``str``, which was wrong.)
        """
        encoded_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors='pt')

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # CLS pooling: hidden state of the first token of each sequence.
            sentence_embeddings = model_output[0][:, 0]
        print(sentence_embeddings.shape)
        # Normalise so that dot product == cosine similarity.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.numpy()

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # Fix: self.model_path was never assigned, so the original raised
        # AttributeError here; report the pretrained checkpoint id instead.
        return {"model_path": self.MODEL_ID}

    def embed_documents(self, texts) -> List[List[float]]:
        """Embed a list of documents — one vector per text, in order."""
        print("embed_documents:", len(texts), type(texts))
        # _call already returns one row per text; just unpack the rows.
        return list(self._call(texts))

    def embed_query(self, text) -> List[float]:
        """Embed a single query string."""
        return self._call([text])[0]
62
+
63
+
64
+ # if __name__ == '__main__':
65
+ # sd = Embedding()
66
+ # v1 = sd.embed_query("他是一个人")
67
+ # v2 = sd.embed_query("他是一个好人")
68
+ # v3 = sd.embed_documents(["她是一条狗","他是一个人"])
69
+ # print(v1 @ v2.T)
requirements.txt CHANGED
@@ -14,4 +14,10 @@ cloudpickle==2.2.1
14
  segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
15
  paddlepaddle==2.6.1
16
  paddleocr==2.7.3
17
- easyocr==1.7.1
 
 
 
 
 
 
 
14
  segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
15
  paddlepaddle==2.6.1
16
  paddleocr==2.7.3
17
+ easyocr==1.7.1
18
+ scikit-learn==1.5.0
19
+ faiss-cpu==1.8.0
20
+ pypdf==4.2.0
21
+ langchain==0.2.5
22
+ langchain-community==0.2.5
23
+ transformers==4.32.1
retriever.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.docstore.in_memory import InMemoryDocstore
5
+ import faiss
6
+ import os
7
+ import glob
8
+ from typing import Any,List,Dict
9
+ from embedding import Embedding
10
+
11
+
12
class KnowledgeBaseManager:
    """Manages named FAISS vector stores persisted under ``base_path``.

    Each knowledge base is saved by ``FAISS.save_local`` as a pair of files,
    ``<name>.faiss`` and ``<name>.pkl``; every base found on disk is loaded
    at start-up.
    """

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        self.base_path = base_path
        self.embedding_dim = embedding_dim  # must match the embedding model's output width
        self.batch_size = batch_size        # documents embedded per add_documents call
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        os.makedirs(self.base_path, exist_ok=True)

        faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
        # Base name = file stem without the .faiss extension.
        file_names_without_extension = [os.path.splitext(os.path.basename(file))[0] for file in faiss_files]
        for name in file_names_without_extension:
            self.load_knowledge_base(name)

    def create_knowledge_base(self, name: str):
        """Create and persist an empty knowledge base; no-op if it exists."""
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return

        # Build the index only after the duplicate check — the original
        # constructed it unconditionally and threw it away on duplicates.
        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Drop a knowledge base from memory and remove its files on disk."""
        if name in self.knowledge_bases:
            del self.knowledge_bases[name]
            # save_local writes both a .faiss and a .pkl file; the original
            # removed only the .faiss, leaving a stale docstore pickle behind.
            for ext in (".faiss", ".pkl"):
                path = os.path.join(self.base_path, f"{name}{ext}")
                if os.path.exists(path):
                    os.remove(path)
            print(f"Knowledge base '{name}' deleted.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    def load_knowledge_base(self, name: str):
        """Load a persisted knowledge base from disk into memory."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if os.path.exists(kb_path):
            # allow_dangerous_deserialization: the .pkl docstore is pickled;
            # safe only because these files are produced locally by this app.
            self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
            print(f"Knowledge base '{name}' loaded.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    def save_knowledge_base(self, name: str):
        """Persist an in-memory knowledge base to disk."""
        if name in self.knowledge_bases:
            self.knowledge_bases[name].save_local(self.base_path, name)
            print(f"Knowledge base '{name}' saved.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    # Example of the Document objects produced by the loaders:
    # Document(page_content='渠道版', metadata={'source': './files/input/PS004.pdf', 'page': 0}),
    # Document(page_content='2/20.', metadata={'source': './files/input/PS004.pdf', 'page': 1})
    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split and index the given files into base ``name``.

        Creates the base on demand.  Returns the list of document ids added.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)

        kb = self.knowledge_bases[name]
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        print(documents)
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")

        # Embed in batches to bound peak memory during indexing.
        doc_ids = []
        for i in range(0, len(pages), self.batch_size):
            batch = pages[i:i + self.batch_size]
            doc_ids.extend(kb.add_documents(batch))

        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Load every file into LangChain Documents via a format-specific loader."""
        documents = []
        for file_path in file_paths:
            loader = self.get_loader(file_path)
            documents.extend(loader.load())
        return documents

    def get_loader(self, file_path: str):
        """Pick a document loader by file extension; raises on unknown formats."""
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        elif file_path.endswith('.json'):
            return JSONLoader(file_path)
        elif file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        else:
            raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split documents into <=512-char chunks on CJK-aware separators."""
        text_splitter = RecursiveCharacterTextSplitter(separators=[
            "\n\n",
            "\n",
            " ",
            ".",
            ",",
            "\u200b",  # Zero-width space
            "\uff0c",  # Fullwidth comma
            "\u3001",  # Ideographic comma
            "\uff0e",  # Fullwidth full stop
            "\u3002",  # Ideographic full stop
            "",
        ],
            chunk_size=512, chunk_overlap=0)
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """Retrieve up to 3 MMR-ranked documents per named base for ``query``.

        Returns a list of {"name", "content", "meta"} dicts; unknown base
        names are skipped with a message.
        """
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue

            # NOTE(review): score_threshold is only honoured by the
            # "similarity_score_threshold" search type; with "mmr" it is
            # ignored — confirm the intended filtering behaviour.
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3}
            )
            docs = retriever.get_relevant_documents(query)
            results.extend([{"name": name, "content": doc.page_content, "meta": doc.metadata} for doc in docs])

        return results

    def get_bases(self):
        """Return the names of all loaded knowledge bases."""
        return list(self.knowledge_bases.keys())

    def get_df_bases(self):
        """Return the base names as a single-column DataFrame for the UI table."""
        import pandas as pd
        return pd.DataFrame(list(self.knowledge_bases.keys()), columns=['列表'])

# Module-level singleton shared by the Gradio app.
knowledgeBase = KnowledgeBaseManager()