Spaces:
Running
Running
import streamlit as st | |
from streamlit_option_menu import option_menu | |
import os | |
# from langchain.llms import HuggingFaceHub # old, for calling HuggingFace Inference API (free for our use case) | |
from langchain_community.llms import HuggingFaceEndpoint # for calling HuggingFace Inference API (free for our use case) | |
from langchain.embeddings import HuggingFaceEmbeddings # to let program know what embeddings the vector store was embedded in earlier | |
from langchain_community.llms import HuggingFaceEndpoint | |
# to set up the agent and tools which will be used to answer questions later | |
from langchain.agents import initialize_agent | |
from langchain.agents import tool # decorator so each function will be recognized as a tool | |
from langchain.chains.retrieval_qa.base import RetrievalQA # to answer questions from vector store retriever | |
# from langchain.chains.question_answering import load_qa_chain # to further customize qa chain if needed | |
from langchain.vectorstores import Chroma # vector store for retriever | |
import ast # to parse user string input to list for one of the tools (agent tools do not support 2 inputs) | |
#from langchain.memory import ConversationBufferMemory # not used as of now | |
import pickle # for loading the bm25 retriever | |
from langchain.retrievers import EnsembleRetriever # to use chroma and | |
# for defining a generic LLMChain as a generic chat tool (if needed) | |
from langchain.prompts import PromptTemplate | |
from langchain.chains import LLMChain | |
# for printing intermediate steps of agent (actions, tool calling etc.) | |
from langchain.callbacks.base import BaseCallbackHandler | |
import warnings | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
# for web scraping and user to override | |
from web_scrape_and_pdf_loader import ( | |
duckduckgo_scrape, | |
process_links_load_documents, | |
setup_chromadb_vectorstore, | |
setup_bm25_retriever, | |
pdf_loader_local | |
) | |
# look for new retrievers that user created (to override existing ones if user chooses) | |
import glob | |
# os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using HuggingFace Inference API | |
# alternatively set your env variable above | |
################################ Callback ################################
# LangChain calls these hooks while the agent runs: once each time the agent decides on an action (its
# reasoning text plus the tool it is about to call) and once each time a tool returns. Every message is
# stored in st.session_state.messages, so it is re-rendered on later reruns, and also drawn immediately in
# an assistant chat bubble, so the user can follow the agent's intermediate steps.
class MyCallbackHandler(BaseCallbackHandler):
    """Mirror agent actions and tool outputs into the Streamlit chat window."""

    def __init__(self):
        # Reserved for token streaming (HuggingFace inference here cannot stream); currently unused.
        self.tokens = []

    @staticmethod
    def _post_to_chat(*contents):
        """Append each text to the chat history in session state and render them in one assistant bubble."""
        st.session_state.messages.extend({"role": "assistant", "content": text} for text in contents)
        with st.chat_message("assistant"):
            for text in contents:
                st.markdown(text)

    def on_agent_action(self, action, **kwargs):
        """Run on agent action."""
        print("\n\nnew action", action)
        # Rewrite newlines so Streamlit markdown renders them as line breaks.
        reasoning = action.log.replace('\n', ' \n')
        notice = f"I am calling the '{action.tool}' tool and waiting for it to give me a result..."
        self._post_to_chat(reasoning, notice)

    def on_tool_end(self, output, **kwargs):
        """Run when tool ends running."""
        self._post_to_chat(f":blue[[Tool Output]] {output} \n \nI am processing the output from the tool...")


my_callback_handler = MyCallbackHandler()
################################ Configs ################################
# Set the webpage title
st.set_page_config(
    page_title="ESG Countries Chatbot",
    # layout="wide"
)

