Yew Chong committed on
Commit
bc4dcba
1 Parent(s): ac9e946

final combined app

Browse files
.gitignore CHANGED
@@ -24,8 +24,5 @@ test*.py
24
  test*.html
25
  test*.ipynb
26
 
27
- ## Images
28
- *.png
29
-
30
  # streamlit
31
  .streamlit/secrets.toml
 
24
  test*.html
25
  test*.ipynb
26
 
 
 
 
27
  # streamlit
28
  .streamlit/secrets.toml
.streamlit/config.toml CHANGED
@@ -1,3 +1,4 @@
1
  [theme]
2
  base = "dark"
3
- primaryColor="#6633F6"
 
 
1
  [theme]
2
  base = "dark"
3
+ primaryColor="#6633F6"
4
+ backgroundColor="#0E1117"
README.md CHANGED
@@ -5,28 +5,21 @@ colorFrom: red
5
  colorTo: indigo
6
  sdk: streamlit
7
  sdk_version: 1.30.0
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- ## To test the LLM only
13
- Download the relevant LLM python notebooks (e.g. `LLM for Patient.ipynb`)
14
 
15
- Add your own .env file in the same directory as the notebook
16
 
17
- Run :)
18
 
19
- ----------------------------------------------------------------
20
- # How to run locally
21
 
22
- git pull everything down
23
 
24
- `python -m pip install -r requirements.txt`
25
 
26
- Add your own .env file based on the env.example (huggingface, openai, firebase tokens required)
27
-
28
- ???
29
-
30
- profit
31
 
32
  ---------------------------------
 
5
  colorTo: indigo
6
  sdk: streamlit
7
  sdk_version: 1.30.0
8
+ app_file: app_final.py
9
  pinned: false
10
  ---
11
 
12
+ ## How to run locally
 
13
 
14
+ 1. git clone
15
 
16
+ 2. `python -m pip install -r requirements.txt`
17
 
18
+ 3. Add your own .env file based on the env.example (huggingface, openai, firebase tokens required)
 
19
 
20
+ 4. `streamlit run app_final.py`
21
 
22
+ 5. Open `localhost:8501`
23
 
 
 
 
 
 
24
 
25
  ---------------------------------
app_final.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ import streamlit as st
3
+ import streamlit.components.v1 as components
4
+ import datetime, time
5
+ from dataclasses import dataclass
6
+ import math
7
+ import base64
8
+
9
+ ## Firestore ??
10
+ import os
11
+ # import sys
12
+ # import inspect
13
+ # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
14
+ # parentdir = os.path.dirname(currentdir)
15
+ # sys.path.append(parentdir)
16
+
17
+
18
+
19
+
20
+ # ## ----------------------------------------------------------------
21
+ # ## LLM Part
22
+ import openai
23
+ from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
24
+ import tiktoken
25
+ from langchain.prompts.few_shot import FewShotPromptTemplate
26
+ from langchain.prompts.prompt import PromptTemplate
27
+ from operator import itemgetter
28
+ from langchain.schema import StrOutputParser
29
+ from langchain_core.output_parsers import StrOutputParser
30
+ from langchain_core.runnables import RunnablePassthrough
31
+
32
+ import langchain_community.embeddings.huggingface
33
+ from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
34
+ from langchain_community.vectorstores import FAISS
35
+
36
+ from langchain.chains import LLMChain
37
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
38
+
39
+ import os, dotenv
40
+ from dotenv import load_dotenv
41
+ load_dotenv()
42
+
43
# Bootstrap .streamlit/secrets.toml from the STREAMLIT_SECRETS env var so the
# st.secrets[...] lookups further down work on a fresh deployment.
if not os.path.isdir("./.streamlit"):
    os.mkdir("./.streamlit")
    print('made streamlit folder')
if not os.path.isfile("./.streamlit/secrets.toml"):
    secrets_payload = os.environ.get("STREAMLIT_SECRETS")
    if secrets_payload is None:
        # f.write(None) would raise an opaque TypeError; fail loudly with an
        # actionable message instead.
        raise RuntimeError(
            "STREAMLIT_SECRETS environment variable is not set; "
            "cannot create ./.streamlit/secrets.toml")
    with open("./.streamlit/secrets.toml", "w") as f:
        f.write(secrets_payload)
    print('made new file')
50
+
51
+
52
+ import db_firestore as db
53
+
54
+ ## Load from streamlit!!
55
+ os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN") or st.secrets["HF_TOKEN"]
56
+ os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or st.secrets["OPENAI_API_KEY"]
57
+ os.environ["FIREBASE_CREDENTIAL"] = os.environ.get("FIREBASE_CREDENTIAL") or st.secrets["FIREBASE_CREDENTIAL"]
58
+
59
+
60
+ if "openai_model" not in st.session_state:
61
+ st.session_state["openai_model"] = "gpt-3.5-turbo-1106"
62
+
63
+ ## Hardcode indexes for now
64
+ ## TODO: Move indexes to firebase
65
+ indexes = """Bleeding
66
+ ChestPain
67
+ Dysphagia
68
+ Headache
69
+ ShortnessOfBreath
70
+ Vomiting
71
+ Weakness
72
+ Weakness2""".split("\n")
73
+
74
+ # if "selected_index" not in st.session_state:
75
+ # st.session_state.selected_index = 3
76
+
77
+ # if "index_selectbox" not in st.session_state:
78
+ # st.session_state.index_selectbox = "Headache"
79
+
80
+ # index_selectbox = st.selectbox("Select index",indexes, index=int(st.session_state.selected_index))
81
+
82
+ # if index_selectbox != indexes[st.session_state.selected_index]:
83
+ # st.session_state.selected_index = indexes.index(index_selectbox)
84
+ # st.session_state.index_selectbox = index_selectbox
85
+ # del st.session_state["store"]
86
+ # del st.session_state["store2"]
87
+ # del st.session_state["retriever"]
88
+ # del st.session_state["retriever2"]
89
+ # del st.session_state["chain"]
90
+ # del st.session_state["chain2"]
91
+
92
+
93
# Embedding model configuration (BGE on CPU, normalized vectors).
# NOTE(review): model_name is defined but the model_name= argument below is
# commented out, so HuggingFaceBgeEmbeddings falls back to its library default
# model rather than this string — confirm that is intentional.
model_name = "bge-large-en-v1.5"
model_kwargs = {"device": "cpu"}  # run embedding inference on CPU
encode_kwargs = {"normalize_embeddings": True}  # unit-norm vectors for cosine similarity
# Cache the (expensive to build) embedding model in session state.
if "embeddings" not in st.session_state:
    st.session_state.embeddings = HuggingFaceBgeEmbeddings(
        # model_name=model_name,
        model_kwargs = model_kwargs,
        encode_kwargs = encode_kwargs)
