Syed Junaid Iqbal committed on
Commit 5103cb0 • 1 Parent(s): d04f7ff

Update app.py

Files changed (1)
  1. app.py  +154 -143
app.py CHANGED
@@ -2,47 +2,71 @@ import subprocess
 
 import streamlit as st
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import FAISS
 from langchain.embeddings import FastEmbedEmbeddings # General embeddings from HuggingFace models.
 from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from htmlTemplates import css, bot_template, user_template
-from langchain.llms import LlamaCpp # For loading transformer models.
 from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain import hub
 import os
 import glob
 import shutil

-# os.environ['FAISS_NO_AVX2'] = '1'
-# os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
-def load_document_text():
-    """
-    input : path to the document
-    output: list of loaded document
     """

-    documents = []

-    for dox in os.listdir(path= "./documents/"):
-        dir = os.path.join("./documents/", dox)

-        if dox.endswith(".pdf"):
-            documents.extend( PyPDFLoader(dir).load() )

-        elif dox.endswith(".txt"):
-            documents.extend( TextLoader(dir).load() )

-        elif dox.endswith(".csv"):
-            documents.extend( CSVLoader(dir).load() )

-    return documents


 def get_text_chunks(documents):
     """
     For the compute purpose we will split the document into multiple smaller chunks.
@@ -50,51 +74,58 @@ def get_text_chunks(documents):
     IMPORTANT : If the chunks too small we will miss the context and if its too large we will have longer compute time
     """
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size= 400,
-        chunk_overlap=50,
     )

-    st.session_state.text_chunks = text_splitter.split_documents(documents)


-def get_vectorstore():
     """
-    given the chunks, we will embed them into vector stores
     """

-    if len(glob.glob("./vectordb/*.faiss")) == 0:
-        st.session_state.vectorstore = FAISS.from_documents(documents= st.session_state.text_chunks,
-                                                            embedding= st.session_state.embeddings)
-        # save the file
-        st.session_state.vectorstore.save_local("./vectordb")
-    else:
-        st.session_state.vectorstore = FAISS.load_local("./vectordb/",
-                                                        st.session_state.embeddings)
-

-def get_conversation_chain():
     """
     This is a langchain model where we will be binding the runner to infer data from LLM
     """
     model_path = st.session_state.model
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

-    llm = LlamaCpp(model_path= model_path,
-                   n_ctx=4000,
-                   max_tokens= 4000,
-                   n_gpu_layers = 40,
-                   callback_manager = callback_manager,
-                   verbose=True)
-
-    memory = ConversationBufferMemory(
-        memory_key='chat_history', return_messages=False)

     prompt_template = """You are a personal HR Bot assistant for answering any questions about Companies policies
     You are given a question and a set of documents.
     If the user's question requires you to provide specific information from the documents, give your answer based only on the examples provided below. DON'T generate an answer that is NOT written in the provided examples.
     If you don't find the answer to the user's question with the examples provided to you below, answer that you didn't find the answer in the documentation and propose him to rephrase his query with more details.
-    Use bullet points if you have to make a list, only if necessary.

     QUESTION: {question}

@@ -107,11 +138,12 @@ def get_conversation_chain():
     rag_prompt_custom = PromptTemplate.from_template(prompt_template)

     prompt = hub.pull("rlm/rag-prompt-mistral")

     conversation_chain = RetrievalQA.from_chain_type(
         llm,
-        retriever= st.session_state.vectorstore.as_retriever(),
         chain_type_kwargs={"prompt": prompt},
     )
     conversation_chain.callback_manager = callback_manager
@@ -119,7 +151,7 @@ def get_conversation_chain():
     return conversation_chain

-
 def handle_userinput():

     clear = False
@@ -129,77 +161,64 @@ def handle_userinput():
         clear = True
         st.session_state.messages = []

     if "messages" not in st.session_state:
-        st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

-    for msg in st.session_state.messages:
-        st.chat_message(msg["role"]).write(msg["content"])

-    if prompt := st.chat_input():
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        st.chat_message("user").write(prompt)
-        if clear:
-            st.session_state.conversation.clean()

-        msg = st.session_state.conversation.run(prompt)
-        print(msg)
-        st.session_state.messages.append({"role": "assistant", "content": msg})
-        st.chat_message("assistant").write(msg)

-
-# Function to apply rounded edges using CSS
-def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
-    st.markdown(
-        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
-        unsafe_allow_html=True,)
-    st.image(image_path, use_column_width=True, output_format='auto')
-
-
-# Delete directory content
-def delete_file(directory_path):
-
-    # Check if the directory exists
-    if os.path.exists(directory_path) and len(os.listdir(directory_path)) > 0:
-        # Iterate over all files in the directory and remove them
-        for filename in os.listdir(directory_path):
-            file_path = os.path.join(directory_path, filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                print(f"Error deleting {file_path}: {e}")
-    else:
-        print(f"The directory {directory_path} does not exist.")


-def save_uploaded_file(uploaded_file):
-    save_directory = "./documents/"
-    file_path = os.path.join(save_directory, uploaded_file.name)
-    with open(file_path, "wb") as f:
-        f.write(uploaded_file.getvalue())
-    return file_path


-def load_dependencies():
-    # append documents to a list
-    doc_list = load_document_text()

-    # get the text chunks
-    get_text_chunks(doc_list)

-    # create vector store
-    get_vectorstore()

-    # create conversation chain
-    st.session_state.conversation = get_conversation_chain()


 def main():
-
-    st.set_page_config(page_title="Randstad Chad Bot",
                        page_icon=":books:")
     st.write(css, unsafe_allow_html=True)
@@ -213,71 +232,65 @@ def main():
     st.subheader("🚀 A HR powered by Generative AI")

     # default model
-    st.session_state.model = "./models/mistral-7b-instruct-v0.2.Q5_K_M.gguf"

-    # Embedding Model
-    st.session_state.embeddings = FastEmbedEmbeddings( model_name= "BAAI/bge-base-en-v1.5",
-                                                       cache_dir="./embedding_model/")

     with st.sidebar:

         # calling a
         add_rounded_edges()

         st.subheader("Select Your Embedding Model Model")
-        st.session_state.model = st.selectbox( 'Models', tuple( glob.glob('./models/*.gguf') ) )

         st.subheader("Your documents")
-
-        # Space to Upload a Document
         docs = st.file_uploader(
             "Upload File (pdf,text,csv...) and click 'Process'", accept_multiple_files=True)
-
-        # Define a process button
         if st.button("Process"):
-
-            # delete the old embeddings
-            delete_file(directory_path= './vectordb/')
-
-            # delete old documents
-            delete_file(directory_path="./documents/")
-
-            # then Embedd new documents
             with st.spinner("Processing"):
-
-
-                # iterate over updated files and save them to the local directory (i.e. "Documents") using a helper function
-                for file in docs:
-                    save_uploaded_file(file)
-

                 # using the helper function below lets load our dependencies
                 # Step 1 : Load the documents
-                # Step 2 : Break them into Chunks
-                # Step 3 : Create Embeddings and save them to Vector DB
-                # Step 4 : Get our conversation chain

-                load_dependencies()

-    # Load our model
-    if len(glob.glob("./vectordb/*.faiss")) == 0:
-        load_dependencies()
-        get_vectorstore()
-    else:
-        get_vectorstore()
-    st.session_state.conversation = get_conversation_chain()

-    handle_userinput()
-
-    # # load dependencies -> chaunks of documents -> Embeddings -> Inference
-    # load_dependencies()


 if __name__ == '__main__':
-
     command = 'CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir'

     # Run the command using subprocess
@@ -286,6 +299,4 @@ if __name__ == '__main__':
         print("Command executed successfully.")
     except subprocess.CalledProcessError as e:
         print(f"Error: {e}")
-
-    # Run the apps
     main()
 
 
 import streamlit as st
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma, FAISS
 from langchain.embeddings import FastEmbedEmbeddings # General embeddings from HuggingFace models.
 from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.manager import CallbackManager
+from langchain.callbacks import StreamlitCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from htmlTemplates import css, bot_template, user_template
+from langchain.llms import LlamaCpp, OpenAI, GooglePalm # For loading transformer models.
 from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader
 from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain import hub
+import tempfile
 import os
 import glob
 import shutil
+import time

+# TEXT LOADERS
+def get_pdf_text(pdf_docs):
     """
+    Purpose: A hypothetical loader for PDF files in Python.
+    Usage: Used to extract text or other information from PDF documents.
+    Load Function: A load_pdf function might be used to read and extract data from a PDF file.

+    input : pdf document path
+    returns : extracted text
+    """
+    temp_dir = tempfile.TemporaryDirectory()
+    temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)

+    with open(temp_filepath, "wb") as f:
+        f.write(pdf_docs.getvalue())

+    pdf_loader = PyPDFLoader(temp_filepath)
+    pdf_doc = pdf_loader.load()
+    return pdf_doc


+def get_text_file(text_docs):
+    """
+    """
+    temp_dir = tempfile.TemporaryDirectory()
+    temp_filepath = os.path.join(temp_dir.name, text_docs.name)
+
+    with open(temp_filepath, "wb") as f:
+        f.write(text_docs.getvalue())
+
+    text_loader = TextLoader(temp_filepath)
+    text_doc = text_loader.load()
+    return text_doc
+
+def get_csv_file(csv_docs):
+    temp_dir = tempfile.TemporaryDirectory()
+    temp_filepath = os.path.join(temp_dir.name, csv_docs.name)
+
+    with open(temp_filepath, "wb") as f:
+        f.write(csv_docs.getvalue())
+
+    csv_loader = CSVLoader(temp_filepath)
+    csv_doc = csv_loader.load()
+    return csv_doc


+# Break the documents into chunks
 def get_text_chunks(documents):
     """
     For the compute purpose we will split the document into multiple smaller chunks.

     IMPORTANT : If the chunks too small we will miss the context and if its too large we will have longer compute time
     """
     text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size= 1000,
+        chunk_overlap=200,
     )

+    text_chunks = text_splitter.split_documents(documents)
+
+    return text_chunks


+# Save chunks to vector store
+def get_vectorstore(text_chunks):
     """
+    Load our vectors into chroma DB, Googles Vector Store
     """
+    vectorstore = Chroma.from_documents(documents= text_chunks,
+                                        embedding= st.session_state.embeddings,
+                                        persist_directory= "./vectordb/")

+    return vectorstore
+

+# Bind the Vector DB, Large Language models and Embedding models all into one container
+def get_conversation_chain(vectorstore):
     """
     This is a langchain model where we will be binding the runner to infer data from LLM
     """
     model_path = st.session_state.model
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

+    if st.session_state.model == "Google_PaLm" :
+        llm = GooglePalm(google_api_key = "AIzaSyAIo7rw6iJPWpkpOXSJk6BnuOKNSaB5muM",
+                         max_output_tokens = 4000,
+                         callback_manager=callback_manager)

+    elif st.session_state.model == "Open_AIGPT-3.5-Turbo":
+        llm = OpenAI(api_key = "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",
+                     callback_manager = callback_manager,
+                     max_tokens= 4000 )
+
+    else:
+        llm = LlamaCpp(model_path= model_path,
+                       n_ctx= 4000,
+                       max_tokens= 4000,
+                       f16_kv = True,
+                       callback_manager = callback_manager,
+                       verbose=True)

     prompt_template = """You are a personal HR Bot assistant for answering any questions about Companies policies
     You are given a question and a set of documents.
     If the user's question requires you to provide specific information from the documents, give your answer based only on the examples provided below. DON'T generate an answer that is NOT written in the provided examples.
     If you don't find the answer to the user's question with the examples provided to you below, answer that you didn't find the answer in the documentation and propose him to rephrase his query with more details.
+    Use bullet points if you have to make a list, only if necessary. Use 'DOCUMENTS' as a reference point, to understand and give a concise output in 3 or 5 sentences.

     QUESTION: {question}

     rag_prompt_custom = PromptTemplate.from_template(prompt_template)

+    # prompt = hub.pull("rlm/rag-prompt")
     prompt = hub.pull("rlm/rag-prompt-mistral")

     conversation_chain = RetrievalQA.from_chain_type(
         llm,
+        retriever= vectorstore.as_retriever(),
         chain_type_kwargs={"prompt": prompt},
     )
     conversation_chain.callback_manager = callback_manager

     return conversation_chain

+# a Streamlit interface to handle and save our chats
 def handle_userinput():

     clear = False

         clear = True
         st.session_state.messages = []

+    # initialise our Streamlit chat interface
     if "messages" not in st.session_state:
+        st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])

+    # Clear the conversation memory
+    if clear:
+        st.session_state.conversation.memory.clear()
+        clear = False

+    if prompt := st.chat_input():

+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # add user question to chat history
+        st.session_state.messages.append( {"role": "user", "content": prompt})

+        with st.chat_message("assistant"):
+            # set up a call back handler
+            st_callback = StreamlitCallbackHandler(st.container())
+            message_holder = st.empty()
+            full_response = ""

+            # streamlit call back manager
+            st.session_state.conversation.callback_manager = st_callback
+            msg = st.session_state.conversation.run(prompt)
+            #st.markdown(msg)
+            for chunk in msg.split():
+                full_response += chunk + " "
+                time.sleep(0.09)

+                # add a blinking cursor to simulate typing
+                message_holder.markdown(full_response + "✏️ ")

+            # Display the response
+            message_holder.info(full_response)
+
+            # add response to session state
+            st.session_state.messages.append({"role": "assistant", "content": full_response})


+# Function to apply rounded edges using CSS
+def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
+    st.markdown(
+        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
+        unsafe_allow_html=True,)
+    st.image(image_path, use_column_width=True, output_format='auto')

 def main():
+
+    st.set_page_config(page_title="RANDSTAD",
                        page_icon=":books:")
     st.write(css, unsafe_allow_html=True)

     st.subheader("🚀 A HR powered by Generative AI")

     # default model
+    st.session_state.model = "Google_PaLm"
+    # user_question = st.text_input("Ask a question about your documents:")

+    st.session_state.embeddings = FastEmbedEmbeddings( model_name= "BAAI/bge-base-en-v1.5", cache_dir="./embedding_model/")

+    if len(glob.glob("./vectordb/*.sqlite3")) > 0 :
+
+        vectorstore = Chroma(persist_directory="./vectordb/", embedding_function=st.session_state.embeddings)
+        st.session_state.conversation = get_conversation_chain(vectorstore)
+        handle_userinput()

+    # side bar information
     with st.sidebar:

         # calling a
         add_rounded_edges()

         st.subheader("Select Your Embedding Model Model")
+        LLM = list( glob.glob('./models/*.gguf') )
+        LLM.extend(["Open_AIGPT-3.5-Turbo", "Google_PaLm"])
+        st.session_state.model = st.selectbox( 'Models', LLM )

         st.subheader("Your documents")
         docs = st.file_uploader(
             "Upload File (pdf,text,csv...) and click 'Process'", accept_multiple_files=True)
+
         if st.button("Process"):
             with st.spinner("Processing"):
+                # get pdf text
+                doc_list = []

                 # using the helper function below lets load our dependencies
+
                 # Step 1 : Load the documents
+                for file in docs:
+                    print('file - type : ', file.type)
+                    if file.type == 'text/plain':
+                        # file is .txt
+                        doc_list.extend(get_text_file(file))
+                    elif file.type in ['application/octet-stream', 'application/pdf']:
+                        # file is .pdf
+                        doc_list.extend(get_pdf_text(file))
+                    elif file.type == 'text/csv':
+                        # file is .csv
+                        doc_list.extend(get_csv_file(file))

+                # Step 2 : Break them into Chunks
+                text_chunks = get_text_chunks(doc_list)

+                # Step 3 : Create Embeddings and save them to Vector DB
+                vectorstore = get_vectorstore(text_chunks)

+                # Step 4 : Get our conversation chain
+                st.session_state.conversation = get_conversation_chain(vectorstore)


 if __name__ == '__main__':
     command = 'CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir'

     # Run the command using subprocess

         print("Command executed successfully.")
     except subprocess.CalledProcessError as e:
         print(f"Error: {e}")

     main()
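
For reference, a minimal sketch of the pipeline this commit moves to (per-file loaders → chunking → a persisted Chroma store → a RetrievalQA chain), outside of the Streamlit UI. It reuses the same legacy langchain imports as app.py; the document path `./documents/policies.pdf` is a placeholder, and the GGUF model path is taken from the old default in this diff. This is an illustrative sketch, not code from the commit.

```python
# Hypothetical end-to-end run of the flow introduced in this commit.
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import FastEmbedEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain import hub

# Step 1: load a document (app.py does this per uploaded file via a tempfile copy).
docs = PyPDFLoader("./documents/policies.pdf").load()  # placeholder path

# Step 2: split into overlapping chunks, mirroring get_text_chunks().
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(docs)

# Step 3: embed and persist to Chroma, mirroring get_vectorstore().
embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5",
                                 cache_dir="./embedding_model/")
vectorstore = Chroma.from_documents(documents=chunks,
                                    embedding=embeddings,
                                    persist_directory="./vectordb/")

# Step 4: build the RetrievalQA chain, mirroring the local-model branch of
# get_conversation_chain() (model path reuses the old default from this diff).
llm = LlamaCpp(model_path="./models/mistral-7b-instruct-v0.2.Q5_K_M.gguf",
               n_ctx=4000, max_tokens=4000, verbose=True)
prompt = hub.pull("rlm/rag-prompt-mistral")
chain = RetrievalQA.from_chain_type(llm,
                                    retriever=vectorstore.as_retriever(),
                                    chain_type_kwargs={"prompt": prompt})

print(chain.run("How many paid leave days do employees get?"))
```

On a later run the persisted store can be reopened with `Chroma(persist_directory="./vectordb/", embedding_function=embeddings)`, which is the same `./vectordb/*.sqlite3` check that the new `main()` performs before rebuilding embeddings.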