plebias committed on
Commit
4c607ac
1 Parent(s): 9ffeffb

Fixing variable name clashes

Browse files
Files changed (1) hide show
  1. app_final.py +33 -44
app_final.py CHANGED
@@ -42,6 +42,7 @@ import plotly.express as px
42
  import plotly.graph_objects as go
43
  import pandas as pd
44
 
 
45
 
46
  if not os.path.isdir("./.streamlit"):
47
  os.mkdir("./.streamlit")
@@ -826,19 +827,19 @@ else:
826
  st.title("Medical Scenario Generator (for Admins)")
827
 
828
  ## Hardcode scenarios for now,
829
- indexes = """
830
  aortic dissection
831
  anemia
832
  cystitis
833
  pneumonia
834
  """.split("\n")
835
 
836
- if "selected_index" not in st.session_state:
837
- st.session_state.selected_index = 0
838
 
839
- if "search_selectbox" not in st.session_state:
840
- st.session_state.search_selectbox = " "
841
- # st.session_state.index_selectbox = "Headache"
842
 
843
  if "search_freetext" not in st.session_state:
844
  st.session_state.search_freetext = " "
@@ -847,7 +848,7 @@ else:
847
  #index_selectbox = st_tags(
848
  # label='What medical condition would you like to generate a scenario for?',
849
  # text='Input here ...',
850
- # suggestions=indexes,
851
  # value = ' ',
852
  # maxtags = 1,
853
  # key='0')
@@ -867,54 +868,42 @@ else:
867
  # st.session_state.selected_index = indexes.index(search_selectbox)
868
  # st.session_state.search_selectbox = search_selectbox
869
 
870
- if "openai_model" not in st.session_state:
871
- st.session_state["openai_model"] = "gpt-3.5-turbo"
872
-
873
- if "active_chat" not in st.session_state:
874
- st.session_state.active_chat = 1
875
 
876
  model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
877
  model_kwargs = {"device": "cpu"}
878
  # model_kwargs = {"device": "cuda"}
879
  encode_kwargs = {"normalize_embeddings": True}
880
 
881
- if "embeddings" not in st.session_state:
882
- st.session_state.embeddings = HuggingFaceEmbeddings(
883
  model_name=model_name,
884
  model_kwargs = model_kwargs,
885
  encode_kwargs = encode_kwargs)
886
- embeddings = st.session_state.embeddings
887
- if "llm" not in st.session_state:
888
- st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
889
- llm = st.session_state.llm
890
  #if "llm" not in st.session_state:
891
  # st.session_state.llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
892
  #llm = st.session_state.llm
893
  #if "llm" not in st.session_state:
894
  # st.session_state.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
895
- llm = st.session_state.llm
896
 
897
  ## ------------------------------------------------------------------------------------------------
898
  ## Generator part
899
- index_name = f"indexes/faiss_index_large_v2"
900
 
901
- if "store" not in st.session_state:
902
- #st.session_state.store = FAISS.load_local(index_name, embeddings)
903
- st.session_state.store = db.get_store(index_name, embeddings=embeddings)
904
- #st.session_state.store.similarity_search('hello')
905
- store = st.session_state.store
906
 
907
  def topk(searchKW):
908
  search_r = st.session_state.store.similarity_search(searchKW, k=5)
909
  return [x.page_content for x in search_r]
910
 
911
- #def onSelectButton(s_index):
912
- # topkindexes = topk(s_index) #return top 5 list of similiar diseases
913
- # return selected_options
914
-
915
- #selectButton = st.button(on_click = onSelectButton(st.session_state.index_selectbox))
916
- #selectButton = st.button("Search")
917
-
918
  if 'searchbtn_clicked' not in st.session_state:
919
  st.session_state['searchbtn_clicked'] = False
920
 
@@ -940,13 +929,13 @@ else:
940
  col2.write('')
941
  else:
942
  st.markdown('---')
943
- st.write("No results found. Perhaps try another condition? Some examples that work: "+', '.join(indexes))
944
 
945
  if search_freetext != " ":