embeddings = st.session_state.embeddings
102
+
103
+ if "llm" not in st.session_state:
104
+ st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
105
+ llm = st.session_state.llm
106
+ if "llm_i" not in st.session_state:
107
+ st.session_state.llm_i = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
108
+ llm_i = st.session_state.llm_i
109
+ if "llm_gpt4" not in st.session_state:
110
+ st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
111
+ llm_gpt4 = st.session_state.llm_gpt4
112
+
113
+ # ## ------------------------------------------------------------------------------------------------
114
+ # ## Patient part
115
+
116
+ # index_name = f"indexes/{st.session_state.index_selectbox}/QA"
117
+
118
+ # if "store" not in st.session_state:
119
+ # st.session_state.store = db.get_store(index_name, embeddings=embeddings)
120
+ # store = st.session_state.store
121
+
122
+ if "TEMPLATE" not in st.session_state:
123
+ with open('templates/patient.txt', 'r') as file:
124
+ TEMPLATE = file.read()
125
+ st.session_state.TEMPLATE = TEMPLATE
126
+ TEMPLATE = st.session_state.TEMPLATE
127
+ # with st.expander("Patient Prompt"):
128
+ # TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
129
+
130
+ prompt = PromptTemplate(
131
+ input_variables = ["question", "context"],
132
+ template = st.session_state.TEMPLATE
133
+ )
134
+
135
+ # if "retriever" not in st.session_state:
136
+ # st.session_state.retriever = store.as_retriever(search_type="similarity", search_kwargs={"k":2})
137
+ # retriever = st.session_state.retriever
138
+
139
def format_docs(docs):
    """Join the retrieved documents' text into one string, separated by a dashed divider."""
    separator = "\n--------------------\n"
    return separator.join([doc.page_content for doc in docs])
141
+
142
+
143
+ # if "memory" not in st.session_state:
144
+ # st.session_state.memory = ConversationBufferWindowMemory(
145
+ # llm=llm, memory_key="chat_history", input_key="question",
146
+ # k=5, human_prefix="student", ai_prefix="patient",)
147
+ # memory = st.session_state.memory
148
+
149
+
150
+ # if ("chain" not in st.session_state
151
+ # or
152
+ # st.session_state.TEMPLATE != TEMPLATE):
153
+ # st.session_state.chain = (
154
+ # {
155
+ # "context": retriever | format_docs,
156
+ # "question": RunnablePassthrough()
157
+ # } |
158
+ # LLMChain(llm=llm, prompt=prompt, memory=memory, verbose=False)
159
+ # )
160
+ # chain = st.session_state.chain
161
+
162
+ sp_mapper = {"human":"student","ai":"patient", "user":"student","assistant":"patient"}
163
+
164
+ # ## ------------------------------------------------------------------------------------------------
165
+ # ## ------------------------------------------------------------------------------------------------
166
+ # ## Grader part
167
+ # index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
168
+
169
+ # if "store2" not in st.session_state:
170
+ # st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
171
+ # store2 = st.session_state.store2
172
+
173
+ if "TEMPLATE2" not in st.session_state:
174
+ with open('templates/grader.txt', 'r') as file:
175
+ TEMPLATE2 = file.read()
176
+ st.session_state.TEMPLATE2 = TEMPLATE2
177
+ TEMPLATE2 = st.session_state.TEMPLATE2
178
+ # with st.expander("Grader Prompt"):
179
+ # TEMPLATE2 = st.text_area("Grader Prompt", value=st.session_state.TEMPLATE2)
180
+
181
+ prompt2 = PromptTemplate(
182
+ input_variables = ["question", "context", "history"],
183
+ template = st.session_state.TEMPLATE2
184
+ )
185
+
186
def get_patient_chat_history(_):
    """Return the cached patient chat transcript (or None if absent).

    Ignores its single argument so it can sit in the grader chain's
    "history" slot as a runnable.
    """
    transcript = st.session_state.get("patient_chat_history")
    return transcript