# Defaults for every user-configurable setting. Streamlit reruns this script on each interaction, so a
# default is only written when its key is missing; afterwards the sidebar widgets sharing these keys own them.
_SESSION_STATE_DEFAULTS = {
    # --- Document config ---
    # countries whose data is overridden by the user's own uploaded PDF or freshly scraped search results
    # (own documents must be scraped or uploaded first)
    'countries_override': [],
    'chunk_size': 1000,  # one of [500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000]
    'chunk_overlap': 100,  # one of [50, 100, 150, 200]
    # --- Retriever config ---
    'chroma_n_similar_documents': 5,  # chunks returned by the Chroma (semantic) retriever
    'bm25_n_similar_documents': 5,  # chunks returned by the BM25 (keyword) retriever
    # one of 'Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'
    'retriever_config': 'Ensemble (Both Re-Ranked)',
    'keyword_retriever_weight': 0.3,  # between 0 and 1, only used by the ensemble retriever
    'source_documents': [],  # all source documents retrieved for the latest query
    # --- LLM config (HuggingFace Inference API) ---
    'model': "mistralai/Mixtral-8x7B-Instruct-v0.1",  # or "mistralai/Mistral-7B-Instruct-v0.2"
    'temperature': 0.25,
    'max_new_tokens': 500,  # max tokens generated by the LLM
}
for _state_key, _default_value in _SESSION_STATE_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# Countries present in the pre-built vector stores (prepared in advance because building them takes very long).
# The RetrievalQA tool checks requested countries against this list: filtering a store by a country it does
# not contain just returns an empty result instead of an error, so the tool has to report that to the agent.
# Only 6 countries were pre-built to keep embedding time down. More countries would not change answer quality
# when comparing two countries, because retrieval only picks chunks for the countries of interest.
countries = [
    "Australia",
    "China",
    "Japan",
    "Malaysia",
    "Singapore",
    "Germany",
]
################################ Get LLM and Embeddings ################################
def get_llm():
    """Create the LLM used by the agent and its tools from the current LLM config in session state.

    The model is not run locally: HuggingFaceEndpoint calls the HuggingFace Inference API (free for quick
    testing, no local LLM deployment needed) for the model id in st.session_state['model'], using the
    temperature and max_new_tokens chosen in the sidebar.

    Raises:
        KeyError: if the HUGGINGFACEHUB_API_TOKEN environment variable is not set.
    """
    return HuggingFaceEndpoint(
        # The session value is a Hub model id (e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1"), which is what
        # repo_id is for; endpoint_url is meant for the URL of a dedicated inference endpoint.
        repo_id=st.session_state['model'],
        huggingfacehub_api_token=os.environ['HUGGINGFACEHUB_API_TOKEN'],
        task="text-generation",
        temperature=st.session_state['temperature'],
        max_new_tokens=st.session_state['max_new_tokens'],
    )


# Streamlit reruns this whole script on every interaction. Without caching, the embedding model would be
# loaded from scratch on each rerun; st.cache_resource keeps one shared instance instead.
@st.cache_resource
def get_embeddings():
    """Return the embedding model used by the Chroma vector stores (HuggingFaceEmbeddings defaults).

    It must match the embeddings the pre-built stores were created with, and is also used to embed new
    documents scraped or uploaded by the user. Open source and free to use.
    """
    return HuggingFaceEmbeddings()


# call above functions
llm = get_llm()
hf_embeddings = get_embeddings()


# when LLM config is changed we will call this function
def update_llm():
    """on_change callback of the LLM config widgets: rebuild the module-level ``llm``."""
    global llm
    llm = get_llm()
################################ Download and Initialize Pre-Built Retrievers ################################
# Chroma vector stores and BM25 retrievers were pre-built for the countries above, for every chunk size and
# overlap offered in the sidebar, and zipped up, because generating the embeddings takes a long time.
# The archives are pulled from Google Drive when the app starts; afterwards, choosing another chunk size or
# overlap only changes which persist directory / pickle file is read.
# If the user later scrapes new data or uploads their own PDF, new retrievers are created next to these.
# This step will take some time.
def _fetch_prebuilt_archive(name, url, description):
    """Make sure the pre-built ``{name}/`` directory exists, downloading and extracting ``{name}.zip`` if needed.

    ``{name}.zip`` is downloaded with the ``gdown`` command-line tool when missing, then extracted into the
    working directory. On failure the (possibly partial) archive is deleted so the next run downloads it
    again, and the script is stopped with an error message. This replaces continuing and failing later with
    a confusing FileNotFoundError when the retrievers are loaded.

    Args:
        name: archive / directory base name, e.g. "bm25" or "chromadb".
        url: Google Drive download URL passed to gdown.
        description: wording used in the spinner messages.
    """
    import subprocess
    import zipfile

    if os.path.exists(f"{name}/"):
        return  # already extracted, nothing to download

    archive = f"{name}.zip"

    def fail(message):
        # Remove a partial/corrupt archive so it is not mistaken for a good one on the next run.
        if os.path.exists(archive):
            os.remove(archive)
        st.error(message)
        st.stop()

    if not os.path.exists(archive):
        with st.spinner(f'Downloading {description} for all chunk sizes and overlaps, will take some time'):
            try:
                downloaded = subprocess.run(["gdown", url], check=False).returncode == 0
            except OSError:  # e.g. the gdown executable is not installed
                downloaded = False
        if not downloaded or not os.path.exists(archive):
            fail(f"Could not download {archive} from {url}. Please reload the app to try again.")

    with st.spinner(f'Unzipping {description} for all chunk sizes and overlaps, will take some time'):
        try:
            with zipfile.ZipFile(archive) as zipped:
                zipped.extractall()
        except zipfile.BadZipFile:
            fail(f"{archive} is not a valid zip file and has been deleted. Please reload the app to download it again.")


