Spaces:
Runtime error
Runtime error
Yew Chong
commited on
Commit
•
534b6f7
1
Parent(s):
02eae13
updates on ui
Browse files- app.py +5 -18
- app_final.py +96 -278
app.py
CHANGED
@@ -199,21 +199,16 @@ retriever2 = st.session_state.retriever2
|
|
199 |
def format_docs(docs):
|
200 |
return "\n--------------------\n".join(doc.page_content for doc in docs)
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
fake_history = '\n'.join([(sp_mapper.get(i['role'], i['role']) + ": "+ i['content']) for i in st.session_state.messages_1])
|
205 |
-
st.write(fake_history)
|
206 |
-
|
207 |
-
def y(_):
|
208 |
-
return fake_history
|
209 |
|
210 |
if ("chain2" not in st.session_state
|
211 |
or
|
212 |
st.session_state.TEMPLATE2 != TEMPLATE2):
|
213 |
st.session_state.chain2 = (
|
214 |
{
|
215 |
-
"context":
|
216 |
-
"history":
|
217 |
"question": RunnablePassthrough(),
|
218 |
} |
|
219 |
|
@@ -337,10 +332,6 @@ from langchain.callbacks.manager import tracing_v2_enabled
|
|
337 |
from uuid import uuid4
|
338 |
import os
|
339 |
|
340 |
-
os.environ['LANGCHAIN_TRACING_V2']='true'
|
341 |
-
os.environ['LANGCHAIN_ENDPOINT']='https://api.smith.langchain.com'
|
342 |
-
os.environ['LANGCHAIN_API_KEY']='ls__4ad767c45b844e6a8d790e12f556d3ca'
|
343 |
-
os.environ['LANGCHAIN_PROJECT']='streamlit'
|
344 |
|
345 |
|
346 |
if text_prompt:
|
@@ -353,7 +344,7 @@ if text_prompt:
|
|
353 |
with (col1 if st.session_state.active_chat == 1 else col2):
|
354 |
with st.chat_message("assistant"):
|
355 |
message_placeholder = st.empty()
|
356 |
-
with tracing_v2_enabled(project_name = "streamlit"):
|
357 |
if st.session_state.active_chat==1:
|
358 |
full_response = chain.invoke(text_prompt).get("text")
|
359 |
else:
|
@@ -381,7 +372,3 @@ if text_prompt:
|
|
381 |
# count_down(int(time_in_seconds))
|
382 |
# if __name__ == '__main__':
|
383 |
# main()
|
384 |
-
|
385 |
-
st.write('fake history is:')
|
386 |
-
st.write(y(""))
|
387 |
-
st.write('done')
|
|
|
199 |
def format_docs(docs):
|
200 |
return "\n--------------------\n".join(doc.page_content for doc in docs)
|
201 |
|
202 |
+
def get_history(_):
|
203 |
+
return '\n'.join([(sp_mapper.get(i['role'], i['role']) + ": "+ i['content']) for i in st.session_state.messages_1])
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
if ("chain2" not in st.session_state
|
206 |
or
|
207 |
st.session_state.TEMPLATE2 != TEMPLATE2):
|
208 |
st.session_state.chain2 = (
|
209 |
{
|
210 |
+
"context": retriever2 | format_docs,
|
211 |
+
"history": get_history,
|
212 |
"question": RunnablePassthrough(),
|
213 |
} |
|
214 |
|
|
|
332 |
from uuid import uuid4
|
333 |
import os
|
334 |
|
|
|
|
|
|
|
|
|
335 |
|
336 |
|
337 |
if text_prompt:
|
|
|
344 |
with (col1 if st.session_state.active_chat == 1 else col2):
|
345 |
with st.chat_message("assistant"):
|
346 |
message_placeholder = st.empty()
|
347 |
+
if True: ## with tracing_v2_enabled(project_name = "streamlit"):
|
348 |
if st.session_state.active_chat==1:
|
349 |
full_response = chain.invoke(text_prompt).get("text")
|
350 |
else:
|
|
|
372 |
# count_down(int(time_in_seconds))
|
373 |
# if __name__ == '__main__':
|
374 |
# main()
|
|
|
|
|
|
|
|
app_final.py
CHANGED
@@ -14,11 +14,6 @@ import os
|
|
14 |
# parentdir = os.path.dirname(currentdir)
|
15 |
# sys.path.append(parentdir)
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
# ## ----------------------------------------------------------------
|
21 |
-
# ## LLM Part
|
22 |
import openai
|
23 |
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
|
24 |
import tiktoken
|
@@ -40,6 +35,13 @@ import os, dotenv
|
|
40 |
from dotenv import load_dotenv
|
41 |
load_dotenv()
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
if not os.path.isdir("./.streamlit"):
|
44 |
os.mkdir("./.streamlit")
|
45 |
print('made streamlit folder')
|
@@ -83,25 +85,6 @@ Vomiting
|
|
83 |
Weakness
|
84 |
Weakness2""".split("\n")
|
85 |
|
86 |
-
# if "selected_index" not in st.session_state:
|
87 |
-
# st.session_state.selected_index = 3
|
88 |
-
|
89 |
-
# if "index_selectbox" not in st.session_state:
|
90 |
-
# st.session_state.index_selectbox = "Headache"
|
91 |
-
|
92 |
-
# index_selectbox = st.selectbox("Select index",indexes, index=int(st.session_state.selected_index))
|
93 |
-
|
94 |
-
# if index_selectbox != indexes[st.session_state.selected_index]:
|
95 |
-
# st.session_state.selected_index = indexes.index(index_selectbox)
|
96 |
-
# st.session_state.index_selectbox = index_selectbox
|
97 |
-
# del st.session_state["store"]
|
98 |
-
# del st.session_state["store2"]
|
99 |
-
# del st.session_state["retriever"]
|
100 |
-
# del st.session_state["retriever2"]
|
101 |
-
# del st.session_state["chain"]
|
102 |
-
# del st.session_state["chain2"]
|
103 |
-
|
104 |
-
|
105 |
model_name = "bge-large-en-v1.5"
|
106 |
model_kwargs = {"device": "cpu"}
|
107 |
encode_kwargs = {"normalize_embeddings": True}
|
@@ -122,73 +105,29 @@ if "llm_gpt4" not in st.session_state:
|
|
122 |
st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
123 |
llm_gpt4 = st.session_state.llm_gpt4
|
124 |
|
125 |
-
# ## ------------------------------------------------------------------------------------------------
|
126 |
-
# ## Patient part
|
127 |
-
|
128 |
-
# index_name = f"indexes/{st.session_state.index_selectbox}/QA"
|
129 |
-
|
130 |
-
# if "store" not in st.session_state:
|
131 |
-
# st.session_state.store = db.get_store(index_name, embeddings=embeddings)
|
132 |
-
# store = st.session_state.store
|
133 |
|
134 |
if "TEMPLATE" not in st.session_state:
|
135 |
with open('templates/patient.txt', 'r') as file:
|
136 |
TEMPLATE = file.read()
|
137 |
st.session_state.TEMPLATE = TEMPLATE
|
138 |
TEMPLATE = st.session_state.TEMPLATE
|
139 |
-
# with st.expander("Patient Prompt"):
|
140 |
-
# TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
|
141 |
|
142 |
prompt = PromptTemplate(
|
143 |
input_variables = ["question", "context"],
|
144 |
template = st.session_state.TEMPLATE
|
145 |
)
|
146 |
|
147 |
-
# if "retriever" not in st.session_state:
|
148 |
-
# st.session_state.retriever = store.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
149 |
-
# retriever = st.session_state.retriever
|
150 |
-
|
151 |
def format_docs(docs):
|
152 |
return "\n--------------------\n".join(doc.page_content for doc in docs)
|
153 |
|
154 |
|
155 |
-
# if "memory" not in st.session_state:
|
156 |
-
# st.session_state.memory = ConversationBufferWindowMemory(
|
157 |
-
# llm=llm, memory_key="chat_history", input_key="question",
|
158 |
-
# k=5, human_prefix="student", ai_prefix="patient",)
|
159 |
-
# memory = st.session_state.memory
|
160 |
-
|
161 |
-
|
162 |
-
# if ("chain" not in st.session_state
|
163 |
-
# or
|
164 |
-
# st.session_state.TEMPLATE != TEMPLATE):
|
165 |
-
# st.session_state.chain = (
|
166 |
-
# {
|
167 |
-
# "context": retriever | format_docs,
|
168 |
-
# "question": RunnablePassthrough()
|
169 |
-
# } |
|
170 |
-
# LLMChain(llm=llm, prompt=prompt, memory=memory, verbose=False)
|
171 |
-
# )
|
172 |
-
# chain = st.session_state.chain
|
173 |
-
|
174 |
sp_mapper = {"human":"student","ai":"patient", "user":"student","assistant":"patient"}
|
175 |
|
176 |
-
# ## ------------------------------------------------------------------------------------------------
|
177 |
-
# ## ------------------------------------------------------------------------------------------------
|
178 |
-
# ## Grader part
|
179 |
-
# index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
|
180 |
-
|
181 |
-
# if "store2" not in st.session_state:
|
182 |
-
# st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
|
183 |
-
# store2 = st.session_state.store2
|
184 |
-
|
185 |
if "TEMPLATE2" not in st.session_state:
|
186 |
with open('templates/grader.txt', 'r') as file:
|
187 |
TEMPLATE2 = file.read()
|
188 |
st.session_state.TEMPLATE2 = TEMPLATE2
|
189 |
TEMPLATE2 = st.session_state.TEMPLATE2
|
190 |
-
# with st.expander("Grader Prompt"):
|
191 |
-
# TEMPLATE2 = st.text_area("Grader Prompt", value=st.session_state.TEMPLATE2)
|
192 |
|
193 |
prompt2 = PromptTemplate(
|
194 |
input_variables = ["question", "context", "history"],
|
@@ -198,178 +137,6 @@ prompt2 = PromptTemplate(
|
|
198 |
def get_patient_chat_history(_):
|
199 |
return st.session_state.get("patient_chat_history")
|
200 |
|
201 |
-
# if "retriever2" not in st.session_state:
|
202 |
-
# st.session_state.retriever2 = store2.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
203 |
-
# retriever2 = st.session_state.retriever2
|
204 |
-
|
205 |
-
# def format_docs(docs):
|
206 |
-
# return "\n--------------------\n".join(doc.page_content for doc in docs)
|
207 |
-
|
208 |
-
|
209 |
-
# fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
|
210 |
-
# fake_history = '\n'.join([(sp_mapper.get(i['role'], i['role']) + ": "+ i['content']) for i in st.session_state.messages_1])
|
211 |
-
# st.write(fake_history)
|
212 |
-
|
213 |
-
# def y(_):
|
214 |
-
# return fake_history
|
215 |
-
|
216 |
-
# if ("chain2" not in st.session_state
|
217 |
-
# or
|
218 |
-
# st.session_state.TEMPLATE2 != TEMPLATE2):
|
219 |
-
# st.session_state.chain2 = (
|
220 |
-
# {
|
221 |
-
# "context": retriever2 | format_docs,
|
222 |
-
# "history": y,
|
223 |
-
# "question": RunnablePassthrough(),
|
224 |
-
# } |
|
225 |
-
|
226 |
-
# # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
|
227 |
-
# LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) #|
|
228 |
-
# | {
|
229 |
-
# "json": itemgetter("text"),
|
230 |
-
# "text": (
|
231 |
-
# LLMChain(
|
232 |
-
# llm=llm,
|
233 |
-
# prompt=PromptTemplate(
|
234 |
-
# input_variables=["text"],
|
235 |
-
# template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
|
236 |
-
# verbose=False)
|
237 |
-
# )
|
238 |
-
# }
|
239 |
-
# )
|
240 |
-
# chain2 = st.session_state.chain2
|
241 |
-
|
242 |
-
# ## ------------------------------------------------------------------------------------------------
|
243 |
-
# ## ------------------------------------------------------------------------------------------------
|
244 |
-
# ## Streamlit now
|
245 |
-
|
246 |
-
# # from dotenv import load_dotenv
|
247 |
-
# # import os
|
248 |
-
# # load_dotenv()
|
249 |
-
# # key = os.environ.get("OPENAI_API_KEY")
|
250 |
-
# # client = OpenAI(api_key=key)
|
251 |
-
|
252 |
-
|
253 |
-
# if st.button("Clear History and Memory", type="primary"):
|
254 |
-
# st.session_state.messages_1 = []
|
255 |
-
# st.session_state.messages_2 = []
|
256 |
-
# st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" )
|
257 |
-
# memory = st.session_state.memory
|
258 |
-
|
259 |
-
# ## Testing HTML
|
260 |
-
# # html_string = """
|
261 |
-
# # <canvas></canvas>
|
262 |
-
|
263 |
-
|
264 |
-
# # <script>
|
265 |
-
# # canvas = document.querySelector('canvas');
|
266 |
-
# # canvas.width = 1024;
|
267 |
-
# # canvas.height = 576;
|
268 |
-
# # console.log(canvas);
|
269 |
-
|
270 |
-
# # const c = canvas.getContext('2d');
|
271 |
-
# # c.fillStyle = "green";
|
272 |
-
# # c.fillRect(0,0,canvas.width,canvas.height);
|
273 |
-
|
274 |
-
# # const img = new Image();
|
275 |
-
# # img.src = "./tksfordumtrive.png";
|
276 |
-
# # c.drawImage(img, 10, 10);
|
277 |
-
# # </script>
|
278 |
-
|
279 |
-
# # <style>
|
280 |
-
# # body {
|
281 |
-
# # margin: 0;
|
282 |
-
# # }
|
283 |
-
# # </style>
|
284 |
-
# # """
|
285 |
-
# # components.html(html_string,
|
286 |
-
# # width=1280,
|
287 |
-
# # height=640)
|
288 |
-
|
289 |
-
|
290 |
-
# st.write("Timer has been removed, switch with this button")
|
291 |
-
|
292 |
-
# if st.button(f"Switch to {'PATIENT' if st.session_state.active_chat==2 else 'GRADER'}"+".... Buggy button, please double click"):
|
293 |
-
# st.session_state.active_chat = 3 - st.session_state.active_chat
|
294 |
-
|
295 |
-
# # st.write("Currently in " + ('PATIENT' if st.session_state.active_chat==2 else 'GRADER'))
|
296 |
-
|
297 |
-
# # Create two columns for the two chat interfaces
|
298 |
-
# col1, col2 = st.columns(2)
|
299 |
-
|
300 |
-
# # First chat interface
|
301 |
-
# with col1:
|
302 |
-
# st.subheader("Student LLM")
|
303 |
-
# for message in st.session_state.messages_1:
|
304 |
-
# with st.chat_message(message["role"]):
|
305 |
-
# st.markdown(message["content"])
|
306 |
-
|
307 |
-
# # Second chat interface
|
308 |
-
# with col2:
|
309 |
-
# # st.write("pls dun spam this, its tons of tokens cos chat history")
|
310 |
-
# st.subheader("Grader LLM")
|
311 |
-
# st.write("grader takes a while to load... please be patient")
|
312 |
-
# for message in st.session_state.messages_2:
|
313 |
-
# with st.chat_message(message["role"]):
|
314 |
-
# st.markdown(message["content"])
|
315 |
-
|
316 |
-
# # Timer and Input
|
317 |
-
# # time_left = None
|
318 |
-
# # if st.session_state.start_time:
|
319 |
-
# # time_elapsed = datetime.datetime.now() - st.session_state.start_time
|
320 |
-
# # time_left = datetime.timedelta(minutes=10) - time_elapsed
|
321 |
-
# # st.write(f"Time left: {time_left}")
|
322 |
-
|
323 |
-
# # if time_left is None or time_left > datetime.timedelta(0):
|
324 |
-
# # # Chat 1 is active
|
325 |
-
# # prompt = st.text_input("Enter your message for Chat 1:")
|
326 |
-
# # active_chat = 1
|
327 |
-
# # messages = st.session_state.messages_1
|
328 |
-
# # elif time_left and time_left <= datetime.timedelta(0):
|
329 |
-
# # # Chat 2 is active
|
330 |
-
# # prompt = st.text_input("Enter your message for Chat 2:")
|
331 |
-
# # active_chat = 2
|
332 |
-
# # messages = st.session_state.messages_2
|
333 |
-
|
334 |
-
# if st.session_state.active_chat==1:
|
335 |
-
# text_prompt = st.text_input("Enter your message for PATIENT")
|
336 |
-
# messages = st.session_state.messages_1
|
337 |
-
# else:
|
338 |
-
# text_prompt = st.text_input("Enter your message for GRADER")
|
339 |
-
# messages = st.session_state.messages_2
|
340 |
-
|
341 |
-
|
342 |
-
# from langchain.callbacks.manager import tracing_v2_enabled
|
343 |
-
# from uuid import uuid4
|
344 |
-
# import os
|
345 |
-
|
346 |
-
# if text_prompt:
|
347 |
-
# messages.append({"role": "user", "content": text_prompt})
|
348 |
-
|
349 |
-
# with (col1 if st.session_state.active_chat == 1 else col2):
|
350 |
-
# with st.chat_message("user"):
|
351 |
-
# st.markdown(text_prompt)
|
352 |
-
|
353 |
-
# with (col1 if st.session_state.active_chat == 1 else col2):
|
354 |
-
# with st.chat_message("assistant"):
|
355 |
-
# message_placeholder = st.empty()
|
356 |
-
# if True: ## with tracing_v2_enabled(project_name = "streamlit"):
|
357 |
-
# if st.session_state.active_chat==1:
|
358 |
-
# full_response = chain.invoke(text_prompt).get("text")
|
359 |
-
# else:
|
360 |
-
# full_response = chain2.invoke(text_prompt).get("text").get("text")
|
361 |
-
# message_placeholder.markdown(full_response)
|
362 |
-
# messages.append({"role": "assistant", "content": full_response})
|
363 |
-
|
364 |
-
|
365 |
-
# st.write('fake history is:')
|
366 |
-
# st.write(y(""))
|
367 |
-
# st.write('done')
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
## ====================
|
373 |
|
374 |
if not st.session_state.get("scenario_list", None):
|
375 |
st.session_state.scenario_list = indexes
|
@@ -497,15 +264,15 @@ else:
|
|
497 |
st.session_state.scenario_tab_index=x
|
498 |
return None
|
499 |
|
500 |
-
def select_scenario_and_change_tab(_):
|
501 |
-
|
502 |
|
503 |
def go_to_patient_llm():
|
504 |
selected_scenario = st.session_state.get('selected_scenario')
|
505 |
if selected_scenario is None or selected_scenario < 0:
|
506 |
st.warning("Please select a scenario!")
|
507 |
else:
|
508 |
-
|
509 |
states = ["store", "store2","retriever","retriever2","chain","chain2"]
|
510 |
for state_to_del in states:
|
511 |
if state_to_del in st.session_state:
|
@@ -549,14 +316,14 @@ else:
|
|
549 |
col1, col2, col3 = st.columns([1,3,1])
|
550 |
with col1:
|
551 |
back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
552 |
-
with col3:
|
553 |
-
|
554 |
|
555 |
with col2:
|
556 |
TIME_LIMIT = 60*10 ## to change to 10 minutes
|
557 |
time.sleep(1)
|
558 |
-
if start_timer_button:
|
559 |
-
|
560 |
# st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time')
|
561 |
st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time
|
562 |
|
@@ -638,9 +405,9 @@ else:
|
|
638 |
""",
|
639 |
)
|
640 |
|
641 |
-
with open("./public/
|
642 |
contents = f.read()
|
643 |
-
|
644 |
|
645 |
with open("./public/chars/Male_talk.gif", "rb") as f:
|
646 |
contents = f.read()
|
@@ -660,26 +427,83 @@ else:
|
|
660 |
st.session_state.patient_response = response
|
661 |
with interactive_container:
|
662 |
html(f"""
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
|
684 |
elif st.session_state.scenario_tab_index == ScenarioTabIndex.GRADER_LLM:
|
685 |
st.session_state.grader_output = "" if not st.session_state.get("grader_output") else st.session_state.grader_output
|
@@ -710,7 +534,7 @@ else:
|
|
710 |
st.session_state.differential_3 = st.text_input("Differential 3")
|
711 |
with st.columns(6)[5]:
|
712 |
send_for_grading = st.button("Get grades!", on_click=get_grades)
|
713 |
-
with st.expander("Your
|
714 |
if st.session_state.grader_output:
|
715 |
st.write(st.session_state.grader_output.get("text").get("text"))
|
716 |
|
@@ -718,14 +542,8 @@ else:
|
|
718 |
back_btn = st.button("New Scenario?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
719 |
|
720 |
with dashboard_tab:
|
721 |
-
|
722 |
-
|
723 |
-
from firebase_admin import credentials, storage, firestore
|
724 |
-
import plotly.express as px
|
725 |
-
import plotly.graph_objects as go
|
726 |
-
import pandas as pd
|
727 |
-
|
728 |
-
cred = credentials.Certificate(json.loads(os.environ.get("FIREBASE_CREDENTIAL")))
|
729 |
|
730 |
# Initialize Firebase (if not already initialized)
|
731 |
if not firebase_admin._apps:
|
|
|
14 |
# parentdir = os.path.dirname(currentdir)
|
15 |
# sys.path.append(parentdir)
|
16 |
|
|
|
|
|
|
|
|
|
|
|
17 |
import openai
|
18 |
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
|
19 |
import tiktoken
|
|
|
35 |
from dotenv import load_dotenv
|
36 |
load_dotenv()
|
37 |
|
38 |
+
import firebase_admin, json
|
39 |
+
from firebase_admin import credentials, storage, firestore
|
40 |
+
import plotly.express as px
|
41 |
+
import plotly.graph_objects as go
|
42 |
+
import pandas as pd
|
43 |
+
|
44 |
+
|
45 |
if not os.path.isdir("./.streamlit"):
|
46 |
os.mkdir("./.streamlit")
|
47 |
print('made streamlit folder')
|
|
|
85 |
Weakness
|
86 |
Weakness2""".split("\n")
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
model_name = "bge-large-en-v1.5"
|
89 |
model_kwargs = {"device": "cpu"}
|
90 |
encode_kwargs = {"normalize_embeddings": True}
|
|
|
105 |
st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
106 |
llm_gpt4 = st.session_state.llm_gpt4
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
if "TEMPLATE" not in st.session_state:
|
110 |
with open('templates/patient.txt', 'r') as file:
|
111 |
TEMPLATE = file.read()
|
112 |
st.session_state.TEMPLATE = TEMPLATE
|
113 |
TEMPLATE = st.session_state.TEMPLATE
|
|
|
|
|
114 |
|
115 |
prompt = PromptTemplate(
|
116 |
input_variables = ["question", "context"],
|
117 |
template = st.session_state.TEMPLATE
|
118 |
)
|
119 |
|
|
|
|
|
|
|
|
|
120 |
def format_docs(docs):
|
121 |
return "\n--------------------\n".join(doc.page_content for doc in docs)
|
122 |
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
sp_mapper = {"human":"student","ai":"patient", "user":"student","assistant":"patient"}
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
if "TEMPLATE2" not in st.session_state:
|
127 |
with open('templates/grader.txt', 'r') as file:
|
128 |
TEMPLATE2 = file.read()
|
129 |
st.session_state.TEMPLATE2 = TEMPLATE2
|
130 |
TEMPLATE2 = st.session_state.TEMPLATE2
|
|
|
|
|
131 |
|
132 |
prompt2 = PromptTemplate(
|
133 |
input_variables = ["question", "context", "history"],
|
|
|
137 |
def get_patient_chat_history(_):
|
138 |
return st.session_state.get("patient_chat_history")
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
if not st.session_state.get("scenario_list", None):
|
142 |
st.session_state.scenario_list = indexes
|
|
|
264 |
st.session_state.scenario_tab_index=x
|
265 |
return None
|
266 |
|
267 |
+
# def select_scenario_and_change_tab(_):
|
268 |
+
# set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM)
|
269 |
|
270 |
def go_to_patient_llm():
|
271 |
selected_scenario = st.session_state.get('selected_scenario')
|
272 |
if selected_scenario is None or selected_scenario < 0:
|
273 |
st.warning("Please select a scenario!")
|
274 |
else:
|
275 |
+
st.session_state.start_time = datetime.datetime.now()
|
276 |
states = ["store", "store2","retriever","retriever2","chain","chain2"]
|
277 |
for state_to_del in states:
|
278 |
if state_to_del in st.session_state:
|
|
|
316 |
col1, col2, col3 = st.columns([1,3,1])
|
317 |
with col1:
|
318 |
back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
319 |
+
# with col3:
|
320 |
+
# start_timer_button = st.button("START")
|
321 |
|
322 |
with col2:
|
323 |
TIME_LIMIT = 60*10 ## to change to 10 minutes
|
324 |
time.sleep(1)
|
325 |
+
# if start_timer_button:
|
326 |
+
# st.session_state.start_time = datetime.datetime.now()
|
327 |
# st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time')
|
328 |
st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time
|
329 |
|
|
|
405 |
""",
|
406 |
)
|
407 |
|
408 |
+
with open("./public/chars/Female_talk.gif", "rb") as f:
|
409 |
contents = f.read()
|
410 |
+
student_url = base64.b64encode(contents).decode("utf-8")
|
411 |
|
412 |
with open("./public/chars/Male_talk.gif", "rb") as f:
|
413 |
contents = f.read()
|
|
|
427 |
st.session_state.patient_response = response
|
428 |
with interactive_container:
|
429 |
html(f"""
|
430 |
+
<style>
|
431 |
+
.conversation-container {{
|
432 |
+
display: grid;
|
433 |
+
grid-template-columns: 1fr 1fr;
|
434 |
+
grid-template-rows: 1fr 1fr;
|
435 |
+
gap: 10px;
|
436 |
+
width: 100%;
|
437 |
+
height: 100%;
|
438 |
+
background-color: #add8e6; /* Soothing blue background */
|
439 |
+
}}
|
440 |
+
|
441 |
+
.doctor-image {{
|
442 |
+
grid-column: 1;
|
443 |
+
grid-row: 2;
|
444 |
+
display: flex;
|
445 |
+
justify-content: center;
|
446 |
+
align-items: center;
|
447 |
+
}}
|
448 |
+
|
449 |
+
.patient-image {{
|
450 |
+
grid-column: 2;
|
451 |
+
grid-row: 1;
|
452 |
+
display: flex;
|
453 |
+
justify-content: center;
|
454 |
+
align-items: center;
|
455 |
+
}}
|
456 |
+
|
457 |
+
.doctor-input {{
|
458 |
+
grid-column: 2;
|
459 |
+
grid-row: 2;
|
460 |
+
display: flex;
|
461 |
+
justify-content: center;
|
462 |
+
align-items: center;
|
463 |
+
}}
|
464 |
+
|
465 |
+
.patient-input {{
|
466 |
+
grid-column: 1;
|
467 |
+
grid-row: 1;
|
468 |
+
display: flex;
|
469 |
+
justify-content: center;
|
470 |
+
align-items: center;
|
471 |
+
}}
|
472 |
+
|
473 |
+
img {{
|
474 |
+
max-width: 100%;
|
475 |
+
height: auto;
|
476 |
+
border-radius: 8px; /* Rounded corners for the images */
|
477 |
+
}}
|
478 |
+
|
479 |
+
input[type="text"] {{
|
480 |
+
width: 90%;
|
481 |
+
padding: 10px;
|
482 |
+
margin: 10px;
|
483 |
+
border: none;
|
484 |
+
border-radius: 5px;
|
485 |
+
}}
|
486 |
+
</style>
|
487 |
+
</head>
|
488 |
+
<body>
|
489 |
+
<div class="conversation-container">
|
490 |
+
<div class="doctor-image">
|
491 |
+
<img src="data:image/png;base64,{student_url}" alt="Doctor" />
|
492 |
+
</div>
|
493 |
+
<div class="patient-image">
|
494 |
+
<img src="data:image/gif;base64,{patient_url}" alt="Patient" />
|
495 |
+
</div>
|
496 |
+
<div class="doctor-input">
|
497 |
+
<span id="doctor_message">You: {st.session_state.get('user_inputs') or ''}</span>
|
498 |
+
</div>
|
499 |
+
<div class="patient-input">
|
500 |
+
<span id="patient_message">{'Patient: '+st.session_state.get('patient_response') if st.session_state.get('patient_response') else '...'}</span>
|
501 |
+
</div>
|
502 |
+
</div>
|
503 |
+
</body>
|
504 |
+
</html>
|
505 |
+
|
506 |
+
""", height=500)
|
507 |
|
508 |
elif st.session_state.scenario_tab_index == ScenarioTabIndex.GRADER_LLM:
|
509 |
st.session_state.grader_output = "" if not st.session_state.get("grader_output") else st.session_state.grader_output
|
|
|
534 |
st.session_state.differential_3 = st.text_input("Differential 3")
|
535 |
with st.columns(6)[5]:
|
536 |
send_for_grading = st.button("Get grades!", on_click=get_grades)
|
537 |
+
with st.expander("Your grade", expanded=st.session_state.has_llm_output):
|
538 |
if st.session_state.grader_output:
|
539 |
st.write(st.session_state.grader_output.get("text").get("text"))
|
540 |
|
|
|
542 |
back_btn = st.button("New Scenario?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
543 |
|
544 |
with dashboard_tab:
|
545 |
+
cred = db.cred
|
546 |
+
# cred = credentials.Certificate(json.loads(os.environ.get("FIREBASE_CREDENTIAL")))
|
|
|
|
|
|
|
|
|
|
|
|
|
547 |
|
548 |
# Initialize Firebase (if not already initialized)
|
549 |
if not firebase_admin._apps:
|