188
+
189
+ # if "retriever2" not in st.session_state:
190
+ # st.session_state.retriever2 = store2.as_retriever(search_type="similarity", search_kwargs={"k":2})
191
+ # retriever2 = st.session_state.retriever2
192
+
193
+ # def format_docs(docs):
194
+ # return "\n--------------------\n".join(doc.page_content for doc in docs)
195
+
196
+
197
+ # fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
198
+ # fake_history = '\n'.join([(sp_mapper.get(i['role'], i['role']) + ": "+ i['content']) for i in st.session_state.messages_1])
199
+ # st.write(fake_history)
200
+
201
+ # def y(_):
202
+ # return fake_history
203
+
204
+ # if ("chain2" not in st.session_state
205
+ # or
206
+ # st.session_state.TEMPLATE2 != TEMPLATE2):
207
+ # st.session_state.chain2 = (
208
+ # {
209
+ # "context": retriever2 | format_docs,
210
+ # "history": y,
211
+ # "question": RunnablePassthrough(),
212
+ # } |
213
+
214
+ # # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
215
+ # LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) #|
216
+ # | {
217
+ # "json": itemgetter("text"),
218
+ # "text": (
219
+ # LLMChain(
220
+ # llm=llm,
221
+ # prompt=PromptTemplate(
222
+ # input_variables=["text"],
223
+ # template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
224
+ # verbose=False)
225
+ # )
226
+ # }
227
+ # )
228
+ # chain2 = st.session_state.chain2
229
+
230
+ # ## ------------------------------------------------------------------------------------------------
231
+ # ## ------------------------------------------------------------------------------------------------
232
+ # ## Streamlit now
233
+
234
+ # # from dotenv import load_dotenv
235
+ # # import os
236
+ # # load_dotenv()
237
+ # # key = os.environ.get("OPENAI_API_KEY")
238
+ # # client = OpenAI(api_key=key)
239
+
240
+
241
+ # if st.button("Clear History and Memory", type="primary"):
242
+ # st.session_state.messages_1 = []
243
+ # st.session_state.messages_2 = []
244
+ # st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" )
245
+ # memory = st.session_state.memory
246
+
247
+ # ## Testing HTML
248
+ # # html_string = """
249
+ # # <canvas></canvas>
250
+
251
+
252
+ # # <script>
253
+ # # canvas = document.querySelector('canvas');
254
+ # # canvas.width = 1024;
255
+ # # canvas.height = 576;
256
+ # # console.log(canvas);
257
+
258
+ # # const c = canvas.getContext('2d');
259
+ # # c.fillStyle = "green";
260
+ # # c.fillRect(0,0,canvas.width,canvas.height);
261
+
262
+ # # const img = new Image();
263
+ # # img.src = "./tksfordumtrive.png";
264
+ # # c.drawImage(img, 10, 10);
265
+ # # </script>
266
+
267
+ # # <style>
268
+ # # body {
269
+ # # margin: 0;
270
+ # # }
271
+ # # </style>
272
+ # # """
273
+ # # components.html(html_string,
274
+ # # width=1280,
275
+ # # height=640)
276
+
277
+
278
+ # st.write("Timer has been removed, switch with this button")
279
+
280
+ # if st.button(f"Switch to {'PATIENT' if st.session_state.active_chat==2 else 'GRADER'}"+".... Buggy button, please double click"):
281
+ # st.session_state.active_chat = 3 - st.session_state.active_chat
282
+
283
+ # # st.write("Currently in " + ('PATIENT' if st.session_state.active_chat==2 else 'GRADER'))
284
+
285
+ # # Create two columns for the two chat interfaces
286
+ # col1, col2 = st.columns(2)
287
+
288
+ # # First chat interface
289
+ # with col1:
290
+ # st.subheader("Student LLM")
291
+ # for message in st.session_state.messages_1:
292
+ # with st.chat_message(message["role"]):
293
+ # st.markdown(message["content"])
294
+
295
+ # # Second chat interface
296
+ # with col2:
297
+ # # st.write("pls dun spam this, its tons of tokens cos chat history")
298
+ # st.subheader("Grader LLM")
299
+ # st.write("grader takes a while to load... please be patient")
300
+ # for message in st.session_state.messages_2:
301
+ # with st.chat_message(message["role"]):
302
+ # st.markdown(message["content"])
303
+
304
+ # # Timer and Input
305
+ # # time_left = None
306
+ # # if st.session_state.start_time:
307
+ # # time_elapsed = datetime.datetime.now() - st.session_state.start_time
308
+ # # time_left = datetime.timedelta(minutes=10) - time_elapsed
309
+ # # st.write(f"Time left: {time_left}")
310
+
311
+ # # if time_left is None or time_left > datetime.timedelta(0):
312
+ # # # Chat 1 is active
313
+ # # prompt = st.text_input("Enter your message for Chat 1:")
314
+ # # active_chat = 1
315
+ # # messages = st.session_state.messages_1
316
+ # # elif time_left and time_left <= datetime.timedelta(0):
317
+ # # # Chat 2 is active
318
+ # # prompt = st.text_input("Enter your message for Chat 2:")
319
+ # # active_chat = 2
320
+ # # messages = st.session_state.messages_2
321
+
322
+ # if st.session_state.active_chat==1:
323
+ # text_prompt = st.text_input("Enter your message for PATIENT")
324
+ # messages = st.session_state.messages_1
325
+ # else:
326
+ # text_prompt = st.text_input("Enter your message for GRADER")
327
+ # messages = st.session_state.messages_2
328
+
329
+
330
+ # from langchain.callbacks.manager import tracing_v2_enabled
331
+ # from uuid import uuid4
332
+ # import os
333
+
334
+ # if text_prompt:
335
+ # messages.append({"role": "user", "content": text_prompt})
336
+
337
+ # with (col1 if st.session_state.active_chat == 1 else col2):
338
+ # with st.chat_message("user"):
339
+ # st.markdown(text_prompt)
340
+
341
+ # with (col1 if st.session_state.active_chat == 1 else col2):
342
+ # with st.chat_message("assistant"):
343
+ # message_placeholder = st.empty()
344
+ # if True: ## with tracing_v2_enabled(project_name = "streamlit"):
345
+ # if st.session_state.active_chat==1:
346
+ # full_response = chain.invoke(text_prompt).get("text")
347
+ # else:
348
+ # full_response = chain2.invoke(text_prompt).get("text").get("text")
349
+ # message_placeholder.markdown(full_response)
350
+ # messages.append({"role": "assistant", "content": full_response})
351
+
352
+
353
+ # st.write('fake history is:')
354
+ # st.write(y(""))
355
+ # st.write('done')
356
+
357
+
358
+
359
+
360
+ ## ====================
361
+
362
+ if not st.session_state.get("scenario_list", None):
363
+ st.session_state.scenario_list = indexes
364
+
365
def init_patient_llm():
    """Set up session state for the patient-roleplay chat.

    Creates (only if missing) the message log, the scenario's QA vector
    store, a similarity retriever, and a rolling conversation memory;
    then builds the patient chain. The chain is rebuilt whenever the
    patient prompt TEMPLATE has changed.
    """
    if "messages_1" not in st.session_state:
        st.session_state.messages_1 = []
    ## messages 2?

    # Vector index of QA documents for the currently selected scenario.
    index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/QA"
    if "store" not in st.session_state:
        st.session_state.store = db.get_store(index_name, embeddings=embeddings)
    if "retriever" not in st.session_state:
        st.session_state.retriever = st.session_state.store.as_retriever(search_type="similarity", search_kwargs={"k":2})
    if "memory" not in st.session_state:
        # Window memory: keep only the last k=5 exchanges, labelled
        # student/patient to match the prompt template's roles.
        st.session_state.memory = ConversationBufferWindowMemory(
            llm=llm, memory_key="chat_history", input_key="question",
            k=5, human_prefix="student", ai_prefix="patient",)

    # (Re)build the chain when absent or when the prompt text was edited.
    if ("chain" not in st.session_state
        or
        st.session_state.TEMPLATE != TEMPLATE):
        st.session_state.chain = (
            {
                "context": st.session_state.retriever | format_docs,
                "question": RunnablePassthrough()
            } |
            LLMChain(llm=llm, prompt=prompt, memory=st.session_state.memory, verbose=False)
        )
