captain-awesome committed on
Commit
f18103b
1 Parent(s): d9b4100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -1
app.py CHANGED
@@ -106,8 +106,17 @@ def create_vector_database(loaded_documents):
106
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function = len)
107
  chunked_documents = text_splitter.split_documents(loaded_documents)
108
 
 
 
 
 
 
 
 
109
  embeddings = HuggingFaceBgeEmbeddings(
110
- model_name = "BAAI/bge-large-en"
 
 
111
  )
112
 
113
  persist_directory = 'db'
@@ -122,3 +131,176 @@ def create_vector_database(loaded_documents):
122
  # db = Chroma(persist_directory=persist_directory,
123
  # embedding_function=embedding)
124
  return db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=30, length_function = len)
107
  chunked_documents = text_splitter.split_documents(loaded_documents)
108
 
109
+ # embeddings = HuggingFaceBgeEmbeddings(
110
+ # model_name = "BAAI/bge-large-en"
111
+ # )
112
+
113
+ model_name = "BAAI/bge-large-en"
114
+ model_kwargs = {'device': 'cpu'}
115
+ encode_kwargs = {'normalize_embeddings': False}
116
  embeddings = HuggingFaceBgeEmbeddings(
117
+ model_name=model_name,
118
+ model_kwargs=model_kwargs,
119
+ encode_kwargs=encode_kwargs
120
  )
121
 
122
  persist_directory = 'db'
 
131
  # db = Chroma(persist_directory=persist_directory,
132
  # embedding_function=embedding)
133
  return db
134
+
135
def set_custom_prompt():
    """
    Build the prompt template applied during retrieval for each vectorstore.
    """
    template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
    # The template exposes exactly the two slots the retrieval chain fills in.
    return PromptTemplate(
        template=template,
        input_variables=["context", "question"],
    )
149
+
150
def create_chain(llm, prompt, db):
    """
    Create a Retrieval Question-Answering (QA) chain.

    Builds a RetrievalQA chain of chain type "stuff" over the given
    database's retriever (top k=3 results), injecting the supplied prompt
    and returning source documents alongside each answer.

    Args:
        llm (any): The language model used to generate answers.
        prompt (str): The prompt used by the "stuff" combine step.
        db (any): The vector database used as the retriever.

    Returns:
        RetrievalQA: The initialized retrieval QA chain.
    """
    # NOTE(review): the previous ConversationTokenBufferMemory was built but
    # never used (the ConversationalRetrievalChain consuming it was commented
    # out), so it only burned token-counting work; it has been removed. The
    # chain below is stateless per query.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
181
+
182
def create_retrieval_qa_bot(loaded_documents):
    """
    Assemble the retrieval QA chain for the given documents.

    Loads the LLM, builds the custom prompt, creates the vector database
    from *loaded_documents*, and wires them together via create_chain.

    Args:
        loaded_documents (list): Documents to index into the vector store.

    Returns:
        The retrieval QA chain produced by create_chain.

    Raises:
        Exception: If any stage (model load, prompt, database, chain) fails;
            the original error is chained as __cause__.
    """
    try:
        llm = load_model()
    except Exception as e:
        # "from e" preserves the original traceback for debugging.
        raise Exception(f"Failed to load model: {str(e)}") from e

    try:
        prompt = set_custom_prompt()
    except Exception as e:
        raise Exception(f"Failed to get prompt: {str(e)}") from e

    try:
        db = create_vector_database(loaded_documents)
    except Exception as e:
        raise Exception(f"Failed to get database: {str(e)}") from e

    try:
        qa = create_chain(llm=llm, prompt=prompt, db=db)
    except Exception as e:
        raise Exception(f"Failed to create retrieval QA chain: {str(e)}") from e

    return qa
217
+
218
def retrieve_bot_answer(query, loaded_documents):
    """
    Retrieve the answer to *query* using a QA bot built over *loaded_documents*.

    Creates a retrieval QA bot instance, runs the query through it, and
    extracts the answer text plus the source paths of supporting documents.

    Args:
        query (str): The question to be answered by the QA bot.
        loaded_documents (list): Documents the bot retrieves answers from.

    Returns:
        tuple: (result, sources) where result is the answer text under the
        chain's 'result' key and sources is the list of 'source' metadata
        values from the returned source documents.
    """
    qa_bot_instance = create_retrieval_qa_bot(loaded_documents)
    # RetrievalQA expects the input under the "query" key.
    bot_response = qa_bot_instance({"query": query})
    result = bot_response['result']
    sources = [doc.metadata['source'] for doc in bot_response["source_documents"]]
    return result, sources
243
+
244
def main():
    """Streamlit entry point: upload documents and chat with them."""
    st.title("Docuverse")

    # Upload files
    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=["pdf", "md", "txt", "csv", "py", "epub", "html", "ppt", "pptx", "doc", "docx", "odt", "ipynb"],
        accept_multiple_files=True,
    )
    loaded_documents = []

    if uploaded_files:
        # Stage uploads in a temporary directory so the Langchain loaders
        # can read them from disk; the directory is removed on exit.
        with tempfile.TemporaryDirectory() as td:
            for uploaded_file in uploaded_files:
                st.write(f"Uploaded: {uploaded_file.name}")
                ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
                st.write(f"Uploaded: {ext}")

                # Check if the extension is in FILE_LOADER_MAPPING
                if ext in FILE_LOADER_MAPPING:
                    loader_class, loader_args = FILE_LOADER_MAPPING[ext]

                    # Save the uploaded file to the temporary directory
                    file_path = os.path.join(td, uploaded_file.name)
                    with open(file_path, 'wb') as temp_file:
                        temp_file.write(uploaded_file.read())

                    # Use Langchain loader to process the file
                    loader = loader_class(file_path, **loader_args)
                    loaded_documents.extend(loader.load())
                else:
                    st.warning(f"Unsupported file extension: {ext}")

        st.write("Chat with the Document:")
        query = st.text_input("Ask a question:")

        if st.button("Get Answer"):
            if query:
                try:
                    start = timeit.default_timer()
                    # retrieve_bot_answer builds the model, prompt and vector
                    # database itself via create_retrieval_qa_bot; the previous
                    # direct calls to load_model()/set_custom_prompt()/
                    # create_vector_database() here repeated that expensive
                    # work (and set_custom_prompt_condense() is not defined in
                    # this change), so they were removed.
                    result, sources = retrieve_bot_answer(query, loaded_documents)
                    end = timeit.default_timer()
                    st.write("Elapsed time:")
                    st.write(end - start)
                    # Display bot response
                    st.write("Bot Response:")
                    st.write(result)
                    st.write(sources)
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
            else:
                st.warning("Please enter a question.")

if __name__ == "__main__":
    main()