_fetch_prebuilt_archive("bm25", "https://drive.google.com/uc?id=1q-hNnyyBA8tKyF3vR69nkwCk9kJj7WHi", "bm25 retriever")
_fetch_prebuilt_archive("chromadb", "https://drive.google.com/uc?id=1zad6tgYm2o5M9E2dTLQqmm6GoI8kxNC3", "chromadb retrievers")
# Two kinds of retrievers are used: a semantic one (Chroma vector store) and a keyword one (BM25).
# Depending on the sidebar config, the RetrievalQA chain uses either of them, or LangChain's
# EnsembleRetriever re-ranking the results of both.
def get_retrievers():
    """Load the pre-built retrievers for the chunk size and overlap currently selected in session state.

    BM25 is keyword based and needs no embeddings, so it is quick to set up. Unlike Chroma, which can filter
    on the ``country`` metadata at query time, BM25 cannot filter. One BM25 retriever was therefore pickled
    in advance per country and per chunk size/overlap.

    Returns:
        (chroma_db, bm25_retrievers): the Chroma vector store for all countries, and a dict mapping each
        country in ``countries`` to its unpickled BM25 retriever.
    """
    chunk_tag = f"chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}"

    with st.spinner(f'Setting up pre-built chroma vector store'):
        chroma_db = Chroma(
            persist_directory=f"chromadb/chromadb_esg_countries_{chunk_tag}",
            embedding_function=hf_embeddings,
        )

    def load_country_bm25(country):
        # The pickles come from the app author's own pre-built archive, not from user input.
        with open(f"bm25/bm25_esg_countries_{country}_{chunk_tag}.pickle", 'rb') as handle:
            return pickle.load(handle)

    with st.spinner(f'Setting up pre-built bm25 retrievers'):
        bm25_retrievers = {country: load_country_bm25(country) for country in countries}

    return chroma_db, bm25_retrievers


chroma_db, bm25_retrievers = get_retrievers()
# on_change callback of the chunk size / chunk overlap selectboxes in the sidebar
def update_retrievers():
    """Reload the pre-built retrievers for the currently selected chunk size and overlap.

    Rebinds the module-level ``chroma_db`` and ``bm25_retrievers``.
    """
    global chroma_db
    global bm25_retrievers
    chroma_db, bm25_retrievers = get_retrievers()


# Retrievers built from the user's own scraped / uploaded documents, for the countries selected in
# 'Countries to Override with Own Docs'. Their files are named after the country and the current chunk config.
# NOTE: chroma_db_new holds a single store, i.e. the one of the last override country loaded.
chroma_db_new = None
bm25_new_retrievers = {}  # country -> BM25 retriever built from the user's own documents
_loadable_overrides = []
for country in st.session_state['countries_override']:
    _new_store_name = (
        f"new_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}_"
    )
    _new_chroma_dir = f"chromadb/{_new_store_name}"
    _new_bm25_file = f"bm25/{_new_store_name}.pickle"
    if not (os.path.isdir(_new_chroma_dir) and os.path.isfile(_new_bm25_file)):
        # No own documents exist for this country at the current chunk config. This typically happens when
        # the chunk size/overlap was changed after the country had been selected. Skip it instead of
        # crashing on the missing files.
        print(f"No own documents found for '{country}' with the current chunk size/overlap; override skipped.")
        continue
    chroma_db_new = Chroma(persist_directory=_new_chroma_dir, embedding_function=hf_embeddings)
    with open(_new_bm25_file, 'rb') as handle:
        bm25_new_retrievers[country] = pickle.load(handle)
    _loadable_overrides.append(country)