390
+
391
def init_grader_llm():
    """Set up session state for the grader.

    Loads the scenario's Rubric vector store and retriever, snapshots the
    patient conversation into `patient_chat_history`, resets the timer,
    and builds the two-stage grader chain (GPT-4 grades against the
    rubric, then a second LLM writes up the JSON grades as prose).
    """
    ## Grader
    index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/Rubric"

    ## Reset time
    st.session_state.start_time = False

    if "store2" not in st.session_state:
        st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
    if "retriever2" not in st.session_state:
        st.session_state.retriever2 = st.session_state.store2.as_retriever(search_type="similarity", search_kwargs={"k":2})

    ## Re-init history: flatten memory messages to "student: .../patient: ..." lines.
    st.session_state["patient_chat_history"] = "History\n" + '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in st.session_state.memory.chat_memory.messages])

    # (Re)build the grader chain when absent or when the grader prompt changed.
    if ("chain2" not in st.session_state
        or
        st.session_state.TEMPLATE2 != TEMPLATE2):
        st.session_state.chain2 = (
            {
                "context": st.session_state.retriever2 | format_docs,
                "history": (get_patient_chat_history),
                "question": RunnablePassthrough(),
            } |

            # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
            # Stage 1: GPT-4 grades the student against the rubric (JSON out).
            LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) #|
            # Stage 2: keep the raw JSON, and have a cheaper LLM write it up.
            | {
                "json": itemgetter("text"),
                "text": (
                    LLMChain(
                        llm=llm,
                        prompt=PromptTemplate(
                            input_variables=["text"],
                            template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
                        verbose=False)
                )
            }
        )
430
+
431
+
432
# Hard-coded demo credentials (username -> plaintext password).
# SECURITY: acceptable only for a hackathon demo — replace with a real
# user store and hashed passwords before any production use.
login_info = {
    "bob":"builder",
    "student1": "password",
    "admin":"admin"
}
437
+
438
def set_username(x):
    # Mark the user as logged in by storing their name in session state;
    # the login/app views below branch on st.session_state.username.
    st.session_state.username = x
440
+
441
def validate_username(username, password):
    """Check credentials against login_info and log the user in on success.

    On success, stores the username in session state via set_username();
    on failure, shows a Streamlit warning. Always returns None (it is
    used as a button on_click callback, so the return value is ignored).
    """
    import hmac  # local import keeps this fix self-contained

    stored = login_info.get(username)
    # compare_digest is a constant-time comparison, closing the timing
    # side channel of a plain `==` on the password; the explicit None
    # check keeps unknown usernames on the failure path.
    if stored is not None and hmac.compare_digest(stored.encode(), password.encode()):
        set_username(username)
    else:
        st.warning("Wrong username or password")
    return None
