mainchainge cgpt 2
app.py CHANGED
@@ -5,24 +5,20 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
 from langchain_together import Together
 from langchain import hub
 from operator import itemgetter
-from langchain.schema import
-from
-from langchain.
-from
-from langchain.memory import StreamlitChatMessageHistory, ConversationBufferMemory, ConversationSummaryMemory
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
-from langchain.schema import RunnableLambda, RunnablePassthrough
+from langchain.schema import format_document
+from langchain.prompts import ChatPromptTemplate, PromptTemplate
+from langchain.memory import StreamlitChatMessageHistory, ConversationBufferMemory
+from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

 # Load the embedding function
 model_name = "BAAI/bge-base-en"
 encode_kwargs = {'normalize_embeddings': True}
-
 embedding_function = HuggingFaceBgeEmbeddings(
     model_name=model_name,
     encode_kwargs=encode_kwargs
 )

-#
+# Initialize the LLMs
 llm = Together(
     model="mistralai/Mixtral-8x22B-Instruct-v0.1",
     temperature=0.2,
@@ -31,7 +27,6 @@ llm = Together(
     together_api_key=os.environ['pilotikval']
 )

-# Load the summarizeLLM
 llmc = Together(
     model="mistralai/Mixtral-8x22B-Instruct-v0.1",
     temperature=0.2,
@@ -40,19 +35,40 @@ llmc = Together(
     together_api_key=os.environ['pilotikval']
 )

+# Memory setup
 msgs = StreamlitChatMessageHistory(key="langchain_messages")
 memory = ConversationBufferMemory(chat_memory=msgs)

-
+# Define the prompt templates
+CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(
+    """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question.
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+)
+
+ANSWER_PROMPT = ChatPromptTemplate.from_template(
+    """You are helping a doctor. Answer based on the provided context:
+{context}
+Question: {question}"""
+)

-
+# Function to combine documents
+def _combine_documents(docs, document_prompt=PromptTemplate.from_template("{page_content}"), document_separator="\n\n"):
     doc_strings = [format_document(doc, document_prompt) for doc in docs]
     return document_separator.join(doc_strings)

-
+# Define the chain using LCEL
+condense_question_chain = RunnableLambda(lambda x: {"chat_history": chistory, "question": x}) | CONDENSE_QUESTION_PROMPT | llmc
+retriever_chain = RunnableLambda(lambda x: {"standalone_question": x}) | retriever | _combine_documents
+answer_chain = ANSWER_PROMPT | llm

-
-
+conversational_qa_chain = RunnableParallel(
+    condense_question=condense_question_chain,
+    retrieve=retriever_chain,
+    generate_answer=answer_chain
+)

 # Define the Streamlit app
 def app():
@@ -63,67 +79,35 @@ def app():
         ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
     )

-
-
-
-
+    # Define retrievers based on option
+    persist_directory = {
+        'General Medicine': "./oxfordmedbookdir/",
+        'RespiratoryFishman': "./respfishmandbcud/",
+        'RespiratoryMurray': "./respmurray/",
+        'MedMRCP2': "./medmrcp2store/",
+        'OldMedicine': "./mrcpchromadb/"
+    }.get(option, "./mrcpchromadb/")

-
-
-
-
-
-
-
-        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes")
-        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
-
-    elif option == 'General Medicine':
-        persist_directory="./oxfordmedbookdir/"
-        vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed")
-        retriever = vectordb.as_retriever(search_kwargs={"k": 7})
+    collection_name = {
+        'General Medicine': "oxfordmed",
+        'RespiratoryFishman': "fishmannotescud",
+        'RespiratoryMurray': "respmurraynotes",
+        'MedMRCP2': "medmrcp2notes",
+        'OldMedicine': "mrcppassmednotes"
+    }.get(option, "mrcppassmednotes")

-
-
-
-        retriever = vectordb.as_retriever(search_kwargs={"k": 5})
-
+    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
+    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+
     if "messages" not in st.session_state:
         st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
-
-    _template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question which contains the themes of the conversation. Do not write the question. Do not write the answer.
-Chat History:
-{chat_history}
-Follow Up Input: {question}
-Standalone question:"""
-    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-
-    template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
-{context}
-Question: {question}
-"""
-    ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
-
-    _inputs = RunnableParallel(
-        standalone_question=RunnablePassthrough.assign(
-            chat_history=lambda x: chistory
-        )
-        | CONDENSE_QUESTION_PROMPT
-        | llmc
-        | StrOutputParser(),
-    )
-    _context = {
-        "context": itemgetter("standalone_question") | retriever | _combine_documents,
-        "question": lambda x: x["standalone_question"],
-    }
-    conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm
-
+
     st.header("Ask Away!")
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
             st.write(message["content"])
             store_chat_history(message["role"], message["content"])
-
+
     prompts2 = st.chat_input("Say something")

     if prompts2:
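
For reference, a minimal sketch (not part of this commit) of how the pieces added above could be composed into one sequential chain: condense the follow-up question, retrieve on the standalone question, then answer from the retrieved context. It mirrors the shape of the chain the previous version built inside app(), and assumes CONDENSE_QUESTION_PROMPT, ANSWER_PROMPT, _combine_documents, retriever, llm and llmc are defined as in the diff, with the chat history supplied by the caller rather than read from a module-level chistory variable.

from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel

# Condense the follow-up question into a standalone question
standalone_question = CONDENSE_QUESTION_PROMPT | llmc | StrOutputParser()

# Retrieve on the standalone question, then answer from the combined documents
sequential_qa_chain = (
    RunnableParallel(standalone_question=standalone_question)
    | {
        "context": itemgetter("standalone_question") | retriever | _combine_documents,
        "question": itemgetter("standalone_question"),
    }
    | ANSWER_PROMPT
    | llm
)

# Example invocation from the chat handler (e.g. inside `if prompts2:`),
# where chistory is assumed to hold the prior conversation as text:
# answer = sequential_qa_chain.invoke({"question": prompts2, "chat_history": chistory})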