if _loadable_overrides != st.session_state['countries_override']:
    # Drop countries that cannot be overridden at this chunk config, so the tool falls back to the
    # pre-built data for them. This runs before the 'countries_override' sidebar widget is created in
    # this script run, so writing its key here is allowed.
    st.session_state['countries_override'] = _loadable_overrides
# check if there are any new retrievers where user uploaded PDF or scraped new links and return list of countries for them
def check_for_new_retrievers():
    """List the countries that have user-built retrievers for the current chunk size and overlap.

    A country qualifies when both ``chromadb/new_{country}_chunk_{size}_overlap_{overlap}_`` and
    ``bm25/new_{country}_chunk_{size}_overlap_{overlap}_.pickle`` exist. Size and overlap must equal the
    values currently selected in session state (the same names the override loader opens).

    Returns:
        (new_countries, info): the sorted list of such countries, and a markdown note appended to the
        sidebar multiselect label saying whether any were found.
    """
    prefix = "new_"
    suffix = f"_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}_"
    bm25_names = {os.path.basename(path) for path in glob.glob("bm25/new*")}

    new_countries = []
    for path in glob.glob("chromadb/new*"):
        name = os.path.basename(path)
        # Names are checked one by one, so an unrelated or malformed entry is skipped instead of
        # aborting the whole scan.
        if (name + ".pickle") not in bm25_names:  # the BM25 retriever must exist too
            continue
        if not (name.startswith(prefix) and name.endswith(suffix)):  # built for another chunk config
            continue
        country = name[len(prefix):-len(suffix)]
        if country:
            new_countries.append(country)
    new_countries.sort()  # stable option order in the sidebar (glob order is arbitrary)

    if not new_countries:
        info = ' (Own documents are :red[NOT FOUND]. Must first scrape or upload own PDF (in menu above) before you can select any countries to override.)'
    else:
        info = ' (⚠️Own documents for the following countries :green[FOUND], select them in the list below to override.)'
    return new_countries, info