447
+
448
+ if not st.session_state.get("username"):
449
+ ## ask to login
450
+ st.title("Login")
451
+ username = st.text_input("Username:")
452
+ password = st.text_input("Password:", type="password")
453
+ login_button = st.button("Login", on_click=validate_username, args=[username, password])
454
+
455
+ else:
456
+ if True: ## Says hello and logout
457
+ col_1, col_2 = st.columns([1,3])
458
+ col_2.title(f"Hello there, {st.session_state.username}")
459
+ # Display logout button
460
+ if col_1.button('Logout'):
461
+ # Remove username from session state
462
+ del st.session_state.username
463
+ # Rerun the app to go back to the login view
464
+ st.rerun()
465
+
466
+ scenario_tab, dashboard_tab = st.tabs(["Training", "Dashboard"])
467
+ # st.header("head")
468
+ # st.markdown("## markdown")
469
+ # st.caption("caption")
470
+ # st.divider()
471
+ # import pandas as pd
472
+ # import numpy as np
473
+ # map_data = pd.DataFrame(
474
+ # np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],
475
+ # columns=['lat', 'lon'])
476
+
477
+ # st.map(map_data)
478
+
479
class ScenarioTabIndex:
    # Enum-like int constants naming the three views of the Training tab;
    # the current view index is kept in st.session_state.scenario_tab_index.
    SELECT_SCENARIO = 0  # scenario picker grid
    PATIENT_LLM = 1      # chat with the simulated patient
    GRADER_LLM = 2       # submit diagnosis/differentials for grading
483
+
484
def set_scenario_tab_index(x):
    # Switch the Training tab to the view identified by x
    # (one of the ScenarioTabIndex constants).
    st.session_state.scenario_tab_index=x
    return None
487
+
488
def select_scenario_and_change_tab(_):
    # Jump straight to the patient chat view; the single argument is
    # ignored so this can be wired up as a widget callback.
    set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM)
490
+
491
def go_to_patient_llm():
    """Validate the scenario choice, reset per-scenario LLM state, and open the patient chat.

    Shows a warning and stays on the picker when no scenario is selected.
    """
    scenario = st.session_state.get('selected_scenario')
    if scenario is None or scenario < 0:
        st.warning("Please select a scenario!")
        return
    ## TODO: Clear state for time, LLM, Index, etc
    # Drop every cached store/retriever/chain so the new scenario rebuilds them.
    for key in ("store", "store2", "retriever", "retriever2", "chain", "chain2"):
        st.session_state.pop(key, None)
    init_patient_llm()
    set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM)
503
+ if not st.session_state.get("scenario_tab_index"):
504
+ set_scenario_tab_index(ScenarioTabIndex.SELECT_SCENARIO)
505
+
506
+ with scenario_tab:
507
+ ## Check in select scenario
508
+ if st.session_state.scenario_tab_index == ScenarioTabIndex.SELECT_SCENARIO:
509
def change_scenario(scenario_index):
    # Tile-button callback: remember which scenario the student picked.
    st.session_state.selected_scenario = scenario_index
511
+ if st.session_state.get("selected_scenario", None) is None:
512
+ st.session_state.selected_scenario = -1
513
+
514
+ total_cols = 3
515
+ rows = list()
516
+ # for _ in range(0, number_of_indexes, total_cols):
517
+ # rows.extend(st.columns(total_cols))
518
+
519
+ st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}")
520
+ for i, scenario in enumerate(st.session_state.scenario_list):
521
+ if i % total_cols == 0:
522
+ rows.extend(st.columns(total_cols))
523
+ curr_col = rows[(-total_cols + i % total_cols)]
524
+ tile = curr_col.container(height=120)
525
+ ## TODO: Implement highlight box if index is selected
526
+ # if st.session_state.selected_scenario == i:
527
+ # tile.markdown("<style>background: pink !important;</style>", unsafe_allow_html=True)
528
+ tile.write(":balloon:")
529
+ tile.button(label=scenario, on_click=change_scenario, args=[i])
530
+
531
+ select_scenario_btn = st.button("Select Scenario", on_click=go_to_patient_llm, args=[])
532
+
533
+ elif st.session_state.scenario_tab_index == ScenarioTabIndex.PATIENT_LLM:
534
+ st.header("Patient info")
535
+ st.write("Pull the info here!!!")
536
+ col1, col2, col3 = st.columns([1,3,1])
537
+ with col1:
538
+ back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
539
+ with col3:
540
+ start_timer_button = st.button("START")
541
+
542
+ with col2:
543
+ TIME_LIMIT = 60*10 ## to change to 10 minutes
544
+ time.sleep(1)
545
+ if start_timer_button:
546
+ st.session_state.start_time = datetime.datetime.now()
547
+ # st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time')
548
+ st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time
549
+
550
+ from streamlit.components.v1 import html
551
+
552
+
553
+ html(f"""
554
+ <style>
555
+ @import url('https://fonts.googleapis.com/css2?family=Pixelify+Sans&display=swap');
556
+ @import url('https://fonts.googleapis.com/css2?family=VT323&display=swap');
557
+ @import url('https://fonts.googleapis.com/css2?family=Monofett&display=swap');
558
+ </style>
559
+
560
+ <style>
561
+ html {{
562
+ font-family: 'Pixelify Sans', monospace, serif;
563
+ font-family: 'VT323', monospace, sans-serif;
564
+ font-family: 'Monofett', monospace, sans-serif;
565
+ font-family: 'Times New Roman', sans-serif;
566
+ background-color: #0E1117 !important;
567
+ color: RGB(250,250,250);
568
+ // border-radius: 25%;
569
+ // border: 1px solid #0E1117;
570
+ }}
571
+ html, body {{
572
+ // background-color: transparent !important;
573
+ // margin: 10px;
574
+ // border: 1px solid pink;
575
+ text-align: center;
576
+ }}
577
+ body {{
578
+ background-color: #0E1117;
579
+ // margin: 10px;
580
+ // border: 1px solid pink;
581
+ }}
582
+
583
+ body #ttime {{
584
+ font-weight: bold;
585
+ font-family: 'VT323', monospace, sans-serif;
586
+ // font-family: 'Pixelify Sans', monospace, serif;
587
+ }}
588
+ </style>
589
+
590
+ <div>
591
+ <h1>Time left</h1>
592
+ <h1 id="ttime"> </h1>
593
+ </div>
594
+
595
+
596
+ <script>
597
+
598
+ var x = setInterval(function() {{
599
+ var start_time_str = "{st.session_state.start_time}";
600
+ var start_date = new Date(start_time_str);
601
+ var curr_date = new Date();
602
+ var time_difference = curr_date - start_date;
603
+ var time_diff_secs = Math.floor(time_difference / 1000);
604
+ var time_left = {TIME_LIMIT} - time_diff_secs;
605
+ var mins = Math.floor(time_left / 60);
606
+ var secs = time_left % 60;
607
+ var fmins = mins.toString().padStart(2, '0');
608
+ var fsecs = secs.toString().padStart(2, '0');
609
+ console.log("run");
610
+
611
+ if (start_time_str == "False") {{
612
+ document.getElementById("ttime").innerHTML = 'Press "Start" to start!';
613
+ clearInterval(x);
614
+ }}
615
+ else if (time_left <= 0) {{
616
+ document.getElementById("ttime").innerHTML = "Time's Up!!!";
617
+ clearInterval(x);
618
+ }}
619
+ else {{
620
+ document.getElementById("ttime").innerHTML = `${{fmins}}:${{fsecs}}`;
621
+ }}
622
+ }}, 999)
623
+
624
+ </script>
625
+ """,
626
+ )
627
+
628
+ with open("./public/char.png", "rb") as f:
629
+ contents = f.read()
630
+ data_url = base64.b64encode(contents).decode("utf-8")
631
+
632
+ with open("./public/chars/Male_talk.gif", "rb") as f:
633
+ contents = f.read()
634
+ patient_url = base64.b64encode(contents).decode("utf-8")
635
+ interactive_container = st.container()
636
+ user_input_col ,r = st.columns([4,1])
637
def to_grader_llm():
    # Button callback: build the grader chain from the patient chat so far,
    # then switch the Training tab to the grading view.
    init_grader_llm()
    set_scenario_tab_index(ScenarioTabIndex.GRADER_LLM)
