Yew Chong commited on
Commit
e17bf8a
·
1 Parent(s): aa1d498

basic streamlit for chest pain

Browse files
Files changed (1) hide show
  1. streamlit/app8.py +370 -0
streamlit/app8.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import streamlit as st
3
+ import streamlit.components.v1 as components
4
+ import datetime
5
+
6
+
7
+ ## Firestore ??
8
+ import os
9
+ import sys
10
+ import inspect
11
+ currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
12
+ parentdir = os.path.dirname(currentdir)
13
+ sys.path.append(parentdir)
14
+ import db_firestore as db
15
+
16
+
17
+ ## ----------------------------------------------------------------
18
+ ## LLM Part
19
+ import openai
20
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
21
+ import tiktoken
22
+ from langchain.prompts.few_shot import FewShotPromptTemplate
23
+ from langchain.prompts.prompt import PromptTemplate
24
+ from operator import itemgetter
25
+ from langchain.schema import StrOutputParser
26
+ from langchain_core.output_parsers import StrOutputParser
27
+ from langchain_core.runnables import RunnablePassthrough
28
+
29
+ import langchain_community.embeddings.huggingface
30
+ # help(langchain_community.embeddings.huggingface)
31
+ from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
32
+ from langchain_community.vectorstores import FAISS
33
+
34
+ from langchain.chains import LLMChain
35
+ from langchain.chains.conversation.memory import ConversationBufferMemory, ConversationBufferWindowMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
36
+
37
+ import os, dotenv
38
+ from dotenv import load_dotenv
39
+ load_dotenv()
40
+
41
+
42
+
43
+ if "openai_model" not in st.session_state:
44
+ st.session_state["openai_model"] = "gpt-3.5-turbo"
45
+
46
+ if "messages_1" not in st.session_state:
47
+ st.session_state.messages_1 = []
48
+
49
+ if "messages_2" not in st.session_state:
50
+ st.session_state.messages_2 = []
51
+
52
+ if "start_time" not in st.session_state:
53
+ st.session_state.start_time = None
54
+
55
+ if "active_chat" not in st.session_state:
56
+ st.session_state.active_chat = 1
57
+
58
+ model_name = "bge-large-en-v1.5"
59
+ model_kwargs = {"device": "cpu"}
60
+ # model_kwargs = {"device": "cuda"}
61
+ encode_kwargs = {"normalize_embeddings": True}
62
+ if "embeddings" not in st.session_state:
63
+ st.session_state.embeddings = HuggingFaceBgeEmbeddings(
64
+ # model_name=model_name,
65
+ model_kwargs = model_kwargs,
66
+ encode_kwargs = encode_kwargs)
67
+ embeddings = st.session_state.embeddings
68
+ if "llm" not in st.session_state:
69
+ st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
70
+ llm = st.session_state.llm
71
+ if "llm_gpt4" not in st.session_state:
72
+ st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
73
+ llm_gpt4 = st.session_state.llm_gpt4
74
+
75
+ ## ------------------------------------------------------------------------------------------------
76
+ ## Patient part
77
+
78
+ index_name = "indexes/ChestPainQA"
79
+
80
+ if "store" not in st.session_state:
81
+ st.session_state.store = db.get_store(index_name, embeddings=embeddings)
82
+ store = st.session_state.store
83
+
84
+ TEMPLATE = """You are a patient undergoing a medical check-up. You will be given the following:
85
+ 1. A context to answer the doctor, for your possible symptoms.
86
+ 2. A question about your current symptoms.
87
+
88
+ Your task is to answer the doctor's questions as simple as possible, acting like a patient.
89
+ Do not include other symptoms that are not included in the context, which provides your symptoms.
90
+
91
+ Answer the question to the point, without any elaboration if you're not prodded with it.
92
+
93
+ As you are a patient, you do not know any medical jargon or lingo. Do not include specific medical terms in your reply.
94
+ You only know colloquial words for medical terms.
95
+ For example, you should not reply with "dysarthria", but instead with "cannot speak properly".
96
+ For example, you should not reply with "syncope", but instead with "fainting".
97
+
98
+ Here is the context:
99
+ {context}
100
+
101
+ ----------------------------------------------------------------
102
+ You are to reply the doctor's following question, with reference to the above context.
103
+ Question:
104
+ {question}
105
+ ----------------------------------------------------------------
106
+ Remember, answer in a short and sweet manner, don't talk too much.
107
+ Your reply:
108
+ """
109
+
110
+ prompt = PromptTemplate(
111
+ input_variables = ["question", "context"],
112
+ template = TEMPLATE
113
+ )
114
+ if "retriever" not in st.session_state:
115
+ st.session_state.retriever = store.as_retriever(search_type="similarity", search_kwargs={"k":2})
116
+ retriever = st.session_state.retriever
117
+
118
+ def format_docs(docs):
119
+ return "\n--------------------\n".join(doc.page_content for doc in docs)
120
+
121
+
122
+ if "memory" not in st.session_state:
123
+ st.session_state.memory = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
124
+ memory = st.session_state.memory
125
+
126
+
127
+ if "chain" not in st.session_state:
128
+ st.session_state.chain = (
129
+ {
130
+ "context": retriever | format_docs,
131
+ "question": RunnablePassthrough()
132
+ } |
133
+ LLMChain(llm=llm, prompt=prompt, memory=memory, verbose=False)
134
+ )
135
+ chain = st.session_state.chain
136
+
137
+ sp_mapper = {"human":"student","ai":"patient"}
138
+
139
+ ## ------------------------------------------------------------------------------------------------
140
+ ## ------------------------------------------------------------------------------------------------
141
+ ## Grader part
142
+ index_name = "indexes/ChestPainRubrics"
143
+
144
+ # store = FAISS.load_local(index_name, embeddings)
145
+
146
+ if "store2" not in st.session_state:
147
+ st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
148
+ store2 = st.session_state.store2
149
+
150
+ TEMPLATE2 = """You are a teacher for medical students. You are grading a medical student on their OSCE, the Object Structured Clinical Examination.
151
+
152
+ Your task is to provide an overall assessment of a student's diagnosis, based on the rubrics provided.
153
+ You will be provided with the following information:
154
+ 1. The rubrics that the student should be judged based upon.
155
+ 2. The conversation history between the medical student and the patient.
156
+ 3. The final diagnosis that the student will make.
157
+
158
+ =================================================================
159
+
160
+ Your task is as follows:
161
+ 1. Your grading should touch on every part of the rubrics, and grade the student holistically.
162
+ Finally, provide an overall grade for the student.
163
+
164
+ Some additional information that is useful to understand the rubrics:
165
+ - The rubrics are segmented, with each area separated by dashes, such as "----------"
166
+ - There will be multiple segments on History Taking. For each segment, the rubrics and corresponding grades will be provided below the required history taking.
167
+ - For History Taking, you are to grade the student based on the rubrics, by checking the chat history between the patients and the medical student.
168
+ - There is an additional segment on Presentation, differentials, and diagnosis. The
169
+
170
+
171
+ =================================================================
172
+
173
+
174
+ Here are the rubrics for grading the student:
175
+ <rubrics>
176
+
177
+ {context}
178
+
179
+ </rubrics>
180
+
181
+ =================================================================
182
+ You are to give a comprehensive judgement based on the student's diagnosis, with reference to the above rubrics.
183
+
184
+ Here is the chat history between the medical student and the patient:
185
+
186
+ <history>
187
+
188
+ {history}
189
+
190
+ </history>
191
+ =================================================================
192
+
193
+
194
+ Student's final diagnosis:
195
+ <diagnosis>
196
+ {question}
197
+ </diagnosis>
198
+
199
+ =================================================================
200
+
201
+ Your grade:
202
+ """
203
+
204
+ prompt2 = PromptTemplate(
205
+ input_variables = ["question", "context", "history"],
206
+ template = TEMPLATE2
207
+ )
208
+ if "retriever2" not in st.session_state:
209
+ st.session_state.retriever2 = store2.as_retriever(search_type="similarity", search_kwargs={"k":2})
210
+ retriever2 = st.session_state.retriever2
211
+
212
+ def format_docs(docs):
213
+ return "\n--------------------\n".join(doc.page_content for doc in docs)
214
+
215
+
216
+ fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
217
+
218
+ if "memory2" not in st.session_state:
219
+ st.session_state.memory2 = ConversationSummaryBufferMemory(llm=llm, memory_key="chat_history", input_key="question" )
220
+ memory2 = st.session_state.memory2
221
+
222
+ def x(_):
223
+ return fake_history
224
+
225
+ if "chain2" not in st.session_state:
226
+ st.session_state.chain2 = (
227
+ {
228
+ "context": retriever | format_docs,
229
+ "history": x,
230
+ "question": RunnablePassthrough(),
231
+ } |
232
+
233
+ LLMChain(llm=llm, prompt=prompt2, memory=memory, verbose=False)
234
+ )
235
+ chain2 = st.session_state.chain2
236
+
237
+ ## ------------------------------------------------------------------------------------------------
238
+ ## ------------------------------------------------------------------------------------------------
239
+ ## Streamlit now
240
+
241
+ # from dotenv import load_dotenv
242
+ # import os
243
+ # load_dotenv()
244
+ # key = os.environ.get("OPENAI_API_KEY")
245
+ # client = OpenAI(api_key=key)
246
+
247
+ st.title("UAT for PatientLLM and GraderLLM")
248
+ st.title("Chest pain for now")
249
+
250
+ ## Testing HTML
251
+ # html_string = """
252
+ # <canvas></canvas>
253
+
254
+
255
+ # <script>
256
+ # canvas = document.querySelector('canvas');
257
+ # canvas.width = 1024;
258
+ # canvas.height = 576;
259
+ # console.log(canvas);
260
+
261
+ # const c = canvas.getContext('2d');
262
+ # c.fillStyle = "green";
263
+ # c.fillRect(0,0,canvas.width,canvas.height);
264
+
265
+ # const img = new Image();
266
+ # img.src = "./tksfordumtrive.png";
267
+ # c.drawImage(img, 10, 10);
268
+ # </script>
269
+
270
+ # <style>
271
+ # body {
272
+ # margin: 0;
273
+ # }
274
+ # </style>
275
+ # """
276
+ # components.html(html_string,
277
+ # width=1280,
278
+ # height=640)
279
+
280
+
281
+ st.write("Timer has been removed, switch with this button")
282
+
283
+ st.write("Buggy button, please double click")
284
+ if st.button(f"Switch to {'PATIENT' if st.session_state.active_chat==2 else 'GRADER'}"):
285
+ st.session_state.active_chat = 3 - st.session_state.active_chat
286
+
287
+ st.write(st.session_state.active_chat)
288
+
289
+ # Create two columns for the two chat interfaces
290
+ col1, col2 = st.columns(2)
291
+
292
+ # First chat interface
293
+ with col1:
294
+ st.subheader("Student LLM")
295
+ for message in st.session_state.messages_1:
296
+ with st.chat_message(message["role"]):
297
+ st.markdown(message["content"])
298
+
299
+ # Second chat interface
300
+ with col2:
301
+ st.write("pls dun spam this, its tons of tokens cos chat history")
302
+ st.subheader("Grader LLM")
303
+ for message in st.session_state.messages_2:
304
+ with st.chat_message(message["role"]):
305
+ st.markdown(message["content"])
306
+
307
+ # Timer and Input
308
+ # time_left = None
309
+ # if st.session_state.start_time:
310
+ # time_elapsed = datetime.datetime.now() - st.session_state.start_time
311
+ # time_left = datetime.timedelta(minutes=10) - time_elapsed
312
+ # st.write(f"Time left: {time_left}")
313
+
314
+ # if time_left is None or time_left > datetime.timedelta(0):
315
+ # # Chat 1 is active
316
+ # prompt = st.text_input("Enter your message for Chat 1:")
317
+ # active_chat = 1
318
+ # messages = st.session_state.messages_1
319
+ # elif time_left and time_left <= datetime.timedelta(0):
320
+ # # Chat 2 is active
321
+ # prompt = st.text_input("Enter your message for Chat 2:")
322
+ # active_chat = 2
323
+ # messages = st.session_state.messages_2
324
+
325
+ if st.session_state.active_chat==1:
326
+ text_prompt = st.text_input("Enter your message for PATIENT")
327
+ messages = st.session_state.messages_1
328
+ else:
329
+ text_prompt = st.text_input("Enter your message for GRADER")
330
+ messages = st.session_state.messages_2
331
+
332
+
333
+ if text_prompt:
334
+ messages.append({"role": "user", "content": text_prompt})
335
+
336
+ with (col1 if st.session_state.active_chat == 1 else col2):
337
+ with st.chat_message("user"):
338
+ st.markdown(text_prompt)
339
+
340
+ with (col1 if st.session_state.active_chat == 1 else col2):
341
+ with st.chat_message("assistant"):
342
+ message_placeholder = st.empty()
343
+ if st.session_state.active_chat==1:
344
+ full_response = chain.invoke(text_prompt).get("text")
345
+ else:
346
+ full_response = chain2.invoke(text_prompt).get("text")
347
+ message_placeholder.markdown(full_response)
348
+ messages.append({"role": "assistant", "content": full_response})
349
+
350
+
351
+ # import streamlit as st
352
+ # import time
353
+ # def count_down(ts):
354
+ # with st.empty():
355
+ # while ts:
356
+ # mins, secs = divmod(ts, 60)
357
+ # time_now = '{:02d}:{:02d}'.format(mins, secs)
358
+ # st.header(f"{time_now}")
359
+ # time.sleep(1)
360
+ # ts -= 1
361
+ # st.write("Time Up!")
362
+ # def main():
363
+ # st.title("Pomodoro")
364
+ # time_minutes = st.number_input('Enter the time in minutes ', min_value=1, value=25)
365
+ # time_in_seconds = time_minutes * 60
366
+ # if st.button("START"):
367
+ # count_down(int(time_in_seconds))
368
+ # if __name__ == '__main__':
369
+ # main()
370
+