Sbnos committed on
Commit
29eec6c
1 Parent(s): 7cee51c
Files changed (1) hide show
  1. app.py +185 -67
app.py CHANGED
@@ -2,80 +2,69 @@ import streamlit as st
2
  import os
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
5
- from langchain_together import Together
6
- from langchain.prompts import ChatPromptTemplate, PromptTemplate
7
- from langchain.schema import format_document
8
- from typing import List
9
- from langchain.memory import ConversationBufferMemory
 
 
10
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 
 
 
 
 
11
  import time
12
 
13
  # Load the embedding function
14
  model_name = "BAAI/bge-base-en"
15
- encode_kwargs = {'normalize_embeddings': True}
16
- embedding_function = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
 
 
 
 
 
17
 
18
- # Load the LLM
19
- llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.2, max_tokens=19096, top_k=10, together_api_key=os.environ['pilotikval'])
20
 
 
 
 
 
21
  msgs = StreamlitChatMessageHistory(key="langchain_messages")
22
  memory = ConversationBufferMemory(chat_memory=msgs)
23
 
 
 
24
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
25
 
26
- def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
 
 
27
  doc_strings = [format_document(doc, document_prompt) for doc in docs]
28
  return document_separator.join(doc_strings)
29
 
 
 
30
  chistory = []
31
 
32
- def store_chat_history(role: str, content: str):
33
  chistory.append({"role": role, "content": content})
34
 
35
- def render_message_with_copy_button(role: str, content: str, key: str):
36
- html_code = f"""
37
- <div class="message" style="position: relative; padding-right: 40px;">
38
- <div class="message-content">{content}</div>
39
- <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0; background-color: transparent; border: none; cursor: pointer;">
40
- <img src="https://img.icons8.com/material-outlined/24/grey/copy.png" alt="Copy">
41
- </button>
42
- </div>
43
- <textarea id="{key}" style="display:none;">{content}</textarea>
44
- <script>
45
- function copyToClipboard(key) {{
46
- var copyText = document.getElementById(key);
47
- copyText.style.display = "block";
48
- copyText.select();
49
- document.execCommand("copy");
50
- copyText.style.display = "none";
51
- alert("Copied to clipboard");
52
- }}
53
- </script>
54
- """
55
- st.write(html_code, unsafe_allow_html=True)
56
-
57
- def get_streaming_response(user_query, chat_history):
58
- template = """
59
- You are a knowledgeable assistant. Provide a detailed and thorough answer to the question based on the following context:
60
 
61
- Chat history: {chat_history}
 
62
 
63
- User question: {user_question}
64
- """
65
- prompt = ChatPromptTemplate.from_template(template)
66
-
67
- inputs = {
68
- "chat_history": chat_history,
69
- "user_question": user_query
70
- }
71
 
72
- chain = prompt | llm
73
- return chain.stream(inputs)
74
 
75
- def app():
76
  with st.sidebar:
 
77
  st.title("dochatter")
78
- option = st.selectbox('Which retriever would you like to use?', ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine'))
 
 
 
79
  if option == 'RespiratoryFishman':
80
  persist_directory = "./respfishmandbcud/"
81
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
@@ -83,47 +72,176 @@ def app():
83
  elif option == 'RespiratoryMurray':
84
  persist_directory = "./respmurray/"
85
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="respmurraynotes")
 
 
86
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
87
  elif option == 'MedMRCP2':
88
  persist_directory = "./medmrcp2store/"
89
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes")
 
 
90
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
91
  elif option == 'General Medicine':
92
  persist_directory = "./oxfordmedbookdir/"
93
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed")
 
 
94
  retriever = vectordb.as_retriever(search_kwargs={"k": 7})
 
 
95
  else:
96
  persist_directory = "./mrcpchromadb/"
97
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
98
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if "messages" not in st.session_state.keys():
101
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
102
 
103
- st.header("Ask Away!")
104
- for i, message in enumerate(st.session_state.messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  with st.chat_message(message["role"]):
106
- render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}")
107
  store_chat_history(message["role"], message["content"])
108
 
109
- user_query = st.chat_input("Say something")
110
- if user_query:
111
- st.session_state.messages.append({"role": "user", "content": user_query})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  with st.chat_message("user"):
113
- st.write(user_query)
 
 
114
 
 
115
  with st.chat_message("assistant"):
116
  with st.spinner("Thinking..."):
117
- chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chistory])
118
- try:
119
- response_generator = get_streaming_response(user_query, chat_history)
120
- response_text = ""
121
- for response_part in response_generator:
122
- response_text += response_part
123
- st.write(response_text)
124
- st.session_state.messages.append({"role": "assistant", "content": response_text})
125
- except Exception as e:
126
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  if __name__ == '__main__':
129
- app()
 
2
  import os
3
  from langchain_community.vectorstores import Chroma
