fsal's picture
first commit
c8ebe28
raw history blame
No virus
22.5 kB
from datetime import datetime
from typing import Tuple, List, Dict, Any, Union, Optional
import anthropic
import langsmith.utils
import openai
import streamlit as st
from langchain.agents import load_tools
from langchain.agents.tools import tool
from langchain.callbacks import StreamlitCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever
from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper
from langsmith.client import Client
from streamlit_feedback import streamlit_feedback
from defaults import default_values
from llm_resources import (
get_agent,
get_doc_agent,
get_llm,
get_runnable,
get_texts_and_multiretriever,
)
from python_coder import get_agent as get_python_agent
from research_assistant.chain import get_chain as get_research_assistant_chain
# Application version; it is interpolated into the page title below.
__version__ = "2.1.4"

# --- Initialization ---
# Configure the page (title and icon) before the rest of the script runs.
st.set_page_config(
    page_icon="🦜",
    page_title=f"langchain-streamlit-demo v{__version__}",
)
def st_init_null(*variable_names) -> None:
    """Make sure every named key exists in ``st.session_state``.

    Keys that are already present keep their current value; keys that are
    absent are created and set to ``None``.

    Args:
        *variable_names: Session-state keys to initialise.
    """
    state = st.session_state
    # The generator is lazy, so each membership test sees assignments made
    # earlier in this same call.
    for absent in (name for name in variable_names if name not in state):
        state[absent] = None
# Pre-create the session keys used by this script (set to None when absent),
# grouped by what they hold.

# Chain / model / tracing state.
st_init_null(
    "chain",
    "client",
    "doc_chain",
    "document_chat_chain_type",
    "llm",
    "ls_tracer",
    "provider",
    "retriever",
    "run",
    "run_id",
    "trace_link",
)

# LangSmith settings.
st_init_null(
    "LANGSMITH_API_KEY",
    "LANGSMITH_PROJECT",
)

# Azure OpenAI settings.
st_init_null(
    "AZURE_OPENAI_BASE_URL",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_MODEL_VERSION",
    "AZURE_AVAILABLE",
)
# --- LLM globals ---
# Chat message history keyed by "langchain_messages", wrapped in a buffer
# memory that exposes it to prompts as "chat_history".
STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
MEMORY = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    chat_memory=STMEMORY,
)

# Collects the runs produced by a chain invocation (read back after the LLM call).
RUN_COLLECTOR = RunCollectorCallbackHandler()

# Keep a LangSmith key that is already in the session; otherwise fall back to
# the configured provider key (which may itself be missing).
if not st.session_state.LANGSMITH_API_KEY:
    st.session_state.LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get(
        "LANGSMITH",
    )

# Same idea for the project name, with a hard-coded last-resort default.
if not st.session_state.LANGSMITH_PROJECT:
    st.session_state.LANGSMITH_PROJECT = (
        default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
    )
def azure_state_or_default(*args):
    """Fill the named session keys from ``default_values.AZURE_DICT``.

    For each name, a truthy value already stored in ``st.session_state`` is
    kept; otherwise the value looked up with ``default_values.AZURE_DICT.get``
    is stored instead.

    Args:
        *args: Session-state keys to resolve.
    """
    resolved = {}
    for setting in args:
        current = st.session_state.get(setting)
        resolved[setting] = (
            current if current else default_values.AZURE_DICT.get(setting)
        )
    # Apply all resolved values in a single update, after every lookup is done.
    st.session_state.update(resolved)
# Resolve each Azure setting: keep the session value if set, else the default.
azure_state_or_default(
    "AZURE_OPENAI_BASE_URL",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_MODEL_VERSION",
)

# Azure chat is considered available only when every one of these is non-empty.
st.session_state.AZURE_AVAILABLE = all(
    [
        st.session_state.AZURE_OPENAI_BASE_URL,
        st.session_state.AZURE_OPENAI_API_VERSION,
        st.session_state.AZURE_OPENAI_DEPLOYMENT_NAME,
        st.session_state.AZURE_OPENAI_API_KEY,
        st.session_state.AZURE_OPENAI_MODEL_VERSION,
    ],
)

# Azure embeddings additionally need an embedding deployment name.
# Coerced to bool: a bare `flag and name` would store the deployment-name
# string (or None) in a key that is used as a flag and as a toggle default.
st.session_state.AZURE_EMB_AVAILABLE = bool(
    st.session_state.AZURE_AVAILABLE
    and st.session_state.AZURE_OPENAI_EMB_DEPLOYMENT_NAME
)

# Keyword arguments handed to the retriever builder for Azure embeddings;
# None when Azure embeddings are not available.
AZURE_KWARGS = (
    {
        "openai_api_base": st.session_state.AZURE_OPENAI_BASE_URL,
        "openai_api_version": st.session_state.AZURE_OPENAI_API_VERSION,
        "deployment": st.session_state.AZURE_OPENAI_EMB_DEPLOYMENT_NAME,
        "openai_api_key": st.session_state.AZURE_OPENAI_API_KEY,
        "openai_api_type": "azure",
    }
    if st.session_state.AZURE_EMB_AVAILABLE
    else None
)
@st.cache_resource
def get_texts_and_retriever_cacheable_wrapper(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
    k: int = default_values.DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
    """Cached pass-through to ``get_texts_and_multiretriever``.

    Every argument is forwarded unchanged, by keyword; the
    ``st.cache_resource`` decorator is the only thing this wrapper adds.

    Args:
        uploaded_file_bytes: Raw bytes of the uploaded document.
        openai_api_key: OpenAI API key forwarded to the retriever builder.
        chunk_size: Chunk size forwarded to the retriever builder.
        chunk_overlap: Chunk overlap forwarded to the retriever builder.
        k: Retriever ``k`` forwarded to the retriever builder.
        azure_kwargs: Optional Azure settings forwarded to the builder.
        use_azure: Azure flag forwarded to the builder.

    Returns:
        Whatever ``get_texts_and_multiretriever`` returns (annotated as the
        document texts and a retriever).
    """
    splitting_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "k": k,
    }
    azure_options = {
        "azure_kwargs": azure_kwargs,
        "use_azure": use_azure,
    }
    return get_texts_and_multiretriever(
        uploaded_file_bytes=uploaded_file_bytes,
        openai_api_key=openai_api_key,
        **splitting_options,
        **azure_options,
    )
# --- Sidebar ---
# Everything in this block renders inside the sidebar and defines the settings
# (model, keys, document-chat options, prompt, sampling) read further down.
sidebar = st.sidebar
with sidebar:
    st.markdown("# Menu")

    model = st.selectbox(
        label="Chat Model",
        options=default_values.SUPPORTED_MODELS,
        index=default_values.SUPPORTED_MODELS.index(default_values.DEFAULT_MODEL),
    )

    # The provider follows from the selected model.
    st.session_state.provider = default_values.MODEL_DICT[model]

    # Non-Azure providers: use the configured key if there is one, otherwise
    # ask for it. Azure OpenAI takes its key from the Azure settings instead,
    # so this is left as an empty string.
    provider_api_key = (
        default_values.PROVIDER_KEY_DICT.get(
            st.session_state.provider,
        )
        or st.text_input(
            f"{st.session_state.provider} API key",
            type="password",
        )
        if st.session_state.provider != "Azure OpenAI"
        else ""
    )

    if st.button("Clear message history"):
        STMEMORY.clear()
        # The trace link and run id refer to the cleared conversation.
        st.session_state.trace_link = None
        st.session_state.run_id = None

    # --- Document Chat Options ---
    with st.expander("Document Chat", expanded=False):
        uploaded_file = st.file_uploader("Upload a PDF", type="pdf")

        # Reuse the provider key when the provider is OpenAI; otherwise use
        # the configured OpenAI key or prompt for one in the sidebar.
        openai_api_key = (
            provider_api_key
            if st.session_state.provider == "OpenAI"
            else default_values.OPENAI_API_KEY
            or st.sidebar.text_input("OpenAI API Key: ", type="password")
        )

        # Checked by default as soon as a file has been uploaded.
        document_chat = st.checkbox(
            "Document Chat",
            value=bool(uploaded_file),
            help="Uploaded document will provide context for the chat.",
        )

        k = st.slider(
            label="Number of Chunks",
            help="How many document chunks will be used for context?",
            value=default_values.DEFAULT_RETRIEVER_K,
            min_value=1,
            max_value=10,
        )

        chunk_size = st.slider(
            label="Number of Tokens per Chunk",
            help="Size of each chunk of text",
            min_value=default_values.MIN_CHUNK_SIZE,
            max_value=default_values.MAX_CHUNK_SIZE,
            value=default_values.DEFAULT_CHUNK_SIZE,
        )

        chunk_overlap = st.slider(
            label="Chunk Overlap",
            help="Number of characters to overlap between chunks",
            min_value=default_values.MIN_CHUNK_OVERLAP,
            max_value=default_values.MAX_CHUNK_OVERLAP,
            value=default_values.DEFAULT_CHUNK_OVERLAP,
        )

        # No trailing slash here: the f-string below adds the separator, and
        # a trailing slash would produce ".../document//stuff" links.
        chain_type_help_root = (
            "https://python.langchain.com/docs/modules/chains/document"
        )

        # Markdown bullet list linking each chain type to its documentation.
        chain_type_help = "\n".join(
            f"- [{chain_type_name}]({chain_type_help_root}/{chain_type_name})"
            for chain_type_name in (
                "stuff",
                "refine",
                "map_reduce",
                "map_rerank",
            )
        )

        document_chat_chain_type = st.selectbox(
            label="Document Chat Chain Type",
            options=[
                "stuff",
                "refine",
                "map_reduce",
                "map_rerank",
                "Q&A Generation",
                "Summarization",
            ],
            index=0,
            help=chain_type_help,
        )

        use_azure = st.toggle(
            label="Use Azure OpenAI",
            value=st.session_state.AZURE_EMB_AVAILABLE,
            help="Use Azure for embeddings instead of using OpenAI directly.",
        )

        # Build the texts and retriever once a file is uploaded and some
        # embedding credential (Azure or an OpenAI key) is available.
        if uploaded_file:
            if st.session_state.AZURE_EMB_AVAILABLE or openai_api_key:
                (
                    st.session_state.texts,
                    st.session_state.retriever,
                ) = get_texts_and_retriever_cacheable_wrapper(
                    uploaded_file_bytes=uploaded_file.getvalue(),
                    openai_api_key=openai_api_key,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    k=k,
                    azure_kwargs=AZURE_KWARGS,
                    use_azure=use_azure,
                )
            else:
                st.error("Please enter a valid OpenAI API key.", icon="❌")

    # --- Advanced Settings ---
    with st.expander("Advanced Settings", expanded=False):
        st.markdown("## Feedback Scale")
        use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
        feedback_option = "faces" if use_faces else "thumbs"

        # Braces are doubled because this text is later embedded in a prompt
        # template, where single braces would be read as template variables.
        system_prompt = (
            st.text_area(
                "Custom Instructions",
                default_values.DEFAULT_SYSTEM_PROMPT,
                help="Custom instructions to provide the language model to determine style, personality, etc.",
            )
            .strip()
            .replace("{", "{{")
            .replace("}", "}}")
        )

        temperature = st.slider(
            "Temperature",
            min_value=default_values.MIN_TEMP,
            max_value=default_values.MAX_TEMP,
            value=default_values.DEFAULT_TEMP,
            help="Higher values give more random results.",
        )

        max_tokens = st.slider(
            "Max Tokens",
            min_value=default_values.MIN_MAX_TOKENS,
            max_value=default_values.MAX_MAX_TOKENS,
            value=default_values.DEFAULT_MAX_TOKENS,
            help="Higher values give longer results.",
        )

    # --- LangSmith Options ---
    if default_values.SHOW_LANGSMITH_OPTIONS:
        with st.expander("LangSmith Options", expanded=False):
            st.session_state.LANGSMITH_API_KEY = st.text_input(
                "LangSmith API Key (optional)",
                value=st.session_state.LANGSMITH_API_KEY,
                type="password",
            )

            st.session_state.LANGSMITH_PROJECT = st.text_input(
                "LangSmith Project Name",
                value=st.session_state.LANGSMITH_PROJECT,
            )

            # Create the LangSmith client and tracer once, as soon as a key
            # is known.
            # NOTE(review): nesting of this check was reconstructed; confirm
            # it belongs under SHOW_LANGSMITH_OPTIONS rather than beside it.
            if st.session_state.client is None and st.session_state.LANGSMITH_API_KEY:
                st.session_state.client = Client(
                    api_url="https://api.smith.langchain.com",
                    api_key=st.session_state.LANGSMITH_API_KEY,
                )
                st.session_state.ls_tracer = LangChainTracer(
                    project_name=st.session_state.LANGSMITH_PROJECT,
                    client=st.session_state.client,
                )

    # --- Azure Options ---
    # Each input is pre-filled with the current session value and writes the
    # edited value back to the same session key.
    if default_values.SHOW_AZURE_OPTIONS:
        with st.expander("Azure Options", expanded=False):
            st.session_state.AZURE_OPENAI_BASE_URL = st.text_input(
                "AZURE_OPENAI_BASE_URL",
                value=st.session_state.AZURE_OPENAI_BASE_URL,
            )

            st.session_state.AZURE_OPENAI_API_VERSION = st.text_input(
                "AZURE_OPENAI_API_VERSION",
                value=st.session_state.AZURE_OPENAI_API_VERSION,
            )

            st.session_state.AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
                "AZURE_OPENAI_DEPLOYMENT_NAME",
                value=st.session_state.AZURE_OPENAI_DEPLOYMENT_NAME,
            )

            st.session_state.AZURE_OPENAI_EMB_DEPLOYMENT_NAME = st.text_input(
                "AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
                value=st.session_state.AZURE_OPENAI_EMB_DEPLOYMENT_NAME,
            )

            st.session_state.AZURE_OPENAI_API_KEY = st.text_input(
                "AZURE_OPENAI_API_KEY",
                value=st.session_state.AZURE_OPENAI_API_KEY,
                type="password",
            )

            st.session_state.AZURE_OPENAI_MODEL_VERSION = st.text_input(
                "AZURE_OPENAI_MODEL_VERSION",
                value=st.session_state.AZURE_OPENAI_MODEL_VERSION,
            )
