Files changed (2) hide show
  1. .gitignore +1 -0
  2. pages/convosim.py +27 -16
.gitignore CHANGED
@@ -26,6 +26,7 @@ share/python-wheels/
26
  .installed.cfg
27
  *.egg
28
  MANIFEST
 
29
 
30
  # PyInstaller
31
  # Usually these files are written by a python script from a template
 
26
  .installed.cfg
27
  *.egg
28
  MANIFEST
29
+ *.csv
30
 
31
  # PyInstaller
32
  # Usually these files are written by a python script from a template
pages/convosim.py CHANGED
@@ -13,6 +13,7 @@ from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
 
14
  logger = get_logger(__name__)
15
  temperature = 0.8
 
16
  # username = "barb-chase" #"ivnban-ctl"
17
  st.set_page_config(page_title="Conversation Simulator")
18
 
@@ -20,7 +21,10 @@ if "sent_messages" not in st.session_state:
20
  st.session_state['sent_messages'] = 0
21
  if not are_models_alive():
22
  st.switch_page("pages/model_loader.py")
23
-
 
 
 
24
  if "total_messages" not in st.session_state:
25
  st.session_state['total_messages'] = 0
26
  if "issue" not in st.session_state:
@@ -49,30 +53,36 @@ if "scored" not in st.session_state:
49
  memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
50
 
51
  with st.sidebar:
52
- username = st.text_input("Username", value='', max_chars=30)
53
- if 'counselor_name' not in st.session_state:
54
- st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
 
 
55
  # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
56
  issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
57
- on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
58
  )
59
  supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
60
  language = st.selectbox("Select a Language", supported_languages, index=0,
61
  format_func=lambda x: "English" if x=="en" else "Spanish",
62
- on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
63
  )
64
 
65
  source = st.selectbox("Select a source Model A", SOURCES, index=0,
66
  format_func=source2label, key="source"
67
- )
68
-
69
- changed_source = any([
70
- st.session_state['previous_source'] != source,
71
- st.session_state['issue'] != issue,
72
- st.session_state['counselor_name'] != username,
73
- ])
74
- if changed_source:
75
- st.session_state["counselor_name"] = username
 
 
 
 
76
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
77
  logger.debug(f"texter name is {st.session_state['texter_name']}")
78
  st.session_state['previous_source'] = source
@@ -85,9 +95,10 @@ if changed_source:
85
  create_memory_add_initial_message(memories,
86
  issue,
87
  language,
88
- changed_source=changed_source,
89
  counselor_name=st.session_state["counselor_name"],
90
  texter_name=st.session_state["texter_name"])
 
91
  st.session_state['previous_source'] = source
92
  memoryA = st.session_state[list(memories.keys())[0]]
93
  # issue only without "." marker for model compatibility
 
13
 
14
  logger = get_logger(__name__)
15
  temperature = 0.8
16
+ reset_convo = False
17
  # username = "barb-chase" #"ivnban-ctl"
18
  st.set_page_config(page_title="Conversation Simulator")
19
 
 
21
  st.session_state['sent_messages'] = 0
22
  if not are_models_alive():
23
  st.switch_page("pages/model_loader.py")
24
+ # if "reset_convo" not in st.session_state:
25
+ # st.session_state['reset_convo'] = False
26
+ if 'counselor_name' not in st.session_state:
27
+ st.session_state["counselor_name"] = ""
28
  if "total_messages" not in st.session_state:
29
  st.session_state['total_messages'] = 0
30
  if "issue" not in st.session_state:
 
53
  memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
54
 
55
  with st.sidebar:
56
+ if st.button("Reset Conversation", type="primary"):
57
+ reset_convo = True
58
+ logger.debug("Clear conversation manually")
59
+ username = st.text_input("Username", value=st.session_state["counselor_name"], max_chars=30)
60
+ st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
61
  # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
62
  issue = st.selectbox("Select a Scenario", ISSUES, index=ISSUES.index(st.session_state['issue']), format_func=issue2label,
63
+ # kwargs={"memories":memories, "username":username, "language":"English"}
64
  )
65
  supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
66
  language = st.selectbox("Select a Language", supported_languages, index=0,
67
  format_func=lambda x: "English" if x=="en" else "Spanish",
68
+ # kwargs={"memories":memories, "username":username, "language":"English"}
69
  )
70
 
71
  source = st.selectbox("Select a source Model A", SOURCES, index=0,
72
  format_func=source2label, key="source"
73
+ )
74
+ # changed_source = any([
75
+ # st.session_state['previous_source'] != source,
76
+ # st.session_state['issue'] != issue,
77
+ # st.session_state['counselor_name'] != username,
78
+ # ])
79
+ logger.info("-"*10)
80
+ logger.info(f"Reset convo is {reset_convo}")
81
+ logger.info("-"*10)
82
+
83
+ if reset_convo:
84
+ clear_memory(memories, username, language)
85
+ st.session_state['counselor_name'] = username
86
  st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
87
  logger.debug(f"texter name is {st.session_state['texter_name']}")
88
  st.session_state['previous_source'] = source
 
95
  create_memory_add_initial_message(memories,
96
  issue,
97
  language,
98
+ changed_source=reset_convo,
99
  counselor_name=st.session_state["counselor_name"],
100
  texter_name=st.session_state["texter_name"])
101
+
102
  st.session_state['previous_source'] = source
103
  memoryA = st.session_state[list(memories.keys())[0]]
104
  # issue only without "." marker for model compatibility