bohmian committed on
Commit dfe22da · verified · 1 Parent(s): e653334

Create app.py

Files changed (1):
  1. app.py +391 -0
app.py ADDED
@@ -0,0 +1,391 @@
+ import streamlit as st
+ from streamlit_chat import message
+
+ import os
+ from langchain.llms import HuggingFaceHub # for calling the HuggingFace Inference API (free for our use case)
+ from langchain.embeddings import HuggingFaceEmbeddings # to let the program know which embeddings the vector store was embedded with earlier
+
+ # to set up the agent and tools which will be used to answer questions later
+ from langchain.agents import initialize_agent
+ from langchain.agents import tool # decorator so each function will be recognized as a tool
+ from langchain.chains.retrieval_qa.base import RetrievalQA # to answer questions from the vector store retriever
+ # from langchain.chains.question_answering import load_qa_chain # to further customize the qa chain if needed
+ from langchain.vectorstores import Chroma # vector store for the retriever
+ import ast # to parse the user's string input into a list for one of the tools (agent tools do not support 2 inputs)
+ # from langchain.memory import ConversationBufferMemory # not used as of now
+ import pickle # for loading the bm25 retriever
+ from langchain.retrievers import EnsembleRetriever # to combine the chroma and bm25 retrievers
+
+ # for defining a generic LLMChain as a generic chat tool (if needed)
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import LLMChain
+
+ import warnings
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'your_api_key' # for using the HuggingFace Inference API
+
+
+ from langchain.callbacks.base import BaseCallbackHandler
+
+ class MyCallbackHandler(BaseCallbackHandler):
+     def __init__(self):
+         self.tokens = []
+
+     def on_llm_new_token(self, token, **kwargs) -> None: # HuggingFaceHub() cannot stream
+         self.tokens.append(token)
+         print(token)
+
+     def on_agent_action(self, action, **kwargs):
+         """Run on agent action."""
+         print("\n\nnew action", action)
+         thought = action.log.replace('\n', '  \n') # two trailing spaces so streamlit markdown will recognize it as a newline
+         tool_called = action.tool
+         # tool_input = action.tool_input
+         calling_tool = f"I am calling the '{tool_called}' tool and waiting for it to give me a result..."
+         st.session_state.messages.extend(
+             [{"role": "assistant", "content": thought}, {"role": "assistant", "content": calling_tool}]
+         )
+         # Add the response to the chat window
+         with st.chat_message("assistant"):
+             st.markdown(thought)
+             st.markdown(calling_tool)
+
+     # def on_agent_finish(self, finish, **kwargs):
+     #     """Run on agent end."""
+     #     # print("\n\nEnd", finish)
+     #     finish_string = finish.log.replace('\n', '  \n') # so streamlit will recognize it as a newline
+     #     st.session_state.messages.append(
+     #         {"role": "assistant", "content": finish_string}
+     #     )
+     #     with st.chat_message("assistant"):
+     #         st.markdown(finish_string)
+
+     # def on_llm_start(self, serialized, prompts, **kwargs):
+     #     """Run when LLM starts running."""
+     #     print("LLM Start: ", prompts)
+
+     # def on_llm_end(self, response, **kwargs):
+     #     """Run when LLM ends running."""
+     #     print(response)
+
+     def on_tool_end(self, output, **kwargs):
+         """Run when tool ends running."""
+         # print("\n\nTool End: ", output)
+         tool_output = f"Tool Output: {output} \n \nI am processing the output from the tool..."
+         st.session_state.messages.append(
+             {"role": "assistant", "content": tool_output}
+         )
+         with st.chat_message("assistant"):
+             st.markdown(tool_output)
+
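+ # One shared handler instance; it is attached to the tools and to the agent below so that the agent's
+ # intermediate thoughts, tool calls and tool outputs are mirrored into the Streamlit chat as they happen.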
+ my_callback_handler = MyCallbackHandler()
+
+ # # Set the webpage title
+ # st.set_page_config(
+ #     page_title="Your own AI-Chat!",
+ #     layout="wide"
+ # )
+
+ # llm for the HuggingFace Inference API
+ # model = "mistralai/Mistral-7B-Instruct-v0.2"
+ model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # with st.spinner('Downloading pre-built Chroma and BM25 vector stores'):
+ #     chroma_db = Chroma(persist_directory=persist_directory, embedding_function=hf_embeddings)
+
+ # Document config
+ if 'chunk_size' not in st.session_state:
+     st.session_state['chunk_size'] = 1000 # choose one of [500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000]
+
+ if 'chunk_overlap' not in st.session_state:
+     st.session_state['chunk_overlap'] = 100 # choose one of [50, 100, 150, 200]
+
+ # scraping results using DuckDuckGo
+ if 'top_n_results' not in st.session_state:
+     st.session_state['top_n_results'] = 10 # this is for returning the top n search results using DuckDuckGo
+
+ if 'countries_to_scrape' not in st.session_state:
+     st.session_state['countries_to_scrape'] = [] # countries the user selects to scrape new data for
+
+ # in the main app, add configuration for the user to scrape new data from DuckDuckGo
+ # in the main app, add configuration for the user to upload a PDF to override a country's existing policies in the vector store
+
+
+ # Retriever config
+ if 'chroma_n_similar_documents' not in st.session_state:
+     st.session_state['chroma_n_similar_documents'] = 5 # number of chunks returned by the chroma vector store retriever (semantic)
+
+ if 'bm25_n_similar_documents' not in st.session_state:
+     st.session_state['bm25_n_similar_documents'] = 5 # number of chunks returned by the bm25 retriever (keyword)
+
+ if 'retriever_config' not in st.session_state:
+     st.session_state['retriever_config'] = 'ensemble' # choose one of ['semantic', 'keyword', 'ensemble']
+
+ if 'keyword_retriever_weight' not in st.session_state:
+     st.session_state['keyword_retriever_weight'] = 0.3 # choose between 0 and 1, only used for the ensemble retriever
+
+ if 'source_documents' not in st.session_state:
+     st.session_state['source_documents'] = [] # this is to store all source documents for a particular search
+
+
+ # LLM config
+ if 'temperature' not in st.session_state:
+     st.session_state['temperature'] = 0.25
+
+ if 'max_new_tokens' not in st.session_state:
+     st.session_state['max_new_tokens'] = 500 # max tokens generated by the LLM
+
+ # This is the list of countries present in the vector store, which was prepared in advance since the embeddings take very long to build.
+ # The RetrievalQA tool later checks against this list, because even if the country given to it is not in the vector store,
+ # it would still filter the vector store by that country and return an empty result instead of raising an error.
+ # We have to manually return an error message so that the agent using the tool knows.
+ # The countries were reduced to just 6 as the time taken to build the embeddings for the chunks is too long.
+ # However, having more countries **will not affect** the quality of the answers when comparing 2 countries in the RAG application,
+ # as the RAG only picks out document chunks for the 2 countries of interest.
+ countries = [
+     "Australia",
+     "China",
+     "Japan",
+     "Malaysia",
+     "Singapore",
+     "Germany",
+ ]
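+ # Note: these names must match the 'country' metadata values stored with each chunk in the vector store,
+ # which is why the country passed to the retrieval tool is .capitalize()'d before filtering.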
+
+ @st.cache_data # only needs to be created once
+ def get_llm(temp=st.session_state['temperature'], tokens=st.session_state['max_new_tokens']):
+     # This is an inference endpoint API from HuggingFace; the model is not run locally, it runs on HuggingFace's servers.
+     # It is a free API that is very useful for deploying online for quick testing without users having to run a local LLM.
+     llm = HuggingFaceHub(repo_id=model,
+                          model_kwargs={
+                              'temperature': temp,
+                              'max_new_tokens': tokens
+                          },
+                          )
+     return llm
+
+ llm = get_llm(st.session_state['temperature'], tokens=st.session_state['max_new_tokens'])
+
+ @st.cache_data # only needs to be created once
+ def get_embeddings():
+     with st.spinner('Getting HuggingFaceEmbeddings'):
+         # We use HuggingFaceEmbeddings() as it is open source and free to use.
+         # Initialize the default hf model for embedding the tokenized texts into vectors with semantic meanings
+         hf_embeddings = HuggingFaceEmbeddings()
+     return hf_embeddings
+
+ hf_embeddings = get_embeddings()
+
+ # Chromadb vector stores have already been pre-created for the countries above, for each of the different chunk sizes and overlaps,
+ # to save time when experimenting, as the embeddings take a long time to generate.
+ # The existing stores are pulled in (previously with !wget) when the app starts. When using the existing vector stores,
+ # we just need to change the name of the persist directory when selecting different chunk sizes and overlaps.
+ # Not implemented here: later in the main app, if the user chooses to scrape new data or to override with their own PDF, a new chromadb would be created.
+ persist_directory = f"chromadb/chromadb_esg_countries_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}"
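+ # e.g. with the default settings above this resolves to "chromadb/chromadb_esg_countries_chunk_1000_overlap_100"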
+ with st.spinner('Setting up pre-built chroma vector store'):
+     chroma_db = Chroma(persist_directory=persist_directory, embedding_function=hf_embeddings)
+
+ # Initialize BM25 Retrievers
+ # Unlike Chroma (semantic), BM25 is a keyword-based algorithm that performs well on queries containing keywords, without capturing the semantic meaning of the query terms,
+ # hence there is no need to embed the text with HuggingFaceEmbeddings and it is relatively faster to set up.
+ # The retrievers for the different chunk sizes, overlaps and countries were created in advance, saved as pickle files and pulled using !wget.
+ # We need to initialize one BM25Retriever per country so the search results later in the main app can be limited to just a particular country.
+ # (Chroma DB gives an option to filter metadata for just a particular country during the retrieval process, but BM25 does not, because it makes use of an external ranking library.)
+ # A separate retriever was saved for each country.
+ bm25_retrievers = {} # to store retrievers for the different countries
+ with st.spinner('Setting up pre-built bm25 retrievers'):
+     for country in countries:
+         bm25_filename = f"bm25/bm25_esg_countries_{country}_chunk_{st.session_state['chunk_size']}_overlap_{st.session_state['chunk_overlap']}.pickle"
+         with open(bm25_filename, 'rb') as handle:
+             bm25_retriever = pickle.load(handle)
+         bm25_retrievers[country] = bm25_retriever
+
+ # Tools for the LLM to Use
+ # The most important tool is the first one, which uses a RetrievalQA chain to answer a question about a specific country's ESG policies,
+ # e.g. the carbon emissions policy of Singapore.
+ # By calling this tool multiple times, the agent is able to look at the responses from this tool for both countries and compare them.
+ # This is far better than just retrieving relevant chunks for the user's query and throwing everything at a single RetrievalQA chain to process.
+ # Multi-input tools are not available, hence we have to prompt the agent to give its input as a list formatted as a string,
+ # then use ast.literal_eval to convert it back into a list.
+ @tool
+ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change diff chain type diff version answers, change
+     """Gives answer to a query about a single country's public ESG policy.
+     The input list should be of the following format:
+     [query, country]
+     The first element of the list is the user query, surrounded by double quotes.
+     The second element is the full name of the country involved, surrounded by double quotes, for example "Singapore".
+     The 2 inputs are separated by a comma. Do not write a list comprehension.
+     The 2 inputs, together, are surrounded by square brackets as it is a list.
+     Do not put multiple countries into the input at once. Instead use this tool multiple times, one time for each country.
+     If you have multiple queries to ask about a country, break the query into separate parts and use this tool multiple times, one for each query.
+     """
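+     # e.g. the agent is expected to pass a string such as: '["What is the carbon emissions policy?", "Singapore"]'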
+     try:
+         query_and_country_list = ast.literal_eval(query_and_country)
+         query = query_and_country_list[0]
+         country = query_and_country_list[1].capitalize() # in case the LLM did not capitalize the first letter, as filtering on metadata is case sensitive
+         if country not in countries:
+             return """The country that you input into the tool cannot be found.
+             If you did not make a mistake and the country that you input is indeed what the user asked,
+             then there is no record for the country and no answer can be obtained."""
+
+         # different retrievers
+         bm = bm25_retrievers[country] # keyword based
+         bm.k = st.session_state['bm25_n_similar_documents']
+         chroma = chroma_db.as_retriever(search_kwargs={'filter': {'country': country}, 'k': st.session_state['chroma_n_similar_documents']}) # semantic
+         # ensemble (below) reranks results from both retrievers
+         ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[st.session_state['keyword_retriever_weight'], 1 - st.session_state['keyword_retriever_weight']])
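+         # with the default keyword_retriever_weight of 0.3, the bm25 (keyword) results get weight 0.3 and the chroma (semantic) results get weight 0.7 during reranking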
240
+ retrievers = {'ensemble': ensemble, 'semantic': chroma, 'keyword': bm}
241
+
242
+ qa = RetrievalQA.from_chain_type(
243
+ llm=llm,
244
+ chain_type='stuff',
245
+ retriever=retrievers[st.session_state['retriever_config']], # selected retriever based on user config
246
+ return_source_documents=True # returned in result['source_documents']
247
+ )
248
+ result = qa(query)
249
+ st.session_state['source_documents'].append(result['source_documents']) # let user know what source docs are used
250
+ return result['result']
251
+
252
+ except Exception as e:
253
+ return f"""There is an error using this tool: {e}. Check if you have input anything wrongly and try again.
254
+ Remember the 2 inputs, query and country, must both be surrounded by double quotes.
255
+ The 2 inputs, together, are surrounded by square brackets as it is a list."""
256
+
257
+ # if a user tries to casually chat with the agent chatbot, the LLM will be able to use this tool to reply instead
258
+ # this is optional, better to let user's know the chatbot is not for casual chatting
259
+ @tool
260
+ def generic_chat_llm(query: str) -> str:
261
+ """Use this tool for general queries and casual chat. Forward the user input directly into this tool, do not come up with your own input.
262
+ This tool IS NOT FOR MAKING COMPARISONS of anything.
263
+ This tool IS NOT FOR FINDING ESG POLICY of any country!
264
+ It is only for casual chat! Do not use this tool unnecessarily!
265
+ """
266
+ try:
267
+ # Second Generic Tool
268
+ prompt = PromptTemplate(
269
+ input_variables=["query"],
270
+ template="{query}"
271
+ )
272
+
273
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
274
+ return llm_chain.run(query)
275
+
276
+ except Exception as e:
277
+ return f"""There is an error using this tool: {e}. Check if you have input anything wrongly and try again.
278
+ If you have already tried 2 times, do not try anymore, there is no response for your input.
279
+ Move on to the next step of your plan."""
280
+
281
+ # sometimes the agent will suddenly ask for a 'compare' tool even though it was not given this tool
282
+ # hence I have decided to give it this tool that gives a prompt to remind it to look at past information
283
+ # and decide whether it is time to darw a conclusion
284
+ # tools cannot have no input, hence I let the agent input a 'query' parameter even though it is not used
285
+ # having the query as input let the LLM 'recall' what is being asked
286
+ # instead of it being lost all the way at the start of the ReAct process
287
+ @tool
288
+ def compare(query:str) -> str:
289
+ """Use this tool to give you hints and instructions on how you can compare between policies of countries.
290
+ Use this tool only at one of your final steps, do not use it at the start.
291
+ When putting the query into this tool, look at the entire query that the user has asked at the start,
292
+ do not leave any details in the query out.
293
+ """
294
+ return f"""Look at all your previous observations to answer the user query.
295
+ Use as much relevant information as possible but only from your previous thoughts and observations.
296
+ If you need more details, you can use a tool to find out more. If you have enough information,
297
+ use your reasoning to answer them to the best of your ability. Give as much detail as you want in your answer."""
298
+
+ retrieve_answer_for_country.callbacks = [my_callback_handler]
+ compare.callbacks = [my_callback_handler]
+ generic_chat_llm.callbacks = [my_callback_handler]
+
+ agent = initialize_agent(
+     [retrieve_answer_for_country, compare], # tools
+     # [retrieve_answer_for_country, generic_chat_llm, compare],
+     llm=llm,
+     agent="zero-shot-react-description", # this is good
+     verbose=False,
+     handle_parsing_errors=True,
+     return_intermediate_steps=True,
+     callbacks=[my_callback_handler]
+     # memory=ConversationBufferMemory(
+     #     memory_key="chat_history", return_messages=True
+     # ),
+     # max_iterations=10
+ )
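+ # The zero-shot ReAct agent plans one step at a time: it decides which tool to call based only on the
+ # tool docstrings above, observes the result, and repeats until it can produce a final answer.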
+
+
+ # Create a header element
+ st.header("Chat")
+
+ col1, col2 = st.columns(2)
+ # with col1:
+
+ # Store the conversation in the session state.
+ # Used to render the chat conversation.
+ # Initialize it with the first message for users to be greeted with
+ if "messages" not in st.session_state:
+     st.session_state.messages = [
+         {"role": "assistant", "content": "How may I help you today?"}
+     ]
+
+ if "current_response" not in st.session_state:
+     st.session_state.current_response = ""
+
+ # Loop through each message in the session state and render it as a chat message.
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # The LLM and agent were already initialized above via the HuggingFace Inference API (no local quantized model is used).
+ # Currently most parameters are fixed but we can make them configurable.
+ # llm_chain = create_chain(retriever)
+
+ # We take questions/instructions from the chat input to pass to the LLM
+ if user_query := st.chat_input("Your message here", key="user_input"):
+
+     # Add our input to the session state
+     st.session_state.messages.append(
+         {"role": "user", "content": user_query}
+     )
+
+     # Add our input to the chat window
+     with st.chat_message("user"):
+         st.markdown(user_query)
+
+     # Let the user know the agent is planning its actions
+     action_plan_message = "Please wait while I plan out a best set of actions to obtain the information and answer your query."
+
+     # Add the response to the session state
+     st.session_state.messages.append(
+         {"role": "assistant", "content": action_plan_message}
+     )
+     # Add the response to the chat window
+     with st.chat_message("assistant"):
+         st.markdown(action_plan_message)
+
+     # Pass our input to the agent and capture the final response.
+     # The callback handler has already rendered the agent's intermediate thoughts, tool calls and tool
+     # outputs in the chat as they happened; here we only collect the final answer once the agent finishes.
+     results = agent(user_query)
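+     # results is a dict: with return_intermediate_steps=True it contains the final answer under 'output'
+     # as well as the 'intermediate_steps' taken by the agent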
375
+ response = f"The answer to your query is: {results['output']}"
376
+
377
+ # Add the response to the session state
378
+ st.session_state.messages.append(
379
+ {"role": "assistant", "content": response}
380
+ )
381
+
382
+ # Add the response to the chat window
383
+ with st.chat_message("assistant"):
384
+ st.markdown(response)
385
+
386
+
387
+ # with col2:
388
+ # st.write("hi")
389
+
390
+
391
+