Update app.py
app.py CHANGED
@@ -1,5 +1,4 @@
 import streamlit as st
-#from streamlit_chat import message
 from streamlit_option_menu import option_menu
 
 import os
@@ -21,14 +20,17 @@ from langchain.retrievers import EnsembleRetriever # to use chroma and
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
 
+# for printing intermediate steps of agent (actions, tool calling etc.)
+from langchain.callbacks.base import BaseCallbackHandler
+
 import warnings
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API
 
-from langchain.callbacks.base import BaseCallbackHandler
 
+################################ Callback ################################
 # callback is needed to print intermediate steps of agent reasoning in the chatbot
 # i.e. when action is taken, when tool is called, when tool call is complete etc.
 class MyCallbackHandler(BaseCallbackHandler):
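The body of MyCallbackHandler (file lines 34-86) falls outside the hunks shown here. As a rough illustrative sketch only, not the code from this commit, such a handler typically overrides LangChain's BaseCallbackHandler hooks and writes each agent step into the Streamlit chat; the exact hooks used and the message formatting below are assumptions:

class MyCallbackHandler(BaseCallbackHandler):
    def on_agent_action(self, action, **kwargs):
        # fired when the agent decides to call a tool
        msg = f"Calling tool '{action.tool}' with input: {action.tool_input}"
        st.session_state.messages.append({"role": "assistant", "content": msg})
        with st.chat_message("assistant"):
            st.markdown(msg)

    def on_tool_end(self, output, **kwargs):
        # fired when the tool call completes
        msg = f"Tool call complete. Output: {output}"
        st.session_state.messages.append({"role": "assistant", "content": msg})
        with st.chat_message("assistant"):
            st.markdown(msg)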
@@ -86,20 +88,15 @@ class MyCallbackHandler(BaseCallbackHandler):
 
 my_callback_handler = MyCallbackHandler()
 
-# # Set the webpage title
-# st.set_page_config(
-#     page_title="Your own AI-Chat!",
-#     layout="wide"
-# )
-
-# llm for HuggingFace Inference API
-# model = "mistralai/Mistral-7B-Instruct-v0.2"
-model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
-
-#
+################################ Configs ################################
+# Set the webpage title
+st.set_page_config(
+    page_title="ESG Countries Chatbot",
+    # layout="wide"
+)
 
-# Document
+# Document Config
 if 'chunk_size' not in st.session_state:
     st.session_state['chunk_size'] = 1000 # choose one of [500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000]
 
@@ -116,8 +113,7 @@ if 'countries_to_scrape' not in st.session_state:
 # in main app, add configuration for user to scrape new data from DuckDuckGo
 # in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
 
-
-# Retriever config
+# Retriever Config
 if 'chroma_n_similar_documents' not in st.session_state:
     st.session_state['chroma_n_similar_documents'] = 5 # number of chunks returned by chroma vector store retriever (semantic)
 
@@ -135,11 +131,16 @@ if 'source_documents' not in st.session_state:
 
 
 # LLM config
+# LLM from HuggingFace Inference API
+if 'model' not in st.session_state:
+    st.session_state['model'] = "mistralai/Mixtral-8x7B-Instruct-v0.1" # or "mistralai/Mistral-7B-Instruct-v0.2"
+
 if 'temperature' not in st.session_state:
     st.session_state['temperature'] = 0.25
 
 if 'max_new_tokens' not in st.session_state:
     st.session_state['max_new_tokens'] = 500 # max tokens generated by LLM
+
 
 # This is the list of countries present in the vector store, since the vector store is previously prepared as they take very long to prepare
 # This is for the RetrievalQA tool later to check, because even if the country given to it is not in the vector store,
@@ -157,21 +158,22 @@ countries = [
     "Germany",
 ]
 
-
-
+
+################################ Get LLM and Embeddings ################################
+@st.cache_data # only going to get this once instead of every time the page refreshes,
+# unless the LLM config changes, in which case the function is called again
+def get_llm():
 # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
 # It is a free API that is very good for deploying online for quick testing without users having to deploy a local LLM
-llm = HuggingFaceHub(repo_id=model,
+    llm = HuggingFaceHub(repo_id=st.session_state['model'],
                          model_kwargs={
-                             'temperature':
-                             "max_new_tokens":
+                             'temperature': st.session_state['temperature'],
+                             "max_new_tokens": st.session_state['max_new_tokens']
                          },
                          )
     return llm
 
-
-
-@st.cache_data # only going to get once
+@st.cache_data # only going to get this once instead of every time the page refreshes
 def get_embeddings():
     with st.spinner(f'Getting HuggingFaceEmbeddings'):
         # We use HuggingFaceEmbeddings() as it is open source and free to use.
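One caveat worth noting: st.cache_data keys its cache on the function's arguments, so a zero-argument get_llm() that reads st.session_state internally will not automatically re-run when the model, temperature, or max_new_tokens settings change. A minimal sketch of the common workaround, passing the config in as arguments so the cache key changes with them (an illustration, not the commit's code; st.cache_resource is the usual choice for client objects like an LLM wrapper):

@st.cache_resource
def get_llm(model_id: str, temperature: float, max_new_tokens: int):
    # cache key now includes the config values, so changing any of them rebuilds the LLM
    return HuggingFaceHub(
        repo_id=model_id,
        model_kwargs={"temperature": temperature, "max_new_tokens": max_new_tokens},
    )

llm = get_llm(st.session_state['model'],
              st.session_state['temperature'],
              st.session_state['max_new_tokens'])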
@@ -179,8 +181,13 @@ def get_embeddings():
         hf_embeddings = HuggingFaceEmbeddings()
     return hf_embeddings
 
+# call above functions
+llm = get_llm()
 hf_embeddings = get_embeddings()
 
+
+################################ Download and Initialize Pre-Built Retrievers ################################
+
 # Chromadb vector stores have already been pre-created for the countries above for each of the different chunk sizes and overlaps, and zipped up,
 # to save time when experimenting as the embeddings take a long time to generate.
 # The existing stores will be pulled from google drive above when the app starts. When using the existing vector stores,
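The code that actually reopens the pre-built Chroma store sits outside the hunks shown here. As a hedged sketch of how a persisted Chroma store is typically loaded with the embeddings created above (the persist_directory name below is purely hypothetical):

from langchain.vectorstores import Chroma

# reopen a pre-built, persisted vector store using the same embeddings
chroma_db = Chroma(persist_directory=f"chroma_db_{st.session_state['chunk_size']}",
                   embedding_function=hf_embeddings)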
@@ -213,7 +220,7 @@ with st.spinner(f'Setting up pre-built chroma vector store'):
 # The retrievers with different chunking sizes and overlaps and countries were created in advance and saved as pickle files and pulled using !wget.
 # Need to initialize one BM25Retriever for each country so the search results later in the main app can be limited to just a particular country.
 # (Chroma DB gives an option to filter metadata for just a particular country during the retrieval process but BM25 does not because it makes use of an external ranking library.)
-# A separate retriever was
+# A separate retriever was hence pre-built for each unique country and each unique chunk size and overlap.
 bm25_retrievers = {} # to store retrievers for different countries
 with st.spinner(f'Setting up pre-built bm25 retrievers'):
     for country in countries:
@@ -222,11 +229,16 @@ with st.spinner(f'Setting up pre-built bm25 retrievers'):
             bm25_retriever = pickle.load(handle)
         bm25_retrievers[country] = bm25_retriever
 
-#
+# One retriever above is semantic based and the other is keyword based
+# Both retrievers will be used
+# Then Langchain's EnsembleRetriever will be used to rerank both their results to give final output to RetrievalQA chain below
+
+################################ Tools for Agent to Use ################################
+
 # The most important tool is the first one, which uses a RetrievalQA chain to answer a question about a specific country's ESG policies,
 # e.g. carbon emissions policy of Singapore.
 # By calling this tool multiple times, the agent is able to look at the responses from this tool for both countries and compare them.
-# This is far better than just retrieving relevant chunks for the user's query and
+# This is far better than just retrieving relevant chunks for the user's query and throwing everything to a single RetrievalQA chain to process
 # Multi input tools are not available, hence we have to prompt the agent to give an input list as a string
 # then use ast.literal_eval to convert it back into a list
 @tool
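Because single-input tools are used, the agent is prompted to encode the two inputs (query and country) as one list-like string, which the tool then converts back with ast.literal_eval. A minimal sketch of that parsing step, with hypothetical names rather than the commit's actual implementation of retrieve_answer_for_country:

import ast

def parse_query_and_country(query_and_country: str):
    # the agent is prompted to pass e.g. "['carbon emissions policy', 'Singapore']"
    query, country = ast.literal_eval(query_and_country)
    return query.strip(), country.strip()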
@@ -251,11 +263,14 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
     then there is no record for the country and no answer can be obtained."""
 
         # different retrievers
-
+        # keyword
+        bm = bm25_retrievers[country]
         bm.k = st.session_state['bm25_n_similar_documents']
-
-
+        # semantic
+        chroma = chroma_db.as_retriever(search_kwargs={'filter': {'country':country}, 'k': st.session_state['chroma_n_similar_documents']})
+        # ensemble (below) reranks results from both retrievers above
         ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[st.session_state['keyword_retriever_weight'], 1 - st.session_state['keyword_retriever_weight']])
+        # for user to make selection
         retrievers = {'ensemble': ensemble, 'semantic': chroma, 'keyword': bm}
 
         qa = RetrievalQA.from_chain_type(
@@ -265,8 +280,10 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
             return_source_documents=True # returned in result['source_documents']
         )
         result = qa(query)
+        # add to source documents session state so it can be loaded later in the other menu
+        # all source documents linked to answer any query (or part of it) are visible
         st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
-        st.session_state['source_documents'].append(result['source_documents'])
+        st.session_state['source_documents'].append(result['source_documents'])
        return f"{query.capitalize()} for {country}: " + result['result']
 
     except Exception as e:
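Only the tail of the RetrievalQA.from_chain_type call appears in the hunk above; its opening arguments are outside the diff. A sketch of what a typical call looks like with the retrievers defined earlier, where the chain_type and the session-state key for the retriever choice are assumptions rather than the commit's actual values:

qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # assumed chain type
    retriever=retrievers[st.session_state.get('retriever_type', 'ensemble')],  # hypothetical key
    return_source_documents=True,
)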
@@ -319,10 +336,12 @@ def compare(query:str) -> str:
     Give as much elaboration in your answer as possible but they MUST be from the earlier context.
     Do not give details that cannot be found in the earlier context."""
 
+# equip tools with callbacks
 retrieve_answer_for_country.callbacks = [my_callback_handler]
 compare.callbacks = [my_callback_handler]
 generic_chat_llm.callbacks = [my_callback_handler]
 
+# Initialize
 agent = initialize_agent(
     [retrieve_answer_for_country, compare], # tools
     #[retrieve_answer_for_country, generic_chat_llm, compare],
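Only the first arguments of initialize_agent appear in the diff; the remaining parameters sit outside the hunks. A sketch of a typical call with LangChain's legacy agent API, where the agent type and extra flags are assumptions rather than the commit's actual settings:

from langchain.agents import initialize_agent, AgentType

agent = initialize_agent(
    [retrieve_answer_for_country, compare],       # tools
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,  # assumed agent type
    callbacks=[my_callback_handler],
    handle_parsing_errors=True,
    verbose=True,
)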
@@ -347,7 +366,7 @@ if "menu" not in st.session_state:
         "Source Documents for Last Query",
     ]
 
-
+################################ Sidebar with Menu ################################
 with st.sidebar:
     st.subheader("DO NOT NAVIGATE between pages when agent is still generating messages in the chat. Wait for query to complete first.")
     st.write("")
@@ -356,6 +375,7 @@ with st.sidebar:
     st.spinner("test")
 
 
+################################ Main Chatbot Page ################################
 if page == "Chatbot":
     st.header("Chat")
 
@@ -373,27 +393,19 @@ if page == "Chatbot":
         """}
     ]
 
-
-    st.session_state.current_response = ""
-
-    # Loop through each message in the session state and render it as a chat message.
+    # Loop through each message in the session state and render it as a chat message
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
 
-    # We initialize the quantized LLM from a local path.
-    # Currently most parameters are fixed but we can make them
-    # configurable.
-    #llm_chain = create_chain(retriever)
-
     # We take questions/instructions from the chat input to pass to the LLM
     if user_query := st.chat_input("Your message here", key="user_input"):
-        # remove source documents option from menu while query is running
 
+        # reset source documents list during a new query
         st.session_state['source_documents'] = [f"User query: '{user_query}'"] # reset source documents list
 
-        formatted_user_query = f":blue[{user_query}]"
         # Add our input to the session state
+        formatted_user_query = f":blue[{user_query}]"
         st.session_state.messages.append(
             {"role": "user", "content": formatted_user_query}
         )
@@ -413,10 +425,6 @@ if page == "Chatbot":
         with st.chat_message("assistant"):
             st.markdown(action_plan_message)
 
-        # Pass our input to the llm chain and capture the final responses.
-        # It is worth noting that the Stream Handler is already receiving the
-        # streaming response as the llm is generating. We get our response
-        # here once the llm has finished generating the complete response.
         results = agent(user_query)
         response = f":blue[The answer to your query is:] {results['output']}"
 
@@ -430,14 +438,22 @@ if page == "Chatbot":
         st.markdown(response)
 
 
+################################ Chat Config Page ################################
+# for changing config like temperature etc.
 if page == "Chat Config":
     st.header(page)
 
 
+################################ Document Page ################################
+# to scrape new documents from DuckDuckGo
+# to change parameters like chunk size
+# to upload own PDF
+# to override existing data with newly scraped data or a newly uploaded PDF
 if page == "Document, Retriever, Web Scraping Config":
     st.header(page)
 
 
+################################ Source Documents Page ################################
 if page == "Source Documents for Last Query":
     st.header(page)
     try: