plebias commited on
Commit
9ffeffb
1 Parent(s): 81302bc

Pushed Generate tab

Browse files
Files changed (1) hide show
  1. app_final.py +363 -2
app_final.py CHANGED
@@ -209,7 +209,7 @@ else:
209
  # Rerun the app to go back to the login view
210
  st.rerun()
211
 
212
- scenario_tab, dashboard_tab = st.tabs(["Training", "Dashboard"])
213
 
214
  class ScenarioTabIndex:
215
  SELECT_SCENARIO = 0
@@ -251,7 +251,7 @@ else:
251
  # rows.extend(st.columns(total_cols))
252
 
253
  st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}")
254
- st.button("Generate a new scenario")
255
  for i, scenario in enumerate(st.session_state.scenario_list):
256
  if i % total_cols == 0:
257
  rows.extend(st.columns(total_cols))
@@ -822,3 +822,364 @@ else:
822
  # Display the figure in Streamlit
823
  st.plotly_chart(fig, use_container_width=True)
824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # Rerun the app to go back to the login view
210
  st.rerun()
211
 
212
+ scenario_tab, dashboard_tab, generate_tab = st.tabs(["Training", "Dashboard", "Generate Scenario"])
213
 
214
  class ScenarioTabIndex:
215
  SELECT_SCENARIO = 0
 
251
  # rows.extend(st.columns(total_cols))
252
 
253
  st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}")
254
+ #st.button("Generate a new scenario")
255
  for i, scenario in enumerate(st.session_state.scenario_list):
256
  if i % total_cols == 0:
257
  rows.extend(st.columns(total_cols))
 
822
  # Display the figure in Streamlit
823
  st.plotly_chart(fig, use_container_width=True)
824
 