946
  options = topk(search_freetext)
947
  searchInner(options)
948
  else:
949
- options = topk(indexes[st.session_state.selected_index])
950
  searchInner(options)
951
 
952
  st.write(st.session_state['selected_option'])
@@ -972,10 +961,10 @@ else:
972
  if 'genbtn_clicked' not in st.session_state:
973
  st.session_state['genbtn_clicked'] = False
974
 
975
- if "TEMPLATE" not in st.session_state:
976
  with open('templates/kgen.txt', 'r') as file:
977
- TEMPLATE = file.read()
978
- st.session_state.TEMPLATE = TEMPLATE
979
 
980
  ### ------------------------------------------------------------------------------------------------
981
  ### DEBUGGING CODE
@@ -985,9 +974,9 @@ else:
985
  ### ------------------------------------------------------------------------------------------------
986
 
987
 
988
- prompt = PromptTemplate(
989
  input_variables = ["infostorekg"],
990
- template = st.session_state.TEMPLATE
991
  )
992
 
993
  if 'formautofill' not in st.session_state:
@@ -1019,16 +1008,16 @@ else:
1019
  infoPrompt = kgMatch(st.session_state.selected_option)
1020
  st.session_state.infostorekg = str(infoPrompt)
1021
 
1022
- if ("chain" not in st.session_state
1023
  or
1024
- st.session_state.TEMPLATE != TEMPLATE):
1025
  #st.session_state.chain = (
1026
  #{
1027
  # "infostorekg": passState,
1028
  # } |
1029
- #LLMChain(llm=llm, prompt=prompt, verbose=False)
1030
- st.session_state.chain = LLMChain(llm=llm, prompt=prompt, verbose = False)
1031
- chain = st.session_state.chain
1032
 
1033
  st.session_state['formautofill'] = chain.invoke({"infostorekg": st.session_state.infostorekg}).get("text")
1034
  else:
 
42
  import plotly.graph_objects as go
43
  import pandas as pd
44
 
45
+ import networkx as nx
46
 
47
  if not os.path.isdir("./.streamlit"):
48
  os.mkdir("./.streamlit")
 
827
  st.title("Medical Scenario Generator (for Admins)")
828
 
829
  ## Hardcode scenarios for now,
830
+ indexes_gen = """
831
  aortic dissection
832
  anemia
833
  cystitis
834
  pneumonia
835
  """.split("\n")
836
 
837
+ if "selected_index_gen" not in st.session_state:
838
+ st.session_state.selected_index_gen = 0
839
 
840
+ if "search_selectbox_gen" not in st.session_state:
841
+ st.session_state.search_selectbox_gen = " "
842
+ # st.session_state.index_selectbox_gen = "Headache"
843
 
844
  if "search_freetext" not in st.session_state:
845
  st.session_state.search_freetext = " "
 
848
  #index_selectbox = st_tags(
849
  # label='What medical condition would you like to generate a scenario for?',
850
  # text='Input here ...',
851
+ # suggestions=indexes_gen,
852
  # value = ' ',
853
  # maxtags = 1,
854
  # key='0')
 
868
  # st.session_state.selected_index = indexes.index(search_selectbox)
869
  # st.session_state.search_selectbox = search_selectbox
870
 
871
+ if "openai_model_gen" not in st.session_state:
872
+ st.session_state["openai_model_gen"] = "gpt-3.5-turbo"
 
 
 
873
 
874
  model_name = "pritamdeka/S-PubMedBert-MS-MARCO"
875
  model_kwargs = {"device": "cpu"}
876
  # model_kwargs = {"device": "cuda"}
877
  encode_kwargs = {"normalize_embeddings": True}
878
 
879
+ if "embeddings_gen" not in st.session_state:
880
+ st.session_state.embeddings_gen = HuggingFaceEmbeddings(
881
  model_name=model_name,
882
  model_kwargs = model_kwargs,
883
  encode_kwargs = encode_kwargs)
884
+ embeddings_gen = st.session_state.embeddings_gen
885
+ if "llm_gen" not in st.session_state:
886
+ st.session_state.llm_gen = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0)
 
