ivnban27-ctl commited on
Commit
70a482f
·
1 Parent(s): 348d7de

cpc and bad practices features

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [client]
2
+ showSidebarNavigation = false
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 💬
4
  colorFrom: red
5
  colorTo: red
6
  sdk: streamlit
7
- sdk_version: 1.26.0
8
- app_file: convosim.py
9
  pinned: false
10
  ---
11
 
 
4
  colorFrom: red
5
  colorTo: red
6
  sdk: streamlit
7
+ sdk_version: 1.38.0
8
+ app_file: main.py
9
  pinned: false
10
  ---
11
 
app_config.py CHANGED
@@ -19,8 +19,11 @@ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
19
  ENDPOINT_NAMES = {
20
  # "CTL_llama2": "texter_simulator",
21
  "CTL_llama3": "texter_simulator_llm",
 
22
  # 'CTL_llama2': "llama2_convo_sim",
23
- "CTL_mistral": "convo_sim_mistral"
 
 
24
  }
25
 
26
  def source2label(source):
@@ -36,6 +39,8 @@ DB_CONVOS = 'conversations'
36
  DB_COMPLETIONS = 'comparison_completions'
37
  DB_BATTLES = 'battles'
38
  DB_ERRORS = 'completion_errors'
 
 
39
 
40
  MAX_MSG_COUNT = 60
41
  WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
 
19
  ENDPOINT_NAMES = {
20
  # "CTL_llama2": "texter_simulator",
21
  "CTL_llama3": "texter_simulator_llm",
22
+ # "CTL_llama3": "databricks-meta-llama-3-1-70b-instruct",
23
  # 'CTL_llama2': "llama2_convo_sim",
24
+ # "CTL_mistral": "convo_sim_mistral",
25
+ "CPC": "phase_classifier",
26
+ "BadPractices": "training_adherence_bp"
27
  }
28
 
29
  def source2label(source):
 
39
  DB_COMPLETIONS = 'comparison_completions'
40
  DB_BATTLES = 'battles'
41
  DB_ERRORS = 'completion_errors'
42
+ DB_CPC = "cpc_comparison"
43
+ DB_BP = "bad_practices_comparison"
44
 
45
  MAX_MSG_COUNT = 60
46
  WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
main.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit.logger import get_logger
3
+
4
+ from utils.app_utils import are_models_alive
5
+
6
+ logger = get_logger(__name__)
7
+
8
+ st.set_page_config(page_title="Conversation Simulator")
9
+
10
# Show a status box while the model endpoints are probed, then route the
# visitor: to the simulator when are_models_alive() is true, otherwise to the
# loader page.
with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
    next_page = "pages/convosim.py" if are_models_alive() else "pages/model_loader.py"
    st.switch_page(next_page)