# --- LLM Instantiation ---
# Keyword arguments for get_llm, assembled from the sidebar selections and the
# Azure settings held in session state.
get_llm_args = {
    "provider": st.session_state.provider,
    "model": model,
    "provider_api_key": provider_api_key,
    "temperature": temperature,
    "max_tokens": max_tokens,
    "azure_available": st.session_state.AZURE_AVAILABLE,
    "azure_dict": {
        "AZURE_OPENAI_BASE_URL": st.session_state.AZURE_OPENAI_BASE_URL,
        "AZURE_OPENAI_API_VERSION": st.session_state.AZURE_OPENAI_API_VERSION,
        "AZURE_OPENAI_DEPLOYMENT_NAME": st.session_state.AZURE_OPENAI_DEPLOYMENT_NAME,
        "AZURE_OPENAI_API_KEY": st.session_state.AZURE_OPENAI_API_KEY,
        "AZURE_OPENAI_MODEL_VERSION": st.session_state.AZURE_OPENAI_MODEL_VERSION,
    },
}

# Same arguments with the temperature forced to zero.
get_llm_args_temp_zero = {**get_llm_args, "temperature": 0.0}

st.session_state.llm = get_llm(**get_llm_args)
# --- Chat History ---
# Replay stored messages; only non-empty assistant/user messages are shown,
# and assistant messages get the parrot avatar.
for msg in STMEMORY.messages:
    if not msg.content:
        continue
    if msg.type not in ("ai", "assistant", "human", "user"):
        continue
    if msg.type in ("ai", "assistant"):
        bubble = st.chat_message(msg.type, avatar="🦜")
    else:
        bubble = st.chat_message(msg.type, avatar=None)
    bubble.write(msg.content)