640
+
641
+ with r:
642
+ to_grader_btn = st.button("To Grader", on_click=to_grader_llm)
643
+ with user_input_col:
644
+ user_inputs = st.text_input("", placeholder="Chat with the patient here!", key="user_inputs")
645
+ if user_inputs:
646
+ response = st.session_state.chain.invoke(user_inputs).get("text")
647
+ st.session_state.patient_response = response
648
+ with interactive_container:
649
+ html(f"""
650
+
651
+ <style>
652
+ @import url('https://fonts.googleapis.com/css2?family=Pixelify+Sans&display=swap');
653
+ </style>
654
+
655
+ <style>
656
+ html {{
657
+ font-family: 'Pixelify Sans', monospace, serif;
658
+ }}
659
+ </style>
660
+ <div>
661
+ <img src="data:image/png;base64,{data_url}" />
662
+ <span id="user_input">You: {st.session_state.get('user_inputs') or ''}</span>
663
+ </div>
664
+
665
+ <div>
666
+ <img src="data:image/gif;base64,{patient_url}" /><br/>
667
+ <span id="bot_response">{'Patient: '+st.session_state.get('patient_response') if st.session_state.get('patient_response') else '...'}</span>
668
+ </div>
669
+ """, height=500)
670
+
671
+ elif st.session_state.scenario_tab_index == ScenarioTabIndex.GRADER_LLM:
672
+ st.session_state.grader_output = "" if not st.session_state.get("grader_output") else st.session_state.grader_output
673
def get_grades():
    # Button callback: wrap the student's case summary/diagnosis and the
    # three differentials in pseudo-XML tags, run them through the grader
    # chain (chain2), and cache the raw chain output in session state for
    # the rubric expander to display.
    txt = f"""
<summary>
{st.session_state.diagnosis}
</summary>
<differential-1>
{st.session_state.differential_1}
</differential-1>
<differential-2>
{st.session_state.differential_2}
</differential-2>
<differential-3>
{st.session_state.differential_3}
</differential-3>
"""
    response = st.session_state.chain2.invoke(txt)
    st.session_state.grader_output = response
