Spaces:
Sleeping
Sleeping
Commit
·
2e79a3c
1
Parent(s):
cf6ebf9
change in model aliveness calculations
Browse files- app_config.py +20 -4
- models/databricks/texter_sim_llm.py +1 -1
- models/ta_models/bp_utils.py +1 -1
- models/ta_models/cpc_utils.py +1 -1
- pages/model_loader.py +9 -6
- utils/app_utils.py +23 -16
app_config.py
CHANGED
@@ -18,12 +18,28 @@ SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT4o',
|
|
18 |
|
19 |
ENDPOINT_NAMES = {
|
20 |
# "CTL_llama2": "texter_simulator",
|
21 |
-
"CTL_llama3":
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# 'CTL_llama2': "llama2_convo_sim",
|
24 |
# "CTL_mistral": "convo_sim_mistral",
|
25 |
-
"CPC":
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
}
|
28 |
|
29 |
def source2label(source):
|
|
|
18 |
|
19 |
ENDPOINT_NAMES = {
|
20 |
# "CTL_llama2": "texter_simulator",
|
21 |
+
"CTL_llama3": {
|
22 |
+
"name": "texter_simulator_llm",
|
23 |
+
"model_type": "llm"
|
24 |
+
},
|
25 |
+
# "CTL_llama3": {
|
26 |
+
# "name": "databricks-meta-llama-3-1-70b-instruct",
|
27 |
+
# "model_type": "llm"
|
28 |
+
# },
|
29 |
# 'CTL_llama2': "llama2_convo_sim",
|
30 |
# "CTL_mistral": "convo_sim_mistral",
|
31 |
+
"CPC": {
|
32 |
+
"name": "phase_classifier",
|
33 |
+
"model_type": "classificator"
|
34 |
+
},
|
35 |
+
"BadPractices": {
|
36 |
+
"name": "training_adherence_bp",
|
37 |
+
"model_type": "classificator"
|
38 |
+
},
|
39 |
+
"training_adherence": {
|
40 |
+
"name": "training_adherence",
|
41 |
+
"model_type": "llm"
|
42 |
+
},
|
43 |
}
|
44 |
|
45 |
def source2label(source):
|
models/databricks/texter_sim_llm.py
CHANGED
@@ -16,7 +16,7 @@ texter:"""
|
|
16 |
|
17 |
def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
|
18 |
|
19 |
-
endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")
|
20 |
PROMPT = PromptTemplate(
|
21 |
input_variables=['history', 'input'],
|
22 |
template=_DATABRICKS_TEMPLATE_
|
|
|
16 |
|
17 |
def get_databricks_chain(source, issue, language, memory, temperature=0.8, texter_name="Kit"):
|
18 |
|
19 |
+
endpoint_name = ENDPOINT_NAMES.get(source, "texter_simulator")['name']
|
20 |
PROMPT = PromptTemplate(
|
21 |
input_variables=['history', 'input'],
|
22 |
template=_DATABRICKS_TEMPLATE_
|
models/ta_models/bp_utils.py
CHANGED
@@ -10,7 +10,7 @@ from app_config import ENDPOINT_NAMES
|
|
10 |
logger = get_logger(__name__)
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
|
13 |
-
BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"])
|
14 |
HEADERS = {
|
15 |
"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
|
16 |
"Content-Type": "application/json",
|
|
|
10 |
logger = get_logger(__name__)
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
|
13 |
+
BP_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["BadPractices"]['name'])
|
14 |
HEADERS = {
|
15 |
"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
|
16 |
"Content-Type": "application/json",
|
models/ta_models/cpc_utils.py
CHANGED
@@ -10,7 +10,7 @@ from app_config import ENDPOINT_NAMES
|
|
10 |
logger = get_logger(__name__)
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
|
13 |
-
CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"])
|
14 |
HEADERS = {
|
15 |
"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
|
16 |
"Content-Type": "application/json",
|
|
|
10 |
logger = get_logger(__name__)
|
11 |
|
12 |
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, truncation_side="left")
|
13 |
+
CPC_URL = os.environ["DATABRICKS_URL"].format(endpoint_name=ENDPOINT_NAMES["CPC"]['name'])
|
14 |
HEADERS = {
|
15 |
"Authorization": f"Bearer {os.environ['DATABRICKS_TOKEN']}",
|
16 |
"Content-Type": "application/json",
|
pages/model_loader.py
CHANGED
@@ -12,17 +12,20 @@ models_alive = False
|
|
12 |
start = time.time()
|
13 |
|
14 |
MODELS2LOAD = {
|
15 |
-
"CPC": {"model_name": "Phase Classifier", "loaded":
|
16 |
-
"CTL_llama3": {"model_name": "Texter Simulator", "loaded":
|
17 |
-
"BadPractices": {"model_name": "Advice Identificator", "loaded":
|
|
|
18 |
}
|
19 |
|
20 |
-
def write_model_status(writer, model_name, loaded, fail=
|
21 |
if loaded:
|
22 |
writer.write(f"✅ - {model_name} Loaded")
|
23 |
else:
|
24 |
-
if fail:
|
25 |
-
writer.write(f"❌ - {model_name} Failed to Load")
|
|
|
|
|
26 |
else:
|
27 |
writer.write(f"🔄 - {model_name} Loading")
|
28 |
|
|
|
12 |
start = time.time()
|
13 |
|
14 |
MODELS2LOAD = {
|
15 |
+
"CPC": {"model_name": "Phase Classifier", "loaded":None,},
|
16 |
+
"CTL_llama3": {"model_name": "Texter Simulator", "loaded":None,},
|
17 |
+
"BadPractices": {"model_name": "Advice Identificator", "loaded":None},
|
18 |
+
"training_adherence": {"model_name": "Training Adherence", "loaded":None},
|
19 |
}
|
20 |
|
21 |
+
def write_model_status(writer, model_name, loaded, fail=None):
|
22 |
if loaded:
|
23 |
writer.write(f"✅ - {model_name} Loaded")
|
24 |
else:
|
25 |
+
if fail in ["400", "500"]:
|
26 |
+
writer.write(f"❌ - {model_name} Failed to Load, Contact ifbarrerarincon@crisistextline.org")
|
27 |
+
elif fail == "404":
|
28 |
+
writer.write(f"❌ - {model_name} Still loading, please try in a couple of minutes")
|
29 |
else:
|
30 |
writer.write(f"🔄 - {model_name} Loading")
|
31 |
|
utils/app_utils.py
CHANGED
@@ -64,33 +64,40 @@ def create_memory_add_initial_message(memories, issue, language, changed_source=
|
|
64 |
if len(st.session_state[memory].buffer_as_messages) < 1:
|
65 |
add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)
|
66 |
|
67 |
-
def is_model_alive(
|
68 |
-
if model_type
|
69 |
-
endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=
|
70 |
headers = HEADERS
|
71 |
-
|
72 |
-
# Send request to Serving
|
73 |
body_request = {
|
74 |
"inputs": [""]
|
75 |
}
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
except:
|
79 |
-
return
|
80 |
-
|
81 |
endpoint_url="https://api.openai.com/v1/models"
|
82 |
headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
|
83 |
try:
|
84 |
_ = requests.get(url=endpoint_url, headers=headers, timeout=1)
|
85 |
-
return
|
86 |
except:
|
87 |
-
return
|
88 |
-
else:
|
89 |
-
raise Exception(f"Model Type {model_type} not supported")
|
90 |
|
91 |
def are_models_alive():
|
92 |
models_alive = []
|
93 |
-
for
|
94 |
-
models_alive.append(is_model_alive(
|
95 |
openai = is_model_alive("openai", model_type="openai")
|
96 |
-
return all(models_alive + [openai])
|
|
|
64 |
if len(st.session_state[memory].buffer_as_messages) < 1:
|
65 |
add_initial_message(issue, language, st.session_state[memory], texter_name=texter_name, counselor_name=counselor_name)
|
66 |
|
67 |
+
def is_model_alive(name, timeout=2, model_type="classificator"):
|
68 |
+
if model_type!="openai":
|
69 |
+
endpoint_url=os.environ['DATABRICKS_URL'].format(endpoint_name=name)
|
70 |
headers = HEADERS
|
71 |
+
if model_type == "classificator":
|
|
|
72 |
body_request = {
|
73 |
"inputs": [""]
|
74 |
}
|
75 |
+
elif model_type == "llm":
|
76 |
+
body_request = {
|
77 |
+
"prompt": "",
|
78 |
+
"temperature": 0,
|
79 |
+
"max_tokens": 1,
|
80 |
+
}
|
81 |
+
|
82 |
+
else:
|
83 |
+
raise Exception(f"Model Type {model_type} not supported")
|
84 |
+
try:
|
85 |
+
response = requests.post(url=endpoint_url, headers=HEADERS, json=body_request, timeout=timeout)
|
86 |
+
return str(response.status_code)
|
87 |
except:
|
88 |
+
return "404"
|
89 |
+
else:
|
90 |
endpoint_url="https://api.openai.com/v1/models"
|
91 |
headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",}
|
92 |
try:
|
93 |
_ = requests.get(url=endpoint_url, headers=headers, timeout=1)
|
94 |
+
return "200"
|
95 |
except:
|
96 |
+
return "404"
|
|
|
|
|
97 |
|
98 |
def are_models_alive():
|
99 |
models_alive = []
|
100 |
+
for config in ENDPOINT_NAMES.values():
|
101 |
+
models_alive.append(is_model_alive(**config))
|
102 |
openai = is_model_alive("openai", model_type="openai")
|
103 |
+
return all([x=="200" for x in models_alive + [openai]])
|