# --- Current Chat ---
# Runs only when an LLM could be built; otherwise the user is asked for a key.
if st.session_state.llm:
    # --- Regular Chat ---
    # System prompt (custom instructions plus the current time), the stored
    # chat history, then the user's query.
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt + "\nIt's currently {time}.",
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{query}"),
        ],
    ).partial(time=lambda: str(datetime.now()))

    # --- Chat Input ---
    prompt = st.chat_input(placeholder="Ask me a question!")
    if prompt:
        st.chat_message("user").write(prompt)
        feedback_update = None
        feedback = None

        # --- Chat Output ---
        with st.chat_message("assistant", avatar="🦜"):
            # The run collector is always attached; the LangSmith tracer only
            # when one has been created.
            callbacks = [RUN_COLLECTOR]

            if st.session_state.ls_tracer:
                callbacks.append(st.session_state.ls_tracer)

            def get_config(callbacks: list[BaseCallbackHandler]) -> dict[str, Any]:
                """Build the config dict passed to every chain/agent ``invoke``.

                Args:
                    callbacks: Callback handlers to attach to the run.

                Returns:
                    A config dict with callbacks, tags and verbosity flags;
                    ``max_concurrency`` is added when the provider is
                    Anthropic.
                """
                config: Dict[str, Any] = dict(
                    callbacks=callbacks,
                    tags=["Streamlit Chat"],
                    verbose=True,
                    return_intermediate_steps=False,
                )
                if st.session_state.provider == "Anthropic":
                    config["max_concurrency"] = 5
                return config

            # Document chat needs both the checkbox and a built retriever.
            use_document_chat = all(
                [
                    document_chat,
                    st.session_state.retriever,
                ],
            )

            full_response: Union[str, None] = None

            # stream_handler = StreamHandler(message_placeholder)
            # callbacks.append(stream_handler)
            # Placeholder that is filled with the response once it is ready.
            message_placeholder = st.empty()

            # Tools offered to the agent regardless of document chat.
            default_tools = [
                DuckDuckGoSearchRun(),
                WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
            ]
            default_tools += load_tools(["requests_get"])
            default_tools += load_tools(["llm-math"], llm=st.session_state.llm)

            # The tool-using agent is built only for the OpenAI-family
            # providers; every other provider gets a plain runnable below.
            if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
                research_assistant_chain = get_research_assistant_chain(
                    search_llm=get_llm(**get_llm_args_temp_zero),  # type: ignore
                    writer_llm=get_llm(**get_llm_args_temp_zero),  # type: ignore
                )

                # Attach a Streamlit callback handler bound to a new container.
                st_callback = StreamlitCallbackHandler(st.container())
                callbacks.append(st_callback)

                @tool("web-research-assistant")
                def research_assistant_tool(question: str, callbacks: Callbacks = None):
                    """This assistant returns a comprehensive report based on web research.
                    It's slow and relatively expensive, so use it sparingly.
                    Consider using a different tool for quick facts or web queries.
                    """
                    return research_assistant_chain.invoke(
                        dict(question=question),
                        config=get_config(callbacks),
                    )

                python_coder_agent = get_python_agent(st.session_state.llm)

                @tool("python-coder-assistant")
                def python_coder_tool(input_str: str, callbacks: Callbacks = None):
                    """This assistant writes PYTHON code.
                    Give it clear instructions and requirements.
                    Do not use it for tasks other than Python.
                    """
                    return python_coder_agent.invoke(
                        dict(input=input_str),
                        config=get_config(callbacks),
                    )

                TOOLS = [research_assistant_tool, python_coder_tool] + default_tools

                if use_document_chat:
                    st.session_state.doc_chain = get_runnable(
                        use_document_chat,
                        document_chat_chain_type,
                        st.session_state.llm,
                        st.session_state.retriever,
                        MEMORY,
                        chat_prompt,
                        prompt,
                    )

                    @tool("user-document-chat")
                    def doc_chain_tool(input_str: str, callbacks: Callbacks = None):
                        """Always use this tool at least once. Input should be a question."""
                        return st.session_state.doc_chain.invoke(
                            input_str,
                            config=get_config(callbacks),
                        )

                    # Agent whose only tool is the document chain above.
                    doc_chain_agent = get_doc_agent(
                        [doc_chain_tool],
                    )

                    @tool("document-question-tool")
                    def doc_question_tool(input_str: str, callbacks: Callbacks = None):
                        """This tool is an AI assistant with access to the user's uploaded document.
                        Input should be one or more questions, requests, instructions, etc.
                        If the user's meaning is unclear, perhaps the answer is here.
                        Generally speaking, try this tool before conducting web research.
                        """
                        return doc_chain_agent.invoke(
                            input_str,
                            config=get_config(callbacks),
                        )

                    # The document tool is listed first.
                    TOOLS = [doc_question_tool] + TOOLS

                st.session_state.chain = get_agent(
                    TOOLS,
                    STMEMORY,
                    st.session_state.llm,
                    callbacks,
                )
            else:
                st.session_state.chain = get_runnable(
                    use_document_chat,
                    document_chat_chain_type,
                    st.session_state.llm,
                    st.session_state.retriever,
                    MEMORY,
                    chat_prompt,
                    prompt,
                )

            # --- LLM call ---
            # On an authentication failure full_response stays None, which
            # skips the display and tracing steps below.
            try:
                full_response = st.session_state.chain.invoke(
                    prompt,
                    config=get_config(callbacks),
                )

            except (openai.AuthenticationError, anthropic.AuthenticationError):
                st.error(
                    f"Please enter a valid {st.session_state.provider} API key.",
                    icon="❌",
                )

            # --- Display output ---
            if full_response is not None:
                message_placeholder.markdown(full_response)

                # --- Tracing ---
                if st.session_state.client:
                    # Remember the first collected run, then reset the
                    # collector.
                    st.session_state.run = RUN_COLLECTOR.traced_runs[0]
                    st.session_state.run_id = st.session_state.run.id
                    RUN_COLLECTOR.traced_runs = []
                    wait_for_all_tracers()
                    try:
                        st.session_state.trace_link = st.session_state.client.read_run(
                            st.session_state.run_id,
                        ).url
                    except (
                        langsmith.utils.LangSmithError,
                        langsmith.utils.LangSmithNotFoundError,
                    ):
                        # Without a readable run there is no link to show.
                        st.session_state.trace_link = None

    # --- LangSmith Trace Link ---
    # Kept outside `if prompt:` so the link is still rendered on runs where
    # no new prompt was submitted.
    if st.session_state.trace_link:
        with sidebar:
            st.markdown(
                f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: 🛠️</button></a>',
                unsafe_allow_html=True,
            )

    # --- Feedback ---
    if st.session_state.client and st.session_state.run_id:
        # The widget key includes the run id, so each run gets its own widget.
        feedback = streamlit_feedback(
            feedback_type=feedback_option,
            optional_text_label="[Optional] Please provide an explanation",
            key=f"feedback_{st.session_state.run_id}",
        )

        # Define score mappings for both "thumbs" and "faces" feedback systems.
        # The keys must be the exact emoji reported in feedback["score"];
        # mis-encoded keys never match and every submission is rejected.
        score_mappings: dict[str, dict[str, Union[int, float]]] = {
            "thumbs": {"👍": 1, "👎": 0},
            "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
        }

        # Get the score mapping based on the selected feedback option
        scores = score_mappings[feedback_option]

        if feedback:
            # Get the score from the selected feedback option's score mapping
            score = scores.get(
                feedback["score"],
            )

            if score is not None:
                # Formulate feedback type string incorporating the feedback option
                # and score value
                feedback_type_str = f"{feedback_option} {feedback['score']}"

                # Record the feedback with the formulated feedback type string
                # and optional comment
                feedback_record = st.session_state.client.create_feedback(
                    st.session_state.run_id,
                    feedback_type_str,
                    score=score,
                    comment=feedback.get("text"),
                )
                st.toast("Feedback recorded!", icon="📝")
            else:
                st.warning("Invalid feedback score.")

else:
    st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")