Spaces:
Runtime error
Runtime error
Fixing variable name clashes
Browse files- app_final.py +33 -44
app_final.py
CHANGED
@@ -42,6 +42,7 @@ import plotly.express as px
|
|
42 |
import plotly.graph_objects as go
|
43 |
import pandas as pd
|
44 |
|
|
|
45 |
|
46 |
if not os.path.isdir("./.streamlit"):
|
47 |
os.mkdir("./.streamlit")
|
@@ -826,19 +827,19 @@ else:
|
|
826 |
st.title("Medical Scenario Generator (for Admins)")
|
827 |
|
828 |
## Hardcode scenarios for now,
|
829 |
-
|
830 |
aortic dissection
|
831 |
anemia
|
832 |
cystitis
|
833 |
pneumonia
|
834 |
""".split("\n")
|
835 |
|
836 |
-
if "
|
837 |
-
st.session_state.
|
838 |
|
839 |
-
if "
|
840 |
-
st.session_state.
|
841 |
-
# st.session_state.
|
842 |
|
843 |
if "search_freetext" not in st.session_state:
|
844 |
st.session_state.search_freetext = " "
|
@@ -847,7 +848,7 @@ else:
|
|
847 |
#index_selectbox = st_tags(
|
848 |
# label='What medical condition would you like to generate a scenario for?',
|
849 |
# text='Input here ...',
|
850 |
-
# suggestions=
|
851 |
# value = ' ',
|
852 |
# maxtags = 1,
|
853 |
# key='0')
|
@@ -867,54 +868,42 @@ else:
|
|
867 |
# st.session_state.selected_index = indexes.index(search_selectbox)
|
868 |
# st.session_state.search_selectbox = search_selectbox
|
869 |
|
870 |
-
if "
|
871 |
-
st.session_state["
|
872 |
-
|
873 |
-
if "active_chat" not in st.session_state:
|
874 |
-
st.session_state.active_chat = 1
|
875 |
|
876 |
model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
|
877 |
model_kwargs = {"device": "cpu"}
|
878 |
# model_kwargs = {"device": "cuda"}
|
879 |
encode_kwargs = {"normalize_embeddings": True}
|
880 |
|
881 |
-
if "
|
882 |
-
st.session_state.
|
883 |
model_name=model_name,
|
884 |
model_kwargs = model_kwargs,
|
885 |
encode_kwargs = encode_kwargs)
|
886 |
-
|
887 |
-
if "
|
888 |
-
st.session_state.
|
889 |
-
llm = st.session_state.llm
|
890 |
#if "llm" not in st.session_state:
|
891 |
# st.session_state.llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
|
892 |
#llm = st.session_state.llm
|
893 |
#if "llm" not in st.session_state:
|
894 |
# st.session_state.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
895 |
-
|
896 |
|
897 |
## ------------------------------------------------------------------------------------------------
|
898 |
## Generator part
|
899 |
-
|
900 |
|
901 |
-
if "
|
902 |
-
#st.session_state.
|
903 |
-
st.session_state.
|
904 |
-
|
905 |
-
store = st.session_state.store
|
906 |
|
907 |
def topk(searchKW):
|
908 |
search_r = st.session_state.store.similarity_search(searchKW, k=5)
|
909 |
return [x.page_content for x in search_r]
|
910 |
|
911 |
-
#def onSelectButton(s_index):
|
912 |
-
# topkindexes = topk(s_index) #return top 5 list of similiar diseases
|
913 |
-
# return selected_options
|
914 |
-
|
915 |
-
#selectButton = st.button(on_click = onSelectButton(st.session_state.index_selectbox))
|
916 |
-
#selectButton = st.button("Search")
|
917 |
-
|
918 |
if 'searchbtn_clicked' not in st.session_state:
|
919 |
st.session_state['searchbtn_clicked'] = False
|
920 |
|
@@ -940,13 +929,13 @@ else:
|
|
940 |
col2.write('')
|
941 |
else:
|
942 |
st.markdown('---')
|
943 |
-
st.write("No results found. Perhaps try another condition? Some examples that work: "+', '.join(
|
944 |
|
945 |
if search_freetext != " ":
|
946 |
options = topk(search_freetext)
|
947 |
searchInner(options)
|
948 |
else:
|
949 |
-
options = topk(
|
950 |
searchInner(options)
|
951 |
|
952 |
st.write(st.session_state['selected_option'])
|
@@ -972,10 +961,10 @@ else:
|
|
972 |
if 'genbtn_clicked' not in st.session_state:
|
973 |
st.session_state['genbtn_clicked'] = False
|
974 |
|
975 |
-
if "
|
976 |
with open('templates/kgen.txt', 'r') as file:
|
977 |
-
|
978 |
-
st.session_state.
|
979 |
|
980 |
### ------------------------------------------------------------------------------------------------
|
981 |
### DEBUGGING CODE
|
@@ -985,9 +974,9 @@ else:
|
|
985 |
### ------------------------------------------------------------------------------------------------
|
986 |
|
987 |
|
988 |
-
|
989 |
input_variables = ["infostorekg"],
|
990 |
-
template = st.session_state.
|
991 |
)
|
992 |
|
993 |
if 'formautofill' not in st.session_state:
|
@@ -1019,16 +1008,16 @@ else:
|
|
1019 |
infoPrompt = kgMatch(st.session_state.selected_option)
|
1020 |
st.session_state.infostorekg = str(infoPrompt)
|
1021 |
|
1022 |
-
if ("
|
1023 |
or
|
1024 |
-
st.session_state.
|
1025 |
#st.session_state.chain = (
|
1026 |
#{
|
1027 |
# "infostorekg": passState,
|
1028 |
# } |
|
1029 |
-
#LLMChain(llm=
|
1030 |
-
st.session_state.
|
1031 |
-
chain = st.session_state.
|
1032 |
|
1033 |
st.session_state['formautofill'] = chain.invoke({"infostorekg": st.session_state.infostorekg}).get("text")
|
1034 |
else:
|
|
|
42 |
import plotly.graph_objects as go
|
43 |
import pandas as pd
|
44 |
|
45 |
+
import networkx as nx
|
46 |
|
47 |
if not os.path.isdir("./.streamlit"):
|
48 |
os.mkdir("./.streamlit")
|
|
|
827 |
st.title("Medical Scenario Generator (for Admins)")
|
828 |
|
829 |
## Hardcode scenarios for now,
|
830 |
+
indexes_gen = """
|
831 |
aortic dissection
|
832 |
anemia
|
833 |
cystitis
|
834 |
pneumonia
|
835 |
""".split("\n")
|
836 |
|
837 |
+
if "selected_index_gen" not in st.session_state:
|
838 |
+
st.session_state.selected_index_gen = 0
|
839 |
|
840 |
+
if "search_selectbox_gen" not in st.session_state:
|
841 |
+
st.session_state.search_selectbox_gen = " "
|
842 |
+
# st.session_state.index_selectbox_gen = "Headache"
|
843 |
|
844 |
if "search_freetext" not in st.session_state:
|
845 |
st.session_state.search_freetext = " "
|
|
|
848 |
#index_selectbox = st_tags(
|
849 |
# label='What medical condition would you like to generate a scenario for?',
|
850 |
# text='Input here ...',
|
851 |
+
# suggestions=indexes_gen,
|
852 |
# value = ' ',
|
853 |
# maxtags = 1,
|
854 |
# key='0')
|
|
|
868 |
# st.session_state.selected_index = indexes.index(search_selectbox)
|
869 |
# st.session_state.search_selectbox = search_selectbox
|
870 |
|
871 |
+
if "openai_model_gen" not in st.session_state:
|
872 |
+
st.session_state["openai_model_gen"] = "gpt-3.5-turbo"
|
|
|
|
|
|
|
873 |
|
874 |
model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
|
875 |
model_kwargs = {"device": "cpu"}
|
876 |
# model_kwargs = {"device": "cuda"}
|
877 |
encode_kwargs = {"normalize_embeddings": True}
|
878 |
|
879 |
+
if "embeddings_gen" not in st.session_state:
|
880 |
+
st.session_state.embeddings_gen = HuggingFaceEmbeddings(
|
881 |
model_name=model_name,
|
882 |
model_kwargs = model_kwargs,
|
883 |
encode_kwargs = encode_kwargs)
|
884 |
+
embeddings_gen = st.session_state.embeddings_gen
|
885 |
+
if "llm_gen" not in st.session_state:
|
886 |
+
st.session_state.llm_gen = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
|
|
|
887 |
#if "llm" not in st.session_state:
|
888 |
# st.session_state.llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
|
889 |
#llm = st.session_state.llm
|
890 |
#if "llm" not in st.session_state:
|
891 |
# st.session_state.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
|
892 |
+
llm_gen = st.session_state.llm_gen
|
893 |
|
894 |
## ------------------------------------------------------------------------------------------------
|
895 |
## Generator part
|
896 |
+
index_name_gen = f"indexes/faiss_index_large_v2"
|
897 |
|
898 |
+
if "store_gen" not in st.session_state:
|
899 |
+
#st.session_state.store_gen = FAISS.load_local(index_name_gen, embeddings_gen)
|
900 |
+
st.session_state.store_gen = db.get_store(index_name_gen, embeddings=embeddings_gen)
|
901 |
+
store_gen = st.session_state.store_gen
|
|
|
902 |
|
903 |
def topk(searchKW):
|
904 |
search_r = st.session_state.store.similarity_search(searchKW, k=5)
|
905 |
return [x.page_content for x in search_r]
|
906 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
907 |
if 'searchbtn_clicked' not in st.session_state:
|
908 |
st.session_state['searchbtn_clicked'] = False
|
909 |
|
|
|
929 |
col2.write('')
|
930 |
else:
|
931 |
st.markdown('---')
|
932 |
+
st.write("No results found. Perhaps try another condition? Some examples that work: "+', '.join(indexes_gen))
|
933 |
|
934 |
if search_freetext != " ":
|
935 |
options = topk(search_freetext)
|
936 |
searchInner(options)
|
937 |
else:
|
938 |
+
options = topk(indexes_gen[st.session_state.selected_index])
|
939 |
searchInner(options)
|
940 |
|
941 |
st.write(st.session_state['selected_option'])
|
|
|
961 |
if 'genbtn_clicked' not in st.session_state:
|
962 |
st.session_state['genbtn_clicked'] = False
|
963 |
|
964 |
+
if "TEMPLATE_gen" not in st.session_state:
|
965 |
with open('templates/kgen.txt', 'r') as file:
|
966 |
+
TEMPLATE_gen = file.read()
|
967 |
+
st.session_state.TEMPLATE_gen = TEMPLATE_gen
|
968 |
|
969 |
### ------------------------------------------------------------------------------------------------
|
970 |
### DEBUGGING CODE
|
|
|
974 |
### ------------------------------------------------------------------------------------------------
|
975 |
|
976 |
|
977 |
+
prompt_gen = PromptTemplate(
|
978 |
input_variables = ["infostorekg"],
|
979 |
+
template = st.session_state.TEMPLATE_gen
|
980 |
)
|
981 |
|
982 |
if 'formautofill' not in st.session_state:
|
|
|
1008 |
infoPrompt = kgMatch(st.session_state.selected_option)
|
1009 |
st.session_state.infostorekg = str(infoPrompt)
|
1010 |
|
1011 |
+
if ("chain_gen" not in st.session_state
|
1012 |
or
|
1013 |
+
st.session_state.TEMPLATE_gen != TEMPLATE):
|
1014 |
#st.session_state.chain = (
|
1015 |
#{
|
1016 |
# "infostorekg": passState,
|
1017 |
# } |
|
1018 |
+
#LLMChain(llm=llm_gen, prompt=prompt, verbose=False)
|
1019 |
+
st.session_state.chain_gen = LLMChain(llm=llm_gen, prompt=prompt_gen, verbose = False)
|
1020 |
+
chain = st.session_state.chain_gen
|
1021 |
|
1022 |
st.session_state['formautofill'] = chain.invoke({"infostorekg": st.session_state.infostorekg}).get("text")
|
1023 |
else:
|