models/databricks/texter_sim_llm.py CHANGED
@@ -17,14 +17,12 @@ texter:"""
17
  def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
18
 
19
  endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
20
-
21
  PROMPT = PromptTemplate(
22
  input_variables=['history', 'input'],
23
  template=_DATABRICKS_TEMPLATE_
24
  )
25
 
26
  llm = CustomDatabricksLLM(
27
- # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
28
  endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
29
  bearer_token=os.environ["DATABRICKS_TOKEN"],
30
  texter_name=texter_name,
@@ -43,4 +41,17 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texte
43
  )
44
 
45
  logging.debug(f"loaded Databricks model")
46
- return llm_chain, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
18
 
19
  endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
 
20
  PROMPT = PromptTemplate(
21
  input_variables=['history', 'input'],
22
  template=_DATABRICKS_TEMPLATE_
23
  )
24
 
25
  llm = CustomDatabricksLLM(
 
26
  endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
27
  bearer_token=os.environ["DATABRICKS_TOKEN"],
28
  texter_name=texter_name,
 
41
  )
42
 
43
  logging.debug(f"loaded Databricks model")
44
+ return llm_chain, None
45
+
46
def cpc_is_alive():
    """Return True when the CPC serving endpoint answers a probe with HTTP 200.

    A minimal payload (one empty input string) is POSTed to ``CPC_URL`` with a
    2-second timeout. Any failure -- a non-200 status, a timeout, a connection
    error -- is reported as ``False`` instead of being raised.
    """
    body_request = {
        "inputs": [""]
    }
    try:
        # Send request to Serving
        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=2)
    except Exception:
        # Best-effort probe: ordinary errors mean "not alive". `Exception`
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        return False
    return response.status_code == 200
models/ta_models/bp_utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit.logger import get_logger
3
+ import requests
4
+ import os
5
+ from .config import model_name_or_path, BP_THRESHOLD
6
+ from transformers import AutoTokenizer
7
+ from utils.mongo_utils import new_bp_comparison
8
+ from app_config import ENDPOINT_NAMES
9
+
10
+ logger = get_logger(__name__)
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
13
+ BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"])
14
+ HEADERS = {
15
+ "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
16
+ "Content-Type": "application/json",
17
+ }
18
+
19
def bp_predict_message(context, input, timeout=None):
    """Ask the bad-practices serving endpoint to classify the latest message.

    ``context`` and ``input`` are tokenized as a pair with
    ``truncation="only_first"``, so only ``context`` is shortened when the pair
    is too long. The token ids are decoded back to text and sent to ``BP_URL``.

    Args:
        context: Conversation history preceding the message.
        input: The message to classify.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (default)
            waits indefinitely, which is the previous behaviour.

    Returns:
        A list with one dict per entry of the endpoint's first prediction,
        where the ``"score"`` value is replaced by the boolean
        ``score > BP_THRESHOLD`` and every other key is passed through.
        ``None`` when the request fails, the status is not 200, or the
        response cannot be parsed -- callers must handle that case.
    """
    # context = memory.load_memory_variables({})[memory.memory_key]
    encoding = tokenizer(
        context,
        input,
        truncation="only_first",
    )['input_ids']
    body_request = {
        "inputs": [tokenizer.decode(encoding)],
        "params": {
            "top_k": None
        }
    }

    try:
        # Send request to Serving
        response = requests.post(url=BP_URL, headers=HEADERS, json=body_request, timeout=timeout)
        if response.status_code != 200:
            # Previously dropped silently; log it so failures can be diagnosed.
            logger.error(f"BP endpoint returned status {response.status_code}")
            return None
        prediction = response.json()['predictions'][0]
        logger.debug(f"Raw BP prediction is {prediction}")
        return [
            {k: (v > BP_THRESHOLD if k == "score" else v) for k, v in entry.items()}
            for entry in prediction.values()
        ]
    except Exception:
        # Still best-effort (returns None), but the failure is now logged
        # with its traceback instead of vanishing.
        logger.exception("BP prediction request failed")
        return None
42
+
43
def bp_push2db(manual_confirmation=None):
    """Store one bad-practice comparison (manual label vs. prediction).

    When ``manual_confirmation`` is not given, it is derived from the
    ``sel_bp`` value in the Streamlit session state. Any selection other than
    the three named options maps to both flags being False. All remaining
    fields are read from the session state.
    """
    if manual_confirmation is None:
        # (is_advice, is_personal_info) for each recognised selection.
        flags_by_selection = {
            "Advice": (True, False),
            "Personal Info": (False, True),
            "Advice & Personal Info": (True, True),
        }
        advice, personal_info = flags_by_selection.get(st.session_state.sel_bp, (False, False))
        manual_confirmation = {"is_advice": advice, "is_personal_info": personal_info}

    session = st.session_state
    new_bp_comparison(
        client=session['db_client'],
        convo_id=session['convo_id'],
        model=session['source'],
        context=session["context"],
        last_message=session["last_message"],
        ytrue=manual_confirmation,
        ypred={pred['label']: pred['score'] for pred in session['bp_prediction']},
    )
models/ta_models/config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_name_or_path = "FacebookAI/xlm-roberta-large"
2
+
3
# Display names for the phase-classifier (CPC) labels. The keys are the label
# identifiers used by the rest of the app; only the values are shown to users.
CPC_LABEL2STR = {
    "0_ActiveEngagement": "Active Engagement",
    "1_Explore": "Explore",
    "2_IRA": "Immediate Risk Assessment",  # spelling fixed (was "Immidiate")
    "3_SafetyAssessment": "Safety Assessment",
    "4_SP&NS": "Safety Planning & Next Steps",
    "5_EmergencyIntervention": "Emergency Intervention",
    "6_WrappingUp": "Wrapping Up",
    "7_Other": "Other",
}
13
+
14
+ CPC_LBL_OPTS = list(CPC_LABEL2STR.keys())
15
+
16
def cpc_label2str(phase):
    """Return the display name for a CPC label key such as ``"1_Explore"``.

    Raises ``KeyError`` when ``phase`` is not a key of ``CPC_LABEL2STR``.
    """
    display_name = CPC_LABEL2STR[phase]
    return display_name
18
+
19
def phase2int(phase):
    """Return the integer prefix of a phase label (the text before the first ``_``)."""
    prefix, _, _ = phase.partition("_")
    return int(prefix)
21
+
22
+ BP_THRESHOLD = 0.7
23
+ BP_LAB2STR = {
24
+ "is_advice": "Advice",
25
+ "is_personal_info": "Personal Info Sharing",
26
+ }
models/ta_models/cpc_utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit.logger import get_logger
3
+ import requests
4
+ import os
5
+ from .config import model_name_or_path
6
+ from transformers import AutoTokenizer
7
+ from utils.mongo_utils import new_cpc_comparison
8
+ from app_config import ENDPOINT_NAMES
9
+
10
+ logger = get_logger(__name__)
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
13
+ CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"])
14
+ HEADERS = {
15
+ "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
16
+ "Content-Type": "application/json",
17
+ }
18
+
19
def cpc_predict_message(context, input, timeout=None):
    """Ask the phase-classifier serving endpoint for the label of the latest message.

    ``context`` and ``input`` are tokenized as a pair with
    ``truncation="only_first"``, so only ``context`` is shortened when the pair
    is too long. The token ids are decoded back to text and sent to ``CPC_URL``.

    Args:
        context: Conversation history preceding the message.
        input: The message to classify.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (default)
            waits indefinitely, which is the previous behaviour.

    Returns:
        The ``"label"`` of entry ``"0"`` of the endpoint's first prediction,
        or ``None`` when the request fails, the status is not 200, or the
        response cannot be parsed -- callers must handle that case.
    """
    # context = memory.load_memory_variables({})[memory.memory_key]
    encoding = tokenizer(
        context,
        input,
        truncation="only_first",
    )['input_ids']
    body_request = {
        "inputs": [tokenizer.decode(encoding)]
    }

    try:
        # Send request to Serving
        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=timeout)
        if response.status_code != 200:
            # Previously dropped silently; log it so failures can be diagnosed.
            logger.error(f"CPC endpoint returned status {response.status_code}")
            return None
        return response.json()['predictions'][0]["0"]["label"]
    except Exception:
        # Still best-effort (returns None), but the failure is now logged
        # with its traceback instead of vanishing.
        logger.exception("CPC prediction request failed")
        return None
37
+
38
def cpc_push2db(is_same):
    """Store one phase-classifier comparison built from the session state.

    The predicted phase is always ``last_phase``. The manual phase is that
    same value when ``is_same`` is true, otherwise the ``sel_phase`` value
    from the session state.
    """
    outcome = "SAME" if is_same else "WRONG"
    logger.debug(f"pushing new {outcome} CPC")

    session = st.session_state
    new_cpc_comparison(
        client=session['db_client'],
        convo_id=session['convo_id'],
        model=session['source'],
        context=session["context"],
        last_message=session["last_message"],
        ytrue=session["last_phase"] if is_same else session["sel_phase"],
        ypred=session["last_phase"],
    )
pages/convosim.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from streamlit.logger import get_logger
4
+ from langchain.schema.messages import HumanMessage
5
+ from utils.mongo_utils import get_db_client
6
+ from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF, are_models_alive
7
+ from utils.memory_utils import clear_memory, push_convo2db
8
+ from utils.chain_utils import get_chain, custom_chain_predict
9
+ from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
10
+ from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
11
+ from models.ta_models.cpc_utils import cpc_push2db
12
+ from models.ta_models.bp_utils import bp_predict_message, bp_push2db
13
+
14
+ logger = get_logger(__name__)
15
+ temperature = 0.8
16
+ # username = "barb-chase" #"ivnban-ctl"
17
+ st.set_page_config(page_title="Conversation Simulator")
18
+
19
+ if "sent_messages" not in st.session_state:
20
+ st.session_state['sent_messages'] = 0
21
+ if not are_models_alive():
22
+ st.switch_page("pages/model_loader.py")
23
+
24
+ if "total_messages" not in st.session_state:
25
+ st.session_state['total_messages'] = 0
26
+ if "issue" not in st.session_state:
27
+ st.session_state['issue'] = ISSUES[0]
28
+ if 'previous_source' not in st.session_state:
29
+ st.session_state['previous_source'] = SOURCES[0]
30
+ if 'db_client' not in st.session_state:
31
+ st.session_state["db_client"] = get_db_client()
32
+ if 'texter_name' not in st.session_state:
33
+ st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
34
+ logger.debug(f"texter name is {st.session_state['texter_name']}")
35
+ if "last_phase" not in st.session_state:
36
+ st.session_state["last_phase"] = CPC_LBL_OPTS[0]
37
+ # st.session_state["sel_phase"] = CPC_LBL_OPTS[0]
38
+ if "changed_cpc" not in st.session_state:
39
+ st.session_state["changed_cpc"] = False
40
+ if "changed_bp" not in st.session_state:
41
+ st.session_state["changed_bp"] = False
42
+
43
+ # st.session_state["sel_phase"] = st.session_state["last_phase"]
44
+
45
+ memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
46
+
47
+ with st.sidebar:
48
+ username = st.text_input("Username", value='Dani', max_chars=30)
49
+ if 'counselor_name' not in st.session_state:
50
+ st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
51
+ # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
52
+ issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
53
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
54
+ )
55
+ supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
56
+ language = st.selectbox("Select a Language", supported_languages, index=0,
57
+ format_func=lambda x: "English" if x=="en" else "Spanish",
58
+ on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
59
+ )
60
+
61
+ source = st.selectbox("Select a source Model A", SOURCES, index=0,
62
+ format_func=source2label, key="source"
63
+ )
64
+
65
+ changed_source = any([
66
+ st.session_state['previous_source'] != source,
67
+ st.session_state['issue'] != issue,
68
+ st.session_state['counselor_name'] != username,
69
+ ])
70
+ if changed_source:
71
+ st.session_state["counselor_name"] = username
72
+ st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
73
+ logger.debug(f"texter name is {st.session_state['texter_name']}")
74
+ st.session_state['previous_source'] = source
75
+ st.session_state['issue'] = issue
76
+ st.session_state['sent_messages'] = 0
77
+ st.session_state['total_messages'] = 0
78
+ create_memory_add_initial_message(memories,
79
+ issue,
80
+ language,
81
+ changed_source=changed_source,
82
+ counselor_name=st.session_state["counselor_name"],
83
+ texter_name=st.session_state["texter_name"])
84
+ st.session_state['previous_source'] = source
85
+ memoryA = st.session_state[list(memories.keys())[0]]
86
+ # issue only without "." marker for model compatibility
87
+ llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
88
+
89
+ st.title("💬 Simulator")
90
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
91
# Replay the stored conversation: human turns render as "user", everything
# else as "assistant".
for msg in memoryA.buffer_as_messages:
    # isinstance rather than an exact type comparison, so subclasses of
    # HumanMessage are still shown as user turns.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)
94
+
95
def sent_request_llm(llm_chain, prompt):
    """Count the message as sent, echo it, and render each reply from the chain.

    Uses the module-level ``stopper`` when calling ``custom_chain_predict``.
    """
    st.session_state['sent_messages'] = st.session_state['sent_messages'] + 1
    user_bubble = st.chat_message("user")
    user_bubble.write(prompt)
    for reply in custom_chain_predict(llm_chain, prompt, stopper):
        st.chat_message("assistant").write(reply)
101
+
102
+ # @st.dialog("Bad Practice Detected")
103
+ # def confirm_bp(bp_prediction, prompt):
104
+ # bps = [BP_LAB2STR[x['label']] for x in bp_prediction if x['score']]
105
+ # st.markdown(f"The last message was considered :red[{' and '.join(bps)}]")
106
+ # "Are you sure you want to send this message?"
107
+ # newprompt = st.text_input("Change message to:")
108
+ # "If you do not want to change leave textbox empty"
109
+ # for bp in BP_LAB2STR.keys():
110
+ # _ = st.checkbox(f"Original Message was {BP_LAB2STR[bp]}", key=f"chkbx_{bp}", value=BP_LAB2STR[bp] in bps)
111
+
112
+ # if st.button("Confirm"):
113
+ # if newprompt is not None and newprompt != "":
114
+ # prompt = newprompt
115
+ # bp_push2db(
116
+ # {bp:st.session_state[f"chkbx_{bp}"] for bp in BP_LAB2STR.keys()}
117
+ # )
118
+ # sent_request_llm(llm_chain, prompt)
119
+ # st.rerun()
120
+
121
+ if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
122
+ if 'convo_id' not in st.session_state:
123
+ push_convo2db(memories, username, language)
124
+ st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
125
+ st.session_state['last_message'] = prompt
126
+ if (not st.session_state.changed_cpc) and st.session_state["sent_messages"] > 0:
127
+ cpc_push2db(True)
128
+ else: st.session_state.changed_cpc = False
129
+ if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
130
+ bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
131
+ else: st.session_state.changed_bp = False
132
+
133
+ context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
134
+ st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
135
+ if any([x['score'] for x in st.session_state['bp_prediction']]):
136
+ for bp in st.session_state['bp_prediction']:
137
+ if bp["score"]:
138
+ st.error(f"Detected {BP_LAB2STR[bp['label']]} in the last message!")
139
+ st.session_state.changed_bp = True
140
+ else:
141
+ sent_request_llm(llm_chain, prompt)
142
+
143
+ with st.sidebar:
144
+ st.divider()
145
+ st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
146
+ st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
147
+ # st.markdown()
148
+ def on_change_cpc():
149
+ cpc_push2db(False)
150
+ st.session_state.changed_cpc = True
151
+ def on_change_bp():
152
+ bp_push2db()
153
+ st.session_state.changed_bp = True
154
+
155
+ if st.session_state["sent_messages"] > 0:
156
+ _ = st.selectbox(f"""Last Human Message was considered :blue[**{
157
+ cpc_label2str(st.session_state['last_phase'])
158
+ }**]. If not please select from the following options""",
159
+
160
+ CPC_LBL_OPTS, index=None,format_func=cpc_label2str, on_change=on_change_cpc,
161
+ key="sel_phase",
162
+ )
163
+
164
+ BPs = [BP_LAB2STR[x['label']] for x in st.session_state['bp_prediction'] if x['score']]
165
+ selecttitle = f"""Last Human Message was considered :blue[**{
166
+ " and ".join(BPs)
167
+ }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
168
+ _ = st.selectbox(selecttitle + " If not please select from the following options""",
169
+
170
+ ["None", "Advice", "Personal Info", "Advice & Personal Info"], index=None, on_change=on_change_bp,
171
+ key="sel_bp"
172
+ )
173
+
174
+ st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
175
+ if st.session_state['total_messages'] >= MAX_MSG_COUNT:
176
+ st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
177
+ elif st.session_state['total_messages'] >= WARN_MSG_COUT:
178
+ st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")
179
+
180
+ if not are_models_alive():
181
+ st.switch_page("pages/model_loader.py")
pages/model_loader.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+ from streamlit.logger import get_logger
4
+ from utils.app_utils import is_model_alive
5
+ from app_config import ENDPOINT_NAMES
6
+
7
+ logger = get_logger(__name__)
8
+
9
+ st.set_page_config(page_title="Conversation Simulator")
10
+
11
+ models_alive = False
12
+ start = time.time()
13
+
14
+ MODELS2LOAD = {
15
+ "CPC": {"model_name": "Phase Classifier", "loaded":False,},
16
+ "CTL_llama3": {"model_name": "Texter Simulator", "loaded":False,},
17
+ "BadPractices": {"model_name": "Advice Identificator", "loaded":False},
18
+ }
19
+
20
def write_model_status(writer, model_name, loaded, fail=False):
    """Write a one-line status for ``model_name`` through ``writer.write``.

    A loaded model always reports success. Otherwise ``fail`` selects between
    the "failed" and the "still loading" message.
    """
    if loaded:
        message = f"✅ - {model_name} Loaded"
    elif fail:
        message = f"❌ - {model_name} Failed to Load"
    else:
        message = f"🔄 - {model_name} Loading"
    writer.write(message)
28
+
29
+ with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
30
+
31
+ for k in MODELS2LOAD.keys():
32
+ MODELS2LOAD[k]["writer"] = st.empty()
33
+
34
+ while not models_alive:
35
+ time.sleep(2)
36
+ for name, config in MODELS2LOAD.items():
37
+ config["loaded"] = is_model_alive(ENDPOINT_NAMES[name])
38
+
39
+ models_alive = all([x['loaded'] for x in MODELS2LOAD.values()])
40
+
41
+ for _, config in MODELS2LOAD.items():
42
+ write_model_status(**config)
43
+
44
+ if int(time.time()-start) > 30:
45
+ status.update(
46
+ label="Models took too long to load. Please Refresh Page in a couple of minutes", state="error", expanded=True
47
+ )
48
+ for _, config in MODELS2LOAD.items():
49
+ write_model_status(**config, fail=True)
50
+ break
51
+
52
+ if models_alive:
53
+ st.switch_page("pages/convosim.py")
requirements.txt CHANGED
@@ -4,4 +4,5 @@ mlflow==2.9.0
4
  langchain==0.3.0
5
  langchain-openai==0.2.0
6
  langchain-community==0.3.0
7
- streamlit==1.38.0
 
 
4
  langchain==0.3.0
5
  langchain-openai==0.2.0
6
  langchain-community==0.3.0
7
+ streamlit==1.38.0
8
+ transformers==4.43.0
utils/app_utils.py CHANGED
@@ -1,19 +1,22 @@
1
  import pandas as pd
2
  import streamlit as st
3
  from streamlit.logger import get_logger
4
- import langchain
 
5
 
6
-
7
- from app_config import ENVIRON
8
  from utils.memory_utils import change_memories
9
  from models.model_seeds import seeds
10
 
11
- langchain.verbose = ENVIRON =="dev"
12
  logger = get_logger(__name__)
13
 
14
  # TODO: Include more variable and representative names
15
  DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
16
  DEFAULT_NAMES_DF = pd.read_csv("./utils/names.csv")
 
 
 
 
17
 
18
  def get_random_name(gender="Neutral", ethnical_group="Neutral", names_df=None):
19
  if names_df is None:
@@ -61,4 +64,33 @@ def create_memory_add_initial_message(memories, issue, language, changed_source=
61
  if len(st.session_state[memory].buffer_as_messages) < 1:
62
  add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)
63
 
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import streamlit as st
3
  from streamlit.logger import get_logger
4
+ import os
5
+ import requests
6
 
7
+ from app_config import ENDPOINT_NAMES
 
8
  from utils.memory_utils import change_memories
9
  from models.model_seeds import seeds
10
 
 
11
  logger = get_logger(__name__)
12
 
13
  # TODO: Include more variable and representative names
14
  DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
15
  DEFAULT_NAMES_DF = pd.read_csv("./utils/names.csv")
16
+ HEADERS = {
17
+ "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
18
+ "Content-Type": "application/json",
19
+ }
20
 
21
  def get_random_name(gender="Neutral", ethnical_group="Neutral", names_df=None):
22
  if names_df is None:
 
64
  if len(st.session_state[memory].buffer_as_messages) < 1:
65
  add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)
66
 
67
def is_model_alive(endpoint_name, timeout=2, model_type="databricks"):
    """Report whether a model backend answers an HTTP probe.

    For ``"databricks"`` a minimal payload (one empty input) is POSTed to the
    serving endpoint named ``endpoint_name``. For ``"openai"`` the models
    listing is requested; ``endpoint_name`` is not used there and the timeout
    is fixed at 1 second rather than taken from ``timeout``.

    Any HTTP response counts as alive, whatever its status code. Only a
    failure to obtain a response (timeout, connection error, ...) counts as
    not alive.

    Args:
        endpoint_name: Databricks serving endpoint name.
        timeout: Seconds to wait for the Databricks probe.
        model_type: ``"databricks"`` or ``"openai"``.

    Returns:
        True if a response was received, False otherwise.

    Raises:
        ValueError: If ``model_type`` is not supported.
        KeyError: If the required environment variable is missing.
    """
    if model_type == "databricks":
        endpoint_url = os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name)
        try:
            # Send request to Serving
            requests.post(url=endpoint_url, headers=HEADERS, json={"inputs": [""]}, timeout=timeout)
        except Exception:
            # Best-effort probe; `Exception` (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            logger.debug(f"Databricks endpoint {endpoint_name} did not respond")
            return False
        return True

    if model_type == "openai":
        headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
        try:
            requests.get(url="https://api.openai.com/v1/models", headers=headers, timeout=1)
        except Exception:
            logger.debug("OpenAI API did not respond")
            return False
        return True

    # ValueError is still an Exception, so existing handlers keep working.
    raise ValueError(f"Model Type {model_type} not supported")
90
+
91
def are_models_alive():
    """Return True only if every endpoint in ``ENDPOINT_NAMES`` and OpenAI respond.

    Every backend is probed, even after one of them has already failed.
    """
    statuses = [is_model_alive(endpoint) for endpoint in ENDPOINT_NAMES.values()]
    statuses.append(is_model_alive("openai", model_type="openai"))
    return all(statuses)
utils/chain_utils.py CHANGED
@@ -1,9 +1,11 @@
 
1
  from streamlit.logger import get_logger
2
- from models.model_seeds import seeds
3
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
4
  from models.openai.role_models import get_role_chain, get_template_role_models
5
  from models.databricks.scenario_sim_biz import get_databricks_biz_chain
6
  from models.databricks.texter_sim_llm import get_databricks_chain
 
7
 
8
  logger = get_logger(__name__)
9
 
@@ -32,7 +34,12 @@ def custom_chain_predict(llm_chain, input, stop):
32
  llm_chain._validate_inputs(inputs)
33
  outputs = llm_chain._call(inputs)
34
  llm_chain._validate_outputs(outputs)
35
- llm_chain.memory.chat_memory.add_user_message(inputs['input'])
 
 
 
 
 
36
  for out in outputs[llm_chain.output_key]:
37
  llm_chain.memory.chat_memory.add_ai_message(out)