825
+ with generate_tab:
826
+ st.title("Medical Scenario Generator (for Admins)")
827
+
828
+ ## Hardcode scenarios for now,
829
+ indexes = """
830
+ aortic dissection
831
+ anemia
832
+ cystitis
833
+ pneumonia
834
+ """.split("\n")
835
+
836
+ if "selected_index" not in st.session_state:
837
+ st.session_state.selected_index = 0
838
+
839
+ if "search_selectbox" not in st.session_state:
840
+ st.session_state.search_selectbox = " "
841
+ # st.session_state.index_selectbox = "Headache"
842
+
843
+ if "search_freetext" not in st.session_state:
844
+ st.session_state.search_freetext = " "
845
+ # st.session_state.index_selectbox = "Headache"
846
+
847
+ #index_selectbox = st_tags(
848
+ # label='What medical condition would you like to generate a scenario for?',
849
+ # text='Input here ...',
850
+ # suggestions=indexes,
851
+ # value = ' ',
852
+ # maxtags = 1,
853
+ # key='0')
854
+
855
+ st.write('What medical condition would you like to generate a scenario for?')
856
+ search_freetext = st.text_input("Type your own", value = " ")
857
+ if search_freetext != st.session_state.search_freetext:
858
+ st.session_state.search_freetext = search_freetext
859
+
860
+ #hard0, free0 = st.columns(2)
861
+ #search_selectbox = hard0.selectbox(
862
+ # 'Choose one OR Type on the right',
863
+ # indexes, index=0)
864
+ #search_freetext = free0.text_input("Type your own")
865
+ #
866
+ #if search_selectbox != indexes[st.session_state.selected_index]:
867
+ # st.session_state.selected_index = indexes.index(search_selectbox)
868
+ # st.session_state.search_selectbox = search_selectbox
869
+
870
+ if "openai_model" not in st.session_state:
871
+ st.session_state["openai_model"] = "gpt-3.5-turbo"
872
+
873
+ if "active_chat" not in st.session_state:
874
+ st.session_state.active_chat = 1
875
+
876
+ model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
877
+ model_kwargs = {"device": "cpu"}
878
+ # model_kwargs = {"device": "cuda"}
879
+ encode_kwargs = {"normalize_embeddings": True}
880
+
881
+ if "embeddings" not in st.session_state:
882
+ st.session_state.embeddings = HuggingFaceEmbeddings(
883
+ model_name=model_name,
884
+ model_kwargs = model_kwargs,
885
+ encode_kwargs = encode_kwargs)
886
+ embeddings = st.session_state.embeddings
887
+ if "llm" not in st.session_state:
888
+ st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
889
+ llm = st.session_state.llm
890
+ #if "llm" not in st.session_state:
891
+ # st.session_state.llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
892
+ #llm = st.session_state.llm
893
+ #if "llm" not in st.session_state:
894
+ # st.session_state.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
895
+ llm = st.session_state.llm
896
+
897
+ ## ------------------------------------------------------------------------------------------------
898
+ ## Generator part
899
+ index_name = f"indexes/faiss_index_large_v2"
900
+
901
+ if "store" not in st.session_state:
902
+ #st.session_state.store = FAISS.load_local(index_name, embeddings)
903
+ st.session_state.store = db.get_store(index_name, embeddings=embeddings)
904
+ #st.session_state.store.similarity_search('hello')
905
+ store = st.session_state.store
906
+
907
+ def topk(searchKW):
908
+ search_r = st.session_state.store.similarity_search(searchKW, k=5)
909
+ return [x.page_content for x in search_r]
910
+
911
+ #def onSelectButton(s_index):
912
+ # topkindexes = topk(s_index) #return top 5 list of similiar diseases
913
+ # return selected_options
914
+
915
+ #selectButton = st.button(on_click = onSelectButton(st.session_state.index_selectbox))
916
+ #selectButton = st.button("Search")
917
+
918
+ if 'searchbtn_clicked' not in st.session_state:
919
+ st.session_state['searchbtn_clicked'] = False
920
+
921
+ if 'selected_option' not in st.session_state:
922
+ st.session_state['selected_option'] = ""
923
+
924
+ def search_callback():
925
+ st.session_state['searchbtn_clicked'] = True
926
+
927
+
928
+ if st.button('search', on_click=search_callback) or st.session_state['searchbtn_clicked'] or st.session_state.search_freetext != ' ':
929
+ def searchInner(searchOptions):
930
+ if len(searchOptions)>0:
931
+ st.markdown('---')
932
+ col1, col2 = st.columns(2)
933
+ selected_options = col1.multiselect(
934
+ 'Choose the most relevant condition:',
935
+ searchOptions, max_selections = 1)
936
+ if len(selected_options)>0:
937
+ col2.write(selected_options[0])
938
+ st.session_state['selected_option'] = selected_options[0]
939
+ else:
940
+ col2.write('')
941
+ else:
942
+ st.markdown('---')
943
+ st.write("No results found. Perhaps try another condition? Some examples that work: "+', '.join(indexes))
944
+
945
+ if search_freetext != " ":
946
+ options = topk(search_freetext)
947
+ searchInner(options)
948
+ else:
949
+ options = topk(indexes[st.session_state.selected_index])
950
+ searchInner(options)
951
+
952
+ st.write(st.session_state['selected_option'])
953
+
954
+ ## ------------------------------------------------------------------------------------------------
955
+ ## LLM part
956
+
957
+ kg_name = f"kgstore"
958
+
959
+ if 'infostorekg' not in st.session_state:
960
+ st.session_state.infostorekg = ""
961
+
962
+ if "dfdisease" not in st.session_state:
963
+ st.session_state.dfdisease = db.get_csv(kg_name, isDiseases = True)
964
+ if "dffull" not in st.session_state:
965
+ st.session_state.dffull = db.get_csv(kg_name, isDiseases = False)
966
+ if "datanet" not in st.session_state:
967
+ st.session_state.datanet = nx.from_pandas_edgelist(st.session_state.dffull , 'x_id', 'y_id', ['relation'])
968
+ datanet = st.session_state.datanet
969
+ kgD = st.session_state.dfdisease[['group_id_bert','group_name_bert', 'mondo_definition', 'umls_description','orphanet_definition']].astype(str).values.tolist()
970
+ kgD2 = [' '.join([x[1]+'.']+list(set([y for y in x[2:] if y != 'nan']))) for x in kgD]
971
+
972
+ if 'genbtn_clicked' not in st.session_state:
973
+ st.session_state['genbtn_clicked'] = False
974
+
975
+ if "TEMPLATE" not in st.session_state:
976
+ with open('templates/kgen.txt', 'r') as file:
977
+ TEMPLATE = file.read()
978
+ st.session_state.TEMPLATE = TEMPLATE
979
+
980
+ ### ------------------------------------------------------------------------------------------------
981
+ ### DEBUGGING CODE
982
+ #with st.expander("Patient Prompt"):
983
+ # TEMPLATE = st.text_area("Patient Prompt", value=st.session_state.TEMPLATE)
984
+ # st.session_state.TEMPLATE= TEMPLATE
985
+ ### ------------------------------------------------------------------------------------------------
986
+
987
+
988
+ prompt = PromptTemplate(
989
+ input_variables = ["infostorekg"],
990
+ template = st.session_state.TEMPLATE
991
+ )
992
+
993
+ if 'formautofill' not in st.session_state:
994
+ st.session_state['formautofill'] = ""
995
+
996
+ def gen_callback():
997
+ st.session_state['genbtn_clicked'] = True
998
+
999
+ def kgMatch(nodeName):
1000
+ newidx = kgD[kgD2.index(nodeName)][0]
1001
+ df_disease = st.session_state.dfdisease
1002
+ df_full = st.session_state.dffull
1003
+ desG = nx.single_source_dijkstra(datanet, newidx, cutoff = 1)
1004
+ diseaseName = df_disease[df_disease.group_id_bert == newidx]['group_name_bert'].unique().tolist()[0]
1005
+
1006
+ phenotypeFilter = df_full[(df_full['x_id'] == newidx)| (df_full['y_id'] == newidx)]
1007
+ phenotypeList = [x for x in list(set(phenotypeFilter.y_name.unique().tolist()+ phenotypeFilter.x_name.unique().tolist())) if diseaseName not in x ]
1008
+
1009
+ return (diseaseName, phenotypeList)
1010
+
1011
+ def passState(dummy):
1012
+ if "infostorekg" in st.session_state:
1013
+ return str(st.session_state.infostorekg)
1014
+ else:
1015
+ return dummy
1016
+
1017
+ if st.button('Generate scenario', on_click=gen_callback) or st.session_state['genbtn_clicked']:
1018
+ if len(st.session_state.selected_option)>0:
1019
+ infoPrompt = kgMatch(st.session_state.selected_option)
1020
+ st.session_state.infostorekg = str(infoPrompt)
1021
+
1022
+ if ("chain" not in st.session_state
1023
+ or
1024
+ st.session_state.TEMPLATE != TEMPLATE):
1025
+ #st.session_state.chain = (
1026
+ #{
1027
+ # "infostorekg": passState,
1028
+ # } |
1029
+ #LLMChain(llm=llm, prompt=prompt, verbose=False)
1030
+ st.session_state.chain = LLMChain(llm=llm, prompt=prompt, verbose = False)
1031
+ chain = st.session_state.chain
1032
+
1033
+ st.session_state['formautofill'] = chain.invoke({"infostorekg": st.session_state.infostorekg}).get("text")
1034
+ else:
1035
+ st.warning('Please search and select a condition first!')
1036
+
1037
+ ## ------------------------------------------------------------------------------------------------
1038
+ ## Forms part
1039
+
1040
+ conDict = {
1041
+ }
1042
+ rubDict = {'complaints': """Grade A: Elicits all of the above points in detail
1043
+ Grade B: Explores both presenting complaints (fill in) and (others) in almost full detail and rules
1044
+ out red flags
1045
+ Grade C: Explores both presenting complaints (fill in) incompletely and looks out for
1046
+ red flags
1047
+ Grade D: Explores both presenting complaints incompletely (fill in) but does not rule
1048
+ out any red flags/ explores one complaint and rules out at least one red flag
1049
+ Grade E: Only explores one of the two presenting complaints (fill in)""",
1050
+ 'syms': """Grade A: Explores at least (5) differentials in detail including (fill in) and elicits all * (6)
1051
+ points
1052
+ Grade B: Explores most (4) of the above systems including (fill in) and elicits all (6) *
1053
+ points
1054
+ Grade C: Explores most (4) of the above systems and elicits most (4-6) * points
1055
+ Grade D: Explores more than half (3) of the above systems and elicits most (4-6) * points
1056
+ Grade E: Explores only 1-2 of the above systems or asks less than half (1-3) * points""",
1057
+ 'others': """Grade A: Elicits all (4) of the * points and past medical Hx of (fill in)
1058
+ Grade B: Elicits all (4) of the * points and past medical Hx of (fill in),
1059
+ but did not go into important details
1060
+ Grade C: Elicits most (2-3) of the * points and past medical Hx of (fill in) in adequate detail
1061
+ Grade D: Elicits most (2-3) of the * points and past medical Hx of (fill in)
1062
+ but not in detail
1063
+ Grade E: Elicits 0-1 of the * points or did not take past medical Hx of (fill in)(not taking a (specific history: fill in ) history will give the candidate this score for the domain)""",
1064
+ 'findings': """Grade A: Presents all (4) of the * points, has (fill in) as top differentials with justification,
1065
+ and at least one other differentials with adequate justification
1066
+ Grade B: Presents most (2-3) of the * points, has (fill in) as top differentials but inadequate
1067
+ justification
1068
+ Grade C: Presents most (2-3) of the * points, has either (fill in) as top differential with at least
1069
+ one other differential
1070
+ Grade D: Presents most (2-3) of the *points OR only able to have 1 diagnosis without differential diagnosis
1071
+ Grade E: Presents few (0-1) of * points OR unable to have any diagnosis or differentials"""
1072
+ }
1073
+
1074
+
1075
+ ### ------------------------------------------------------------------------------------------------
1076
+ ### DEBUGGING CODE
1077
+ #with st.expander("GPTOUTPUT"):
1078
+ # out = st.text_area(" ", value=st.session_state['formautofill'])
1079
+ ### ------------------------------------------------------------------------------------------------
1080
+
1081
+ def splitReply():
1082
+ gendata = json.loads(st.session_state['formautofill'], strict = False)
1083
+ conditionsGen = []
1084
+ def curseDict(possibleDict, defDict):
1085
+ if type(defDict[possibleDict]) == str:
1086
+ return '\n' + possibleDict + ': '+ defDict[possibleDict]
1087
+ elif type(defDict[possibleDict]) == list:
1088
+ if all(isinstance(item, str) for item in defDict[possibleDict]):
1089
+ return '\n' + possibleDict + ': '+ '\n '.join(defDict[possibleDict])
1090
+ else:
1091
+ returnList = [str(x) for x in defDict[possibleDict]]
1092
+ return '\n' + possibleDict + ': '+ '\n '.join(returnList)
1093
+ elif type(defDict[possibleDict]) == dict:
1094
+ out = possibleDict
1095
+ for m in defDict[possibleDict]:
1096
+ out += curseDict(m, defDict[possibleDict])
1097
+ return out
1098
+ else:
1099
+ return possibleDict+'\n'+ str(defDict[possibleDict])
1100
+
1101
+ for x in gendata:
1102
+ if 'patient' in x.lower():
1103
+ conditionsGen.append(x)
1104
+ for y in gendata[x]:
1105
+ conditionsGen[-1] += curseDict(y, gendata[x])
1106
+ conDict['patients'] = conditionsGen[-1]
1107
+ elif 'complain' in x.lower() or 'present' in x.lower():
1108
+ conditionsGen.append(x)
1109
+ for y in gendata[x]:
1110
+ conditionsGen[-1] += curseDict(y, gendata[x])
1111
+ conDict['complaints'] = conditionsGen[-1]
1112
+
1113
+ elif 'symptom' in x.lower() or 'associate' in x.lower():
1114
+ conditionsGen.append(x)
1115
+ for y in gendata[x]:
1116
+ conditionsGen[-1] += curseDict(y, gendata[x])
1117
+ conDict['syms'] = conditionsGen[-1]
1118
+
1119
+ elif 'other' in x.lower():
1120
+ conditionsGen.append(x)
1121
+ for y in gendata[x]:
1122
+ conditionsGen[-1] += curseDict(y, gendata[x])
1123
+ conDict['others'] = conditionsGen[-1]
1124
+
1125
+ if 'diagnosis' in x.lower() or 'differential' in x.lower():
1126
+ conditionsGen.append(x)
1127
+ for y in gendata[x]:
1128
+ conditionsGen[-1] += curseDict(y, gendata[x])
1129
+ conDict['findings'] = conditionsGen[-1]
1130
+
1131
+ if len(st.session_state['formautofill'])>0:
1132
+ with st.form("filled_form"):
1133
+ st.write("Generated Autofill")
1134
+
1135
+ splitReply()
1136
+ with st.expander("Patient Scenario: Provided to students at the start of the exam"):
1137
+ patient_val_filled = st.text_area(" ", conDict['patients'], height=400, key="patientscenario")
1138
+
1139
+ st.write("Rubrics: Details students are expected to ask about and rubrics details for grading")
1140
+ with st.expander("History Taking: Presenting Complaints"):
1141
+ patient_val_filled = st.text_area(" ", conDict['complaints'], height=400, key="complaints1")
1142
+ complaints_val_filled = st.text_area("Rubrics: Complaints", rubDict['complaints'], height=400, key="complaints2")
1143
+ with st.expander("History Taking: Associated Symptoms"):
1144
+ syms_val_filled = st.text_area(" ", conDict['syms'], height=400, key="syms")
1145
+ syms_rubrics_filled = st.text_area("Rubrics: Symptoms", rubDict['syms'], height=400, key="syms2")
1146
+ with st.expander("History Taking: Others"):
1147
+ others_val_filled = st.text_area(" ", conDict['others'], height=400, key="others")
1148
+ others_rubrics_filled = st.text_area("Rubrics: Others", rubDict['others'], height=400, key="others2")
1149
+ with st.expander("Presentation of Findings, Diagnosis, and Differentials"):
1150
+ findings_val_filled = st.text_area(" ", conDict['findings'], height=400, key="findings")
1151
+ findings_rubrics_filled = st.text_area("Rubrics: Findings and Diagnosis",rubDict['findings'], height=400, key="findings2")
1152
+
1153
+ # Every form must have a submit button.
1154
+ submitted = st.form_submit_button("Submit")
1155
+ if submitted:
1156
+ #conDict.send(to firebase, with key) # retrieve from key
1157
+ st.write("check out your new scenario here!")
1158
+ loadScenario = st.button("Go to patient simulator (currently not implemented)")
1159
+ else:
1160
+ with st.form("empty_form"):
1161
+ st.write("Blank Form")
1162
+ with st.expander("Patient Scenario: Provided to students at the start of the exam"):
1163
+ patient_val_filled = st.text_area(" ", height=400, key="patientscenario_empty")
1164
+
1165
+ st.write("Rubrics: Details students are expected to ask about and rubrics details for grading")
1166
+ with st.expander("History Taking: Presenting Complaints"):
1167
+ col1_com, col2_com= st.columns(2)
1168
+ patient_val_filled = col1_com.text_area(" ", height=400, key="complaints_empty")
1169
+ complaints_val_filled = col2_com.text_area("Rubrics: Complaints", rubDict['complaints'], height=400, key="complaints2_empty")
1170
+ with st.expander("History Taking: Associated Symptoms"):
1171
+ syms_val_filled = st.text_area(" ", height=400, key="syms_empty")
1172
+ syms_rubrics_filled = st.text_area("Rubrics: Symptoms", rubDict['syms'], height=400, key="syms2_empty")
1173
+ with st.expander("History Taking: Others"):
1174
+ others_val_filled = st.text_area(" ", height=400, key="others_empty")
1175
+ others_rubrics_filled = st.text_area("Rubrics: Others", rubDict['others'], height=400, key="others2_empty")
1176
+ with st.expander("Presentation of Findings, Diagnosis, and Differentials"):
1177
+ findings_val_filled = st.text_area(" ", height=400, key="findings_empty")
1178
+ findings_rubrics_filled = st.text_area("Rubrics: Findings and Diagnosis",rubDict['findings'], height=400, key="findings2_empty")
1179
+
1180
+ # Every form must have a submit button.
1181
+ submitted_empty = st.form_submit_button("Submit")
1182
+ if submitted_empty:
1183
+ #conDict.send(to firebase, with key) # retrieve from key
1184
+ st.write("check out your new scenario here!")
1185
+ loadScenario = st.button("Go to patient simulator (currently not implemented)")