################################ Tools for Agent to Use ################################
# The most important tool: a RetrievalQA chain answering a question about ONE country's ESG policies,
# e.g. the carbon emissions policy of Singapore. By calling it once per country, the agent can compare
# the separate answers. This works far better than retrieving chunks for the whole user query and
# passing everything to a single RetrievalQA chain.
# Agent tools take a single string input, so the agent is prompted to pass a list written as a string,
# which is turned back into a list with ast.literal_eval.
def retrieve_answer_for_country(query_and_country: str) -> str:  # TODO (original note): try other chain types / answer versions
    """Gives answer to a query about a single country's public ESG policy.
    The input list should be of the following format:
    [query, country]
    The first element of the list is the user query, surrounded by double quotes.
    The second element is the full name of the country involved, surrounded by double quotes, for example "Singapore".
    The 2 inputs are separated by a comma. Do not write a list comprehension.
    The 2 inputs, together, are surrounded by square brackets as it is a list.
    Do not put multiple countries into the input at once. Instead use this tool multiple times, one time for each country.
    If you have multiple queries to ask about a country, break the query into separate parts and use this tool multiple times, one for each query.
    """
    # Developer notes are kept out of the docstring because, once wrapped as a tool, the docstring is the
    # description the agent's LLM reads. Errors are returned as text (not raised) so the agent can retry.
    try:
        parsed = ast.literal_eval(query_and_country)
        if not (isinstance(parsed, (list, tuple)) and len(parsed) >= 2
                and isinstance(parsed[0], str) and isinstance(parsed[1], str)):
            raise ValueError("the input must be a list of 2 strings: [query, country]")
        query, requested_country = parsed[0], parsed[1]

        # Metadata filtering is case sensitive, so map the LLM's spelling onto the canonical name
        # case-insensitively. str.capitalize() would lowercase the rest of the name and break multi-word
        # countries such as "South Korea" or "United Kingdom".
        overrides = list(st.session_state['countries_override'])
        canonical = {name.casefold(): name for name in countries + overrides}
        country = canonical.get(requested_country.strip().casefold())
        if country is None:
            return ("The country that you input into the tool cannot be found.\n"
                    "If you did not make a mistake and the country that you input is indeed what the user asked,\n"
                    "then there is no record for the country and no answer can be obtained.")

        if country in overrides:
            # The user's own documents. Each override country has its own Chroma store on disk; a single
            # shared store would only contain the last override country loaded.
            bm = bm25_new_retrievers[country]
            store = Chroma(
                persist_directory=f"chromadb/new_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}_",
                embedding_function=hf_embeddings,
            )
        else:
            bm = bm25_retrievers[country]
            store = chroma_db
        # keyword
        bm.k = st.session_state['bm25_n_similar_documents']
        # semantic
        chroma = store.as_retriever(
            search_kwargs={'filter': {'country': country}, 'k': st.session_state['chroma_n_similar_documents']}
        )
        # ensemble re-ranks the results of both retrievers above
        keyword_weight = st.session_state['keyword_retriever_weight']
        ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[keyword_weight, 1 - keyword_weight])
        # selectable in the sidebar 'Retriever Config'
        retrievers = {'Ensemble (Both Re-Ranked)': ensemble, 'Semantic (Chroma DB)': chroma, 'Keyword (BM 2.5)': bm}
        qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=retrievers[st.session_state['retriever_config']],
            return_source_documents=True,  # returned in result['source_documents']
        )
        result = qa(query)
        # Record every document used for (any part of) the query, for the "View Source Docs" page.
        st.session_state['source_documents'].append(f"Documents retrieved for agent query '{query}' for country '{country}'.")
        st.session_state['source_documents'].append(result['source_documents'])
        # Upper-case only the first character: str.capitalize() would lowercase acronyms such as "ESG".
        return f"'{query[:1].upper() + query[1:]}' for '{country}': " + result['result']
    except Exception as e:
        return (f"There is an error using this tool: {e}. Check if you have input anything wrongly and try again.\n"
                "Remember the 2 inputs, query and country, must both be surrounded by double quotes.\n"
                "The 2 inputs, together, are surrounded by square brackets as it is a list.")
# Optional small-talk tool: forwards the user's text unchanged to the LLM so the agent can reply to casual chat.
# It is not given to the agent by default, since users are better told the chatbot is not for casual chatting.
def generic_chat_llm(query: str) -> str:
    """Use this tool for general queries and casual chat. Forward the user input directly into this tool, do not come up with your own input.
    This tool IS NOT FOR MAKING COMPARISONS of anything.
    This tool IS NOT FOR FINDING ESG POLICY of any country!
    It is only for casual chat! Do not use this tool unnecessarily!
    """
    # If this function is wrapped as a LangChain tool, the docstring above becomes its description for the LLM.
    try:
        # A pass-through prompt: the template is nothing but the raw query.
        passthrough_prompt = PromptTemplate(template="{query}", input_variables=["query"])
        chat_chain = LLMChain(prompt=passthrough_prompt, llm=llm)
        return chat_chain.run(query)
    except Exception as err:
        return (
            f"There is an error using this tool: {err}. Check if you have input anything wrongly and try again.\n"
            "If you have already tried 2 times, do not try anymore, there is no response for your input.\n"
            "Move on to the next step of your plan."
        )
# The agent sometimes asks for a 'compare' tool it was never given, so it gets one. The tool returns a fixed
# prompt reminding the agent to review its earlier observations and decide whether it is time to draw a
# conclusion. A tool must take an input, so the agent passes the user's query. The query is ignored, but
# restating it makes the LLM recall the question instead of losing it at the start of the ReAct trace.
_COMPARE_INSTRUCTIONS = "\n".join([
    "Once again, check through all your previous observations to answer the user query.",
    "Make sure every part of the query is addressed by the context, or that you have at least tried to do so.",
    "Make sure you have not forgotten to address anything in the query.",
    "If you still need more details, you can use another tool to find out more if you have not tried using the same tool with the necessary input earlier.",
    "If you have enough information, use your reasoning to answer them to the best of your ability.",
    "Give as much elaboration in your answer as possible but they MUST be from the earlier context.",
    "Do not give details that cannot be found in the earlier context.",
])