38
  return outputs[llm_chain.output_key]
 
1
+ import streamlit as st
2
  from streamlit.logger import get_logger
3
+ from langchain_core.messages import HumanMessage
4
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
5
  from models.openai.role_models import get_role_chain, get_template_role_models
6
  from models.databricks.scenario_sim_biz import get_databricks_biz_chain
7
  from models.databricks.texter_sim_llm import get_databricks_chain
8
+ from models.ta_models.cpc_utils import cpc_predict_message
9
 
10
  logger = get_logger(__name__)
11
 
 
34
  llm_chain._validate_inputs(inputs)
35
  outputs = llm_chain._call(inputs)
36
  llm_chain._validate_outputs(outputs)
37
+ phase = cpc_predict_message(st.session_state['context'], st.session_state['last_message'])
38
+ st.session_state['last_phase'] = phase
39
+ logger.debug(phase)
40
+ llm_chain.memory.chat_memory.add_user_message(
41
+ HumanMessage(inputs['input'], response_metadata={"phase":phase})
42
+ )
43
  for out in outputs[llm_chain.output_key]:
44
  llm_chain.memory.chat_memory.add_ai_message(out)
45
  return outputs[llm_chain.output_key]
utils/memory_utils.py CHANGED
@@ -30,7 +30,6 @@ def change_memories(memories, language, changed_source=False):
30
 
31
  if ("convo_id" in st.session_state) and changed_source:
32
  del st.session_state['convo_id']
33
-
34
 
35
  def clear_memory(memories, username, language):
36
  for memory, _ in memories.items():
 
30
 
31
  if ("convo_id" in st.session_state) and changed_source:
32
  del st.session_state['convo_id']
 
33
 
34
  def clear_memory(memories, username, language):
35
  for memory, _ in memories.items():
utils/mongo_utils.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
4
  from streamlit.logger import get_logger
5
  from pymongo.mongo_client import MongoClient
6
  from pymongo.server_api import ServerApi
7
- from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
8
 
9
  DB_URL = os.environ['MONGO_URL']
10
  DB_USR = os.environ['MONGO_USR']
@@ -99,6 +99,42 @@ def new_completion_error(client, comparison_id, username, model):
99
  error_id = errors.insert_one(error).inserted_id
100
  logger.info(f"DBUTILS: new error id is {error_id}")
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def get_non_assesed_comparison(client, username):
103
  from bson.son import SON
104
  pipeline = [
 
4
  from streamlit.logger import get_logger
5
  from pymongo.mongo_client import MongoClient
6
  from pymongo.server_api import ServerApi
7
+ from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP
8
 
9
  DB_URL = os.environ['MONGO_URL']
10
  DB_USR = os.environ['MONGO_USR']
 
99
  error_id = errors.insert_one(error).inserted_id
100
  logger.info(f"DBUTILS: new error id is {error_id}")
101
 
102
def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
    """Insert one phase-classifier comparison document into the ``DB_CPC`` collection.

    Args:
        client: Mongo client used to reach the ``DB_SCHEMA`` database.
        convo_id: Identifier of the conversation the message belongs to.
        model: Source model name stored with the record.
        context: Conversation history preceding the message.
        last_message: The message that was classified.
        ytrue: Manually confirmed phase, stored as ``manual_phase``.
        ypred: Predicted phase, stored as ``predicted_phase``.
    """
    # context = memory.load_memory_variables({})[memory.memory_key]
    comp = {
        # NOTE(review): the key says "error" although this is not an error
        # record; left unchanged so the document schema stays the same.
        "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
        "conversation_id": convo_id,
        "model": model,
        "context": context,
        "last_message": last_message,
        "predicted_phase": ypred,
        "manual_phase": ytrue,
    }

    db = client[DB_SCHEMA]
    cpc_comps = db[DB_CPC]
    comparison_id = cpc_comps.insert_one(comp).inserted_id
    # Log the new id, as the BP comparison insert does.
    logger.info(f"DBUTILS: new CPC id is {comparison_id}")
118
+
119
def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
    """Insert one bad-practice comparison document into the ``DB_BP`` collection.

    ``ypred`` and ``ytrue`` must both provide the keys ``"is_advice"`` and
    ``"is_personal_info"``. A missing key raises ``KeyError`` before anything
    is written.
    """
    # context = memory.load_memory_variables({})[memory.memory_key]
    record = {
        "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
        "conversation_id": convo_id,
        "model": model,
        "context": context,
        "last_message": last_message,
        "is_advice": ypred["is_advice"],
        "manual_is_advice": ytrue["is_advice"],
        "is_pi": ypred["is_personal_info"],
        "manual_is_pi": ytrue["is_personal_info"],
    }

    collection = client[DB_SCHEMA][DB_BP]
    inserted_id = collection.insert_one(record).inserted_id
    logger.info(f"DBUTILS: new BP id is {inserted_id}")
137
+
138
  def get_non_assesed_comparison(client, username):
139
  from bson.son import SON
140
  pipeline = [