Update app.py
app.py CHANGED
@@ -52,7 +52,7 @@ class MyCallbackHandler(BaseCallbackHandler):
             [{"role": "assistant", "content": thought}, {"role": "assistant", "content": calling_tool}]
         )
         # Add the response to the chat window
-        with
+        with st.chat_message("assistant"):
             st.markdown(thought)
             st.markdown(calling_tool)
 
@@ -83,7 +83,7 @@ class MyCallbackHandler(BaseCallbackHandler):
         st.session_state.messages.append(
             {"role": "assistant", "content": tool_output}
         )
-        with
+        with st.chat_message("assistant"):
             st.markdown(tool_output)
 
 my_callback_handler = MyCallbackHandler()
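Both callback hunks follow the same pattern: persist the step to st.session_state.messages, then mirror it into the chat UI. A minimal sketch of a handler built this way (hypothetical class name; on_agent_action and on_tool_end are the standard BaseCallbackHandler hooks this diff appears to be overriding):

    import streamlit as st
    from langchain.callbacks.base import BaseCallbackHandler

    class StreamlitChatCallbackHandler(BaseCallbackHandler):
        """Mirrors agent steps into the Streamlit chat as they happen."""

        def on_agent_action(self, action, **kwargs):
            # action.log carries the LLM's thought plus the chosen tool and input
            with st.chat_message("assistant"):
                st.markdown(action.log)

        def on_tool_end(self, output, **kwargs):
            # raw tool output, shown as an intermediate assistant message
            with st.chat_message("assistant"):
                st.markdown(str(output))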
@@ -121,7 +121,7 @@ if 'bm25_n_similar_documents' not in st.session_state:
     st.session_state['bm25_n_similar_documents'] = 5 # number of chunks returned by bm25 retriever (keyword)
 
 if 'retriever_config' not in st.session_state:
-    st.session_state['retriever_config'] = '
+    st.session_state['retriever_config'] = 'Ensemble (Both Re-Ranked)' # choose one of ['semantic', 'keyword', 'ensemble']
 
 if 'keyword_retriever_weight' not in st.session_state:
     st.session_state['keyword_retriever_weight'] = 0.3 # choose between 0 and 1, only when using ensemble
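Each default here is guarded by its own `not in st.session_state` check. The same initialization can be written more compactly as a single loop over a defaults table (a sketch; values copied from the diff):

    import streamlit as st

    defaults = {
        'bm25_n_similar_documents': 5,    # chunks returned by the BM25 (keyword) retriever
        'retriever_config': 'Ensemble (Both Re-Ranked)',
        'keyword_retriever_weight': 0.3,  # between 0 and 1, only used by the ensemble retriever
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value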
@@ -160,7 +160,6 @@ countries = [
 
 
 ################################ Get LLM and Embeddings ################################
-# when LLM config change we will call the function again
 def get_llm():
     # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
     # It is a free API that is very good for deploying online for quick testing without users having to deploy a local LLM
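Per the comments, the model is served by Hugging Face's hosted inference API rather than run locally. A sketch of what such a factory can look like with LangChain's HuggingFaceEndpoint wrapper; the body below is an assumption for illustration, not the app's actual implementation (it reads the same session-state keys the sidebar sets, and expects HUGGINGFACEHUB_API_TOKEN in the environment):

    import streamlit as st
    from langchain_community.llms import HuggingFaceEndpoint

    def get_llm():
        # hosted inference API: nothing runs locally, so the Space stays lightweight
        return HuggingFaceEndpoint(
            repo_id=st.session_state.get('model', 'mistralai/Mixtral-8x7B-Instruct-v0.1'),
            temperature=st.session_state.get('temperature', 0.05),
            max_new_tokens=st.session_state.get('max_new_tokens', 500),
        )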
@@ -185,6 +184,11 @@ def get_embeddings():
 llm = get_llm()
 hf_embeddings = get_embeddings()
 
+# when LLM config is changed we will call this function
+def update_llm():
+    global llm
+    llm = get_llm()
+
 
 ################################ Download and Initialize Pre-Built Retrievers ################################
 
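The `global` rebuild works because Streamlit re-runs the whole script on every widget interaction, so update_llm only has to swap the module-level object before later code uses it. An alternative sketch without globals is to cache the constructor on its config, so a changed widget value builds a fresh instance and an unchanged one is reused (assumes a config-taking variant of get_llm, as in the sketch above):

    import streamlit as st
    from langchain_community.llms import HuggingFaceEndpoint

    @st.cache_resource
    def build_llm(model: str, temperature: float, max_new_tokens: int):
        # cached per unique (model, temperature, max_new_tokens) tuple
        return HuggingFaceEndpoint(repo_id=model, temperature=temperature,
                                   max_new_tokens=max_new_tokens)

    llm = build_llm(st.session_state['model'],
                    st.session_state['temperature'],
                    st.session_state['max_new_tokens'])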
@@ -238,6 +242,12 @@ def get_retrievers():
 
 chroma_db, bm25_retrievers = get_retrievers()
 
+# when retriever config is changed we will call this function
+def update_retrievers():
+    global chroma_db
+    global bm25_retrievers
+    chroma_db, bm25_retrievers = get_retrievers()
+
 ################################ Tools for Agent to Use ################################
 
 # The most important tool is the first one, which uses a RetrievalQA chain to answer a question about a specific country's ESG policies,
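get_retrievers is rebuilt wholesale whenever a document setting changes, mirroring update_llm above. A sketch of what building the two retriever families typically involves (hypothetical helper; the app's plural bm25_retrievers suggests one keyword retriever per country, which this sketch flattens to a single collection):

    import streamlit as st
    from langchain_community.retrievers import BM25Retriever
    from langchain_community.vectorstores import Chroma

    def build_retrievers(docs, embeddings):
        # semantic retriever backed by a Chroma vector store
        chroma_db = Chroma.from_documents(docs, embeddings)
        # keyword retriever; k mirrors the session-state setting
        bm25 = BM25Retriever.from_documents(docs)
        bm25.k = st.session_state['bm25_n_similar_documents']
        return chroma_db, bm25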
@@ -276,7 +286,7 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
     # ensemble (below) reranks results from both retrievers above
     ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[st.session_state['keyword_retriever_weight'], 1 - st.session_state['keyword_retriever_weight']])
     # for user to make selection
-    retrievers = {'
+    retrievers = {'Ensemble (Both Re-Ranked)': ensemble, 'Semantic (Chroma DB)': chroma, 'Keyword (BM 2.5)': bm}
 
     qa = RetrievalQA.from_chain_type(
         llm=llm,
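With the dictionary populated, the user's sidebar choice selects the retriever by key. A sketch of how the pieces plug together, using the names from the surrounding function (bm, chroma, llm); return_source_documents=True is an assumption that would feed the source-documents page further down:

    import streamlit as st
    from langchain.chains import RetrievalQA
    from langchain.retrievers import EnsembleRetriever

    w = st.session_state['keyword_retriever_weight']
    ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[w, 1 - w])
    retrievers = {'Ensemble (Both Re-Ranked)': ensemble,
                  'Semantic (Chroma DB)': chroma,
                  'Keyword (BM 2.5)': bm}

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retrievers[st.session_state['retriever_config']],
        return_source_documents=True,  # assumed: lets the source-docs page show what was used
    )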
@@ -362,30 +372,91 @@ agent = initialize_agent(
     # max_iterations=10
 )
 
-
-if "menu" not in st.session_state:
-    st.session_state["menu"] = [
-        "Chatbot",
-        "Chat Config",
-        "Document, Retriever, Web Scraping Config",
-        "Source Documents for Last Query",
-    ]
-
+
 ################################ Sidebar with Menu ################################
 with st.sidebar:
-
-
-
-
-
+    page = option_menu("Chatbot",
+                       [
+                           "Main Chatbot",
+                           "View Source Docs for Last Query",
+                           "Scrape or Upload Docs",
+                       ],
+                       icons=['house', 'gear', 'gear', 'gear'],
+                       menu_icon="", default_index=0)
+
+    with st.container(border = True):
+        st.write("DO NOT NAVIGATE between pages or change the config while the agent is still generating messages in the chat. Wait for the query to complete first.")
+        st.write("")
+
+    with st.expander("LLM Config", expanded = True):
+
+        st.selectbox(
+            "HuggingFace Inference Model",
+            options=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
+            on_change=update_llm,
+            key="model"
+        )
+
+        st.slider(
+            "Temperature",
+            0.0, 1.0, 0.05,
+            #value = st.session_state['temperature'],
+            on_change=update_llm,
+            key="temperature"
+        )
+
+        st.slider(
+            "Max Tokens Generated",
+            200, 1000,
+            on_change=update_llm,
+            key="max_new_tokens"
+        )
+
+    with st.expander("Document Config", expanded = True):
+        st.selectbox(
+            "Chunk Size",
+            options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
+            on_change=update_retrievers,
+            key="chunk_size"
+        )
+
+        st.selectbox(
+            "Chunk Overlap",
+            options=[50, 100, 150, 200],
+            on_change=update_retrievers,
+            key="chunk_overlap"
+        )
+
+    with st.expander("Retriever Config", expanded = True):
+
+        st.selectbox(
+            "Retriever to Use",
+            options=['Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'],
+            key="retriever_config"
+        )
+
+        st.slider(
+            "Keyword Retriever Weight (If using ensemble retriever, this is the weight of the keyword retriever, semantic retriever would be 1 minus this value)",
+            0.0, 0.05, 1.0,
+            key="keyword_retriever_weight"
+        )
+
+        st.slider(
+            "Number of Relevant Documents Returned by Keyword Retriever",
+            0, 1, 20,
+            key="bm25_n_similar_documents"
+        )
+
+        st.slider(
+            "Number of Relevant Documents Returned by Semantic Retriever",
+            0, 1, 20,
+            key="chroma_n_similar_documents"
+        )
 
-tab1, tab2, tab3 = st.tabs(["Cat", "Dog", "Owl"])
 
 ################################ Main Chatbot Page ################################
-
-
-#st.header("Chat")
-messages = st.container()
+if page == "Main Chatbot":
+    st.subheader("Chatbot")
 
 # Store the conversation in the session state.
 # Used to render the chat conversation.
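One thing to double-check in the new retriever sliders: st.slider takes positional arguments in the order (label, min_value, max_value, value), so `0.0, 0.05, 1.0` and `0, 1, 20` each set a default above the maximum, which raises a StreamlitAPIException at runtime. A corrected sketch under the presumable intent (the exact bounds and defaults here are assumptions):

    import streamlit as st

    # (label, min_value, max_value, value): the default must sit inside the bounds
    st.slider("Keyword Retriever Weight", 0.0, 1.0, 0.3, step=0.05,
              key="keyword_retriever_weight")
    st.slider("Number of Relevant Documents Returned by Keyword Retriever",
              1, 20, 5, key="bm25_n_similar_documents")
    st.slider("Number of Relevant Documents Returned by Semantic Retriever",
              1, 20, 5, key="chroma_n_similar_documents")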
@@ -403,7 +474,7 @@ with tab1:
 
     # Loop through each message in the session state and render it as a chat message
     for message in st.session_state.messages:
-        with
+        with st.chat_message(message["role"]):
             st.markdown(message["content"])
 
     # We take questions/instructions from the chat input to pass to the LLM
@@ -419,7 +490,7 @@ with tab1:
     )
 
     # Add our input to the chat window
-    with
+    with st.chat_message("user"):
         st.markdown(formatted_user_query)
 
     # Let user know agent is planning the actions
@@ -430,7 +501,7 @@ with tab1:
         {"role": "assistant", "content": action_plan_message}
     )
     # Add the response to the chat window
-    with
+    with st.chat_message("assistant"):
         st.markdown(action_plan_message)
 
     results = agent(user_query)
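Taken together, the main-page hunks follow Streamlit's standard chat pattern: replay history on each re-run, then append and render both sides of the new turn. A condensed sketch of that flow (message format and agent call as in the diff; the prompt text is illustrative):

    import streamlit as st

    # replay history on every script re-run
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # new turn: record it, show it, answer it
    if user_query := st.chat_input("Ask about a country's ESG policies"):
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.markdown(user_query)

        results = agent(user_query)      # callbacks stream the intermediate steps
        response = results["output"]
        st.session_state.messages.append({"role": "assistant", "content": response})
        with st.chat_message("assistant"):
            st.markdown(response)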
@@ -442,63 +513,32 @@ with tab1:
             )
 
             # Add the response to the chat window
-            with
+            with st.chat_message("assistant"):
                 st.markdown(response)
 
 
-################################ Chat Config Page ################################
-# for changing config like temperature etc.
-with tab2:
-    # if page == "Chat Config":
-    #     st.header(page)
-
-    st.selectbox(
-        "HuggingFace Inference Model",
-        options=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
-        on_change=get_llm,
-        key="model"
-    )
-
-    st.slider(
-        "Temperature",
-        0.0, 1.0, 0.05,
-        #value = st.session_state['temperature'],
-        on_change=get_llm,
-        key="temperature"
-    )
-
 
 ################################ Document Page ################################
 # to scrape new documents from DuckDuckGo
 # to change parameters like chunk size
 # to upload own PDF
 # to override existing data on new scraped data or new pdf uploaded
-with tab3:
-    # if page == "Document, Retriever, Web Scraping Config":
-    #     st.header(page)
-
-    st.selectbox(
-        "Chunk Size",
-        options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
-        on_change=get_retrievers,
-        key="chunk_size"
-    )
-
-    st.selectbox(
-        "Chunk Overlap",
-        options=[50, 100, 150, 200],
-        on_change=get_retrievers,
-        key="chunk_overlap"
-    )
 
 
-
-
-
+
+################################ Source Documents Page ################################
+if page == "View Source Docs for Last Query":
+    st.header("Source Documents for Last Query")
     try:
         st.subheader(st.session_state['source_documents'][0])
         for doc in st.session_state['source_documents'][1:]:
-            st.write("Source: " + doc.metadata['source'])
+            #st.write("Source: " + doc.metadata['source'])
             st.write(doc)
     except:
         st.write("No source documents retrieved yet. Please run a user query before coming back to this page.")
+
+
+
+# in main app, add configuration for user to scrape new data from DuckDuckGo
+# in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
+
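For the new source-documents page to have anything to render, the query path has to stash what the retriever returned. A sketch of the producer side (the source_documents key and title-first layout are taken from the page code above; the rest is assumed):

    import streamlit as st

    result = qa.invoke({"query": user_query})   # qa built with return_source_documents=True
    st.session_state['source_documents'] = (
        [f"Source documents for: {user_query}"] + result["source_documents"]
    )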