690
+ st.session_state.has_llm_output = bool(st.session_state.get("grader_output"))
691
+ ## TODO: False for now, need check llm output!
692
+ with st.expander("Your Diagnosis and Differentials", expanded=not st.session_state.has_llm_output):
693
+ st.session_state.diagnosis = st.text_area("Input your case summary and **main** diagnosis:", placeholder="This is a young gentleman with significant family history of stroke, and medical history of poorly-controlled hypertension. He presents with acute onset of bitemporal headache associated with dysarthria and meningism symptoms. Important negatives include the absence of focal neurological deficits, ataxia, and recent trauma.")
694
+ st.divider()
695
+ st.session_state.differential_1 = st.text_input("Differential 1")
696
+ st.session_state.differential_2 = st.text_input("Differential 2")
697
+ st.session_state.differential_3 = st.text_input("Differential 3")
698
+ with st.columns(6)[5]:
699
+ send_for_grading = st.button("Get grades!", on_click=get_grades)
700
+ with st.expander("Your rubrics", expanded=st.session_state.has_llm_output):
701
+ if st.session_state.grader_output:
702
+ st.write(st.session_state.grader_output.get("text").get("text"))
703
+
704
+ # back_btn = st.button("back to LLM?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.PATIENT_LLM])
705
+ back_btn = st.button("New Scenario?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
706
+
707
with dashboard_tab:
    import dotenv
    import firebase_admin, json
    from firebase_admin import credentials, storage, firestore
    import plotly.express as px
    import plotly.graph_objects as go
    import pandas as pd

    # Load the Firebase service-account credential from .env only when it is
    # not already present in the environment (e.g. set as a hosted secret).
    # dotenv.get_key returns None when missing, which would previously have
    # raised a TypeError on assignment.
    if not os.environ.get("FIREBASE_CREDENTIAL"):
        os.environ["FIREBASE_CREDENTIAL"] = dotenv.get_key(dotenv.find_dotenv(), "FIREBASE_CREDENTIAL")
    cred = credentials.Certificate(json.loads(os.environ.get("FIREBASE_CREDENTIAL")))

    # Streamlit re-executes this script on every interaction, so only
    # initialize the Firebase app once per process.
    if not firebase_admin._apps:
        firebase_admin.initialize_app(cred, {'storageBucket': 'healthhack-store.appspot.com'})

    db_client = firestore.client()

    # Pull every clinical-score document into a DataFrame.
    docs = db_client.collection("clinical_scores").stream()
    data = []
    for doc in docs:
        doc_dict = doc.to_dict()
        doc_dict['document_id'] = doc.id  # keep the document ID in case it is needed later
        data.append(doc_dict)
    df = pd.DataFrame(data)

    username = st.session_state.get("username")
    st.title("Dashboard")

    # Normalise dates; unparseable strings become NaT instead of raising.
    df['date'] = pd.to_datetime(df['date'], errors='coerce')

    # Students only see their own rows; admin sees everything.
    # .copy() so the per-user view never aliases (and later mutates) df,
    # which the radar chart below still reads in full.
    if username != 'admin':
        df_selection = df[df['name'] == username].copy()
    else:
        df_selection = df.copy()

    st.title(":bar_chart: Student Performance Dashboard")
    st.markdown("##")

    def _annotate_points(fig, frame, x_col, y_col):
        """Attach an orange arrowed label above every (x, y) point of *fig*."""
        for i in range(frame.shape[0]):
            fig.add_annotation(
                text=str(frame[y_col].iloc[i]),
                x=frame[x_col].iloc[i],
                y=frame[y_col].iloc[i],
                showarrow=True,
                font=dict(family="Courier New, monospace", size=18, color="#ffffff"),
                align="center",
                arrowhead=2,
                arrowsize=1,
                arrowwidth=2,
                arrowcolor="#636363",
                ax=20,
                ay=-30,
                bordercolor="#c7c7c7",
                borderwidth=2,
                borderpad=4,
                bgcolor="#ff7f0e",
                opacity=0.8,
            )

    if df_selection.empty:
        st.error("No data available to display.")
    else:
        # Chart 1: total attempts per student (counted as rows per name).
        total_attempts_by_name = df_selection.groupby("name")['date'].count().reset_index()
        total_attempts_by_name.columns = ['name', 'total_attempts']

        fig_total_attempts = px.scatter(
            total_attempts_by_name,
            x="name",
            y="total_attempts",
            title="<b>Total Attempts</b>",
            size='total_attempts',  # scale point size by attempt count
            color_discrete_sequence=["#0083B8"] * len(total_attempts_by_name),
            template="plotly_white",
            text='total_attempts'   # display total_attempts as text labels
        )
        _annotate_points(fig_total_attempts, total_attempts_by_name, 'name', 'total_attempts')
        fig_total_attempts.update_traces(marker=dict(size=12), selector=dict(mode='markers+text'))
        st.plotly_chart(fig_total_attempts, use_container_width=True)

        # Chart 2 (students only): personal scores over time.
        if username != 'admin':
            df_selection = df_selection.sort_values(by='date')

            if len(df_selection) > 1:
                # Map letter grades onto a numeric scale so they can be barred.
                grade_to_score = {'A': 100, 'B': 80, 'C': 60, 'D': 40, 'E': 20}
                df_selection['numeric_score'] = df_selection['global_score'].map(grade_to_score)

                fig = px.bar(df_selection, x='date', y='numeric_score', title='Your scores over time')

                # Show the grade letters on the y axis instead of raw numbers.
                fig.update_yaxes(
                    tickmode='array',
                    tickvals=list(grade_to_score.values()),  # positions for the ticks
                    ticktext=list(grade_to_score.keys()),    # text labels for the ticks
                    range=[0, 120]  # extend a bit beyond 100 to accommodate 'A'
                )
            else:
                # Single attempt: scatter plot with the grade annotated.
                fig = px.scatter(df_selection, x='date', y='global_score', title='Global Score',
                                 text='global_score', size_max=60)
                _annotate_points(fig, df_selection, 'date', 'global_score')
                fig.update_traces(marker=dict(size=12), selector=dict(mode='markers+text'))

            st.plotly_chart(fig, use_container_width=True)

            # Show students their scores over time.
            st.dataframe(df_selection[['date', 'global_score', 'name']])

        # Charts 3-5 (admin only): cohort-wide summaries.
        if username == 'admin':
            # Chart 3: global score distribution, ordered A-E.
            df_selection['global_score'] = pd.Categorical(
                df_selection['global_score'],
                categories=['A', 'B', 'C', 'D', 'E'],
                ordered=True,
            )
            fig_score_distribution = px.histogram(
                df_selection,
                x="global_score",
                title="<b>Global Score Distribution</b>",
                color_discrete_sequence=["#33CFA5"],
                category_orders={"global_score": ["A", "B", "C", "D", "E"]}
            )
            st.plotly_chart(fig_score_distribution, use_container_width=True)

            # Chart 4: students with fewer than 5 attempts.
            students_with_less_than_5_attempts = total_attempts_by_name[total_attempts_by_name['total_attempts'] < 5]
            fig_less_than_5_attempts = px.bar(
                students_with_less_than_5_attempts,
                x="name",
                y="total_attempts",
                title="<b>Students with <5 Attempts</b>",
                color_discrete_sequence=["#D62728"] * len(students_with_less_than_5_attempts),
                template="plotly_white",
            )
            st.plotly_chart(fig_less_than_5_attempts, use_container_width=True)

            # Drill-down for a student with <5 attempts.
            selected_student_less_than_5 = st.selectbox("Select a student with less than 5 attempts to view details:", students_with_less_than_5_attempts['name'])
            if selected_student_less_than_5:
                st.write(df_selection[df_selection['name'] == selected_student_less_than_5])

            # Chart 5: students with at least one global score of 'C', 'D', 'E'.
            students_with_cde = df_selection[df_selection['global_score'].isin(['C', 'D', 'E'])].groupby("name")['date'].count().reset_index()
            students_with_cde.columns = ['name', 'total_attempts']
            fig_students_with_cde = px.bar(
                students_with_cde,
                x="name",
                y="total_attempts",
                title="<b>Students with at least one global score of 'C', 'D', 'E'</b>",
                color_discrete_sequence=["#FF7F0E"] * len(students_with_cde),
                template="plotly_white",
            )
            st.plotly_chart(fig_students_with_cde, use_container_width=True)

            # Drill-down for a student with a C/D/E score.
            selected_student_cde = st.selectbox("Select a student with at least one score of 'C', 'D', 'E' to view details:", students_with_cde['name'])
            if selected_student_cde:
                st.write(df_selection[df_selection['name'] == selected_student_cde])

        # Chart 7 (all users): radar chart of average segment scores,
        # computed over the full cohort so students can compare themselves.
        grade_to_numeric = {'A': 90, 'B': 70, 'C': 50, 'D': 30, 'E': 10}
        score_cols = ['hx_PC_score', 'hx_AS_score', 'hx_others_score', 'differentials_score']
        # Only convert the score columns: a blanket df.replace() would also
        # rewrite grade letters appearing in unrelated fields (e.g. names).
        df[score_cols] = df[score_cols].replace(grade_to_numeric)

        average_scores = df.groupby('name')[score_cols].mean().reset_index()

        if username == 'admin':
            st.title('Average Scores Radar Chart')
        else:
            st.title('Performance in each segment as compared to your friends!')

        # Categories for the radar chart axes.
        categories = ['Presenting complaint', 'Associated symptoms', '(Others)', 'Differentials']

        st.markdown("""
        ### 
        Double click on the names in the legend to include/exclude them from the plot.
        """)

        # Custom colors for better contrast, cycled per student.
        colors = ['gold', 'cyan', 'magenta', 'green']

        fig = go.Figure()
        for index, row in average_scores.iterrows():
            fig.add_trace(go.Scatterpolar(
                r=[row['hx_PC_score'], row['hx_AS_score'], row['hx_others_score'], row['differentials_score']],
                theta=categories,
                fill='toself',
                name=row['name'],
                line=dict(color=colors[index % len(colors)])
            ))

        fig.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    range=[0, 100],                    # numeric range
                    tickvals=[10, 30, 50, 70, 90],     # positions for the grade labels
                    ticktext=['E', 'D', 'C', 'B', 'A'] # grade labels
                )),
            showlegend=True,
            height=600,  # set the height of the figure
            width=600    # set the width of the figure
        )
        st.plotly_chart(fig, use_container_width=True)
