bohmian committed
Commit eabbef9 · verified · 1 Parent(s): 610ed65

Update app.py

Files changed (1): app.py (+110 -70)

app.py CHANGED
@@ -52,7 +52,7 @@ class MyCallbackHandler(BaseCallbackHandler):
             [{"role": "assistant", "content": thought}, {"role": "assistant", "content": calling_tool}]
         )
         # Add the response to the chat window
-        with messages.chat_message("assistant"):
+        with st.chat_message("assistant"):
            st.markdown(thought)
            st.markdown(calling_tool)

@@ -83,7 +83,7 @@ class MyCallbackHandler(BaseCallbackHandler):
         st.session_state.messages.append(
             {"role": "assistant", "content": tool_output}
         )
-        with messages.chat_message("assistant"):
+        with st.chat_message("assistant"):
            st.markdown(tool_output)

 my_callback_handler = MyCallbackHandler()
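
Review note: this commit drops the dedicated chat container (messages = st.container()) from the main page, so the callback handler can no longer render into it; both hunks above therefore switch to st.chat_message, which writes into the main page body directly. A minimal sketch of the handler pattern, assuming LangChain's BaseCallbackHandler interface (the method body is illustrative, not the app's exact code):

    from langchain.callbacks.base import BaseCallbackHandler
    import streamlit as st

    class ChatUICallbackHandler(BaseCallbackHandler):
        def on_agent_action(self, action, **kwargs):
            # surface each intermediate agent step in the chat as it happens
            with st.chat_message("assistant"):
                st.markdown(f"Calling tool: {action.tool}")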
@@ -121,7 +121,7 @@ if 'bm25_n_similar_documents' not in st.session_state:
     st.session_state['bm25_n_similar_documents'] = 5 # number of chunks returned by bm25 retriever (keyword)

 if 'retriever_config' not in st.session_state:
-    st.session_state['retriever_config'] = 'ensemble' # choose one of ['semantic', 'keyword', 'ensemble']
+    st.session_state['retriever_config'] = 'Ensemble (Both Re-Ranked)' # choose one of ['semantic', 'keyword', 'ensemble']

 if 'keyword_retriever_weight' not in st.session_state:
     st.session_state['keyword_retriever_weight'] = 0.3 # choose between 0 and 1, only when using ensemble
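
Review note: the new default is the widget's display label, and it has to match the "Retriever to Use" options added later in this diff character-for-character, or the selectbox cannot restore the stored selection; note the inline comment still lists the old ['semantic', 'keyword', 'ensemble'] names and is now stale. A quick guard (illustrative only):

    # keep the stored default in lockstep with the sidebar widget's options
    RETRIEVER_OPTIONS = ['Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)']
    assert st.session_state['retriever_config'] in RETRIEVER_OPTIONS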
@@ -160,7 +160,6 @@ countries = [


 ################################ Get LLM and Embeddings ################################
-# when LLM config change we will call the function again
 def get_llm():
     # This is an inference endpoint API from huggingface, the model is not run locally, it is run on huggingface
     # It is a free API that is very good for deploying online for quick testing without users having to deploy a local LLM
@@ -185,6 +184,11 @@ def get_embeddings():
 llm = get_llm()
 hf_embeddings = get_embeddings()

+# when LLM config is changed we will call this function
+def update_llm():
+    global llm
+    llm = get_llm()
+

 ################################ Download and Initialize Pre-Built Retrievers ################################

@@ -238,6 +242,12 @@ def get_retrievers():

 chroma_db, bm25_retrievers = get_retrievers()

+# when retriever config is changed we will call this function
+def update_retrievers():
+    global chroma_db
+    global bm25_retrievers
+    chroma_db, bm25_retrievers = get_retrievers()
+
 ################################ Tools for Agent to Use ################################

 # The most important tool is the first one, which uses a RetrievalQA chain to answer a question about a specific country's ESG policies,
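
Review note: update_llm and update_retrievers rely on the same trick: Streamlit runs on_change callbacks before re-executing the script, and global rebinds the module-level objects that the tools and chains close over. A sketch of an alternative that avoids the global state by caching per configuration (build_llm and make_hf_endpoint are hypothetical names; the real get_llm reads its settings from st.session_state):

    import streamlit as st

    @st.cache_resource  # one cached client per distinct (model, temperature)
    def build_llm(model: str, temperature: float):
        return make_hf_endpoint(model, temperature)  # hypothetical constructor

    # the rerun triggered by the sidebar widgets picks up new values by itself,
    # so no on_change callback or global rebinding is needed
    llm = build_llm(st.session_state["model"], st.session_state["temperature"])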
@@ -276,7 +286,7 @@ def retrieve_answer_for_country(query_and_country: str) -> str: # TODO, change d
     # ensemble (below) reranks results from both retrievers above
     ensemble = EnsembleRetriever(retrievers=[bm, chroma], weights=[st.session_state['keyword_retriever_weight'], 1 - st.session_state['keyword_retriever_weight']])
     # for user to make selection
-    retrievers = {'ensemble': ensemble, 'semantic': chroma, 'keyword': bm}
+    retrievers = {'Ensemble (Both Re-Ranked)': ensemble, 'Semantic (Chroma DB)': chroma, 'Keyword (BM 2.5)': bm}

     qa = RetrievalQA.from_chain_type(
         llm=llm,
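
Review note: keying this dict by the display labels lets the sidebar's selectbox value index it directly, and EnsembleRetriever fuses the two ranked lists with weighted reciprocal rank fusion, so the keyword weight and its complement control the blend. Selection then reduces to (sketch):

    # look up whichever retriever the user picked in the sidebar
    retriever = retrievers[st.session_state['retriever_config']]
    docs = retriever.get_relevant_documents("Singapore renewable energy targets")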
@@ -362,30 +372,91 @@ agent = initialize_agent(
     # max_iterations=10
 )

-# menu options
-if "menu" not in st.session_state:
-    st.session_state["menu"] = [
-        "Chatbot",
-        "Chat Config",
-        "Document, Retriever, Web Scraping Config",
-        "Source Documents for Last Query",
-    ]
-
+
 ################################ Sidebar with Menu ################################
 with st.sidebar:
-    st.subheader("DO NOT NAVIGATE between pages when agent is still generating messages in the chat. Wait for query to complete first.")
-    st.write("")
-    page = option_menu("Main Menu", st.session_state["menu"],
-                       icons=['house', 'gear', 'gear', 'gear'], menu_icon="cast", default_index=0)
-    st.write(st.session_state['chunk_size'])
+    page = option_menu("Chatbot",
+                       [
+                           "Main Chatbot",
+                           "View Source Docs for Last Query",
+                           "Scrape or Upload Docs",
+                       ],
+                       icons=['house', 'gear', 'gear', 'gear'],
+                       menu_icon="", default_index=0)
+
+    with st.container(border = True):
+        st.write("DO NOT NAVIGATE between pages or change when agent is still generating messages in the chat. Wait for query to complete first.")
+        st.write("")
+
+    with st.expander("LLM Config", expanded = True):
+
+        st.selectbox(
+            "HuggingFace Inference Model",
+            options=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
+            on_change=update_llm,
+            key="model"
+        )
+
+        st.slider(
+            "Temperature",
+            0.0, 1.0, 0.05,
+            #value = st.session_state['temperature'],
+            on_change=update_llm,
+            key="temperature"
+        )
+
+        st.slider(
+            "Max Tokens Generated",
+            200, 1000,
+            on_change=update_llm,
+            key="max_new_tokens"
+        )
+
+    with st.expander("Document Config", expanded = True):
+        st.selectbox(
+            "Chunk Size",
+            options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
+            on_change=update_retrievers,
+            key="chunk_size"
+        )
+
+        st.selectbox(
+            "Chunk Overlap",
+            options=[50, 100, 150, 200],
+            on_change=update_retrievers,
+            key="chunk_overlap"
+        )
+
+    with st.expander("Retriever Config", expanded = True):
+
+        st.selectbox(
+            "Retriever to Use",
+            options=['Ensemble (Both Re-Ranked)', 'Semantic (Chroma DB)', 'Keyword (BM 2.5)'],
+            key="retriever_config"
+        )
+
+        st.slider(
+            "Keyword Retriever Weight (If using ensemble retriever, this is the weight of the keyword retriever, semantic retriever would be 1 minus this value)",
+            0.0, 0.05, 1.0,
+            key="keyword_retriever_weight"
+        )
+
+        st.slider(
+            "Number of Relevant Documents Returned by Keyword Retriever",
+            0, 1, 20,
+            key="bm25_n_similar_documents"
+        )
+
+        st.slider(
+            "Number of Relevant Documents Returned by Semantic Retriever",
+            0, 1, 20,
+            key="chroma_n_similar_documents"
+        )

-tab1, tab2, tab3 = st.tabs(["Cat", "Dog", "Owl"])

 ################################ Main Chatbot Page ################################
-with tab1:
-#if page == "Chatbot":
-    #st.header("Chat")
-    messages = st.container()
+if page == "Main Chatbot":
+    st.subheader("Chatbot")

 # Store the conversation in the session state.
 # Used to render the chat conversation.
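
Review note: the option_menu call passes four icons for three options (the fourth appears unused), and menu_icon="" leaves the menu header icon blank; a tidier call might be (icon names are hypothetical Bootstrap choices):

    page = option_menu(
        "Chatbot",
        ["Main Chatbot", "View Source Docs for Last Query", "Scrape or Upload Docs"],
        icons=['house', 'file-text', 'cloud-upload'],  # one icon per option
        menu_icon="cast", default_index=0,
    )

More importantly, st.slider's positional signature is (label, min_value, max_value, value, step). As committed, "0.0, 0.05, 1.0" parses as min=0.0, max=0.05, value=1.0, and "0, 1, 20" as min=0, max=1, value=20; both put value outside [min, max] and should raise a StreamlitAPIException at render time. A corrected sketch (the ranges are assumptions; the session-state keys already supply the defaults):

    st.slider("Keyword Retriever Weight (If using ensemble retriever, this is the weight of the keyword retriever, semantic retriever would be 1 minus this value)",
              0.0, 1.0, step=0.05, key="keyword_retriever_weight")
    st.slider("Number of Relevant Documents Returned by Keyword Retriever",
              1, 20, key="bm25_n_similar_documents")
    st.slider("Number of Relevant Documents Returned by Semantic Retriever",
              1, 20, key="chroma_n_similar_documents")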
@@ -403,7 +474,7 @@ with tab1:

     # Loop through each message in the session state and render it as a chat message
     for message in st.session_state.messages:
-        with messages.chat_message(message["role"]):
+        with st.chat_message(message["role"]):
             st.markdown(message["content"])

     # We take questions/instructions from the chat input to pass to the LLM
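
Review note: this replay loop is why the conversation is kept in st.session_state at all; Streamlit re-executes app.py on every interaction, so anything not stored there is lost. The usual initialization guard (a sketch of the pattern the context lines refer to):

    # runs once per browser session; the list survives script re-runs
    if "messages" not in st.session_state:
        st.session_state.messages = []  # items look like {"role": "user", "content": "..."}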
@@ -419,7 +490,7 @@ with tab1:
     )

     # Add our input to the chat window
-    with messages.chat_message("user"):
+    with st.chat_message("user"):
         st.markdown(formatted_user_query)

     # Let user know agent is planning the actions
@@ -430,7 +501,7 @@ with tab1:
         {"role": "assistant", "content": action_plan_message}
     )
     # Add the response to the chat window
-    with messages.chat_message("assistant"):
+    with st.chat_message("assistant"):
         st.markdown(action_plan_message)

     results = agent(user_query)
@@ -442,63 +513,32 @@ with tab1:
     )

     # Add the response to the chat window
-    with messages.chat_message("assistant"):
+    with st.chat_message("assistant"):
         st.markdown(response)


-################################ Chat Config Page ################################
-# for changing config like temperature etc.
-with tab2:
-# if page == "Chat Config":
-    # st.header(page)
-
-    st.selectbox(
-        "HuggingFace Inference Model",
-        options=["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.2"],
-        on_change=get_llm,
-        key="model"
-    )
-
-    st.slider(
-        "Temperature",
-        0.0, 1.0, 0.05,
-        #value = st.session_state['temperature'],
-        on_change=get_llm,
-        key="temperature"
-    )
-

 ################################ Document Page ################################
 # to scrape new documents from DuckDuckGo
 # to change parameters like chunk size
 # to upload own PDF
 # to override existing data on new scraped data or new pdf uploaded
-with tab3:
-# if page == "Document, Retriever, Web Scraping Config":
-    # st.header(page)
-
-    st.selectbox(
-        "Chunk Size",
-        options=[500, 600, 700, 800, 900, 1000, 1250, 1500, 1750, 2000, 2250, 2500, 2750, 3000],
-        on_change=get_retrievers,
-        key="chunk_size"
-    )
-
-    st.selectbox(
-        "Chunk Overlap",
-        options=[50, 100, 150, 200],
-        on_change=get_retrievers,
-        key="chunk_overlap"
-    )


-################################ Main Chatbot Page ################################
-if page == "Source Documents for Last Query":
-    st.header(page)
+
+################################ Source Documents Page ################################
+if page == "View Source Docs for Last Query":
+    st.header("Source Documents for Last Query")
     try:
         st.subheader(st.session_state['source_documents'][0])
         for doc in st.session_state['source_documents'][1:]:
-            st.write("Source: " + doc.metadata['source'])
+            #st.write("Source: " + doc.metadata['source'])
             st.write(doc)
     except:
         st.write("No source documents retrieved yet. Please run a user query before coming back to this page.")
+
+
+
+# in main app, add configuration for user to scrape new data from DuckDuckGo
+# in main app, add configuration for user to upload PDF to override country's existing policies in vectorstore
+
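
Review note: the bare `except:` on the source-documents page above will also swallow real errors (for example a typo in the session-state key) behind the "no documents yet" message. A narrower sketch:

    try:
        st.subheader(st.session_state['source_documents'][0])
        for doc in st.session_state['source_documents'][1:]:
            st.write(doc)
    except (KeyError, IndexError):
        st.write("No source documents retrieved yet. Please run a user query before coming back to this page.")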