def compare(query:str) -> str:
    """Use this tool to give you hints and instructions on how you can compare between policies of countries.
    Use this tool as a final step, only after you have used other tools to obtain all the information you need.
    When putting the query into this tool, look at the entire query that the user has asked at the start,
    do not leave any details in the query out.
    """
    return _COMPARE_INSTRUCTIONS
# initialize_agent() expects LangChain tool objects (with a name and description), not plain functions.
# The `tool` decorator imported at the top ("so each function will be recognized as a tool") builds them
# from each function's name and docstring.
def _as_agent_tool(func):
    """Wrap a plain function with LangChain's ``tool`` decorator; return anything else (e.g. an existing tool) unchanged."""
    import types

    return tool(func) if isinstance(func, types.FunctionType) else func


retrieve_answer_for_country = _as_agent_tool(retrieve_answer_for_country)
compare = _as_agent_tool(compare)
generic_chat_llm = _as_agent_tool(generic_chat_llm)

# equip tools with callbacks, so each tool's output is posted to the chat when it finishes
retrieve_answer_for_country.callbacks = [my_callback_handler]
compare.callbacks = [my_callback_handler]
generic_chat_llm.callbacks = [my_callback_handler]

# Initialize
agent = initialize_agent(
    [retrieve_answer_for_country, compare],  # tools
    # uncomment below if want to enable general chat option also, if user engages bot with casual talk
    # however user should be advised not to do this
    # [generic_chat_llm, retrieve_answer_for_country, compare],
    llm=llm,
    agent="zero-shot-react-description",
    verbose=False,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
    callbacks=[my_callback_handler],
    # no memories, limited RAM in HuggingFaceSpaces
    # in production mode conversation can be stored for separate users/chat sessions in postgresql database
    # memory=ConversationBufferMemory(
    #     memory_key="chat_history", return_messages=True
    # ),
    # max_iterations=10
)
################################ Sidebar with Menu ################################
with st.sidebar:
    st.title("ESG Countries Chatbot")
    page = option_menu(
        "Menu",
        [
            "Main Chatbot",
            "View Source Docs for Last Query",
            "Scrape or Upload Own Docs",
        ],
        icons=['robot', 'list-task', 'cloud-upload-fill'],
        default_index=0,
    )
    with st.expander("Warning", expanded=True):
        st.write("⚠️ DO NOT navigate between pages or change config when chat is ongoing. Wait for query to complete first.")
    st.write("")

    # Countries with user-built retrievers for the current chunk config can override the pre-built data;
    # the label tells the user whether any were found or that they must scrape/upload documents first.
    new_countries, info = check_for_new_retrievers()
    with st.expander("Document Config", expanded=True):
        st.multiselect(
            'Countries to Override with Own Docs:' + info,
            new_countries,
            key="countries_override",
        )
        st.selectbox(
            "Chunk Size",
            options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
            on_change=update_retrievers,
            key="chunk_size",
        )
        st.selectbox(
            "Chunk Overlap",
            options=[50, 100, 150, 200],
            on_change=update_retrievers,
            key="chunk_overlap",
        )
    st.write("")

    # The sliders' values come from session state (via `key`), so no `value` is passed. The step must be
    # given by keyword: st.slider's third positional argument is the initial value, not the step.
    with st.expander("LLM Config", expanded=True):
        st.selectbox(
            "HuggingFace Inference Model",
            options=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
            on_change=update_llm,
            key="model",
        )
        st.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            on_change=update_llm,
            key="temperature",
        )
        st.slider(
            "Max Tokens Generated",
            min_value=200,
            max_value=1000,
            on_change=update_llm,
            key="max_new_tokens",
        )
    st.write("")

    with st.expander("Retriever Config", expanded=True):
        st.selectbox(
            "Retriever to Use",
            options=['Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'],
            key="retriever_config",
        )
        st.slider(
            "Keyword Retriever Weight (If using ensemble retriever, this is the weight of the keyword retriever, semantic retriever would be 1 minus this value)",
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            key="keyword_retriever_weight",
        )
        st.number_input(
            "Number of Relevant Documents Returned by Keyword Retriever (BM25)",
            min_value=0,
            max_value=20,
            key="bm25_n_similar_documents",
        )
        st.number_input(
            "Number of Relevant Documents Returned by Semantic Retriever (ChromaDB)",
            min_value=0,
            max_value=20,
            key="chroma_n_similar_documents",
        )
