Spaces:
Runtime error
Runtime error
Yew Chong
commited on
Commit
•
bc4dcba
1
Parent(s):
ac9e946
final combined app
Browse files- .gitignore +0 -3
- .streamlit/config.toml +2 -1
- README.md +7 -14
- app_final.py +981 -0
- public/char.png +0 -0
- public/chars/Female_talk.gif +0 -0
- public/chars/Female_walk .gif +0 -0
- public/chars/Male_talk.gif +0 -0
- public/chars/Male_wait.gif +0 -0
- requirements.txt +2 -1
- templates/grader.txt +4 -4
.gitignore
CHANGED
@@ -24,8 +24,5 @@ test*.py
|
|
24 |
test*.html
|
25 |
test*.ipynb
|
26 |
|
27 |
-
## Images
|
28 |
-
*.png
|
29 |
-
|
30 |
# streamlit
|
31 |
.streamlit/secrets.toml
|
|
|
24 |
test*.html
|
25 |
test*.ipynb
|
26 |
|
|
|
|
|
|
|
27 |
# streamlit
|
28 |
.streamlit/secrets.toml
|
.streamlit/config.toml
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
[theme]
|
2 |
base = "dark"
|
3 |
-
primaryColor="#6633F6"
|
|
|
|
1 |
[theme]
|
2 |
base = "dark"
|
3 |
+
primaryColor="#6633F6"
|
4 |
+
backgroundColor="#0E1117"
|
README.md
CHANGED
@@ -5,28 +5,21 @@ colorFrom: red
|
|
5 |
colorTo: indigo
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.30.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
-
##
|
13 |
-
Download the relevant LLM python notebooks (e.g. `LLM for Patient.ipynb`)
|
14 |
|
15 |
-
|
16 |
|
17 |
-
|
18 |
|
19 |
-
|
20 |
-
# How to run locally
|
21 |
|
22 |
-
|
23 |
|
24 |
-
|
25 |
|
26 |
-
Add your own .env file based on the env.example (huggingface, openai, firebase tokens required)
|
27 |
-
|
28 |
-
???
|
29 |
-
|
30 |
-
profit
|
31 |
|
32 |
---------------------------------
|
|
|
5 |
colorTo: indigo
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.30.0
|
8 |
+
app_file: app_final.py
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
## How to run locally
|
|
|
13 |
|
14 |
+
1. git clone
|
15 |
|
16 |
+
2. `python -m pip install -r requirements.txt`
|
17 |
|
18 |
+
3. Add your own .env file based on the env.example (huggingface, openai, firebase tokens required)
|
|
|
19 |
|
20 |
+
4. `streamlit run app.py`
|
21 |
|
22 |
+
5. Open `localhost:8501`
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
---------------------------------
|
app_final.py
ADDED
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from openai import OpenAI
|
2 |
+
import streamlit as st
|
3 |
+
import streamlit.components.v1 as components
|
4 |
+
import datetime, time
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import math
|
7 |
+
import base64
|
8 |
+
|
9 |
+
## Firestore ??
|
10 |
+
import os
|
11 |
+
# import sys
|
12 |
+
# import inspect
|
13 |
+
# currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
14 |
+
# parentdir = os.path.dirname(currentdir)
|
15 |
+
# sys.path.append(parentdir)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
# ## ----------------------------------------------------------------
|
21 |
+
# ## LLM Part
|
22 |
+
import openai
|
23 |
+
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
|
24 |
+
import tiktoken
|
25 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
26 |
+
from langchain.prompts.prompt import PromptTemplate
|
27 |
+
from operator import itemgetter
|
28 |
+
from langchain.schema import StrOutputParser
|
29 |
+
from langchain_core.output_parsers import StrOutputParser
|
30 |
+
from langchain_core.runnables import RunnablePassthrough
|
31 |
+
|
32 |
+
import langchain_community.embeddings.huggingface
|
33 |
+
from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings
|
34 |
+
from langchain_community.vectorstores import FAISS
|
35 |
+
|
36 |
+
from langchain.chains import LLMChain
|
37 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory
|
38 |
+
|
39 |
+
import os, dotenv
|
40 |
+
from dotenv import load_dotenv
|
41 |
+
load_dotenv()
|
42 |
+
|
43 |
+
if not os.path.isdir("./.streamlit"):
|
44 |
+
os.mkdir("./.streamlit")
|
45 |
+
print('made streamlit folder')
|
46 |
+
if not os.path.isfile("./.streamlit/secrets.toml"):
|
47 |
+
with open("./.streamlit/secrets.toml", "w") as f:
|
48 |
+
f.write(os.environ.get("STREAMLIT_SECRETS"))
|
49 |
+
print('made new file')
|
50 |
+
|
51 |
+
|
52 |
+
import db_firestore as db
|
53 |
+
|
54 |
+
## Load from streamlit!!
|
55 |
+
os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN") or st.secrets["HF_TOKEN"]
|
56 |
+
os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or st.secrets["OPENAI_API_KEY"]
|
57 |
+
os.environ["FIREBASE_CREDENTIAL"] = os.environ.get("FIREBASE_CREDENTIAL") or st.secrets["FIREBASE_CREDENTIAL"]
|
58 |
+
|
59 |
+
|
60 |
+
if "openai_model" not in st.session_state:
|
61 |
+
st.session_state["openai_model"] = "gpt-3.5-turbo-1106"
|
62 |
+
|
63 |
+
## Hardcode indexes for now
|
64 |
+
## TODO: Move indexes to firebase
|
65 |
+
indexes = """Bleeding
|
66 |
+
ChestPain
|
67 |
+
Dysphagia
|
68 |
+
Headache
|
69 |
+
ShortnessOfBreath
|
70 |
+
Vomiting
|
71 |
+
Weakness
|
72 |
+
Weakness2""".split("\n")
|
73 |
+
|
74 |
+
# if "selected_index" not in st.session_state:
|
75 |
+
# st.session_state.selected_index = 3
|
76 |
+
|
77 |
+
# if "index_selectbox" not in st.session_state:
|
78 |
+
# st.session_state.index_selectbox = "Headache"
|
79 |
+
|
80 |
+
# index_selectbox = st.selectbox("Select index",indexes, index=int(st.session_state.selected_index))
|
81 |
+
|
82 |
+
# if index_selectbox != indexes[st.session_state.selected_index]:
|
83 |
+
# st.session_state.selected_index = indexes.index(index_selectbox)
|
84 |
+
# st.session_state.index_selectbox = index_selectbox
|
85 |
+
# del st.session_state["store"]
|
86 |
+
# del st.session_state["store2"]
|
87 |
+
# del st.session_state["retriever"]
|
88 |
+
# del st.session_state["retriever2"]
|
89 |
+
# del st.session_state["chain"]
|
90 |
+
# del st.session_state["chain2"]
|
91 |
+
|
92 |
+
|
93 |
+
model_name = "bge-large-en-v1.5"
|
94 |
+
model_kwargs = {"device": "cpu"}
|
95 |
+
encode_kwargs = {"normalize_embeddings": True}
|
96 |
+
if "embeddings" not in st.session_state:
|
97 |
+
st.session_state.embeddings = HuggingFaceBgeEmbeddings(
|
98 |
+
# model_name=model_name,
|
99 |
+
model_kwargs = model_kwargs,
|
100 |
+
encode_kwargs = encode_kwargs)
|
101 |
+
embeddings = st.session_state.embeddings
|
102 |
+
|
103 |
+
if "llm" not in st.session_state:
|
104 |
+
st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
|
105 |
+
llm = st.session_state.llm
|
106 |
+
if "llm_i" not in st.session_state:
|
107 |
+
st.session_state.llm_i = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
|
108 |
+
llm_i = st.session_state.llm_i
|
109 |
+
if "llm_gpt4" not in st.session_state:
|
110 |
+
st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
111 |
+
llm_gpt4 = st.session_state.llm_gpt4
|
112 |
+
|
113 |
+
# ## ------------------------------------------------------------------------------------------------
|
114 |
+
# ## Patient part
|
115 |
+
|
116 |
+
# index_name = f"indexes/{st.session_state.index_selectbox}/QA"
|
117 |
+
|
118 |
+
# if "store" not in st.session_state:
|
119 |
+
# st.session_state.store = db.get_store(index_name, embeddings=embeddings)
|
120 |
+
# store = st.session_state.store
|
121 |
+
|
122 |
+
if "TEMPLATE" not in st.session_state:
|
123 |
+
with open('templates/patient.txt', 'r') as file:
|
124 |
+
TEMPLATE = file.read()
|
125 |
+
st.session_state.TEMPLATE = TEMPLATE
|
126 |
+
TEMPLATE = st.session_state.TEMPLATE
|
127 |
+
# with st.expander("Patient Prompt"):
|
128 |
+
# TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
|
129 |
+
|
130 |
+
prompt = PromptTemplate(
|
131 |
+
input_variables = ["question", "context"],
|
132 |
+
template = st.session_state.TEMPLATE
|
133 |
+
)
|
134 |
+
|
135 |
+
# if "retriever" not in st.session_state:
|
136 |
+
# st.session_state.retriever = store.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
137 |
+
# retriever = st.session_state.retriever
|
138 |
+
|
139 |
+
def format_docs(docs):
|
140 |
+
return "\n--------------------\n".join(doc.page_content for doc in docs)
|
141 |
+
|
142 |
+
|
143 |
+
# if "memory" not in st.session_state:
|
144 |
+
# st.session_state.memory = ConversationBufferWindowMemory(
|
145 |
+
# llm=llm, memory_key="chat_history", input_key="question",
|
146 |
+
# k=5, human_prefix="student", ai_prefix="patient",)
|
147 |
+
# memory = st.session_state.memory
|
148 |
+
|
149 |
+
|
150 |
+
# if ("chain" not in st.session_state
|
151 |
+
# or
|
152 |
+
# st.session_state.TEMPLATE != TEMPLATE):
|
153 |
+
# st.session_state.chain = (
|
154 |
+
# {
|
155 |
+
# "context": retriever | format_docs,
|
156 |
+
# "question": RunnablePassthrough()
|
157 |
+
# } |
|
158 |
+
# LLMChain(llm=llm, prompt=prompt, memory=memory, verbose=False)
|
159 |
+
# )
|
160 |
+
# chain = st.session_state.chain
|
161 |
+
|
162 |
+
sp_mapper = {"human":"student","ai":"patient", "user":"student","assistant":"patient"}
|
163 |
+
|
164 |
+
# ## ------------------------------------------------------------------------------------------------
|
165 |
+
# ## ------------------------------------------------------------------------------------------------
|
166 |
+
# ## Grader part
|
167 |
+
# index_name = f"indexes/{st.session_state.index_selectbox}/Rubric"
|
168 |
+
|
169 |
+
# if "store2" not in st.session_state:
|
170 |
+
# st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
|
171 |
+
# store2 = st.session_state.store2
|
172 |
+
|
173 |
+
if "TEMPLATE2" not in st.session_state:
|
174 |
+
with open('templates/grader.txt', 'r') as file:
|
175 |
+
TEMPLATE2 = file.read()
|
176 |
+
st.session_state.TEMPLATE2 = TEMPLATE2
|
177 |
+
TEMPLATE2 = st.session_state.TEMPLATE2
|
178 |
+
# with st.expander("Grader Prompt"):
|
179 |
+
# TEMPLATE2 = st.text_area("Grader Prompt", value=st.session_state.TEMPLATE2)
|
180 |
+
|
181 |
+
prompt2 = PromptTemplate(
|
182 |
+
input_variables = ["question", "context", "history"],
|
183 |
+
template = st.session_state.TEMPLATE2
|
184 |
+
)
|
185 |
+
|
186 |
+
def get_patient_chat_history(_):
|
187 |
+
return st.session_state.get("patient_chat_history")
|
188 |
+
|
189 |
+
# if "retriever2" not in st.session_state:
|
190 |
+
# st.session_state.retriever2 = store2.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
191 |
+
# retriever2 = st.session_state.retriever2
|
192 |
+
|
193 |
+
# def format_docs(docs):
|
194 |
+
# return "\n--------------------\n".join(doc.page_content for doc in docs)
|
195 |
+
|
196 |
+
|
197 |
+
# fake_history = '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in memory.chat_memory.messages])
|
198 |
+
# fake_history = '\n'.join([(sp_mapper.get(i['role'], i['role']) + ": "+ i['content']) for i in st.session_state.messages_1])
|
199 |
+
# st.write(fake_history)
|
200 |
+
|
201 |
+
# def y(_):
|
202 |
+
# return fake_history
|
203 |
+
|
204 |
+
# if ("chain2" not in st.session_state
|
205 |
+
# or
|
206 |
+
# st.session_state.TEMPLATE2 != TEMPLATE2):
|
207 |
+
# st.session_state.chain2 = (
|
208 |
+
# {
|
209 |
+
# "context": retriever2 | format_docs,
|
210 |
+
# "history": y,
|
211 |
+
# "question": RunnablePassthrough(),
|
212 |
+
# } |
|
213 |
+
|
214 |
+
# # LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
|
215 |
+
# LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) #|
|
216 |
+
# | {
|
217 |
+
# "json": itemgetter("text"),
|
218 |
+
# "text": (
|
219 |
+
# LLMChain(
|
220 |
+
# llm=llm,
|
221 |
+
# prompt=PromptTemplate(
|
222 |
+
# input_variables=["text"],
|
223 |
+
# template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
|
224 |
+
# verbose=False)
|
225 |
+
# )
|
226 |
+
# }
|
227 |
+
# )
|
228 |
+
# chain2 = st.session_state.chain2
|
229 |
+
|
230 |
+
# ## ------------------------------------------------------------------------------------------------
|
231 |
+
# ## ------------------------------------------------------------------------------------------------
|
232 |
+
# ## Streamlit now
|
233 |
+
|
234 |
+
# # from dotenv import load_dotenv
|
235 |
+
# # import os
|
236 |
+
# # load_dotenv()
|
237 |
+
# # key = os.environ.get("OPENAI_API_KEY")
|
238 |
+
# # client = OpenAI(api_key=key)
|
239 |
+
|
240 |
+
|
241 |
+
# if st.button("Clear History and Memory", type="primary"):
|
242 |
+
# st.session_state.messages_1 = []
|
243 |
+
# st.session_state.messages_2 = []
|
244 |
+
# st.session_state.memory = ConversationBufferWindowMemory(llm=llm, memory_key="chat_history", input_key="question" )
|
245 |
+
# memory = st.session_state.memory
|
246 |
+
|
247 |
+
# ## Testing HTML
|
248 |
+
# # html_string = """
|
249 |
+
# # <canvas></canvas>
|
250 |
+
|
251 |
+
|
252 |
+
# # <script>
|
253 |
+
# # canvas = document.querySelector('canvas');
|
254 |
+
# # canvas.width = 1024;
|
255 |
+
# # canvas.height = 576;
|
256 |
+
# # console.log(canvas);
|
257 |
+
|
258 |
+
# # const c = canvas.getContext('2d');
|
259 |
+
# # c.fillStyle = "green";
|
260 |
+
# # c.fillRect(0,0,canvas.width,canvas.height);
|
261 |
+
|
262 |
+
# # const img = new Image();
|
263 |
+
# # img.src = "./tksfordumtrive.png";
|
264 |
+
# # c.drawImage(img, 10, 10);
|
265 |
+
# # </script>
|
266 |
+
|
267 |
+
# # <style>
|
268 |
+
# # body {
|
269 |
+
# # margin: 0;
|
270 |
+
# # }
|
271 |
+
# # </style>
|
272 |
+
# # """
|
273 |
+
# # components.html(html_string,
|
274 |
+
# # width=1280,
|
275 |
+
# # height=640)
|
276 |
+
|
277 |
+
|
278 |
+
# st.write("Timer has been removed, switch with this button")
|
279 |
+
|
280 |
+
# if st.button(f"Switch to {'PATIENT' if st.session_state.active_chat==2 else 'GRADER'}"+".... Buggy button, please double click"):
|
281 |
+
# st.session_state.active_chat = 3 - st.session_state.active_chat
|
282 |
+
|
283 |
+
# # st.write("Currently in " + ('PATIENT' if st.session_state.active_chat==2 else 'GRADER'))
|
284 |
+
|
285 |
+
# # Create two columns for the two chat interfaces
|
286 |
+
# col1, col2 = st.columns(2)
|
287 |
+
|
288 |
+
# # First chat interface
|
289 |
+
# with col1:
|
290 |
+
# st.subheader("Student LLM")
|
291 |
+
# for message in st.session_state.messages_1:
|
292 |
+
# with st.chat_message(message["role"]):
|
293 |
+
# st.markdown(message["content"])
|
294 |
+
|
295 |
+
# # Second chat interface
|
296 |
+
# with col2:
|
297 |
+
# # st.write("pls dun spam this, its tons of tokens cos chat history")
|
298 |
+
# st.subheader("Grader LLM")
|
299 |
+
# st.write("grader takes a while to load... please be patient")
|
300 |
+
# for message in st.session_state.messages_2:
|
301 |
+
# with st.chat_message(message["role"]):
|
302 |
+
# st.markdown(message["content"])
|
303 |
+
|
304 |
+
# # Timer and Input
|
305 |
+
# # time_left = None
|
306 |
+
# # if st.session_state.start_time:
|
307 |
+
# # time_elapsed = datetime.datetime.now() - st.session_state.start_time
|
308 |
+
# # time_left = datetime.timedelta(minutes=10) - time_elapsed
|
309 |
+
# # st.write(f"Time left: {time_left}")
|
310 |
+
|
311 |
+
# # if time_left is None or time_left > datetime.timedelta(0):
|
312 |
+
# # # Chat 1 is active
|
313 |
+
# # prompt = st.text_input("Enter your message for Chat 1:")
|
314 |
+
# # active_chat = 1
|
315 |
+
# # messages = st.session_state.messages_1
|
316 |
+
# # elif time_left and time_left <= datetime.timedelta(0):
|
317 |
+
# # # Chat 2 is active
|
318 |
+
# # prompt = st.text_input("Enter your message for Chat 2:")
|
319 |
+
# # active_chat = 2
|
320 |
+
# # messages = st.session_state.messages_2
|
321 |
+
|
322 |
+
# if st.session_state.active_chat==1:
|
323 |
+
# text_prompt = st.text_input("Enter your message for PATIENT")
|
324 |
+
# messages = st.session_state.messages_1
|
325 |
+
# else:
|
326 |
+
# text_prompt = st.text_input("Enter your message for GRADER")
|
327 |
+
# messages = st.session_state.messages_2
|
328 |
+
|
329 |
+
|
330 |
+
# from langchain.callbacks.manager import tracing_v2_enabled
|
331 |
+
# from uuid import uuid4
|
332 |
+
# import os
|
333 |
+
|
334 |
+
# if text_prompt:
|
335 |
+
# messages.append({"role": "user", "content": text_prompt})
|
336 |
+
|
337 |
+
# with (col1 if st.session_state.active_chat == 1 else col2):
|
338 |
+
# with st.chat_message("user"):
|
339 |
+
# st.markdown(text_prompt)
|
340 |
+
|
341 |
+
# with (col1 if st.session_state.active_chat == 1 else col2):
|
342 |
+
# with st.chat_message("assistant"):
|
343 |
+
# message_placeholder = st.empty()
|
344 |
+
# if True: ## with tracing_v2_enabled(project_name = "streamlit"):
|
345 |
+
# if st.session_state.active_chat==1:
|
346 |
+
# full_response = chain.invoke(text_prompt).get("text")
|
347 |
+
# else:
|
348 |
+
# full_response = chain2.invoke(text_prompt).get("text").get("text")
|
349 |
+
# message_placeholder.markdown(full_response)
|
350 |
+
# messages.append({"role": "assistant", "content": full_response})
|
351 |
+
|
352 |
+
|
353 |
+
# st.write('fake history is:')
|
354 |
+
# st.write(y(""))
|
355 |
+
# st.write('done')
|
356 |
+
|
357 |
+
|
358 |
+
|
359 |
+
|
360 |
+
## ====================
|
361 |
+
|
362 |
+
if not st.session_state.get("scenario_list", None):
|
363 |
+
st.session_state.scenario_list = indexes
|
364 |
+
|
365 |
+
def init_patient_llm():
|
366 |
+
if "messages_1" not in st.session_state:
|
367 |
+
st.session_state.messages_1 = []
|
368 |
+
## messages 2?
|
369 |
+
|
370 |
+
index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/QA"
|
371 |
+
if "store" not in st.session_state:
|
372 |
+
st.session_state.store = db.get_store(index_name, embeddings=embeddings)
|
373 |
+
if "retriever" not in st.session_state:
|
374 |
+
st.session_state.retriever = st.session_state.store.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
375 |
+
if "memory" not in st.session_state:
|
376 |
+
st.session_state.memory = ConversationBufferWindowMemory(
|
377 |
+
llm=llm, memory_key="chat_history", input_key="question",
|
378 |
+
k=5, human_prefix="student", ai_prefix="patient",)
|
379 |
+
|
380 |
+
if ("chain" not in st.session_state
|
381 |
+
or
|
382 |
+
st.session_state.TEMPLATE != TEMPLATE):
|
383 |
+
st.session_state.chain = (
|
384 |
+
{
|
385 |
+
"context": st.session_state.retriever | format_docs,
|
386 |
+
"question": RunnablePassthrough()
|
387 |
+
} |
|
388 |
+
LLMChain(llm=llm, prompt=prompt, memory=st.session_state.memory, verbose=False)
|
389 |
+
)
|
390 |
+
|
391 |
+
def init_grader_llm():
|
392 |
+
## Grader
|
393 |
+
index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/Rubric"
|
394 |
+
|
395 |
+
## Reset time
|
396 |
+
st.session_state.start_time = False
|
397 |
+
|
398 |
+
if "store2" not in st.session_state:
|
399 |
+
st.session_state.store2 = db.get_store(index_name, embeddings=embeddings)
|
400 |
+
if "retriever2" not in st.session_state:
|
401 |
+
st.session_state.retriever2 = st.session_state.store2.as_retriever(search_type="similarity", search_kwargs={"k":2})
|
402 |
+
|
403 |
+
## Re-init history
|
404 |
+
st.session_state["patient_chat_history"] = "History\n" + '\n'.join([(sp_mapper.get(i.type, i.type) + ": "+ i.content) for i in st.session_state.memory.chat_memory.messages])
|
405 |
+
|
406 |
+
if ("chain2" not in st.session_state
|
407 |
+
or
|
408 |
+
st.session_state.TEMPLATE2 != TEMPLATE2):
|
409 |
+
st.session_state.chain2 = (
|
410 |
+
{
|
411 |
+
"context": st.session_state.retriever2 | format_docs,
|
412 |
+
"history": (get_patient_chat_history),
|
413 |
+
"question": RunnablePassthrough(),
|
414 |
+
} |
|
415 |
+
|
416 |
+
# LLMChain(llm=llm_i, prompt=prompt2, verbose=False ) #|
|
417 |
+
LLMChain(llm=llm_gpt4, prompt=prompt2, verbose=False ) #|
|
418 |
+
| {
|
419 |
+
"json": itemgetter("text"),
|
420 |
+
"text": (
|
421 |
+
LLMChain(
|
422 |
+
llm=llm,
|
423 |
+
prompt=PromptTemplate(
|
424 |
+
input_variables=["text"],
|
425 |
+
template="Interpret the following JSON of the student's grades, and do a write-up for each section.\n\n```json\n{text}\n```"),
|
426 |
+
verbose=False)
|
427 |
+
)
|
428 |
+
}
|
429 |
+
)
|
430 |
+
|
431 |
+
|
432 |
+
login_info = {
|
433 |
+
"bob":"builder",
|
434 |
+
"student1": "password",
|
435 |
+
"admin":"admin"
|
436 |
+
}
|
437 |
+
|
438 |
+
def set_username(x):
|
439 |
+
st.session_state.username = x
|
440 |
+
|
441 |
+
def validate_username(username, password):
|
442 |
+
if login_info.get(username) == password:
|
443 |
+
set_username(username)
|
444 |
+
else:
|
445 |
+
st.warning("Wrong username or password")
|
446 |
+
return None
|
447 |
+
|
448 |
+
if not st.session_state.get("username"):
|
449 |
+
## ask to login
|
450 |
+
st.title("Login")
|
451 |
+
username = st.text_input("Username:")
|
452 |
+
password = st.text_input("Password:", type="password")
|
453 |
+
login_button = st.button("Login", on_click=validate_username, args=[username, password])
|
454 |
+
|
455 |
+
else:
|
456 |
+
if True: ## Says hello and logout
|
457 |
+
col_1, col_2 = st.columns([1,3])
|
458 |
+
col_2.title(f"Hello there, {st.session_state.username}")
|
459 |
+
# Display logout button
|
460 |
+
if col_1.button('Logout'):
|
461 |
+
# Remove username from session state
|
462 |
+
del st.session_state.username
|
463 |
+
# Rerun the app to go back to the login view
|
464 |
+
st.rerun()
|
465 |
+
|
466 |
+
scenario_tab, dashboard_tab = st.tabs(["Training", "Dashboard"])
|
467 |
+
# st.header("head")
|
468 |
+
# st.markdown("## markdown")
|
469 |
+
# st.caption("caption")
|
470 |
+
# st.divider()
|
471 |
+
# import pandas as pd
|
472 |
+
# import numpy as np
|
473 |
+
# map_data = pd.DataFrame(
|
474 |
+
# np.random.randn(1000, 2) / [50, 50] + [37.76, -122.4],
|
475 |
+
# columns=['lat', 'lon'])
|
476 |
+
|
477 |
+
# st.map(map_data)
|
478 |
+
|
479 |
+
class ScenarioTabIndex:
|
480 |
+
SELECT_SCENARIO = 0
|
481 |
+
PATIENT_LLM = 1
|
482 |
+
GRADER_LLM = 2
|
483 |
+
|
484 |
+
def set_scenario_tab_index(x):
|
485 |
+
st.session_state.scenario_tab_index=x
|
486 |
+
return None
|
487 |
+
|
488 |
+
def select_scenario_and_change_tab(_):
|
489 |
+
set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM)
|
490 |
+
|
491 |
+
def go_to_patient_llm():
|
492 |
+
selected_scenario = st.session_state.get('selected_scenario')
|
493 |
+
if selected_scenario is None or selected_scenario < 0:
|
494 |
+
st.warning("Please select a scenario!")
|
495 |
+
else:
|
496 |
+
## TODO: Clear state for time, LLM, Index, etc
|
497 |
+
states = ["store", "store2","retriever","retriever2","chain","chain2"]
|
498 |
+
for state_to_del in states:
|
499 |
+
if state_to_del in st.session_state:
|
500 |
+
del st.session_state[state_to_del]
|
501 |
+
init_patient_llm()
|
502 |
+
set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM)
|
503 |
+
if not st.session_state.get("scenario_tab_index"):
|
504 |
+
set_scenario_tab_index(ScenarioTabIndex.SELECT_SCENARIO)
|
505 |
+
|
506 |
+
with scenario_tab:
|
507 |
+
## Check in select scenario
|
508 |
+
if st.session_state.scenario_tab_index == ScenarioTabIndex.SELECT_SCENARIO:
|
509 |
+
def change_scenario(scenario_index):
|
510 |
+
st.session_state.selected_scenario = scenario_index
|
511 |
+
if st.session_state.get("selected_scenario", None) is None:
|
512 |
+
st.session_state.selected_scenario = -1
|
513 |
+
|
514 |
+
total_cols = 3
|
515 |
+
rows = list()
|
516 |
+
# for _ in range(0, number_of_indexes, total_cols):
|
517 |
+
# rows.extend(st.columns(total_cols))
|
518 |
+
|
519 |
+
st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}")
|
520 |
+
for i, scenario in enumerate(st.session_state.scenario_list):
|
521 |
+
if i % total_cols == 0:
|
522 |
+
rows.extend(st.columns(total_cols))
|
523 |
+
curr_col = rows[(-total_cols + i % total_cols)]
|
524 |
+
tile = curr_col.container(height=120)
|
525 |
+
## TODO: Implement highlight box if index is selected
|
526 |
+
# if st.session_state.selected_scenario == i:
|
527 |
+
# tile.markdown("<style>background: pink !important;</style>", unsafe_allow_html=True)
|
528 |
+
tile.write(":balloon:")
|
529 |
+
tile.button(label=scenario, on_click=change_scenario, args=[i])
|
530 |
+
|
531 |
+
select_scenario_btn = st.button("Select Scenario", on_click=go_to_patient_llm, args=[])
|
532 |
+
|
533 |
+
elif st.session_state.scenario_tab_index == ScenarioTabIndex.PATIENT_LLM:
|
534 |
+
st.header("Patient info")
|
535 |
+
st.write("Pull the info here!!!")
|
536 |
+
col1, col2, col3 = st.columns([1,3,1])
|
537 |
+
with col1:
|
538 |
+
back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
539 |
+
with col3:
|
540 |
+
start_timer_button = st.button("START")
|
541 |
+
|
542 |
+
with col2:
|
543 |
+
TIME_LIMIT = 60*10 ## to change to 10 minutes
|
544 |
+
time.sleep(1)
|
545 |
+
if start_timer_button:
|
546 |
+
st.session_state.start_time = datetime.datetime.now()
|
547 |
+
# st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time')
|
548 |
+
st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time
|
549 |
+
|
550 |
+
from streamlit.components.v1 import html
|
551 |
+
|
552 |
+
|
553 |
+
html(f"""
|
554 |
+
<style>
|
555 |
+
@import url('https://fonts.googleapis.com/css2?family=Pixelify+Sans&display=swap');
|
556 |
+
@import url('https://fonts.googleapis.com/css2?family=VT323&display=swap');
|
557 |
+
@import url('https://fonts.googleapis.com/css2?family=Monofett&display=swap');
|
558 |
+
</style>
|
559 |
+
|
560 |
+
<style>
|
561 |
+
html {{
|
562 |
+
font-family: 'Pixelify Sans', monospace, serif;
|
563 |
+
font-family: 'VT323', monospace, sans-serif;
|
564 |
+
font-family: 'Monofett', monospace, sans-serif;
|
565 |
+
font-family: 'Times New Roman', sans-serif;
|
566 |
+
background-color: #0E1117 !important;
|
567 |
+
color: RGB(250,250,250);
|
568 |
+
// border-radius: 25%;
|
569 |
+
// border: 1px solid #0E1117;
|
570 |
+
}}
|
571 |
+
html, body {{
|
572 |
+
// background-color: transparent !important;
|
573 |
+
// margin: 10px;
|
574 |
+
// border: 1px solid pink;
|
575 |
+
text-align: center;
|
576 |
+
}}
|
577 |
+
body {{
|
578 |
+
background-color: #0E1117;
|
579 |
+
// margin: 10px;
|
580 |
+
// border: 1px solid pink;
|
581 |
+
}}
|
582 |
+
|
583 |
+
body #ttime {{
|
584 |
+
font-weight: bold;
|
585 |
+
font-family: 'VT323', monospace, sans-serif;
|
586 |
+
// font-family: 'Pixelify Sans', monospace, serif;
|
587 |
+
}}
|
588 |
+
</style>
|
589 |
+
|
590 |
+
<div>
|
591 |
+
<h1>Time left</h1>
|
592 |
+
<h1 id="ttime"> </h1>
|
593 |
+
</div>
|
594 |
+
|
595 |
+
|
596 |
+
<script>
|
597 |
+
|
598 |
+
var x = setInterval(function() {{
|
599 |
+
var start_time_str = "{st.session_state.start_time}";
|
600 |
+
var start_date = new Date(start_time_str);
|
601 |
+
var curr_date = new Date();
|
602 |
+
var time_difference = curr_date - start_date;
|
603 |
+
var time_diff_secs = Math.floor(time_difference / 1000);
|
604 |
+
var time_left = {TIME_LIMIT} - time_diff_secs;
|
605 |
+
var mins = Math.floor(time_left / 60);
|
606 |
+
var secs = time_left % 60;
|
607 |
+
var fmins = mins.toString().padStart(2, '0');
|
608 |
+
var fsecs = secs.toString().padStart(2, '0');
|
609 |
+
console.log("run");
|
610 |
+
|
611 |
+
if (start_time_str == "False") {{
|
612 |
+
document.getElementById("ttime").innerHTML = 'Press "Start" to start!';
|
613 |
+
clearInterval(x);
|
614 |
+
}}
|
615 |
+
else if (time_left <= 0) {{
|
616 |
+
document.getElementById("ttime").innerHTML = "Time's Up!!!";
|
617 |
+
clearInterval(x);
|
618 |
+
}}
|
619 |
+
else {{
|
620 |
+
document.getElementById("ttime").innerHTML = `${{fmins}}:${{fsecs}}`;
|
621 |
+
}}
|
622 |
+
}}, 999)
|
623 |
+
|
624 |
+
</script>
|
625 |
+
""",
|
626 |
+
)
|
627 |
+
|
628 |
+
with open("./public/char.png", "rb") as f:
|
629 |
+
contents = f.read()
|
630 |
+
data_url = base64.b64encode(contents).decode("utf-8")
|
631 |
+
|
632 |
+
with open("./public/chars/Male_talk.gif", "rb") as f:
|
633 |
+
contents = f.read()
|
634 |
+
patient_url = base64.b64encode(contents).decode("utf-8")
|
635 |
+
interactive_container = st.container()
|
636 |
+
user_input_col ,r = st.columns([4,1])
|
637 |
+
def to_grader_llm():
|
638 |
+
init_grader_llm()
|
639 |
+
set_scenario_tab_index(ScenarioTabIndex.GRADER_LLM)
|
640 |
+
|
641 |
+
with r:
|
642 |
+
to_grader_btn = st.button("To Grader", on_click=to_grader_llm)
|
643 |
+
with user_input_col:
|
644 |
+
user_inputs = st.text_input("", placeholder="Chat with the patient here!", key="user_inputs")
|
645 |
+
if user_inputs:
|
646 |
+
response = st.session_state.chain.invoke(user_inputs).get("text")
|
647 |
+
st.session_state.patient_response = response
|
648 |
+
with interactive_container:
|
649 |
+
html(f"""
|
650 |
+
|
651 |
+
<style>
|
652 |
+
@import url('https://fonts.googleapis.com/css2?family=Pixelify+Sans&display=swap');
|
653 |
+
</style>
|
654 |
+
|
655 |
+
<style>
|
656 |
+
html {{
|
657 |
+
font-family: 'Pixelify Sans', monospace, serif;
|
658 |
+
}}
|
659 |
+
</style>
|
660 |
+
<div>
|
661 |
+
<img src="data:image/png;base64,{data_url}" />
|
662 |
+
<span id="user_input">You: {st.session_state.get('user_inputs') or ''}</span>
|
663 |
+
</div>
|
664 |
+
|
665 |
+
<div>
|
666 |
+
<img src="data:image/gif;base64,{patient_url}" /><br/>
|
667 |
+
<span id="bot_response">{'Patient: '+st.session_state.get('patient_response') if st.session_state.get('patient_response') else '...'}</span>
|
668 |
+
</div>
|
669 |
+
""", height=500)
|
670 |
+
|
671 |
+
elif st.session_state.scenario_tab_index == ScenarioTabIndex.GRADER_LLM:
|
672 |
+
st.session_state.grader_output = "" if not st.session_state.get("grader_output") else st.session_state.grader_output
|
673 |
+
def get_grades():
|
674 |
+
txt = f"""
|
675 |
+
<summary>
|
676 |
+
{st.session_state.diagnosis}
|
677 |
+
</summary>
|
678 |
+
<differential-1>
|
679 |
+
{st.session_state.differential_1}
|
680 |
+
</differential-1>
|
681 |
+
<differential-2>
|
682 |
+
{st.session_state.differential_2}
|
683 |
+
</differential-2>
|
684 |
+
<differential-3>
|
685 |
+
{st.session_state.differential_3}
|
686 |
+
</differential-3>
|
687 |
+
"""
|
688 |
+
response = st.session_state.chain2.invoke(txt)
|
689 |
+
st.session_state.grader_output = response
|
690 |
+
st.session_state.has_llm_output = bool(st.session_state.get("grader_output"))
|
691 |
+
## TODO: False for now, need check llm output!
|
692 |
+
with st.expander("Your Diagnosis and Differentials", expanded=not st.session_state.has_llm_output):
|
693 |
+
st.session_state.diagnosis = st.text_area("Input your case summary and **main** diagnosis:", placeholder="This is a young gentleman with significant family history of stroke, and medical history of poorly-controlled hypertension. He presents with acute onset of bitemporal headache associated with dysarthria and meningism symptoms. Important negatives include the absence of focal neurological deficits, ataxia, and recent trauma.")
|
694 |
+
st.divider()
|
695 |
+
st.session_state.differential_1 = st.text_input("Differential 1")
|
696 |
+
st.session_state.differential_2 = st.text_input("Differential 2")
|
697 |
+
st.session_state.differential_3 = st.text_input("Differential 3")
|
698 |
+
with st.columns(6)[5]:
|
699 |
+
send_for_grading = st.button("Get grades!", on_click=get_grades)
|
700 |
+
with st.expander("Your rubrics", expanded=st.session_state.has_llm_output):
|
701 |
+
if st.session_state.grader_output:
|
702 |
+
st.write(st.session_state.grader_output.get("text").get("text"))
|
703 |
+
|
704 |
+
# back_btn = st.button("back to LLM?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.PATIENT_LLM])
|
705 |
+
back_btn = st.button("New Scenario?", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO])
|
706 |
+
|
707 |
+
with dashboard_tab:
|
708 |
+
import dotenv
|
709 |
+
import firebase_admin, json
|
710 |
+
from firebase_admin import credentials, storage, firestore
|
711 |
+
import plotly.express as px
|
712 |
+
import plotly.graph_objects as go
|
713 |
+
import pandas as pd
|
714 |
+
|
715 |
+
os.environ["FIREBASE_CREDENTIAL"] = dotenv.get_key(dotenv.find_dotenv(), "FIREBASE_CREDENTIAL")
|
716 |
+
cred = credentials.Certificate(json.loads(os.environ.get("FIREBASE_CREDENTIAL")))
|
717 |
+
|
718 |
+
# Initialize Firebase (if not already initialized)
|
719 |
+
if not firebase_admin._apps:
|
720 |
+
firebase_admin.initialize_app(cred, {'storageBucket': 'healthhack-store.appspot.com'})
|
721 |
+
|
722 |
+
#firebase_admin.initialize_app(cred,{'storageBucket': 'healthhack-store.appspot.com'}) # connecting to firebase
|
723 |
+
db_client = firestore.client()
|
724 |
+
|
725 |
+
docs = db_client.collection("clinical_scores").stream()
|
726 |
+
|
727 |
+
# Create a list of dictionaries from the documents
|
728 |
+
data = []
|
729 |
+
for doc in docs:
|
730 |
+
doc_dict = doc.to_dict()
|
731 |
+
doc_dict['document_id'] = doc.id # In case you need the document ID later
|
732 |
+
data.append(doc_dict)
|
733 |
+
|
734 |
+
# Create a DataFrame
|
735 |
+
df = pd.DataFrame(data)
|
736 |
+
|
737 |
+
username = st.session_state.get("username")
|
738 |
+
st.title("Dashboard")
|
739 |
+
|
740 |
+
# Convert date from string to datetime if it's not already in datetime format
|
741 |
+
df['date'] = pd.to_datetime(df['date'], errors='coerce')
|
742 |
+
|
743 |
+
# Streamlit page configuration
|
744 |
+
#st.set_page_config(page_title="Interactive Data Dashboard", layout="wide")
|
745 |
+
|
746 |
+
# Use df_selection for filtering data based on authenticated user
|
747 |
+
if username != 'admin':
|
748 |
+
df_selection = df[df['name'] == username]
|
749 |
+
else:
|
750 |
+
df_selection = df # Admin sees all data
|
751 |
+
|
752 |
+
# Chart Title: Student Performance Dashboard
|
753 |
+
st.title(":bar_chart: Student Performance Dashboard")
|
754 |
+
st.markdown("##")
|
755 |
+
|
756 |
+
# Chart 1: Total attempts
|
757 |
+
if df_selection.empty:
|
758 |
+
st.error("No data available to display.")
|
759 |
+
else:
|
760 |
+
# Total attempts by name (filtered)
|
761 |
+
total_attempts_by_name = df_selection.groupby("name")['date'].count().reset_index()
|
762 |
+
total_attempts_by_name.columns = ['name', 'total_attempts']
|
763 |
+
|
764 |
+
# For a single point or multiple points, use a scatter plot
|
765 |
+
fig_total_attempts = px.scatter(
|
766 |
+
total_attempts_by_name,
|
767 |
+
x="name",
|
768 |
+
y="total_attempts",
|
769 |
+
title="<b>Total Attempts</b>",
|
770 |
+
size='total_attempts', # Adjust the size of points
|
771 |
+
color_discrete_sequence=["#0083B8"] * len(total_attempts_by_name),
|
772 |
+
template="plotly_white",
|
773 |
+
text='total_attempts' # Display total_attempts as text labels
|
774 |
+
)
|
775 |
+
|
776 |
+
# Add text annotation for each point
|
777 |
+
for line in range(0, total_attempts_by_name.shape[0]):
|
778 |
+
fig_total_attempts.add_annotation(
|
779 |
+
text=str(total_attempts_by_name['total_attempts'].iloc[line]),
|
780 |
+
x=total_attempts_by_name['name'].iloc[line],
|
781 |
+
y=total_attempts_by_name['total_attempts'].iloc[line],
|
782 |
+
showarrow=True,
|
783 |
+
font=dict(family="Courier New, monospace", size=18, color="#ffffff"),
|
784 |
+
align="center",
|
785 |
+
arrowhead=2,
|
786 |
+
arrowsize=1,
|
787 |
+
arrowwidth=2,
|
788 |
+
arrowcolor="#636363",
|
789 |
+
ax=20,
|
790 |
+
ay=-30,
|
791 |
+
bordercolor="#c7c7c7",
|
792 |
+
borderwidth=2,
|
793 |
+
borderpad=4,
|
794 |
+
bgcolor="#ff7f0e",
|
795 |
+
opacity=0.8
|
796 |
+
)
|
797 |
+
|
798 |
+
# Update traces for styling
|
799 |
+
fig_total_attempts.update_traces(marker=dict(size=12), selector=dict(mode='markers+text'))
|
800 |
+
|
801 |
+
# Display the scatter plot in Streamlit
|
802 |
+
st.plotly_chart(fig_total_attempts, use_container_width=True)
|
803 |
+
|
804 |
+
# Chart 2 (students only): Personal scores over time
|
805 |
+
if username != 'admin':
|
806 |
+
# Sort the DataFrame by 'date' in chronological order
|
807 |
+
df_selection = df_selection.sort_values(by='date')
|
808 |
+
#fig = px.bar(df_selection, x='date', y='global_score', title='Your scores!')
|
809 |
+
|
810 |
+
if len(df_selection) > 1:
|
811 |
+
# # If more than one point, use a bar chart
|
812 |
+
# fig = px.bar(df_selection, x='date', y='global_score', title='Global Score Over Time')
|
813 |
+
# # fig.update_yaxes(
|
814 |
+
# # tickmode='array',
|
815 |
+
# # tickvals=[1, 2, 3, 4, 5], # Reverse the order of tickvals
|
816 |
+
# # ticktext=['A', 'B','C','D','E'] # Reverse the order of ticktext
|
817 |
+
# # )
|
818 |
+
# Mapping dictionary
|
819 |
+
grade_to_score = {'A': 100, 'B': 80, 'C': 60, 'D': 40, 'E': 20}
|
820 |
+
|
821 |
+
# Apply mapping to convert letter grades to numerical scores
|
822 |
+
df_selection['numeric_score'] = df_selection['global_score'].map(grade_to_score)
|
823 |
+
|
824 |
+
# Sort the DataFrame by 'date' in chronological order
|
825 |
+
df_selection = df_selection.sort_values(by='date')
|
826 |
+
|
827 |
+
# Check if there's more than one point in the DataFrame
|
828 |
+
if len(df_selection) > 1:
|
829 |
+
# Create a bar chart using Plotly Express
|
830 |
+
fig = px.bar(df_selection, x='date', y='numeric_score', title='Your scores over time')
|
831 |
+
else:
|
832 |
+
# Create a bar chart with just one point
|
833 |
+
fig = px.bar(df_selection, x='date', y='numeric_score', title='Global Score')
|
834 |
+
|
835 |
+
# Manually set the y-axis ticks and labels
|
836 |
+
fig.update_yaxes(
|
837 |
+
tickmode='array',
|
838 |
+
tickvals=list(grade_to_score.values()), # Positions for the ticks
|
839 |
+
ticktext=list(grade_to_score.keys()), # Text labels for the ticks
|
840 |
+
range=[0, 120] # Extend the range a bit beyond 100 to accommodate 'A'
|
841 |
+
)
|
842 |
+
|
843 |
+
# # Use st.plotly_chart to display the chart in Streamlit
|
844 |
+
# st.plotly_chart(fig, use_container_width=True)
|
845 |
+
|
846 |
+
else:
|
847 |
+
# For a single point, use a scatter plot
|
848 |
+
fig = px.scatter(df_selection, x='date', y='global_score', title='Global Score',
|
849 |
+
text='global_score', size_max=60)
|
850 |
+
# Add text annotation
|
851 |
+
for line in range(0,df_selection.shape[0]):
|
852 |
+
fig.add_annotation(text=df_selection['global_score'].iloc[line],
|
853 |
+
x=df_selection['date'].iloc[line], y=df_selection['global_score'].iloc[line],
|
854 |
+
showarrow=True, font=dict(family="Courier New, monospace", size=18, color="#ffffff"),
|
855 |
+
align="center", arrowhead=2, arrowsize=1, arrowwidth=2, arrowcolor="#636363",
|
856 |
+
ax=20, ay=-30, bordercolor="#c7c7c7", borderwidth=2, borderpad=4, bgcolor="#ff7f0e",
|
857 |
+
opacity=0.8)
|
858 |
+
fig.update_traces(marker=dict(size=12), selector=dict(mode='markers+text'))
|
859 |
+
|
860 |
+
# Display the chart in Streamlit
|
861 |
+
st.plotly_chart(fig, use_container_width=True)
|
862 |
+
|
863 |
+
# Show students their scores over time
|
864 |
+
st.dataframe(df_selection[['date', 'global_score', 'name']])
|
865 |
+
|
866 |
+
|
867 |
+
# Chart 3 (admin only): Global score chart
|
868 |
+
# Define the order of categories explicitly
|
869 |
+
order_of_categories = ['A', 'B', 'C', 'D', 'E']
|
870 |
+
|
871 |
+
# Convert global_score to a categorical type with the specified order
|
872 |
+
df_selection['global_score'] = pd.Categorical(df_selection['global_score'], categories=order_of_categories, ordered=True)
|
873 |
+
|
874 |
+
# Plot the histogram
|
875 |
+
fig_score_distribution = px.histogram(
|
876 |
+
df_selection,
|
877 |
+
x="global_score",
|
878 |
+
title="<b>Global Score Distribution</b>",
|
879 |
+
color_discrete_sequence=["#33CFA5"],
|
880 |
+
category_orders={"global_score": ["A", "B", "C", "D", "E"]}
|
881 |
+
)
|
882 |
+
if username == 'admin':
|
883 |
+
st.plotly_chart(fig_score_distribution, use_container_width=True)
|
884 |
+
|
885 |
+
|
886 |
+
# Chart 4 (admin only): Students with <5 attempts (filtered)
|
887 |
+
if username == 'admin':
|
888 |
+
students_with_less_than_5_attempts = total_attempts_by_name[total_attempts_by_name['total_attempts'] < 5]
|
889 |
+
fig_less_than_5_attempts = px.bar(
|
890 |
+
students_with_less_than_5_attempts,
|
891 |
+
x="name",
|
892 |
+
y="total_attempts",
|
893 |
+
title="<b>Students with <5 Attempts</b>",
|
894 |
+
color_discrete_sequence=["#D62728"] * len(students_with_less_than_5_attempts),
|
895 |
+
template="plotly_white",
|
896 |
+
)
|
897 |
+
|
898 |
+
if username == 'admin':
|
899 |
+
st.plotly_chart(fig_less_than_5_attempts, use_container_width=True)
|
900 |
+
|
901 |
+
|
902 |
+
# Selection of a student for detailed view (<5 attempts) - based on filtered data
|
903 |
+
if username == 'admin':
|
904 |
+
selected_student_less_than_5 = st.selectbox("Select a student with less than 5 attempts to view details:", students_with_less_than_5_attempts['name'])
|
905 |
+
if selected_student_less_than_5:
|
906 |
+
st.write(df_selection[df_selection['name'] == selected_student_less_than_5])
|
907 |
+
|
908 |
+
# Chart 5 (admin only): Students with at least one global score of 'C', 'D', 'E' (filtered)
|
909 |
+
if username == 'admin':
|
910 |
+
students_with_cde = df_selection[df_selection['global_score'].isin(['C', 'D', 'E'])].groupby("name")['date'].count().reset_index()
|
911 |
+
students_with_cde.columns = ['name', 'total_attempts']
|
912 |
+
fig_students_with_cde = px.bar(
|
913 |
+
students_with_cde,
|
914 |
+
x="name",
|
915 |
+
y="total_attempts",
|
916 |
+
title="<b>Students with at least one global score of 'C', 'D', 'E'</b>",
|
917 |
+
color_discrete_sequence=["#FF7F0E"] * len(students_with_cde),
|
918 |
+
template="plotly_white",
|
919 |
+
)
|
920 |
+
st.plotly_chart(fig_students_with_cde, use_container_width=True)
|
921 |
+
|
922 |
+
# Selection of a student for detailed view (score of 'C', 'D', 'E') - based on filtered data
|
923 |
+
if username == 'admin':
|
924 |
+
selected_student_cde = st.selectbox("Select a student with at least one score of 'C', 'D', 'E' to view details:", students_with_cde['name'])
|
925 |
+
if selected_student_cde:
|
926 |
+
st.write(df_selection[df_selection['name'] == selected_student_cde])
|
927 |
+
|
928 |
+
# Chart 7 (all): Radar Chart
|
929 |
+
|
930 |
+
# Mapping grades to numeric values
|
931 |
+
grade_to_numeric = {'A': 90, 'B': 70, 'C': 50, 'D': 30, 'E': 10}
|
932 |
+
df.replace(grade_to_numeric, inplace=True)
|
933 |
+
|
934 |
+
# Calculate average numeric scores for each category
|
935 |
+
average_scores = df.groupby('name')[['hx_PC_score', 'hx_AS_score', 'hx_others_score', 'differentials_score']].mean().reset_index()
|
936 |
+
|
937 |
+
if username == 'admin':
|
938 |
+
st.title('Average Scores Radar Chart')
|
939 |
+
else:
|
940 |
+
st.title('Performance in each segment as compared to your friends!')
|
941 |
+
|
942 |
+
# Categories for the radar chart
|
943 |
+
categories = ['Presenting complaint', 'Associated symptoms', '(Others)', 'Differentials']
|
944 |
+
|
945 |
+
st.markdown("""
|
946 |
+
###
|
947 |
+
Double click on the names in the legend to include/exclude them from the plot.
|
948 |
+
""")
|
949 |
+
|
950 |
+
|
951 |
+
# Custom colors for better contrast
|
952 |
+
colors = ['gold', 'cyan', 'magenta', 'green']
|
953 |
+
|
954 |
+
# Plotly Radar Chart
|
955 |
+
fig = go.Figure()
|
956 |
+
|
957 |
+
for index, row in average_scores.iterrows():
|
958 |
+
fig.add_trace(go.Scatterpolar(
|
959 |
+
r=[row['hx_PC_score'], row['hx_AS_score'], row['hx_others_score'], row['differentials_score']],
|
960 |
+
theta=categories,
|
961 |
+
fill='toself',
|
962 |
+
name=row['name'],
|
963 |
+
line=dict(color=colors[index % len(colors)])
|
964 |
+
))
|
965 |
+
|
966 |
+
fig.update_layout(
|
967 |
+
polar=dict(
|
968 |
+
radialaxis=dict(
|
969 |
+
visible=True,
|
970 |
+
range=[0, 100], # Numeric range
|
971 |
+
tickvals=[10, 30, 50, 70, 90], # Positions for the grade labels
|
972 |
+
ticktext=['E', 'D', 'C', 'B', 'A'] # Grade labels
|
973 |
+
)),
|
974 |
+
showlegend=True,
|
975 |
+
height=600, # Set the height of the figure
|
976 |
+
width=600 # Set the width of the figure
|
977 |
+
)
|
978 |
+
|
979 |
+
# Display the figure in Streamlit
|
980 |
+
st.plotly_chart(fig, use_container_width=True)
|
981 |
+
|
public/char.png
ADDED
public/chars/Female_talk.gif
ADDED
public/chars/Female_walk .gif
ADDED
public/chars/Male_talk.gif
ADDED
public/chars/Male_wait.gif
ADDED
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ faiss-cpu
|
|
13 |
streamlit
|
14 |
firebase-admin
|
15 |
plotly
|
16 |
-
torch==2.1.2
|
|
|
|
13 |
streamlit
|
14 |
firebase-admin
|
15 |
plotly
|
16 |
+
torch==2.1.2
|
17 |
+
streamlit_authenticator
|
templates/grader.txt
CHANGED
@@ -30,25 +30,25 @@ Example output JSON:
|
|
30 |
{{{{
|
31 |
"history_presenting_complain": {{{{
|
32 |
"grade": "A",
|
33 |
-
"remarks": "Your remarks here"
|
34 |
}}}}
|
35 |
}}}},
|
36 |
{{{{
|
37 |
"history_associated_symptoms": {{{{
|
38 |
"grade": "B",
|
39 |
-
"remarks": "Your remarks here"
|
40 |
}}}}
|
41 |
}}}},
|
42 |
{{{{
|
43 |
"history_others": {{{{
|
44 |
"grade": "C",
|
45 |
-
"remarks": "Your remarks here"
|
46 |
}}}}
|
47 |
}}}},
|
48 |
{{{{
|
49 |
"diagnosis_and_differentials": {{{{
|
50 |
"grade": "D",
|
51 |
-
"remarks": "Your remarks here"
|
52 |
}}}}
|
53 |
}}}},
|
54 |
{{{{
|
|
|
30 |
{{{{
|
31 |
"history_presenting_complain": {{{{
|
32 |
"grade": "A",
|
33 |
+
"remarks": "Your remarks here"
|
34 |
}}}}
|
35 |
}}}},
|
36 |
{{{{
|
37 |
"history_associated_symptoms": {{{{
|
38 |
"grade": "B",
|
39 |
+
"remarks": "Your remarks here"
|
40 |
}}}}
|
41 |
}}}},
|
42 |
{{{{
|
43 |
"history_others": {{{{
|
44 |
"grade": "C",
|
45 |
+
"remarks": "Your remarks here"
|
46 |
}}}}
|
47 |
}}}},
|
48 |
{{{{
|
49 |
"diagnosis_and_differentials": {{{{
|
50 |
"grade": "D",
|
51 |
+
"remarks": "Your remarks here"
|
52 |
}}}}
|
53 |
}}}},
|
54 |
{{{{
|