4
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
5
+ from langchain_community.llms import Together
6
+ from langchain import hub
7
+ from operator import itemgetter
8
+ from langchain.schema.runnable import RunnableParallel
9
+ from langchain.chains import LLMChain
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.schema.output_parser import StrOutputParser
12
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
13
+ from langchain.memory import ConversationBufferMemory
14
+ from langchain.chains import ConversationalRetrievalChain
15
+ from langchain.memory import ConversationSummaryMemory
16
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
17
+ from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
18
  import time
19
 
20
  # Load the embedding function
21
  model_name = "BAAI/bge-base-en"
22
+ encode_kwargs=encode_kwargs
23
+ )
24
+
25
+
26
+
27
+
28
+
29
 
 
 
30
 
31
+
32
+ # Load the LLM
33
+ llm = Together(
34
+ model="mistralai/Mixtral-8x22B-Instruct-v0.1",
35
  msgs = StreamlitChatMessageHistory(key="langchain_messages")
36
  memory = ConversationBufferMemory(chat_memory=msgs)
37
 
38
+
39
+
40
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
41
 
42
+ def _combine_documents(
43
+ docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
44
+ ):
45
  doc_strings = [format_document(doc, document_prompt) for doc in docs]
46
  return document_separator.join(doc_strings)
47
 
48
+
49
+
50
  chistory = []
51
 
52
+ # Append the new message to the chat history
53
  chistory.append({"role": role, "content": content})
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Define the Streamlit app
57
+ def app():
58
 
 
 
 
 
 
 
 
 
59
 
 
 
60
 
 
61
  with st.sidebar:
62
+
63
  st.title("dochatter")
64
+ # Create a dropdown selection box
65
+ option = st.selectbox(
66
+ )
67
+ # Depending on the selected option, choose the appropriate retriever
68
  if option == 'RespiratoryFishman':
69
  persist_directory = "./respfishmandbcud/"
70
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
 
72
  elif option == 'RespiratoryMurray':
73
  persist_directory = "./respmurray/"
74
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="respmurraynotes")
75
+
76
+
77
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
78
  elif option == 'MedMRCP2':
79
  persist_directory = "./medmrcp2store/"
80
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="medmrcp2notes")
81
+
82
+
83
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
84
  elif option == 'General Medicine':
85
  persist_directory = "./oxfordmedbookdir/"
86
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="oxfordmed")
87
+
88
+
89
  retriever = vectordb.as_retriever(search_kwargs={"k": 7})
90
+
91
+
92
  else:
93
  persist_directory = "./mrcpchromadb/"
94
  vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
95
  retriever = vectordb.as_retriever(search_kwargs={"k": 5})
96
 
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+ # Session State
112
+
113
  if "messages" not in st.session_state.keys():
114
  st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
115
 
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+ _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question which contains the themes of the conversation. Do not write the question. Do not write the answer.
128
+
129
+ Chat History:
130
+ {chat_history}
131
+ Follow Up Input: {question}
132
+
133
+ template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
134
+ {context}
135
+
136
+ Question: {question}
137
+ """
138
+ ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
139
+
140
+
141
+ _inputs = RunnableParallel(
142
+ standalone_question=RunnablePassthrough.assign(
143
+ chat_history=lambda x: chistory
144
+ ) | CONDENSE_QUESTION_PROMPT | llmc | StrOutputParser(),
145
+
146
+
147
+
148
+ )
149
+ _context = {
150
+ "context": itemgetter("standalone_question") | retriever | _combine_documents,
151
+ }
152
+ conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm
153
+
154
+ st.header("Hello Doctor!")
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+ for message in st.session_state.messages:
167
  with st.chat_message(message["role"]):
168
+ st.write(message["content"])
169
  store_chat_history(message["role"], message["content"])
170
 
171
+
172
+
173
+
174
+
175
+
176
+ prompts2 = st.chat_input("Say something")
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+ if prompts2:
194
+ st.session_state.messages.append({"role": "user", "content": prompts2})
195
  with st.chat_message("user"):
196
+ st.write(prompts2)
197
+
198
+
199
 
200
+ if st.session_state.messages[-1]["role"] != "assistant":
201
  with st.chat_message("assistant"):
202
  with st.spinner("Thinking..."):
203
+ for _ in range(3): # Retry up to 3 times
204
+ try:
205
+ response = conversational_qa_chain.invoke(
206
+ {
207
+ "question": prompts2,
208
+ "chat_history": chistory,
209
+ }
210
+ )
211
+ st.write(response)
212
+ message = {"role": "assistant", "content": response}
213
+ st.session_state.messages.append(message)
214
+ break
215
+ except Exception as e:
216
+ st.error(f"An error occurred: {e}")
217
+ time.sleep(2) # Wait 2 seconds before retrying
218
+
219
+
220
+
221
+
222
+
223
+
224
+
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+
235
+
236
+
237
+
238
+
239
+
240
+
241
+
242
+
243
+
244
+
245
 
246
  if __name__ == '__main__':
247
+ app()