fsal commited on
Commit
b85d32f
1 Parent(s): ba0d57d

add sources in main chat

Browse files
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +52 -27
langchain-streamlit-demo/app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from datetime import datetime
2
  from typing import Any, Dict, List, Optional, Tuple, Union
3
 
@@ -9,16 +10,16 @@ from defaults import default_values
9
  from langchain.agents.tools import tool
10
  from langchain.callbacks.base import BaseCallbackHandler
11
  from langchain.callbacks.manager import Callbacks
 
 
12
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
13
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
14
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
15
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
16
  from langchain.schema.document import Document
17
  from langchain.schema.retriever import BaseRetriever
18
- from langchain_community.callbacks import StreamlitCallbackHandler
19
  from langsmith.client import Client
20
  from llm_resources import (
21
- get_agent,
22
  get_llm,
23
  get_runnable,
24
  get_texts_and_multiretriever,
@@ -44,6 +45,8 @@ def st_init_null(*variable_names) -> None:
44
 
45
 
46
  st_init_null(
 
 
47
  "chain",
48
  "client",
49
  "doc_chain",
@@ -430,7 +433,7 @@ if st.session_state.llm:
430
  if question and question != "--" and not prompt:
431
  prompt = question
432
  if not uploaded_file:
433
- st.error("Please upload a PDF to use the document chat feature.")
434
  elif prompt:
435
  feedback_update = None
436
  feedback = None
@@ -463,6 +466,7 @@ if st.session_state.llm:
463
  full_response: Union[str, None] = None
464
 
465
  message_placeholder = st.empty()
 
466
  default_tools = [
467
  # DuckDuckGoSearchRun(),
468
  # WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
@@ -474,10 +478,10 @@ if st.session_state.llm:
474
  # search_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
475
  # writer_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
476
  # )
477
- st_callback = StreamlitCallbackHandler(
478
- st.container(), expand_new_thoughts=False
479
- )
480
- callbacks.append(st_callback)
481
 
482
  # @tool("web-research-assistant")
483
  # def research_assistant_tool(question: str, callbacks: Callbacks = None):
@@ -526,10 +530,16 @@ if st.session_state.llm:
526
  input_str,
527
  config=get_config(callbacks),
528
  )
 
529
  with st.sidebar.expander("Sources"):
530
- for source in response["source_documents"][:3]:
531
  st.markdown("-" * 50)
532
  st.markdown(source.page_content)
 
 
 
 
 
533
  return response["output_text"]
534
 
535
  # doc_chain_agent = get_doc_agent(
@@ -555,29 +565,31 @@ if st.session_state.llm:
555
 
556
  TOOLS = TOOLS + [doc_chain_tool]
557
 
558
- st.session_state.chain = get_agent(
559
- TOOLS,
560
- STMEMORY,
561
- st.session_state.llm,
562
- callbacks,
563
- )
564
- # else:
565
- # st.session_state.chain = get_runnable(
566
- # True, # use_document_chat,
567
- # document_chat_chain_type,
568
  # st.session_state.llm,
569
- # st.session_state.retriever,
570
- # MEMORY,
571
- # chat_prompt,
572
- # prompt,
573
  # )
 
 
 
 
 
 
 
 
 
 
574
 
575
  # --- LLM call ---
576
  try:
577
- full_response = st.session_state.chain.invoke(
578
- prompt,
579
- config=get_config(callbacks),
580
- )
 
 
581
 
582
  except (openai.AuthenticationError, anthropic.AuthenticationError):
583
  st.error(
@@ -587,7 +599,20 @@ if st.session_state.llm:
587
 
588
  # --- Display output ---
589
  if full_response is not None:
590
- message_placeholder.markdown(full_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
591
 
592
  # --- Tracing ---
593
  if st.session_state.client:
 
1
+ import time
2
  from datetime import datetime
3
  from typing import Any, Dict, List, Optional, Tuple, Union
4
 
 
10
  from langchain.agents.tools import tool
11
  from langchain.callbacks.base import BaseCallbackHandler
12
  from langchain.callbacks.manager import Callbacks
13
+
14
+ # from langchain.callbacks.streamlit import StreamlitCallbackHandler
15
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
16
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
17
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
18
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
19
  from langchain.schema.document import Document
20
  from langchain.schema.retriever import BaseRetriever
 
21
  from langsmith.client import Client
22
  from llm_resources import (
 
23
  get_llm,
24
  get_runnable,
25
  get_texts_and_multiretriever,
 
45
 
46
 
47
  st_init_null(
48
+ "start_time",
49
+ "sources",
50
  "chain",
51
  "client",
52
  "doc_chain",
 
433
  if question and question != "--" and not prompt:
434
  prompt = question
435
  if not uploaded_file:
436
+ st.warning("Please upload a PDF to use the document chat feature.")
437
  elif prompt:
438
  feedback_update = None
439
  feedback = None
 
466
  full_response: Union[str, None] = None
467
 
468
  message_placeholder = st.empty()
469
+ sources_placeholder = st.empty()
470
  default_tools = [
471
  # DuckDuckGoSearchRun(),
472
  # WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
 
478
  # search_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
479
  # writer_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
480
  # )
481
+ # st_callback = StreamlitCallbackHandler(
482
+ # st.container(), expand_new_thoughts=False
483
+ # )
484
+ # callbacks.append(st_callback)
485
 
486
  # @tool("web-research-assistant")
487
  # def research_assistant_tool(question: str, callbacks: Callbacks = None):
 
530
  input_str,
531
  config=get_config(callbacks),
532
  )
533
+ st.session_state.sources = response["source_documents"][:3]
534
  with st.sidebar.expander("Sources"):
535
+ for source in st.session_state.sources:
536
  st.markdown("-" * 50)
537
  st.markdown(source.page_content)
538
+ # with sources_placeholder:
539
+ # with st.expander("Sources", expanded=False):
540
+ # for source in response["source_documents"][:3]:
541
+ # st.markdown("-" * 50)
542
+ # st.markdown(source.page_content)
543
  return response["output_text"]
544
 
545
  # doc_chain_agent = get_doc_agent(
 
565
 
566
  TOOLS = TOOLS + [doc_chain_tool]
567
 
568
+ # st.session_state.chain = get_agent(
569
+ # TOOLS,
570
+ # STMEMORY,
 
 
 
 
 
 
 
571
  # st.session_state.llm,
572
+ # callbacks,
 
 
 
573
  # )
574
+ # else:
575
+ st.session_state.chain = get_runnable(
576
+ True, # use_document_chat,
577
+ document_chat_chain_type,
578
+ st.session_state.llm,
579
+ st.session_state.retriever,
580
+ MEMORY,
581
+ chat_prompt,
582
+ prompt,
583
+ )
584
 
585
  # --- LLM call ---
586
  try:
587
+ with st.spinner("Thinking..."):
588
+ st.session_state.start_time = time.time()
589
+ full_response = st.session_state.chain.invoke(
590
+ prompt,
591
+ config=get_config(callbacks),
592
+ )
593
 
594
  except (openai.AuthenticationError, anthropic.AuthenticationError):
595
  st.error(
 
599
 
600
  # --- Display output ---
601
  if full_response is not None:
602
+ finish_time = time.time() - st.session_state.start_time
603
+ text_response, sources = (
604
+ full_response["output_text"],
605
+ full_response["source_documents"],
606
+ )
607
+ st.session_state.sources = sources[:3]
608
+ message_placeholder.markdown(
609
+ text_response + "\n\n" + f"⏱️ {finish_time:.2f}s"
610
+ )
611
+ with sources_placeholder:
612
+ with st.expander("Sources", expanded=False):
613
+ for source in st.session_state.sources:
614
+ st.markdown("-" * 50)
615
+ st.markdown(source.page_content)
616
 
617
  # --- Tracing ---
618
  if st.session_state.client: