bohmian committed
Commit e638c19
1 Parent(s): 399efe8

Update app.py

Files changed (1)
  1. app.py +63 -47
app.py CHANGED
@@ -1,5 +1,4 @@
 import streamlit as st
-#from streamlit_chat import message
 from streamlit_option_menu import option_menu
 
 import os
@@ -21,14 +20,17 @@ from langchain.retrievers import EnsembleRetriever # to use chroma and
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
 
+# for printing intermediate steps of agent (actions, tool calling etc.)
+from langchain.callbacks.base import BaseCallbackHandler
+
 import warnings
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API
 
-from langchain.callbacks.base import BaseCallbackHandler
 
+################################ Callback ################################
 # callback is needed to print intermediate steps of agent reasoning in the chatbot
 # i.e. when action is taken, when tool is called, when tool call is complete etc.
 class MyCallbackHandler(BaseCallbackHandler):
@@ -86,20 +88,15 @@ class MyCallbackHandler(BaseCallbackHandler):
 
 my_callback_handler = MyCallbackHandler()
 
-# # Set the webpage title
-# st.set_page_config(
-#     page_title="Your own AI-Chat!",
-#     layout="wide"
-# )
-
-# llm for HuggingFace Inference API
-# model = "mistralai/Mistral-7B-Instruct-v0.2"
-model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
-# with st.spinner('Downloading pre-built Chroma and BM25 vector stores'):
-#     chroma_db = Chroma(persist_directory=persist_directory,embedding_function=hf_embeddings)
+################################ Configs ################################
+# Set the webpage title
+st.set_page_config(
+    page_title="ESG Countries Chatbot",
+    # layout="wide"
+)
 
-# Document config
+# Document Config
 if 'chunk_size' not in st.session_state:
     st.session_state['chunk_size'] = 1000 # choose one of [500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000]
 
@@ -116,8 +113,7 @@ if 'countries_to_scrape' not in st.session_state:
 # in main app, add configuration for user to scrape new data from DuckDuckGo
 # in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
 
-
-# Retriever config
+# Retriever Config
 if 'chroma_n_similar_documents' not in st.session_state:
     st.session_state['chroma_n_similar_documents'] = 5 # number of chunks returned by chroma vector store retriever (semantic)
 
@@ -135,11 +131,16 @@ if 'source_documents' not in st.session_state:
 
 
 # LLM config
+# LLM from HuggingFace Inference API
+if 'model' not in st.session_state:
+    st.session_state['model'] = "mistralai/Mixtral-8x7B-Instruct-v0.1" # or "mistralai/Mistral-7B-Instruct-v0.2"
+
 if 'temperature' not in st.session_state:
     st.session_state['temperature'] = 0.25
 
 if 'max_new_tokens' not in st.session_state:
     st.session_state['max_new_tokens'] = 500 # max tokens generated by LLM
+
 
 # This is the list of countries present in the vector store, prepared in advance since the vector stores take very long to build
 # This is for the RetrievalQA tool later to check, because even if the country given to it is not in the vector store,
@@ -157,21 +158,22 @@ countries = [
     "Germany",
 ]
 
-@st.cache_data # only going to get once
-def get_llm(temp = st.session_state['temperature'], tokens = st.session_state['max_new_tokens']):
+
+################################ Get LLM and Embeddings ################################
+@st.cache_data # only going to get this once instead of every time the page refreshes
+# unless the LLM config changes, in which case we will call the function again
+def get_llm():
     # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
     # It is a free API that is very good for deploying online for quick testing without users having to deploy a local LLM
-    llm = HuggingFaceHub(repo_id=model,
+    llm = HuggingFaceHub(repo_id=st.session_state['model'],
                          model_kwargs={
-                             'temperature':temp,
-                             "max_new_tokens":tokens
+                             'temperature': st.session_state['temperature'],
+                             "max_new_tokens": st.session_state['max_new_tokens']
                          },
                          )
     return llm
 
-llm = get_llm(st.session_state['temperature'], tokens = st.session_state['max_new_tokens'])
-
-@st.cache_data # only going to get once
+@st.cache_data # only going to get this once instead of every time the page refreshes
 def get_embeddings():
     with st.spinner(f'Getting HuggingFaceEmbeddings'):
         # We use HuggingFaceEmbeddings() as it is open source and free to use.
@@ -179,8 +181,13 @@ def get_embeddings():
         hf_embeddings = HuggingFaceEmbeddings()
     return hf_embeddings
 
+# call above functions
+llm = get_llm()
 hf_embeddings = get_embeddings()
 
+
+################################ Download and Initialize Pre-Built Retrievers ################################
+
 # Chromadb vector stores have already been pre-created for the countries above for each of the different chunk sizes and overlaps, and zipped up,
 # to save time when experimenting as the embeddings take a long time to generate.
 # The existing stores will be pulled from google drive above when the app starts. When using the existing vector stores,
@@ -213,7 +220,7 @@ with st.spinner(f'Setting up pre-built chroma vector store'):
 # The retrievers with different chunking sizes and overlaps and countries were created in advance and saved as pickle files and pulled using !wget.
 # Need to initialize one BM25Retriever for each country so the search results later in the main app can be limited to just a particular country.
 # (Chroma DB gives an option to filter metadata for just a particular country during the retrieval process but BM25 does not because it makes use of an external ranking library.)
-# A separate retriever was saved for each country.
+# A separate retriever was hence pre-built for each unique country and each unique chunk size and overlap.
 bm25_retrievers = {} # to store retrievers for different countries
 with st.spinner(f'Setting up pre-built bm25 retrievers'):
     for country in countries:
@@ -222,11 +229,16 @@ with st.spinner(f'Setting up pre-built bm25 retrievers'):
             bm25_retriever = pickle.load(handle)
         bm25_retrievers[country] = bm25_retriever
 
-# Tools for LLM to Use
+# One retriever above is semantic based and the other is keyword based
+# Both retrievers will be used
+# Then Langchain's EnsembleRetriever will be used to rerank both their results to give the final output to the RetrievalQA chain below
+
+################################ Tools for Agent to Use ################################
+
 # The most important tool is the first one, which uses a RetrievalQA chain to answer a question about a specific country's ESG policies,
 # e.g. carbon emissions policy of Singapore.
 # By calling this tool multiple times, the agent is able to look at the responses from this tool for both countries and compare them.
-# This is far better than just retrieving relevant chunks for the user's query and throw everything to a single RetrievalQA chain to process
+# This is far better than just retrieving relevant chunks for the user's query and throwing everything to a single RetrievalQA chain to process
 # Multi input tools are not available, hence we have to prompt the agent to give an input list as a string
 # then use ast.literal_eval to convert it back into a list
 @tool
@@ -251,11 +263,14 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
         then there is no record for the country and no answer can be obtained."""
 
         # different retrievers
-        bm = bm25_retrievers[country] # keyword based
+        # keyword
+        bm = bm25_retrievers[country]
         bm.k = st.session_state['bm25_n_similar_documents']
-        chroma = chroma_db.as_retriever(search_kwargs={'filter': {'country':country}, 'k': st.session_state['chroma_n_similar_documents']}) # semantic
-        # ensemble (below) reranks results from both retrievers
+        # semantic
+        chroma = chroma_db.as_retriever(search_kwargs={'filter': {'country':country}, 'k': st.session_state['chroma_n_similar_documents']})
+        # ensemble (below) reranks results from both retrievers above
        ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[st.session_state['keyword_retriever_weight'], 1 - st.session_state['keyword_retriever_weight']])
+        # for user to make selection
        retrievers = {'ensemble': ensemble, 'semantic': chroma, 'keyword': bm}
 
        qa = RetrievalQA.from_chain_type(
@@ -265,8 +280,10 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
            return_source_documents=True # returned in result['source_documents']
        )
        result = qa(query)
+        # add to source documents session state so it can be loaded later in the other menu
+        # all source documents linked to the answer for any query (or part of it) are visible
        st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
-        st.session_state['source_documents'].append(result['source_documents']) # let user know what source docs are used
+        st.session_state['source_documents'].append(result['source_documents'])
        return f"{query.capitalize()} for {country}: " + result['result']
 
    except Exception as e:
@@ -319,10 +336,12 @@ def compare(query:str) -> str:
    Give as much elaboration in your answer as possible but they MUST be from the earlier context.
    Do not give details that cannot be found in the earlier context."""
 
+# equip tools with callbacks
 retrieve_answer_for_country.callbacks = [my_callback_handler]
 compare.callbacks = [my_callback_handler]
 generic_chat_llm.callbacks = [my_callback_handler]
 
+# Initialize the agent
 agent = initialize_agent(
    [retrieve_answer_for_country, compare], # tools
    #[retrieve_answer_for_country, generic_chat_llm, compare],
@@ -347,7 +366,7 @@ if "menu" not in st.session_state:
    "Source Documents for Last Query",
 ]
 
-# sidebar with menu navigation
+################################ Sidebar with Menu ################################
 with st.sidebar:
    st.subheader("DO NOT NAVIGATE between pages when agent is still generating messages in the chat. Wait for query to complete first.")
    st.write("")
@@ -356,6 +375,7 @@ with st.sidebar:
    st.spinner("test")
 
 
+################################ Main Chatbot Page ################################
 if page == "Chatbot":
    st.header("Chat")
 
@@ -373,27 +393,19 @@
        """}
    ]
 
-    if "current_response" not in st.session_state:
-        st.session_state.current_response = ""
-
-    # Loop through each message in the session state and render it as a chat message.
+    # Loop through each message in the session state and render it as a chat message
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
 
-    # We initialize the quantized LLM from a local path.
-    # Currently most parameters are fixed but we can make them
-    # configurable.
-    #llm_chain = create_chain(retriever)
-
    # We take questions/instructions from the chat input to pass to the LLM
    if user_query := st.chat_input("Your message here", key="user_input"):
-        # remove source documents option from menu while query is running
 
+        # reset source documents list during a new query
        st.session_state['source_documents'] = [f"User query: '{user_query}'"] # reset source documents list
 
-        formatted_user_query = f":blue[{user_query}]"
        # Add our input to the session state
+        formatted_user_query = f":blue[{user_query}]"
        st.session_state.messages.append(
            {"role": "user", "content": formatted_user_query}
        )
@@ -413,10 +425,6 @@
        with st.chat_message("assistant"):
            st.markdown(action_plan_message)
 
-        # Pass our input to the llm chain and capture the final responses.
-        # It is worth noting that the Stream Handler is already receiving the
-        # streaming response as the llm is generating. We get our response
-        # here once the llm has finished generating the complete response.
        results = agent(user_query)
        response = f":blue[The answer to your query is:] {results['output']}"
 
@@ -430,14 +438,22 @@
            st.markdown(response)
 
 
+################################ Chat Config Page ################################
+# for changing config like temperature etc.
 if page == "Chat Config":
    st.header(page)
 
 
+################################ Document Page ################################
+# to scrape new documents from DuckDuckGo
+# to change parameters like chunk size
+# to upload own PDF
+# to override existing data with newly scraped data or a newly uploaded PDF
 if page == "Document, Retriever, Web Scraping Config":
    st.header(page)
 
 
+################################ Source Documents Page ################################
 if page == "Source Documents for Last Query":
    st.header(page)
    try:
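
A note on the callback wiring in this diff: the body of MyCallbackHandler is collapsed in the view above, so the snippet below is only a minimal sketch of what such a handler can look like, assuming it simply surfaces agent steps in the Streamlit chat. The hook names are real BaseCallbackHandler methods; the class name and message formatting are illustrative.

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler

# Hypothetical sketch -- the committed MyCallbackHandler body is not visible in this diff.
class SketchCallbackHandler(BaseCallbackHandler):
    def on_agent_action(self, action, **kwargs):
        # fired when the agent decides to call a tool
        with st.chat_message("assistant"):
            st.markdown(f"Calling `{action.tool}` with input `{action.tool_input}`")

    def on_tool_end(self, output, **kwargs):
        # fired when a tool call completes
        with st.chat_message("assistant"):
            st.markdown(f"Tool finished: {str(output)[:200]}")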
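
The hybrid retrieval inside retrieve_answer_for_country can be reproduced in isolation. Below is a self-contained sketch with toy texts; the texts, k values, and the 0.3 weight are illustrative stand-ins for the app's pickled retrievers and session-state settings.

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import Chroma

texts = ["Singapore levies a carbon tax on large emitters.",
         "Germany subsidises renewable energy generation."]

bm25 = BM25Retriever.from_texts(texts)  # keyword-based (uses the external rank_bm25 package)
bm25.k = 2                              # number of documents to return

chroma = Chroma.from_texts(texts, HuggingFaceEmbeddings()).as_retriever(
    search_kwargs={'k': 2})             # semantic

w = 0.3  # illustrative value for st.session_state['keyword_retriever_weight']
ensemble = EnsembleRetriever(retrievers=[bm25, chroma], weights=[w, 1 - w])
docs = ensemble.get_relevant_documents("carbon tax policy")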
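
On the single-input workaround the tool comments describe (the agent is prompted to pass a stringified list, which is decoded with ast.literal_eval), here is a stripped-down sketch of just that parsing step. The tool name and docstring are hypothetical and the retrieval itself is elided.

import ast
from langchain.agents import tool

@tool
def lookup_policy(query_and_country: str) -> str:
    """Input must be a stringified two-item list, e.g. "['carbon tax', 'Singapore']"."""
    try:
        query, country = ast.literal_eval(query_and_country)
    except (ValueError, SyntaxError):
        # malformed input from the agent -- tell it how to retry
        return "Input must be a Python-style list: ['<query>', '<country>']"
    return f"{query.capitalize()} for {country}: ..."  # RetrievalQA call elided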
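
Finally, the RetrievalQA construction is split across two hunks above; assembled, it looks roughly like the following, reusing llm and the ensemble retriever from the sketches above. The "stuff" chain type is an assumption here, since the committed chain_type argument is not visible in this diff.

from langchain.chains import RetrievalQA

qa = RetrievalQA.from_chain_type(
    llm,                           # the HuggingFaceHub LLM returned by get_llm()
    chain_type="stuff",            # assumption: the committed value is not shown in the diff
    retriever=ensemble,
    return_source_documents=True,  # returned in result['source_documents']
)
result = qa("What is the carbon emissions policy?")
answer, sources = result['result'], result['source_documents']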