Commit 70a482f (parent: 348d7de): cpc and bad practices features
Files changed:
- .streamlit/config.toml +2 -0
- README.md +2 -2
- app_config.py +6 -1
- main.py +14 -0
- models/databricks/texter_sim_llm.py +14 -3
- models/ta_models/bp_utils.py +61 -0
- models/ta_models/config.py +26 -0
- models/ta_models/cpc_utils.py +49 -0
- pages/convosim.py +181 -0
- pages/model_loader.py +53 -0
- requirements.txt +2 -1
- utils/app_utils.py +37 -5
- utils/chain_utils.py +9 -2
- utils/memory_utils.py +0 -1
- utils/mongo_utils.py +37 -1
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
+[client]
+showSidebarNavigation = false
README.md
CHANGED
@@ -4,8 +4,8 @@ emoji: 💬
 colorFrom: red
 colorTo: red
 sdk: streamlit
-sdk_version: 1.
-app_file:
+sdk_version: 1.38.0
+app_file: main.py
 pinned: false
 ---
app_config.py
CHANGED
@@ -19,8 +19,11 @@ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
 ENDPOINT_NAMES = {
     # "CTL_llama2": "texter_simulator",
     "CTL_llama3": "texter_simulator_llm",
+    # "CTL_llama3": "databricks-meta-llama-3-1-70b-instruct",
     # 'CTL_llama2': "llama2_convo_sim",
-    "CTL_mistral": "convo_sim_mistral"
+    # "CTL_mistral": "convo_sim_mistral",
+    "CPC": "phase_classifier",
+    "BadPractices": "training_adherence_bp"
 }

 def source2label(source):
@@ -36,6 +39,8 @@ DB_CONVOS = 'conversations'
 DB_COMPLETIONS = 'comparison_completions'
 DB_BATTLES = 'battles'
 DB_ERRORS = 'completion_errors'
+DB_CPC = "cpc_comparison"
+DB_BP = "bad_practices_comparison"

 MAX_MSG_COUNT = 60
 WARN_MSG_COUT = int(MAX_MSG_COUNT*0.8)
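The new "CPC" and "BadPractices" entries follow the same convention as the texter-simulator endpoint: the value is a Databricks serving-endpoint name that gets substituted into the DATABRICKS_URL template elsewhere in this commit. A minimal sketch of how a name becomes an invocation URL; the workspace host below is a placeholder, and the {endpoint_name} slot in DATABRICKS_URL is an assumption inferred from the .format(...) calls in the other files:

import os

# Hypothetical template for illustration only; the real value comes from the
# DATABRICKS_URL environment variable used throughout this repo.
os.environ.setdefault(
    "DATABRICKS_URL",
    "https://example.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations",
)

ENDPOINT_NAMES = {"CPC": "phase_classifier", "BadPractices": "training_adherence_bp"}

cpc_url = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"])
print(cpc_url)
# https://example.cloud.databricks.com/serving-endpoints/phase_classifier/invocations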
main.py
ADDED
@@ -0,0 +1,14 @@
+import streamlit as st
+from streamlit.logger import get_logger
+
+from utils.app_utils import are_models_alive
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+    if not are_models_alive():
+        st.switch_page("pages/model_loader.py")
+    else:
+        st.switch_page("pages/convosim.py")
models/databricks/texter_sim_llm.py
CHANGED
@@ -17,14 +17,12 @@ texter:"""
 def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):

     endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
-
     PROMPT = PromptTemplate(
         input_variables=['history', 'input'],
         template=_DATABRICKS_TEMPLATE_
     )

     llm = CustomDatabricksLLM(
-        # endpoint_url="https://dbc-6dca8e8f-4084.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-1-70b-instruct/invocations",
         endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name),
         bearer_token=os.environ["DATABRICKS_TOKEN"],
         texter_name=texter_name,
@@ -43,4 +41,17 @@ def get_databricks_chain(source, issue, language, memory, temperature=0.8, texte
     )

     logging.debug(f"loaded Databricks model")
-    return llm_chain, None
+    return llm_chain, None
+
+def cpc_is_alive():
+    body_request = {
+        "inputs": [""]
+    }
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request, timeout=2)
+        if response.status_code == 200:
+            return True
+        else: return False
+    except:
+        return False
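Note that cpc_is_alive depends on requests, CPC_URL, and HEADERS, none of which appear in this hunk. A self-contained sketch of the same liveness probe, assuming the URL and headers are built the way models/ta_models/cpc_utils.py builds them (the "phase_classifier" name is ENDPOINT_NAMES["CPC"] from app_config.py):

import os
import requests

# Assumed to match the construction in models/ta_models/cpc_utils.py.
CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name="phase_classifier")
HEADERS = {
    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
    "Content-Type": "application/json",
}

def cpc_is_alive() -> bool:
    """Probe the serving endpoint with an empty input; any 200 means it is up."""
    try:
        response = requests.post(
            url=CPC_URL, headers=HEADERS, json={"inputs": [""]}, timeout=2
        )
        return response.status_code == 200
    except requests.RequestException:
        return False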
models/ta_models/bp_utils.py
ADDED
@@ -0,0 +1,61 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path, BP_THRESHOLD
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_bp_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def bp_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding)],
+        "params": {
+            "top_k": None
+        }
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=BP_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            response = response.json()['predictions'][0]
+            logger.debug(f"Raw BP prediction is {response}")
+            return [{k:v > BP_THRESHOLD if k=="score" else v for k,v in dict_.items()} for _, dict_ in response.items()]
+    except:
+        pass
+
+def bp_push2db(manual_confirmation=None):
+    if manual_confirmation is None:
+        if st.session_state.sel_bp == "Advice":
+            manual_confirmation = {"is_advice":True, "is_personal_info":False}
+        elif st.session_state.sel_bp == "Personal Info":
+            manual_confirmation = {"is_advice":False, "is_personal_info":True}
+        elif st.session_state.sel_bp == "Advice & Personal Info":
+            manual_confirmation = {"is_advice":True, "is_personal_info":True}
+        else:
+            manual_confirmation = {"is_advice":False, "is_personal_info":False}
+    new_bp_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": manual_confirmation,
+        "ypred": {x['label']:x['score'] for x in st.session_state['bp_prediction']},
+    })
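The comprehension in bp_predict_message turns raw endpoint scores into booleans by comparing each "score" value against BP_THRESHOLD while passing "label" through untouched. A worked example under an assumed response shape (the real serving payload may differ; the positional "0"/"1" keys are inferred from the response.items() iteration):

BP_THRESHOLD = 0.7

# Hypothetical predictions payload for illustration.
response = {
    "0": {"label": "is_advice", "score": 0.91},
    "1": {"label": "is_personal_info", "score": 0.12},
}

flags = [
    {k: v > BP_THRESHOLD if k == "score" else v for k, v in dict_.items()}
    for _, dict_ in response.items()
]
print(flags)
# [{'label': 'is_advice', 'score': True}, {'label': 'is_personal_info', 'score': False}]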
models/ta_models/config.py
ADDED
@@ -0,0 +1,26 @@
+model_name_or_path = "FacebookAI/xlm-roberta-large"
+
+CPC_LABEL2STR = {
+    "0_ActiveEngagement": "Active Engagement",
+    "1_Explore": "Explore",
+    "2_IRA": "Immidiate Risk Assessment",
+    "3_SafetyAssessment": "Safety Assessment",
+    "4_SP&NS": "Safety Planning & Next Steps",
+    "5_EmergencyIntervention": "Emergency Intervention",
+    "6_WrappingUp": "Wrapping Up",
+    "7_Other": "Other",
+}
+
+CPC_LBL_OPTS = list(CPC_LABEL2STR.keys())
+
+def cpc_label2str(phase):
+    return CPC_LABEL2STR[phase]
+
+def phase2int(phase):
+    return int(phase.split("_")[0])
+
+BP_THRESHOLD = 0.7
+BP_LAB2STR = {
+    "is_advice": "Advice",
+    "is_personal_info": "Personal Info Sharing",
+}
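The phase keys encode their ordering in a numeric prefix, which phase2int strips out, while cpc_label2str maps the raw model label to a display name. For example (imports assume the repo layout above):

from models.ta_models.config import cpc_label2str, phase2int

print(cpc_label2str("4_SP&NS"))   # Safety Planning & Next Steps
print(phase2int("4_SP&NS"))       # 4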
models/ta_models/cpc_utils.py
ADDED
@@ -0,0 +1,49 @@
+import streamlit as st
+from streamlit.logger import get_logger
+import requests
+import os
+from .config import model_name_or_path
+from transformers import AutoTokenizer
+from utils.mongo_utils import new_cpc_comparison
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
+CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"])
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
+
+def cpc_predict_message(context, input):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    encoding = tokenizer(
+        context,
+        input,
+        truncation="only_first",
+    )['input_ids']
+    body_request = {
+        "inputs": [tokenizer.decode(encoding)]
+    }
+
+    try:
+        # Send request to Serving
+        response = requests.post(url=CPC_URL, headers=HEADERS, json=body_request)
+        if response.status_code == 200:
+            return response.json()['predictions'][0]["0"]["label"]
+    except:
+        pass
+
+def cpc_push2db(is_same):
+    text_is_same = "SAME" if is_same else "WRONG"
+    logger.debug(f"pushing new {text_is_same} CPC")
+    new_cpc_comparison(**{
+        "client": st.session_state['db_client'],
+        "convo_id": st.session_state['convo_id'],
+        "model": st.session_state['source'],
+        "context": st.session_state["context"],
+        "last_message": st.session_state["last_message"],
+        "ytrue": st.session_state["last_phase"] if is_same else st.session_state["sel_phase"],
+        "ypred": st.session_state["last_phase"],
+    })
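The tokenizer call passes context and input as a sentence pair with truncation="only_first", so when the pair exceeds the model's maximum length only the conversation context is trimmed, and because the tokenizer is loaded with truncation_side="left" the oldest context tokens are dropped first, keeping the most recent turns plus the full new message. A small sketch of that behavior (model name as in config.py; downloading the tokenizer requires network access, and the sample strings are invented):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "FacebookAI/xlm-roberta-large", truncation_side="left"
)

context = "helper: Hi, my name is Dani\ntexter: hey... I feel really alone"
new_message = "I'm here with you. Can you tell me more about what's going on?"

# Pair encoding: if max length were exceeded, only the first segment (context)
# would be truncated, from the left, preserving the newest turns.
ids = tokenizer(context, new_message, truncation="only_first")["input_ids"]
print(tokenizer.decode(ids))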
pages/convosim.py
ADDED
@@ -0,0 +1,181 @@
+import os
+import streamlit as st
+from streamlit.logger import get_logger
+from langchain.schema.messages import HumanMessage
+from utils.mongo_utils import get_db_client
+from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF, are_models_alive
+from utils.memory_utils import clear_memory, push_convo2db
+from utils.chain_utils import get_chain, custom_chain_predict
+from app_config import ISSUES, SOURCES, source2label, issue2label, MAX_MSG_COUNT, WARN_MSG_COUT
+from models.ta_models.config import CPC_LBL_OPTS, cpc_label2str, BP_LAB2STR
+from models.ta_models.cpc_utils import cpc_push2db
+from models.ta_models.bp_utils import bp_predict_message, bp_push2db
+
+logger = get_logger(__name__)
+temperature = 0.8
+# username = "barb-chase" #"ivnban-ctl"
+st.set_page_config(page_title="Conversation Simulator")
+
+if "sent_messages" not in st.session_state:
+    st.session_state['sent_messages'] = 0
+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
+
+if "total_messages" not in st.session_state:
+    st.session_state['total_messages'] = 0
+if "issue" not in st.session_state:
+    st.session_state['issue'] = ISSUES[0]
+if 'previous_source' not in st.session_state:
+    st.session_state['previous_source'] = SOURCES[0]
+if 'db_client' not in st.session_state:
+    st.session_state["db_client"] = get_db_client()
+if 'texter_name' not in st.session_state:
+    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+    logger.debug(f"texter name is {st.session_state['texter_name']}")
+if "last_phase" not in st.session_state:
+    st.session_state["last_phase"] = CPC_LBL_OPTS[0]
+    # st.session_state["sel_phase"] = CPC_LBL_OPTS[0]
+if "changed_cpc" not in st.session_state:
+    st.session_state["changed_cpc"] = False
+if "changed_bp" not in st.session_state:
+    st.session_state["changed_bp"] = False
+
+# st.session_state["sel_phase"] = st.session_state["last_phase"]
+
+memories = {'memory':{"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}
+
+with st.sidebar:
+    username = st.text_input("Username", value='Dani', max_chars=30)
+    if 'counselor_name' not in st.session_state:
+        st.session_state["counselor_name"] = username #get_random_name(names_df=DEFAULT_NAMES_DF)
+    # temperature = st.slider("Temperature", 0., 1., value=0.8, step=0.1)
+    issue = st.selectbox("Select a Scenario", ISSUES, index=0, format_func=issue2label,
+                         on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
+                         )
+    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
+    language = st.selectbox("Select a Language", supported_languages, index=0,
+                            format_func=lambda x: "English" if x=="en" else "Spanish",
+                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
+                            )
+
+    source = st.selectbox("Select a source Model A", SOURCES, index=0,
+                          format_func=source2label, key="source"
+                          )
+
+    changed_source = any([
+        st.session_state['previous_source'] != source,
+        st.session_state['issue'] != issue,
+        st.session_state['counselor_name'] != username,
+    ])
+    if changed_source:
+        st.session_state["counselor_name"] = username
+        st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
+        logger.debug(f"texter name is {st.session_state['texter_name']}")
+        st.session_state['previous_source'] = source
+        st.session_state['issue'] = issue
+        st.session_state['sent_messages'] = 0
+        st.session_state['total_messages'] = 0
+    create_memory_add_initial_message(memories,
+                                      issue,
+                                      language,
+                                      changed_source=changed_source,
+                                      counselor_name=st.session_state["counselor_name"],
+                                      texter_name=st.session_state["texter_name"])
+    st.session_state['previous_source'] = source
+    memoryA = st.session_state[list(memories.keys())[0]]
+    # issue only without "." marker for model compatibility
+    llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
+
+st.title("💬 Simulator")
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+for msg in memoryA.buffer_as_messages:
+    role = "user" if type(msg) == HumanMessage else "assistant"
+    st.chat_message(role).write(msg.content)
+
+def sent_request_llm(llm_chain, prompt):
+    st.session_state['sent_messages'] += 1
+    st.chat_message("user").write(prompt)
+    responses = custom_chain_predict(llm_chain, prompt, stopper)
+    for response in responses:
+        st.chat_message("assistant").write(response)
+
+# @st.dialog("Bad Practice Detected")
+# def confirm_bp(bp_prediction, prompt):
+#     bps = [BP_LAB2STR[x['label']] for x in bp_prediction if x['score']]
+#     st.markdown(f"The last message was considered :red[{' and '.join(bps)}]")
+#     "Are you sure you want to send this message?"
+#     newprompt = st.text_input("Change message to:")
+#     "If you do not want to change leave textbox empty"
+#     for bp in BP_LAB2STR.keys():
+#         _ = st.checkbox(f"Original Message was {BP_LAB2STR[bp]}", key=f"chkbx_{bp}", value=BP_LAB2STR[bp] in bps)
+
+#     if st.button("Confirm"):
+#         if newprompt is not None and newprompt != "":
+#             prompt = newprompt
+#         bp_push2db(
+#             {bp:st.session_state[f"chkbx_{bp}"] for bp in BP_LAB2STR.keys()}
+#         )
+#         sent_request_llm(llm_chain, prompt)
+#         st.rerun()
+
+if prompt := st.chat_input(disabled=st.session_state['total_messages'] > MAX_MSG_COUNT - 4): #account for next interaction
+    if 'convo_id' not in st.session_state:
+        push_convo2db(memories, username, language)
+    st.session_state['context'] = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['last_message'] = prompt
+    if (not st.session_state.changed_cpc) and st.session_state["sent_messages"] > 0:
+        cpc_push2db(True)
+    else: st.session_state.changed_cpc = False
+    if (not st.session_state.changed_bp) and st.session_state["sent_messages"] > 0:
+        bp_push2db({x['label']:x['score'] for x in st.session_state['bp_prediction']})
+    else: st.session_state.changed_bp = False
+
+    context = llm_chain.memory.load_memory_variables({})[llm_chain.memory.memory_key]
+    st.session_state['bp_prediction'] = bp_predict_message(context, prompt)
+    if any([x['score'] for x in st.session_state['bp_prediction']]):
+        for bp in st.session_state['bp_prediction']:
+            if bp["score"]:
+                st.error(f"Detected {BP_LAB2STR[bp['label']]} in the last message!")
+        st.session_state.changed_bp = True
+    else:
+        sent_request_llm(llm_chain, prompt)
+
+with st.sidebar:
+    st.divider()
+    st.markdown(f"### Total Sent Messages: :red[**{st.session_state['sent_messages']}**]")
+    st.markdown(f"### Total Messages: :red[**{st.session_state['total_messages']}**]")
+    # st.markdown()
+    def on_change_cpc():
+        cpc_push2db(False)
+        st.session_state.changed_cpc = True
+    def on_change_bp():
+        bp_push2db()
+        st.session_state.changed_bp = True
+
+    if st.session_state["sent_messages"] > 0:
+        _ = st.selectbox(f"""Last Human Message was considered :blue[**{
+            cpc_label2str(st.session_state['last_phase'])
+            }**]. If not please select from the following options""",
+            CPC_LBL_OPTS, index=None, format_func=cpc_label2str, on_change=on_change_cpc,
+            key="sel_phase",
+        )
+
+        BPs = [BP_LAB2STR[x['label']] for x in st.session_state['bp_prediction'] if x['score']]
+        selecttitle = f"""Last Human Message was considered :blue[**{
+            " and ".join(BPs)
+            }**].""" if len(BPs) > 0 else "Last Human Message was NOT considered Bad Practice."
+        _ = st.selectbox(selecttitle + " If not please select from the following options",
+            ["None", "Advice", "Personal Info", "Advice & Personal Info"], index=None, on_change=on_change_bp,
+            key="sel_bp"
+        )
+
+st.session_state['total_messages'] = len(memoryA.chat_memory.messages)
+if st.session_state['total_messages'] >= MAX_MSG_COUNT:
+    st.toast(f"Total of {MAX_MSG_COUNT} Messages reached. Conversation Ended", icon=":material/verified:")
+elif st.session_state['total_messages'] >= WARN_MSG_COUT:
+    st.toast(f"The conversation will end at {MAX_MSG_COUNT} Total Messages ", icon=":material/warning:")

+if not are_models_alive():
+    st.switch_page("pages/model_loader.py")
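The changed_cpc/changed_bp flags implement an implicit-confirmation scheme: if the counselor sends the next message without touching the correction selectbox, the previous prediction is logged as accepted ("SAME"); if they did correct it, the on_change callback already logged the disagreement, so the flag suppresses a duplicate write. A stripped-down sketch of that bookkeeping, outside Streamlit (all names here are illustrative stand-ins):

# Minimal model of the implicit-confirmation flow above (no Streamlit).
state = {"changed_cpc": False, "sent_messages": 0}
log = []

def on_change_cpc():
    log.append("WRONG")          # user corrected the phase: log disagreement now
    state["changed_cpc"] = True

def on_new_message():
    # Mirrors the chat_input branch: silence means the last prediction was right.
    if not state["changed_cpc"] and state["sent_messages"] > 0:
        log.append("SAME")
    else:
        state["changed_cpc"] = False
    state["sent_messages"] += 1

on_new_message()                 # first message: nothing to confirm yet
on_change_cpc()                  # user corrects the predicted phase
on_new_message()                 # flag set, so no duplicate entry
on_new_message()                 # untouched selectbox: implicit "SAME"
print(log)                       # ['WRONG', 'SAME']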
pages/model_loader.py
ADDED
@@ -0,0 +1,53 @@
+import time
+import streamlit as st
+from streamlit.logger import get_logger
+from utils.app_utils import is_model_alive
+from app_config import ENDPOINT_NAMES
+
+logger = get_logger(__name__)
+
+st.set_page_config(page_title="Conversation Simulator")
+
+models_alive = False
+start = time.time()
+
+MODELS2LOAD = {
+    "CPC": {"model_name": "Phase Classifier", "loaded":False,},
+    "CTL_llama3": {"model_name": "Texter Simulator", "loaded":False,},
+    "BadPractices": {"model_name": "Advice Identificator", "loaded":False},
+}
+
+def write_model_status(writer, model_name, loaded, fail=False):
+    if loaded:
+        writer.write(f"✅ - {model_name} Loaded")
+    else:
+        if fail:
+            writer.write(f"❌ - {model_name} Failed to Load")
+        else:
+            writer.write(f"🔄 - {model_name} Loading")
+
+with st.status("Loading Models Please Wait...(this may take up to 5 min)", expanded=True) as status:
+
+    for k in MODELS2LOAD.keys():
+        MODELS2LOAD[k]["writer"] = st.empty()
+
+    while not models_alive:
+        time.sleep(2)
+        for name, config in MODELS2LOAD.items():
+            config["loaded"] = is_model_alive(ENDPOINT_NAMES[name])
+
+        models_alive = all([x['loaded'] for x in MODELS2LOAD.values()])
+
+        for _, config in MODELS2LOAD.items():
+            write_model_status(**config)
+
+        if int(time.time()-start) > 30:
+            status.update(
+                label="Models took too long to load. Please Refresh Page in a couple of minutes", state="error", expanded=True
+            )
+            for _, config in MODELS2LOAD.items():
+                write_model_status(**config, fail=True)
+            break
+
+if models_alive:
+    st.switch_page("pages/convosim.py")
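The loader keeps one st.empty() placeholder per model so each status line is rewritten in place on every poll instead of appending new lines. The same pattern in isolation (a sketch, with a fixed three-tick loop standing in for the real liveness polling):

import time
import streamlit as st

# One placeholder per model; each poll rewrites its line in place.
placeholders = {name: st.empty() for name in ["CPC", "CTL_llama3", "BadPractices"]}

for tick in range(3):
    for name, slot in placeholders.items():
        slot.write(f"🔄 - {name} Loading (poll {tick + 1})")
    time.sleep(1)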
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ mlflow==2.9.0
 langchain==0.3.0
 langchain-openai==0.2.0
 langchain-community==0.3.0
-streamlit==1.38.0
+streamlit==1.38.0
+transformers==4.43.0
utils/app_utils.py
CHANGED
@@ -1,19 +1,22 @@
 import pandas as pd
 import streamlit as st
 from streamlit.logger import get_logger
-import
+import os
+import requests

-
-from app_config import ENVIRON
+from app_config import ENDPOINT_NAMES
 from utils.memory_utils import change_memories
 from models.model_seeds import seeds

-langchain.verbose = ENVIRON =="dev"
 logger = get_logger(__name__)

 # TODO: Include more variable and representative names
 DEFAULT_NAMES = ["Olivia", "Kit", "Abby", "Tom", "Carolyne", "Jessiny"]
 DEFAULT_NAMES_DF = pd.read_csv("./utils/names.csv")
+HEADERS = {
+    "Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
+    "Content-Type": "application/json",
+}
@@ -61,4 +64,33 @@ def create_memory_add_initial_message(memories, issue, language, changed_source=
     if len(st.session_state[memory].buffer_as_messages) < 1:
         add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)

-
+def is_model_alive(endpoint_name, timeout=2, model_type="databricks"):
+    if model_type=="databricks":
+        endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=endpoint_name)
+        headers = HEADERS
+        try:
+            # Send request to Serving
+            body_request = {
+                "inputs": [""]
+            }
+            _ = requests.post(url=endpoint_url, headers=HEADERS, json=body_request, timeout=timeout)
+            return True
+        except:
+            return False
+    elif model_type=="openai":
+        endpoint_url="https://api.openai.com/v1/models"
+        headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
+        try:
+            _ = requests.get(url=endpoint_url, headers=headers, timeout=1)
+            return True
+        except:
+            return False
+    else:
+        raise Exception(f"Model Type {model_type} not supported")
+
+def are_models_alive():
+    models_alive = []
+    for ename in ENDPOINT_NAMES.values():
+        models_alive.append(is_model_alive(ename))
+    openai = is_model_alive("openai", model_type="openai")
+    return all(models_alive + [openai])
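As written, is_model_alive counts any HTTP response, even a 4xx/5xx, as alive; only a timeout or connection error returns False. A usage sketch (endpoint names as defined in app_config.py, environment variables assumed to be set):

from utils.app_utils import is_model_alive, are_models_alive
from app_config import ENDPOINT_NAMES

# Databricks endpoints are probed with an empty-inputs POST and a short timeout.
print(is_model_alive(ENDPOINT_NAMES["CPC"]))           # True if it responds at all
print(is_model_alive("openai", model_type="openai"))   # GET /v1/models reachability
print(are_models_alive())                              # all Databricks endpoints + OpenAI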
utils/chain_utils.py
CHANGED
@@ -1,9 +1,11 @@
+import streamlit as st
 from streamlit.logger import get_logger
-from
+from langchain_core.messages import HumanMessage
 from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
 from models.openai.role_models import get_role_chain, get_template_role_models
 from models.databricks.scenario_sim_biz import get_databricks_biz_chain
 from models.databricks.texter_sim_llm import get_databricks_chain
+from models.ta_models.cpc_utils import cpc_predict_message

 logger = get_logger(__name__)

@@ -32,7 +34,12 @@ def custom_chain_predict(llm_chain, input, stop):
     llm_chain._validate_inputs(inputs)
     outputs = llm_chain._call(inputs)
     llm_chain._validate_outputs(outputs)
-
+    phase = cpc_predict_message(st.session_state['context'], st.session_state['last_message'])
+    st.session_state['last_phase'] = phase
+    logger.debug(phase)
+    llm_chain.memory.chat_memory.add_user_message(
+        HumanMessage(inputs['input'], response_metadata={"phase":phase})
+    )
     for out in outputs[llm_chain.output_key]:
         llm_chain.memory.chat_memory.add_ai_message(out)
     return outputs[llm_chain.output_key]
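custom_chain_predict now stores the predicted phase on the user message itself via response_metadata, so the phase travels with the chat history instead of living only in session state. Reading it back later, using the same langchain-core message API the hunk relies on (the sample text and phase are illustrative):

from langchain_core.messages import HumanMessage

msg = HumanMessage("How are you feeling today?", response_metadata={"phase": "1_Explore"})

# The metadata rides along with the message wherever the history is consumed.
print(msg.content)                       # How are you feeling today?
print(msg.response_metadata["phase"])    # 1_Explore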
utils/memory_utils.py
CHANGED
@@ -30,7 +30,6 @@ def change_memories(memories, language, changed_source=False):

     if ("convo_id" in st.session_state) and changed_source:
         del st.session_state['convo_id']
-

 def clear_memory(memories, username, language):
     for memory, _ in memories.items():
utils/mongo_utils.py
CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
 from streamlit.logger import get_logger
 from pymongo.mongo_client import MongoClient
 from pymongo.server_api import ServerApi
-from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
+from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS, DB_CPC, DB_BP

 DB_URL = os.environ['MONGO_URL']
 DB_USR = os.environ['MONGO_USR']
@@ -99,6 +99,42 @@ def new_completion_error(client, comparison_id, username, model):
     error_id = errors.insert_one(error).inserted_id
     logger.info(f"DBUTILS: new error id is {error_id}")

+def new_cpc_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "predicted_phase": ypred,
+        "manual_phase": ytrue,
+    }
+
+    db = client[DB_SCHEMA]
+    cpc_comps = db[DB_CPC]
+    comarison_id = cpc_comps.insert_one(comp).inserted_id
+    # logger.info(f"DBUTILS: new error id is {error_id}")
+
+def new_bp_comparison(client, convo_id, model, context, last_message, ytrue, ypred):
+    # context = memory.load_memory_variables({})[memory.memory_key]
+    comp = {
+        "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
+        "conversation_id": convo_id,
+        "model": model,
+        "context": context,
+        "last_message": last_message,
+        "is_advice": ypred["is_advice"],
+        "manual_is_advice": ytrue["is_advice"],
+        "is_pi": ypred["is_personal_info"],
+        "manual_is_pi": ytrue["is_personal_info"],
+    }
+
+    db = client[DB_SCHEMA]
+    bp_comps = db[DB_BP]
+    comarison_id = bp_comps.insert_one(comp).inserted_id
+    logger.info(f"DBUTILS: new BP id is {comarison_id}")
+
 def get_non_assesed_comparison(client, username):
     from bson.son import SON
     pipeline = [
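For reference, a cpc_comparison document produced by new_cpc_comparison would look roughly like the following; all values are illustrative placeholders, and note the timestamp field is named error_timestamp even though it records a comparison:

import datetime as dt

# Illustrative shape of one document in the cpc_comparison collection.
example_doc = {
    "error_timestamp": dt.datetime.now(tz=dt.timezone.utc),
    "conversation_id": "<ObjectId of the conversation>",
    "model": "CTL_llama3",
    "context": "helper: Hi...\ntexter: ...",
    "last_message": "Can you tell me more?",
    "predicted_phase": "1_Explore",   # ypred: the classifier's output
    "manual_phase": "1_Explore",      # ytrue: equals ypred when confirmed
}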