################################ Main Chatbot Page ################################
if page == "Main Chatbot":
    st.subheader("Chatbot")

    def post_message(role, content):
        """Store a chat message in session state (so it is re-rendered on later reruns) and render it now."""
        st.session_state.messages.append({"role": role, "content": content})
        with st.chat_message(role):
            st.markdown(content)

    # The conversation lives in session state and is used to render the chat.
    # It starts with a greeting for the user.
    if "messages" not in st.session_state:
        greeting = "\n".join([
            "Hello, I am a chatbot which specializes in ESG policies of countries.",
            f"Currently I have data for {', '.join(countries)}.",
            'You can update the data or add data for more countries in the left menu under "Scrape or Upload Own Docs".',
            "You can ask me to compare specific policies between multiple countries too. An example of a question you can ask me is:",
            '"What are the differences between carbon emissions policy in Singapore, Malaysia and China?" How may I help you today?',
        ])
        st.session_state.messages = [{"role": "assistant", "content": greeting}]

    # Render the conversation so far
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Questions/instructions from the chat input are passed to the agent
    if user_query := st.chat_input("Your message here", key="user_input"):
        # a new query starts a new list of source documents
        st.session_state['source_documents'] = [f"User query: '{user_query}'"]
        post_message("user", f":blue[{user_query}]")
        # let the user know the agent is planning its actions
        post_message("assistant", "Please wait while I plan out a best set of actions to obtain the necessary information to answer your query.")
        try:
            results = agent(user_query)
        except Exception as e:
            # e.g. HuggingFace Inference API errors or rate limits: report them in the chat instead of a traceback
            print("Agent error:", repr(e))
            response = f":red[Sorry, an error occurred while answering your query:] {e}"
        else:
            response = f":blue[The answer to your query is:] {results['output']}"
        post_message("assistant", response)
################################ Source Documents Page ################################
if page == "View Source Docs for Last Query":
    st.subheader("Source Documents for Last Query")
    source_documents = st.session_state.get('source_documents', [])
    if not source_documents:
        st.write("No source documents retrieved yet. Please run a full user query before coming back to this page.")
    else:
        # First entry: the "User query: ..." header written when the query was submitted. The rest alternate
        # between a description of each retrieval and the list of documents it returned.
        st.subheader(source_documents[0])
        for doc in source_documents[1:]:
            st.write(doc)
