Yew Chong committed
Commit bae2e43
1 Parent(s): ca422fb

Update streamlit app with new LLM and prompts

Files changed (1):
  1. streamlit/app8.py +30 -101
streamlit/app8.py CHANGED
@@ -17,7 +17,7 @@ import db_firestore as db
 ## ----------------------------------------------------------------
 ## LLM Part
 import openai
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
 import tiktoken
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
@@ -31,7 +31,7 @@ from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
 from langchain_community.vectorstores import FAISS
 
 from langchain.chains import LLMChain
-from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationBufferWindowMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
+from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
 
 import os, dotenv
 from dotenv import load_dotenv
@@ -114,8 +114,11 @@ if "embeddings" not in st.session_state:
         encode_kwargs = encode_kwargs)
 embeddings = st.session_state.embeddings
 if "llm" not in st.session_state:
-    st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
+    st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
 llm = st.session_state.llm
+if "llm_i" not in st.session_state:
+    st.session_state.llm_i = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
+llm_i = st.session_state.llm_i
 if "llm_gpt4" not in st.session_state:
     st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
 llm_gpt4 = st.session_state.llm_gpt4
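
A note on the added llm_i: gpt-3.5-turbo-instruct is a completion-style model, which is why it is wrapped in OpenAI rather than ChatOpenAI. A minimal sketch of the difference between the two wrappers (the prompt text is illustrative, and an OPENAI_API_KEY is assumed to be set in the environment):

    from langchain_openai import ChatOpenAI, OpenAI

    # Chat wrapper: message-oriented; .invoke returns an AIMessage.
    chat_llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
    print(chat_llm.invoke("Name one symptom of anaemia.").content)

    # Completion wrapper: plain string in, plain string out.
    instruct_llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
    print(instruct_llm.invoke("Name one symptom of anaemia."))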
@@ -129,40 +132,13 @@ if "store" not in st.session_state:
     st.session_state.store = db.get_store(index_name, embeddings=embeddings)
 store = st.session_state.store
 
-TEMPLATE = """You are a patient undergoing a medical check-up. You will be given the following:
-1. A context to answer the doctor, for your possible symptoms.
-2. A question about your current symptoms.
-
-Your task is to answer the doctor's questions as simple as possible, acting like a patient.
-Do not include other symptoms that are not included in the context, which provides your symptoms.
-
-Answer the question to the point, without any elaboration if you're not prodded with it.
-
-As you are a patient, you do not know any medical jargon or lingo. Do not include specific medical terms in your reply.
-You only know colloquial words for medical terms.
-For example, you should not reply with "dysarthria", but instead with "cannot speak properly".
-For example, you should not reply with "syncope", but instead with "fainting".
-
-Here is the context:
-{context}
-
-----------------------------------------------------------------
-You are to reply the doctor's following question, with reference to the above context.
-Question:
-{question}
-----------------------------------------------------------------
-Remember, answer in a short and sweet manner, don't talk too much.
-Your reply:
-"""
-
-with open('templates/patient.txt', 'r') as file:
-    TEMPLATE = file.read()
-
 if "TEMPLATE" not in st.session_state:
+    with open('templates/patient.txt', 'r') as file:
+        TEMPLATE = file.read()
     st.session_state.TEMPLATE = TEMPLATE
 
 with st.expander("Patient Prompt"):
-    TEMPLATE = st.text_area("Patient Prompt", value=TEMPLATE)
+    TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
 
 prompt = PromptTemplate(
     input_variables = ["question", "context"],
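
Moving the file read inside the `if "TEMPLATE" not in st.session_state:` guard means the default prompt is loaded once per session rather than on every rerun, and seeding the text area from st.session_state.TEMPLATE keeps it stable across reruns. The pattern in isolation (a sketch; templates/default.txt is a hypothetical path):

    import streamlit as st

    # Streamlit re-executes the whole script on every interaction, so the
    # default is read from disk only on the first run of the session.
    if "TEMPLATE" not in st.session_state:
        with open("templates/default.txt", "r") as file:  # hypothetical path
            st.session_state.TEMPLATE = file.read()

    # The editable copy is seeded from session state, not re-read from disk.
    TEMPLATE = st.text_area("Prompt", value=st.session_state.TEMPLATE)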
@@ -177,7 +153,9 @@ def format_docs(docs):
 
 
 if "memory" not in st.session_state:
-    st.session_state.memory = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
+    st.session_state.memory = ConversationBufferWindowMemory(
+        llm=llm, memory_key="chat_history", input_key="question",
+        k=5, human_prefix="student", ai_prefix="patient",)
 memory = st.session_state.memory
 
 
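Swapping ConversationSummaryBufferMemory for ConversationBufferWindowMemory removes the extra summarization LLM call: window memory just keeps the last k exchanges verbatim (the llm= argument retained in the diff is not used by window memory for summarization). A small sketch of the k-window behaviour:

    from langchain.chains.conversation.memory import ConversationBufferWindowMemory

    memory = ConversationBufferWindowMemory(
        memory_key="chat_history", input_key="question",
        k=5, human_prefix="student", ai_prefix="patient")

    # Save seven exchanges; only the five most recent are kept.
    for i in range(7):
        memory.save_context({"question": f"question {i}"}, {"output": f"answer {i}"})

    print(memory.load_memory_variables({})["chat_history"])
    # Starts at "student: question 2"; exchanges 0 and 1 have been dropped.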
@@ -200,74 +178,17 @@ sp_mapper = {"human":"student","ai":"patient"}
 ## Grader part
 index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
 
-# store = FAISS.load_local(index_name, embeddings)
-
 if "store2" not in st.session_state:
     st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
 store2 = st.session_state.store2
 
-TEMPLATE2 = """You are a teacher for medical students. You are grading a medical student on their OSCE, the Object Structured Clinical Examination.
-
-Your task is to provide an overall assessment of a student's diagnosis, based on the rubrics provided.
-You will be provided with the following information:
-1. The rubrics that the student should be judged based upon.
-2. The conversation history between the medical student and the patient.
-3. The final diagnosis that the student will make.
-
-=================================================================
-
-Your task is as follows:
-1. Your grading should touch on every part of the rubrics, and grade the student holistically.
-Finally, provide an overall grade for the student.
-
-Some additional information that is useful to understand the rubrics:
-- The rubrics are segmented, with each area separated by dashes, such as "----------"
-- There will be multiple segments on History Taking. For each segment, the rubrics and corresponding grades will be provided below the required history taking.
-- For History Taking, you are to grade the student based on the rubrics, by checking the chat history between the patients and the medical student.
-- There is an additional segment on Presentation, differentials, and diagnosis. The
-
-
-=================================================================
-
-e
-Here are the rubrics for grading the student:
-<rubrics>
-
-{context}
-
-</rubrics>
-
-=================================================================
-You are to give a comprehensive judgement based on the student's diagnosis, with reference to the above rubrics.
-
-Here is the chat history between the medical student and the patient:
-
-<history>
-
-{history}
-
-</history>
-=================================================================
-
-
-Student's final diagnosis:
-<diagnosis>
-{question}
-</diagnosis>
-
-=================================================================
-
-Your grade:
-"""
-
-with open('templates/grader.txt', 'r') as file:
-    TEMPLATE2 = file.read()
-
 if "TEMPLATE2" not in st.session_state:
+    with open('templates/grader.txt', 'r') as file:
+        TEMPLATE2 = file.read()
     st.session_state.TEMPLATE2 = TEMPLATE2
 
 with st.expander("Grader Prompt"):
-    TEMPLATE2 = st.text_area("Grader Prompt", value=TEMPLATE2)
+    TEMPLATE2 = st.text_area("Grader Prompt", value=st.session_state.TEMPLATE2)
 
 prompt2 = PromptTemplate(
     input_variables = ["question", "context", "history"],
@@ -283,10 +204,6 @@ def format_docs(docs):
 
 fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
 
-if "memory2" not in st.session_state:
-    st.session_state.memory2 = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
-memory2 = st.session_state.memory2
-
 def x(_):
     return fake_history
 
@@ -300,7 +217,19 @@ if ("chain2" not in st.session_state
         "question": RunnablePassthrough(),
     } |
 
-    LLMChain(llm=llm, prompt=prompt2, memory=memory, verbose=False)
+    # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
+    LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
+    | {
+        "json": itemgetter("text"),
+        "text": (
+            LLMChain(
+                llm=llm,
+                prompt=PromptTemplate(
+                    input_variables=["text"],
+                    template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
+                verbose=False)
+        )
+    }
 )
 chain2 = st.session_state.chain2
 
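The rebuilt chain2 is now two LLM calls in sequence: the instruct model grades the diagnosis against the rubric and is expected to emit JSON, then a second chain turns that JSON into a per-section write-up, with both results kept in the output map. A condensed sketch of the same shape, reusing llm, llm_i, and prompt2 from above and stubbing out the retriever and history:

    from operator import itemgetter
    from langchain.chains import LLMChain
    from langchain.prompts.prompt import PromptTemplate
    from langchain_core.runnables import RunnablePassthrough

    chain2 = (
        {
            "context": lambda _: "rubric text",   # stand-in for the FAISS retriever
            "history": lambda _: "student: ...",  # stand-in for fake_history
            "question": RunnablePassthrough(),
        }
        # Stage 1: instruct model grades against the rubric; output is {"text": <json>}.
        | LLMChain(llm=llm_i, prompt=prompt2, verbose=False)
        # Stage 2: keep the raw JSON, and in parallel expand it into prose.
        | {
            "json": itemgetter("text"),
            "text": LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    input_variables=["text"],
                    template="Interpret the following JSON of the student's grades, "
                             "and do a write-up for each section.\n\n```json\n{text}\n```"),
                verbose=False),
        }
    )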
@@ -318,7 +247,7 @@ chain2 = st.session_state.chain2
 if st.button("Clear History and Memory", type="primary"):
     st.session_state.messages_1 = []
     st.session_state.messages_2 = []
-    st.session_state.memory = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
+    st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" )
     memory = st.session_state.memory
 
 ## Testing HTML
@@ -417,7 +346,7 @@ if text_prompt:
     if st.session_state.active_chat==1:
         full_response = chain.invoke(text_prompt).get("text")
     else:
-        full_response = chain2.invoke(text_prompt).get("text")
+        full_response = chain2.invoke(text_prompt).get("text").get("text")
     message_placeholder.markdown(full_response)
     messages.append({"role": "assistant", "content": full_response})
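
Since the "text" branch of the final map is itself an LLMChain, whose output is a dict with its own "text" key, the grader result is nested one level deeper than the patient chain's, hence the doubled .get("text"):

    result = chain2.invoke(text_prompt)
    # result == {"json": "<grades as a JSON string>",
    #            "text": {"text": "<per-section write-up>", ...}}
    full_response = result.get("text").get("text")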
 
352