# medchat2 / app.py
# Hugging Face Space source by Sbnos — commit 64aa216 ("adsas", verified), 5.39 kB.
import streamlit as st
import os
import asyncio
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_together import Together
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Initialize the LLM: Mixtral-8x22B-Instruct served through the Together API.
# The API key is read from the `pilotikval` environment variable; indexing
# os.environ raises KeyError at import time if that variable is unset.
_together_api_key = os.environ['pilotikval']
llm = Together(
    together_api_key=_together_api_key,
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=12,
    # NOTE(review): 22048 is an unusually large completion cap — confirm it is
    # intended (possibly a typo for 2048).
    max_tokens=22048,
)
# Chat histories handed out by get_session_history, keyed by session id.
store = {}

# Sentence-embedding model used to query the Chroma vector stores.
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}  # set True to compute cosine similarity


@st.cache_resource
def _load_embedding_function(name: str, normalize_embeddings: bool):
    """Load the BGE embedding model once and reuse it across reruns.

    Streamlit re-executes this script on every user interaction; without the
    resource cache the HuggingFace model would be re-instantiated each time.

    Args:
        name: HuggingFace model id to load.
        normalize_embeddings: Whether to L2-normalise the produced vectors.

    Returns:
        A ``HuggingFaceBgeEmbeddings`` instance shared by all reruns/sessions.
    """
    return HuggingFaceBgeEmbeddings(
        model_name=name,
        encode_kwargs={'normalize_embeddings': normalize_embeddings},
    )


embedding_function = _load_embedding_function(
    model_name, encode_kwargs['normalize_embeddings']
)
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history registered for *session_id*.

    On first use a ``StreamlitChatMessageHistory`` keyed by *session_id* is
    created and remembered in the module-level ``store`` dict; later calls
    with the same id return that same object.
    """
    try:
        return store[session_id]
    except KeyError:
        history = StreamlitChatMessageHistory(key=session_id)
        store[session_id] = history
        return history
# Define the Streamlit app

# Vector stores selectable in the sidebar: label -> (persist directory, Chroma
# collection name). The keys are also the selectbox options, so the labels the
# user sees and the lookup keys cannot drift apart.
RETRIEVER_SOURCES = {
    'General Medicine': ("./oxfordmedbookdir/", "oxfordmed"),
    'RespiratoryFishman': ("./respfishmandbcud/", "fishmannotescud"),
    'RespiratoryMurray': ("./respmurray/", "respmurraynotes"),
    'MedMRCP2': ("./medmrcp2store/", "medmrcp2notes"),
    'OldMedicine': ("./mrcpchromadb/", "mrcppassmednotes"),
}
# Store used if an unknown label is ever requested.
DEFAULT_RETRIEVER_SOURCE = ("./mrcpchromadb/", "mrcppassmednotes")


def _build_retriever(option):
    """Return a top-5 retriever over the Chroma store selected by *option*."""
    persist_directory, collection_name = RETRIEVER_SOURCES.get(
        option, DEFAULT_RETRIEVER_SOURCE
    )
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
        collection_name=collection_name,
    )
    return vectordb.as_retriever(search_kwargs={"k": 5})


def _build_conversational_chain(retriever):
    """Wrap *retriever* in a history-aware RAG chain with per-session history.

    The returned runnable is invoked with ``{"input": question}`` and a
    ``session_id`` under ``config["configurable"]``; history is read and
    written through ``get_session_history`` and the reply is returned under
    the ``"answer"`` key.
    """
    # Rewrites the latest question into a standalone one so that retrieval
    # also works for follow-up questions that refer back to the chat.
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

    # Answering prompt; {context} receives the retrieved documents.
    system_prompt = (
        "You are helping a doctor. Be as detailed and thorough as possible. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    # Statefully manage chat history
    return RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )


def app():
    """Render the dochatter chat page and answer any pending user question.

    The sidebar chooses which Chroma store backs retrieval. The visible
    transcript lives in ``st.session_state.messages``; the history fed to the
    RAG chain is kept under the ``"current_session"`` session id.
    """
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            tuple(RETRIEVER_SOURCES),
        )

    conversational_rag_chain = _build_conversational_chain(_build_retriever(option))

    # Session State
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

    st.header("Hello Doc!")
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    prompts2 = st.chat_input("Say something")
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

    # Answer the most recent user turn. The question is read from the
    # transcript rather than from this run's chat_input: on a rerun without
    # new input chat_input returns None, and a question left unanswered by a
    # failed run would otherwise be sent to the chain as None.
    if st.session_state.messages[-1]["role"] != "assistant":
        question = st.session_state.messages[-1]["content"]
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                final_response = conversational_rag_chain.invoke(
                    {"input": question},
                    config={"configurable": {"session_id": "current_session"}},
                )
            st.write(final_response['answer'])
        st.session_state.messages.append(
            {"role": "assistant", "content": final_response['answer']}
        )
# Render the app when this file is executed as the main script.
if __name__ == "__main__":
    app()