################################ Scrape or Upload Documents Page ################################
# Build new retrievers for a country from either the user's own PDF or newly scraped DuckDuckGo results.
# They can then override the existing data for that country (sidebar 'Document Config').
if page == "Scrape or Upload Own Docs":
    st.header("Scrape or Upload Own PDF")
    st.write("Here you can choose to upload your own PDF or scrape more recent data via DuckDuckGo search for a selected country below.")
    st.write(":blue[NOTE: Certain countries were not present in the original default vector stores, you can scrape data for these countries too so you can ask about them in the chat.]")
    st.write("You will create new BM2.5 (keyword) and Chroma (semantic) retrievers for it. Note that this can take a very long time.")
    country_scrape_upload = st.selectbox(
        "Select Country",
        options=[
            "Australia", "Bangladesh", "Brunei", "Cambodia", "China", "India", "Indonesia", "Japan", "Laos", "Macau", "Malaysia", "Myanmar",
            "Nepal", "Philippines", "Singapore", "South Korea", "Sri Lanka", "Thailand", "Vietnam", "France", "Germany", "Israel", "Poland",
            "Sweden", "Turkey", "United Kingdom", "United States",
        ],
    )

    # display the chunk size and overlap the new documents will be split with
    col1, col2 = st.columns(2)
    with col1:
        with st.container(border=True):
            st.write("New Documents Chunk Size: (Can change in sidebar)")
            st.text(f"{st.session_state['chunk_size']}")
    with col2:
        with st.container(border=True):
            st.write("New Documents Chunk Overlap: (Can change in sidebar)")
            st.text(f"{st.session_state['chunk_overlap']}")

    # how the user wishes to populate documents
    options = [
        "Upload Own PDF",
        "Automatically Scrape Web Data using DuckDuckGo (may take more than 5 mins)",
    ]
    option = st.radio(
        "How Do You Wish To Create New Documents",
        options=options,
    )
    submit_upload_pdf = False
    submit_scrape_web = False

    def save_new_retrievers(all_documents, chunk_size, chunk_overlap, country_scrape_upload):
        """Build and save new BM25 and Chroma retrievers for the documents, then rerun the app.

        They are saved under names of the form ``new_{country}_chunk_{chunk_size}_overlap_{chunk_overlap}_``,
        the names under which the override loader looks for them once the country is selected in the
        sidebar 'Document Config'.
        """
        with st.spinner('Setting up new bm25 retrievers with documents, may take more than 5 mins...'):
            setup_bm25_retriever(all_documents, chunk_size, chunk_overlap, country_scrape_upload)
        with st.spinner('Setting up new chromadb vector stores with documents, may take more than 5 mins...'):
            setup_chromadb_vectorstore(hf_embeddings, all_documents, chunk_size, chunk_overlap, country_scrape_upload)
        st.toast(":blue[SUCCESS!] New retrievers set up with your new data. To override data for this country, you can :blue[Select the Countries to Override in the 'Document Config'] section of the left sidebar.")
        # rerun so the sidebar picks up the new retrievers
        st.rerun()

    # form for user to configure pdf loading options
    if option == options[0]:
        temp_file = None
        with st.form(key='upload_pdf_form'):
            st.subheader(f"Selected Option: {option}")
            uploaded_pdf = st.file_uploader("Upload a PDF")
            if uploaded_pdf:
                # The file name comes from the client: keep only its base name so the upload cannot be written
                # outside the working directory.
                safe_name = os.path.basename(uploaded_pdf.name.replace("\\", "/"))
                temp_file = safe_name if safe_name not in ("", ".", "..") else "uploaded.pdf"
                with open(temp_file, "wb") as file:
                    file.write(uploaded_pdf.getvalue())
            submit_upload_pdf = st.form_submit_button(label='Upload and Create Vector Store (Scroll down after clicking)')
            st.markdown(":blue[NOTE:] After you are done creating the vector store, the country will appear under :blue[Countries to Override in the 'Document Config'] section of the left sidebar. Select the country to override it.")
        if submit_upload_pdf:
            try:
                if temp_file is None:
                    raise ValueError("no PDF file was uploaded")
                with st.spinner('Generating documents from PDF...may take more than 5 mins...'):
                    all_documents = pdf_loader_local(temp_file, country_scrape_upload)
                save_new_retrievers(all_documents, st.session_state['chunk_size'], st.session_state['chunk_overlap'], country_scrape_upload)
            except Exception as e:
                st.write(f"Error! Did you remember to upload the PDF file? Error Message: {e}")

    # form for user to configure web scraping for duckduckgo
    if option == options[1]:
        with st.form(key='scrape_web_form'):
            st.subheader(f"Selected Option: {option}")
            n_search_results = st.number_input(
                "How many DuckDuckGo search results would you like to scrape? In the default vector stores, the number is 10 but it will take a very long time!",
                0, 20,
                value=5,
            )
            search_term = st.text_input(
                "Search Term",
                value=f"{country_scrape_upload} sustainability esg newest updated public policy document government",
            )
            submit_scrape_web = st.form_submit_button(label='Scrape Web for Results and Create Vector Store (Scroll down after clicking)')
            st.markdown(":blue[NOTE:] After you are done creating the vector store, the country will appear under :blue[Countries to Override in the 'Document Config'] section of the left sidebar. Select the country to override it.")
        if submit_scrape_web:
            all_documents = None
            try:
                with st.spinner('Scraping web using Duck Duck Go search...'):
                    all_links, df_links = duckduckgo_scrape(country_scrape_upload, search_term, n_search_results)
                st.write("Results from Web Scrape")
                try:
                    st.write(df_links)
                except Exception:
                    st.write("Waiting for web scraping results.")
                with st.spinner('Generating documents from web search results...may take more than 5 mins...'):
                    all_documents = process_links_load_documents(all_links)
            except Exception as e:
                st.write(f"Error! Web scraping or document loading failed. Error Message: {e}")
            # Kept outside the try so the rerun it triggers is never caught above.
            if all_documents is not None:
                save_new_retrievers(all_documents, st.session_state['chunk_size'], st.session_state['chunk_overlap'], country_scrape_upload)