981
+
public/char.png ADDED
public/chars/Female_talk.gif ADDED
public/chars/Female_walk .gif ADDED
public/chars/Male_talk.gif ADDED
public/chars/Male_wait.gif ADDED
requirements.txt CHANGED
@@ -13,4 +13,5 @@ faiss-cpu
13
  streamlit
14
  firebase-admin
15
  plotly
16
- torch==2.1.2
 
 
13
  streamlit
14
  firebase-admin
15
  plotly
16
+ torch==2.1.2
17
+ streamlit_authenticator
templates/grader.txt CHANGED
@@ -30,25 +30,25 @@ Example output JSON:
30
  {{{{
31
  "history_presenting_complain": {{{{
32
  "grade": "A",
33
- "remarks": "Your remarks here""
34
  }}}}
35
  }}}},
36
  {{{{
37
  "history_associated_symptoms": {{{{
38
  "grade": "B",
39
- "remarks": "Your remarks here""
40
  }}}}
41
  }}}},
42
  {{{{
43
  "history_others": {{{{
44
  "grade": "C",
45
- "remarks": "Your remarks here""
46
  }}}}
47
  }}}},
48
  {{{{
49
  "diagnosis_and_differentials": {{{{
50
  "grade": "D",
51
- "remarks": "Your remarks here""
52
  }}}}
53
  }}}},
54
  {{{{
 
30
  {{{{
31
  "history_presenting_complain": {{{{
32
  "grade": "A",
33
+ "remarks": "Your remarks here"
34
  }}}}
35
  }}}},
36
  {{{{
37
  "history_associated_symptoms": {{{{
38
  "grade": "B",
39
+ "remarks": "Your remarks here"
40
  }}}}
41
  }}}},
42
  {{{{
43
  "history_others": {{{{
44
  "grade": "C",
45
+ "remarks": "Your remarks here"
46
  }}}}
47
  }}}},
48
  {{{{
49
  "diagnosis_and_differentials": {{{{
50
  "grade": "D",
51
+ "remarks": "Your remarks here"
52
  }}}}
53
  }}}},
54
  {{{{