Devie committed on
Commit
ef48816
1 Parent(s): 285e249

Add application file

Files changed (1)
  1. app.py +318 -0
app.py ADDED
@@ -0,0 +1,318 @@
+ #import gradio as gr
+ #import cv2
+ #def to_black(image):
+ #    output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ #    return output
+ #interface = gr.Interface(fn=to_black, inputs="image", outputs="image")
+ #print('here')
+ #interface.launch()
+
+ #print(share_url)
+ #print(local_url)
+ #print(app)
+ #interface.launch(inbrowser=True, share=True, port=8888)
+ #url = interface.share()
+ #print(url)
+ from langchain.chains import RetrievalQA
+ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
+ from langchain.vectorstores import DocArrayInMemorySearch, FAISS
+ from langchain.indexes import VectorstoreIndexCreator
+ from langchain.prompts import PromptTemplate
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+ from langchain import HuggingFacePipeline
+ import torch
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.chains.base import Chain
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.chains.summarize import load_summarize_chain
+ import gradio as gr
+ from typing import List, Union
+ from tqdm import tqdm
+ import logging
+ import argparse
+ import os
+ import string
+
39
+ CHUNK_SIZE=600
40
+ CHUNK_OVERLAP = 100
41
+ SEARCH_TOP_K = 5
42
+ logger = logging.getLogger("bio_LLM_logger")
43
+
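+ # Note: no handler is configured for this named logger in this file, so the
+ # logger.info/logger.error calls below are silent by default. A minimal sketch
+ # of surfacing them (an assumption, not part of the original commit):
+ # logging.basicConfig(level=logging.INFO,
+ #                     format="%(asctime)s %(name)s %(levelname)s: %(message)s")
+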
+ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
+     """Return two lists: the first holds the full paths of every file under
+     filepath, the second the corresponding file names."""
+     if ignore_dir_names is None:
+         ignore_dir_names = []
+     if ignore_file_names is None:
+         ignore_file_names = []
+     ret_list = []
+     if isinstance(filepath, str):
+         if not os.path.exists(filepath):
+             print("path does not exist")
+             return None, None
+         elif os.path.isfile(filepath) and os.path.basename(filepath) not in ignore_file_names:
+             return [filepath], [os.path.basename(filepath)]
+         elif os.path.isdir(filepath) and os.path.basename(filepath) not in ignore_dir_names:
+             for file in os.listdir(filepath):
+                 fullfilepath = os.path.join(filepath, file)
+                 if os.path.isfile(fullfilepath) and os.path.basename(fullfilepath) not in ignore_file_names:
+                     ret_list.append(fullfilepath)
+                 if os.path.isdir(fullfilepath) and os.path.basename(fullfilepath) not in ignore_dir_names:
+                     ret_list.extend(tree(fullfilepath, ignore_dir_names, ignore_file_names)[0])
+     return ret_list, [os.path.basename(p) for p in ret_list]
+
+
+ def load_file(file_path, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
+     # Pick a loader by extension; all formats share the same splitter settings.
+     if file_path.lower().endswith(".pdf"):
+         loader = UnstructuredFileLoader(file_path, mode="elements")
+     elif file_path.lower().endswith(".txt"):
+         loader = TextLoader(file_path, autodetect_encoding=True)
+     elif file_path.lower().endswith(".csv"):
+         loader = CSVLoader(file_path)
+     else:
+         raise ValueError(f"unsupported file format: {file_path}")
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     docs = loader.load_and_split(text_splitter=text_splitter)
+     return docs
+
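+ # Minimal usage sketch for load_file (the path below is hypothetical):
+ # docs = load_file("data/example.pdf")
+ # print(len(docs), "chunks;", docs[0].page_content[:80])
+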
+ #class summary_chain:
+ #    def init_cfg(self,
+ #                 llm_model: Chain,
+
+ def summary(model, chain_type, PROMPT, REFINE_PROMPT, docs):
+     if chain_type == "stuff":
+         chain = load_summarize_chain(model, chain_type="stuff", prompt=PROMPT)
+     elif chain_type == "refine":
+         chain = load_summarize_chain(model, chain_type="refine", question_prompt=PROMPT, refine_prompt=REFINE_PROMPT)
+     else:
+         raise ValueError(f"unsupported chain_type: {chain_type}")
+     print(chain.run(docs))
+
+
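+ # Hedged usage sketch for summary() (PROMPT is assumed to be a PromptTemplate;
+ # REFINE_PROMPT is unused by the "stuff" chain, so None is fine there):
+ # summary(model, "stuff", PROMPT, None, docs)
+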
+ class QA_Localdb:
+     llm_model_chain: Chain = None
+     embeddings: object = None
+     top_k: int = SEARCH_TOP_K
+     chunk_size: int = CHUNK_SIZE
+
+     def init_cfg(self,
+                  llm_model: Chain,
+                  embedding_model: str,
+                  #embedding_device: str,
+                  top_k=SEARCH_TOP_K,
+                  ):
+         self.llm_model_chain = llm_model
+         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
+         self.top_k = top_k
+
+     def init_knowledge_vector_store(self,
+                                     file_path: Union[str, List[str]],
+                                     vectorstore_path: Union[str, os.PathLike] = None,
+                                     ):
+         loaded_files = []
+         failed_files = []
+         if isinstance(file_path, str):
+             if not os.path.exists(file_path):
+                 print("unknown path")
+                 return None
+             elif os.path.isfile(file_path):
+                 file = os.path.split(file_path)[-1]
+                 try:
+                     docs = load_file(file_path)
+                     logger.info(f"{file} loaded successfully")
+                     loaded_files.append(file_path)
+                 except Exception as e:
+                     logger.error(e)
+                     logger.info(f"{file} failed to load")
+                     return None
+
+             elif os.path.isdir(file_path):
+                 docs = []
+                 for fullfilepath, file in tqdm(zip(*tree(file_path, ignore_dir_names=['tmp_files'])), desc="load file"):
+                     try:
+                         docs += load_file(fullfilepath)
+                         loaded_files.append(fullfilepath)
+                     except Exception as e:
+                         logger.error(e)
+                         failed_files.append(file)
+
+                 if len(failed_files) > 0:
+                     logger.info('the following files failed to load:')
+                     for file in failed_files:
+                         logger.info(f"{file}\n")
+         else:
+             docs = []
+             for file in file_path:
+                 try:
+                     docs += load_file(file)
+                     logger.info(f"{file} loaded successfully")
+                     loaded_files.append(file)
+                 except Exception as e:
+                     logger.error(e)
+                     logger.info(f"{file} failed to load")
+         if len(docs) > 0:
+             logger.info("files loaded successfully, generating vector store")
+             if vectorstore_path and os.path.isdir(vectorstore_path) and "index.faiss" in os.listdir(vectorstore_path):
+                 print("temp")
+                 # vector_store = load_vector_store(vectorstore_path, self.embeddings)
+                 # vector_store.add_documents(docs)
+                 # torch_gc()
+             else:
+                 if not vectorstore_path:
+                     vectorstore_path = ""
+                 vector_store = FAISS.from_documents(docs, self.embeddings)
+                 #vector_store.save_local(vectorstore_path)
+             return vector_store, loaded_files
+         else:
+             logger.info("file load failed")
+
+
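+     # Hedged usage sketch for the method above (variable names and path are
+     # hypothetical; similarity_search is part of the FAISS vectorstore API):
+     # qa = QA_Localdb()
+     # qa.init_cfg(llm_model=model, embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2")
+     # vs, files = qa.init_knowledge_vector_store("data/example.pdf")
+     # hits = vs.similarity_search("What does the paper propose?", k=SEARCH_TOP_K)
+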
+     '''
+     def delete_file_from_vector_store(self,
+                                       filepath: Union[str, List[str]],
+                                       vs_path):
+         vector_store = load_vector_store(vs_path, self.embeddings)
+         status = vector_store.delete_doc(filepath)
+         return status
+
+     def update_file_from_vector_store(self,
+                                       filepath: Union[str, List[str]],
+                                       vs_path,
+                                       docs: List[Document], ):
+         vector_store = load_vector_store(vs_path, self.embeddings)
+         status = vector_store.update_doc(filepath, docs)
+         return status
+
+     def list_file_from_vector_store(self,
+                                     vs_path,
+                                     fullpath=False):
+         vector_store = load_vector_store(vs_path, self.embeddings)
+         docs = vector_store.list_docs()
+         if fullpath:
+             return docs
+         else:
+             return [os.path.split(doc)[-1] for doc in docs]
+     '''
+ def QA_model():
+     # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/OPUS-DSD.pdf"
+     file_path = "OPUS-BioLLM-v1/data/test/Interageting-Prior-into-DA.pdf"
+     # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/Interageting-Prior-into-DA.pdf"
+     # file_path = "/mnt/petrelfs/lvying/LLM/BoMA/data/test/"
+
+     model_path = "/mnt/petrelfs/lvying/LLM/BoMA/models/LLM/Llama-2-13b-chat-hf"
+     embedding_path = "/mnt/petrelfs/lvying/LLM/BoMA/text2vec/instructor-xl/"
+
+     model = HuggingFacePipeline.from_model_id(model_id="daryl149/llama-2-7b-chat-hf",
+                                               task="text-generation",
+                                               model_kwargs={
+                                                   "torch_dtype": torch.float16,
+                                                   "low_cpu_mem_usage": True,
+                                                   "temperature": 0.2,
+                                                   "max_length": 2048,
+                                                   #"device_map": "auto",
+                                                   "repetition_penalty": 1.1}
+                                               )
+     print(model.model_id)
+     QA = QA_Localdb()
+     QA.init_cfg(llm_model=model, embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2")
+
+     vector_store, _ = QA.init_knowledge_vector_store(file_path)
+     retriever = vector_store.as_retriever(search_kwargs={"k": 3})
+
+     print("loading LLM...")
+     prompt_template = ("Below is an instruction that describes a task. "
+                        "Write a response that appropriately completes the request.\n\n"
+                        "### Instruction:\n{context}\n{question}\n\n### Response: ")
+
+     PROMPT = PromptTemplate(
+         template=prompt_template, input_variables=["context", "question"]
+     )
+
+     chain_type_kwargs = {"prompt": PROMPT}
+
+     #print(chain_type_kwargs)
+     '''
+     qa_stuff = RetrievalQA.from_chain_type(
+         llm=model,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs=chain_type_kwargs,
+         # verbose=True
+     )
+     while True:
+         print("Input Question:")
+         query = input()
+         if len(query.strip()) == 0:
+             break
+         print(qa_stuff.run(query))
+     '''
+     '''
+     qa = ConversationalRetrievalChain.from_llm(
+         llm=QA.llm_model_chain,
+         chain_type="stuff",
+         retriever=retriever,
+         combine_docs_chain_kwargs=chain_type_kwargs,
+         # verbose=True
+     )
+     '''
+     qa = RetrievalQA.from_chain_type(
+         llm=QA.llm_model_chain,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs=chain_type_kwargs,
+         # verbose=True
+     )
+     return qa
+ qa_temp = QA_model()
+
+ def temp(query):
+     return qa_temp.run(query)
+
+
+ def answer_question(query):
+     # Note: this helper targets the ConversationalRetrievalChain variant that is
+     # commented out in QA_model(); the RetrievalQA chain returned there expects
+     # {"query": ...} instead of {"question": ..., "chat_history": ...}.
+     print(query)
+     chat_history = []
+     threshold_history = 10  # number of remembered historical turns
+     i = 0
+     if i > threshold_history:
+         chat_history = []
+     print("Send a Message:")
+     #query = context
+     #if len(query.strip()) == 0:
+     #    break
+     result = qa_temp({"question": query, "chat_history": chat_history})
+     print(type(result["answer"]))
+     chat_history.append((query, result["answer"]))
+     i = i + 1
+     resp = result["answer"]
+     return str(resp)
+
+
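+ # For reference, a hedged sketch of invoking the RetrievalQA chain that
+ # QA_model() actually returns (input/output keys per the LangChain RetrievalQA
+ # API; the question text is only an example):
+ # result = qa_temp({"query": "What does the paper propose?"})
+ # print(result["result"])
+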
+ iface = gr.Interface(
+     fn=temp,
+     inputs="text",
+     outputs="text",)
+     #title="Q&A Interface",
+     #description="Enter a question and some related text to get an answer to the question.",
+     #article="Related text goes here. You can enter a few paragraphs or background for the question.",
+     #examples=[
+     #    ["What is Gradio?", "Gradio is an open-source library for building and deploying machine learning models."],
+     #    ["Who created Python?", "Python was created by Guido van Rossum."]
+     #])
+ #print(iface.launch(share=True))
+
+ #print("======Finish======")
+ #share_url = iface.share()
+ #print(share_url)
+ iface.launch()