887
  #if "llm" not in st.session_state:
888
  # st.session_state.llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)
889
  #llm = st.session_state.llm
890
  #if "llm" not in st.session_state:
891
  # st.session_state.llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
892
+ llm_gen = st.session_state.llm_gen
893
 
894
  ## ------------------------------------------------------------------------------------------------
895
  ## Generator part
896
+ index_name_gen = f"indexes/faiss_index_large_v2"
897
 
898
+ if "store_gen" not in st.session_state:
899
+ #st.session_state.store_gen = FAISS.load_local(index_name_gen, embeddings_gen)
900
+ st.session_state.store_gen = db.get_store(index_name_gen, embeddings=embeddings_gen)
901
+ store_gen = st.session_state.store_gen
 
902
 
903
def topk(searchKW):
    """Return the page contents of the 5 documents most similar to *searchKW*.

    Queries the FAISS vector store cached in Streamlit session state.
    Fix: this commit renamed the session key ``store`` -> ``store_gen``,
    but this function was left reading ``st.session_state.store``, which
    no longer exists and raises AttributeError at search time.
    """
    search_r = st.session_state.store_gen.similarity_search(searchKW, k=5)
    return [x.page_content for x in search_r]
906
 
 
 
 
 
 
 
 
907
  if 'searchbtn_clicked' not in st.session_state:
908
  st.session_state['searchbtn_clicked'] = False
909
 
 
929
  col2.write('')
930
  else:
931
  st.markdown('---')
932
+ st.write("No results found. Perhaps try another condition? Some examples that work: "+', '.join(indexes_gen))
933
 
934
  if search_freetext != " ":
935
  options = topk(search_freetext)
936
  searchInner(options)
937
  else:
938
# Fix missed rename: the session key was renamed to ``selected_index_gen``
# earlier in this commit, so ``selected_index`` is no longer initialised
# on this page and would raise AttributeError.
options = topk(indexes_gen[st.session_state.selected_index_gen])
939
  searchInner(options)
940
 
941
  st.write(st.session_state['selected_option'])
 
961
  if 'genbtn_clicked' not in st.session_state:
962
  st.session_state['genbtn_clicked'] = False
963
 
964
# Read the generator prompt template on every rerun so the local
# ``TEMPLATE_gen`` is always bound; the downstream chain-rebuild check
# compares the cached session copy against this fresh read to detect
# template edits.  (Previously the read happened only on the first run,
# leaving the local name undefined on every later rerun.)
with open('templates/kgen.txt', 'r') as file:
    TEMPLATE_gen = file.read()
if "TEMPLATE_gen" not in st.session_state:
    st.session_state.TEMPLATE_gen = TEMPLATE_gen
968
 
969
  ### ------------------------------------------------------------------------------------------------
970
  ### DEBUGGING CODE
 
974
  ### ------------------------------------------------------------------------------------------------
975
 
976
 
977
+ prompt_gen = PromptTemplate(
978
  input_variables = ["infostorekg"],
979
+ template = st.session_state.TEMPLATE_gen
980
  )
981
 
982
  if 'formautofill' not in st.session_state:
 
1008
  infoPrompt = kgMatch(st.session_state.selected_option)
1009
  st.session_state.infostorekg = str(infoPrompt)
1010
 
1011
+ if ("chain_gen" not in st.session_state
1012
  or
1013
+ st.session_state.TEMPLATE_gen != TEMPLATE):
1014
  #st.session_state.chain = (
1015
  #{
1016
  # "infostorekg": passState,
1017
  # } |
1018
+ #LLMChain(llm=llm_gen, prompt=prompt, verbose=False)
1019
+ st.session_state.chain_gen = LLMChain(llm=llm_gen, prompt=prompt_gen, verbose = False)
1020
+ chain = st.session_state.chain_gen
1021
 
1022
  st.session_state['formautofill'] = chain.invoke({"infostorekg": st.session_state.infostorekg}).get("text")
1023
  else: