Spaces:
Paused
Paused
from dotenv import load_dotenv | |
import os | |
from timeit import default_timer as timer | |
import time | |
import requests | |
import streamlit as st | |
import tiktoken | |
# Load API keys / endpoints from the env file before any function reads os.getenv.
load_dotenv("environments/.env")

# Answer the LLM is presumably told to give when it cannot propose a question
# (NOTE(review): confirm against the system prompt in use).
LLM_IDK_ANSWER = "CANT_PROVIDE_NBQS"

# Engine identifiers; they also equal the Azure deployment names hard-coded in
# send_openai_request_old.
ENGINE_GPT_3_5 = "gpt3_5_test"
ENGINE_GPT_4 = "gpt-4-test"

# When True, send_openai_request prints outgoing messages, URLs and cost events.
DEBUG = True
HUNDRED_CENTS = 100
# When True, the request helpers skip the network and return a canned response.
FAKE_OPENAI_RESPONSE = False
def get_openai_response_msg(response): | |
if response is None: | |
raise Exception("Unexpected error querying OpenAI: response is None") | |
if "choices" not in response: | |
st.error("Missing choices from response:") | |
st.error(response) | |
return None | |
choices = list(response["choices"]) | |
choice = choices[0] | |
return choice["message"] | |
def build_query_msg_content(selected_guidelines, chat_array):
    """Build the user-message text asking the LLM for the next doctor question.

    Args:
        selected_guidelines: Guideline names. Each is lower-cased, looked up in
            ``st.session_state["guidelines_dict"]``, and its text appended.
        chat_array: Chat turns as dicts with ``"role"`` and ``"content"``. Any
            role other than ``"Doctor"`` is rendered as ``"Patient"``.

    Returns:
        The prompt string. It contains the base instruction, then an optional
        quoted transcript of the previous turns, then optional guideline text.
    """
    dr_patient_conv = "Give 1 new question for which we don't know the answer"
    if chat_array:
        lines = []
        for turn in chat_array:
            speaker = "Doctor" if turn["role"] == "Doctor" else "Patient"
            # str() before strip() so non-str content does not raise.
            lines.append(speaker + ": " + str(turn["content"]).strip() + "\n")
        transcript = '"' + "".join(lines) + '"\n'
        # Separate the instruction from the transcript sentence; without the
        # ". " the prompt read "...the answerThe patient already answered...".
        dr_patient_conv += (
            ". The patient already answered the following questions: \n" + transcript
        )
    guidelines_txt = ""
    if selected_guidelines:
        guidelines_dict = st.session_state["guidelines_dict"]
        guidelines_txt = (
            ". Only ask questions strictly based on the following without hallucinating:\n"
            + "".join(guidelines_dict[g.lower()] for g in selected_guidelines)
        )
    return dr_patient_conv + guidelines_txt
def build_general_chat_system_prompt(system_prompt, pre_chat_summary):
    """Compose the system message for the doctor-question chat.

    The content is ``system_prompt``, a "Patient input: ..." line built from
    ``pre_chat_summary``, and the task instructions, separated by newlines.

    Returns:
        A ``{"role": "system", "content": ...}`` message dict.
    """
    sections = (
        system_prompt,
        "Patient input: " + pre_chat_summary,
        "Task: Based on the patient input,\n"
        "propose the most suited question. Don't use the same question twice.",
    )
    return {"role": "system", "content": "\n".join(sections)}
def get_general_chat_user_msg():
    """Build the user message for the next doctor-question request.

    Reads ``selected_guidelines`` and ``chat_history_array`` from Streamlit
    session state. Returns the ``build_query_msg_content`` text wrapped as a
    ``{"role": "user", ...}`` message.
    """
    state = st.session_state
    content = build_query_msg_content(
        state["selected_guidelines"], state["chat_history_array"]
    )
    return {"role": "user", "content": content}
def get_chat_history_string(chat_history):
    """Render chat turns as Markdown with bold speaker labels.

    Turns whose role is ``"Doctor"`` are labelled Doctor and end with one line
    break. Every other role is labelled Patient and ends with two line breaks
    (a paragraph break).
    """
    def _render(turn):
        is_doctor = turn["role"] == "Doctor"
        text = str(turn["content"].strip())
        if is_doctor:
            return "**Doctor**: " + text + " \n "
        return "**Patient**: " + text + " \n\n "

    return "".join(_render(turn) for turn in chat_history)
def get_doctor_question(
    engine,
    temperature,
    top_p,
    system_prompt,
    pre_chat_summary,
    patient_reply
):
    """Ask the LLM for the next question the doctor should ask.

    On the first call of a session (``st.session_state["past_messages"]`` is
    empty), the system message is built from ``system_prompt`` and
    ``pre_chat_summary`` and appended to ``past_messages``. The user message and
    the returned message are stored as ``last_request`` / ``last_proposal`` in
    session state.

    Note: ``patient_reply`` is not used. The transcript comes from session state
    via ``get_general_chat_user_msg``.

    Returns:
        The assistant message dict, or ``None`` when the response had no choices.
    """
    print("Requesting Doctor question...")
    state = st.session_state
    if not state["past_messages"]:
        print("Initializing system prompt...")
        state["past_messages"].append(
            build_general_chat_system_prompt(system_prompt, pre_chat_summary)
        )
    user_msg = get_general_chat_user_msg()
    state["last_request"] = user_msg
    conversation = state["past_messages"] + [user_msg]
    reply = send_openai_request(
        engine, None, temperature, top_p, conversation, "get_doctor_question"
    )
    proposal = get_openai_response_msg(reply)
    state["last_proposal"] = proposal
    return proposal
def summarize_conversation(prompt_msg, content, engine, temperature, top_p):
    """Summarize ``content`` using ``prompt_msg`` as the system instruction.

    The two-message request is stored in ``st.session_state["last_request"]``.
    The returned message (or ``None``) is stored in
    ``st.session_state["last_proposal"]`` and returned.
    """
    print("Summarizing conversation...")
    messages = [
        {"role": "system", "content": prompt_msg},
        {"role": "user", "content": content},
    ]
    st.session_state["last_request"] = messages
    reply = send_openai_request(
        engine, None, temperature, top_p, messages, "summarize_session"
    )
    proposal = get_openai_response_msg(reply)
    st.session_state["last_proposal"] = proposal
    return proposal
def get_triage_recommendation(prompt_msg, content, engine, temperature, top_p):
    """Request a triage recommendation for ``content``.

    ``prompt_msg`` is sent as the system message and ``content`` as the user
    message. Returns the first returned message (or ``None``). This helper does
    not store ``last_request`` / ``last_proposal`` in session state.
    """
    print("Requesting triage recommendation...")
    conversation = [
        {"role": "system", "content": prompt_msg},
        {"role": "user", "content": content},
    ]
    reply = send_openai_request(
        engine, None, temperature, top_p, conversation, "get_llm_triage_reco"
    )
    return get_openai_response_msg(reply)
def summarize_feed_info(
    engine, temperature, top_p, age, gender, patient_medical_info, contact_reason, health_situation
):
    """Ask the LLM to summarize the patient's intake information.

    Args:
        engine: Engine identifier forwarded to ``send_openai_request``.
        temperature, top_p: Sampling parameters forwarded to the request.
        age, gender: Rendered as "Patient is <gender> <age> old."
        patient_medical_info, contact_reason, health_situation: Optional free
            text. Falsy values are left out of the prompt.

    Returns:
        The summary text (``content`` of the first returned message).

    Raises:
        Exception: If the response contains no message to read.
    """
    print("Summarizing feed info...")
    # The trailing space after the colon was missing ("following:Patient is").
    parts = ["Please summarize the following: ", f"Patient is {gender} {age} old. "]
    if patient_medical_info:
        parts.append(patient_medical_info + ". ")
    if contact_reason:
        parts.append("Contact reason: " + contact_reason + ". ")
    if health_situation:
        parts.append("Health situation: " + health_situation + ". ")
    messages = [
        {"role": "system", "content": "You summarize patient information"},
        {"role": "user", "content": "".join(parts)},
    ]
    response = send_openai_request(
        engine, None, temperature, top_p, messages, "summarize_params_and_concern"
    )
    openai_proposal = get_openai_response_msg(response)
    if openai_proposal is None:
        # get_openai_response_msg returns None (after reporting it in the UI)
        # when there are no choices; fail clearly instead of with a TypeError.
        raise Exception("Unexpected error querying OpenAI: no message returned for feed info summary")
    return openai_proposal["content"]
def get_available_engines():
    """Return the supported engine identifiers as a new list (GPT-3.5 first)."""
    engines = [ENGINE_GPT_3_5, ENGINE_GPT_4]
    return engines
# See API ref & Swagger: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference | |
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/use-your-data-quickstart?source=recommendations&tabs=bash&pivots=rest-api#retrieve-required-variables | |
# for instructions on where to find the different parameters in Azure portal | |
def send_openai_request_old(
    engine, search_query_type, temperature, top_p, messages, event_name
):
    """Legacy chat-completion request against hard-coded Azure OpenAI hosts.

    When ``search_query_type`` is not ``None``, the request goes to the
    ``/extensions/chat/completions`` route. It attaches an Azure Cognitive
    Search data source restricted (``inScope``) to the guideline indexes.

    Args:
        engine: ``ENGINE_GPT_3_5`` or ``ENGINE_GPT_4``.
        search_query_type: ``None`` for a plain chat request, or one of
            ``"simple"``/``"keyword"``, ``"semantic"``, ``"vector"``,
            ``"vectorSimpleHybrid"``, ``"vectorSemanticHybrid"``. Other values
            still attach a data source, but without an ``indexName``.
        temperature, top_p: Sampling parameters.
        messages: Chat messages (list of role/content dicts).
        event_name: Label of the cost/latency event appended to
            ``st.session_state["session_events"]``.

    Returns:
        The parsed JSON response, or a canned response when
        ``FAKE_OPENAI_RESPONSE`` is set.

    Raises:
        Exception: If ``engine`` is not supported.
    """
    print('send_openai_request: ' + str(event_name) + '\n\n')
    if FAKE_OPENAI_RESPONSE:
        print("Faking OpenAI response...")
        session_event = {
            "event_name": event_name,
            "prompt_tokens": 10,
            "prompt_cost_chf": 0.1,
            "completion_tokens": 11,
            "completion_cost_chf": 0.11,
            "total_cost_chf": 0,
            "response_time": 0,
        }
        st.session_state["session_events"] += [session_event]
        return {
            'id': 'chatcmpl-86wTdbCLS1wxeEOKNCtWPu7vMgyoq',
            'object': 'chat.completion',
            'created': 1696665445,
            'model': 'gpt-4',
            'prompt_filter_results': [{
                'prompt_index': 0,
                'content_filter_results': {
                    'hate': {'filtered': False, 'severity': 'safe'},
                    'self_harm': {'filtered': False, 'severity': 'safe'},
                    'sexual': {'filtered': False, 'severity': 'safe'},
                    'violence': {'filtered': False, 'severity': 'safe'},
                },
            }],
            'choices': [{
                'index': 0,
                'finish_reason': 'stop',
                'message': {
                    'role': 'assistant',
                    'content': 'How long have you been experiencing these headaches and how have they developed over time?',
                },
                'content_filter_results': {
                    'hate': {'filtered': False, 'severity': 'safe'},
                    'self_harm': {'filtered': False, 'severity': 'safe'},
                    'sexual': {'filtered': False, 'severity': 'safe'},
                    'violence': {'filtered': False, 'severity': 'safe'},
                },
            }],
            'usage': {'completion_tokens': 16, 'prompt_tokens': 518, 'total_tokens': 534},
        }
    request_start = timer()
    print("Sending messages: ")
    print(messages)
    api_version = "2023-08-01-preview"
    if engine == ENGINE_GPT_3_5:
        api_base = "https://cog-gpt-35-sandbox.openai.azure.com/"
        llm_deployment_name = "gpt3_5_test"
        api_key = os.getenv("AZURE_OPENAI_GPT3_5_KEY")
        embedding_deployment_name = "embedding-gpt3_5"
        index_model_tag = "gpt35"  # model suffix used in the search index names
    elif engine == ENGINE_GPT_4:
        api_base = "https://cog-gpt-4-sandbox-uks.openai.azure.com/"
        llm_deployment_name = "gpt-4-test"
        api_key = os.getenv("AZURE_OPENAI_GPT4_KEY")
        embedding_deployment_name = "embedding-gpt4"
        index_model_tag = "gpt4"
    else:
        raise Exception("Engine not yet supported: " + str(engine))
    deployments_base = api_base + "openai/deployments/"
    url = (
        deployments_base
        + llm_deployment_name
        + "/chat/completions?api-version="
        + api_version
    )
    headers = {"Content-Type": "application/json", "api-key": api_key}
    payload = {"temperature": temperature, "top_p": top_p, "messages": messages}
    if search_query_type is not None:
        parameters = {
            "endpoint": "https://cog-robin-test-euw.search.windows.net",
            "key": os.getenv("AZURE_COG_SEARCH_KEY"),
            "inScope": True,  # Limit responses to grounded data
            "queryType": search_query_type,
        }
        # Key insertion order matches the original if-chains: embedding
        # settings, then semanticConfiguration, then indexName.
        if search_query_type in ("vector", "vectorSimpleHybrid", "vectorSemanticHybrid"):
            parameters["embeddingEndpoint"] = (
                deployments_base
                + embedding_deployment_name
                + "/embeddings?api-version=2023-05-15"
            )
            parameters["embeddingKey"] = api_key
        if search_query_type in ("semantic", "vectorSemanticHybrid"):
            parameters["semanticConfiguration"] = "default"
        index_prefix = {
            "simple": "guidelines-simple-",
            "keyword": "guidelines-simple-",
            "semantic": "guidelines-",
            "vector": "guidelines-vector-",
            "vectorSimpleHybrid": "guidelines-vector-hybrid-",
            "vectorSemanticHybrid": "guidelines-vector-hybrid-sem-",
        }.get(search_query_type)
        if index_prefix is not None:
            parameters["indexName"] = index_prefix + index_model_tag + "-230907"
        data_source = {"type": "AzureCognitiveSearch", "parameters": parameters}
        print("Data source:")
        print(data_source)
        # Here 'extensions' is needed if dataSource arg is provided in the payload
        # See file upload limitations in https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
        url = (
            deployments_base
            + llm_deployment_name
            + "/extensions/chat/completions?api-version="
            + api_version
        )
        payload["dataSources"] = [data_source]
    print("Querying " + url + " ...")
    response = requests.post(url, headers=headers, json=payload)
    response_json = response.json()
    print("\n\n\nResponse:")
    print(str(response_json))
    print("\n\n")
    request_end = timer()
    try:
        prompt_tokens = response_json["usage"]["prompt_tokens"]
        prompt_cost = get_token_costs(prompt_tokens, engine, "prompt")
        completion_tokens = response_json["usage"]["completion_tokens"]
        completion_cost = get_token_costs(completion_tokens, engine, "completion")
        session_event = {
            "event_name": event_name,
            "prompt_tokens": prompt_tokens,
            "prompt_cost_chf": prompt_cost,
            "completion_tokens": completion_tokens,
            "completion_cost_chf": completion_cost,
            "total_cost_chf": prompt_cost + completion_cost,
            "response_time": request_end - request_start,
        }
        st.session_state["session_events"] += [session_event]
    except (KeyError, TypeError):
        # Cost tracking is best-effort (e.g. error responses carry no "usage").
        # Narrowed from a bare except so KeyboardInterrupt etc. still propagate.
        print("Unable to update prompt and response tokens")
    return response_json
# See API ref & Swagger: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference | |
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/use-your-data-quickstart?source=recommendations&tabs=bash&pivots=rest-api#retrieve-required-variables | |
# for instructions on where to find the different parameters in Azure portal | |
def send_openai_request(
    engine, search_query_type, temperature, top_p, messages, event_name,
    max_rate_limit_retries=30,
):
    """Send a chat-completion request to the Azure OpenAI endpoint of ``engine``.

    The endpoint URL and key are read from the environment:
    ``AZURE_OPENAI_GPT3_5_ENDPOINT`` / ``AZURE_OPENAI_GPT3_5_KEY`` or
    ``AZURE_OPENAI_GPT4_ENDPOINT`` / ``AZURE_OPENAI_GPT4_KEY``.

    Side effects:
        - Sent messages are appended to ``st.session_state["llm_messages"]``.
        - Token usage, CHF cost and latency are appended to
          ``st.session_state["session_events"]`` (best-effort).

    Args:
        engine: ``ENGINE_GPT_3_5`` or ``ENGINE_GPT_4``.
        search_query_type: Ignored by this implementation (no search data
            source is attached).
        temperature, top_p: Sampling parameters.
        messages: Chat messages (list of role/content dicts).
        event_name: Label of the recorded session event.
        max_rate_limit_retries: How many times a rate-limited (code 429)
            response is retried, 2 s apart, before giving up. The original
            loop retried forever.

    Returns:
        The parsed JSON response, or a canned response when
        ``FAKE_OPENAI_RESPONSE`` is set.

    Raises:
        Exception: The engine is unsupported, or its endpoint variable is unset.
            Also raised on a non-rate-limit API error, or when rate limiting
            persists past ``max_rate_limit_retries``.
    """
    request_start = timer()
    if DEBUG:
        print("Sending messages: ")
        print(messages)
    if FAKE_OPENAI_RESPONSE:
        print("Faking OpenAI response...")
        session_event = {
            "event_name": "mocked_" + event_name,
            "prompt_tokens": 0,
            "prompt_cost_chf": 0,
            "completion_tokens": 0,
            "completion_cost_chf": 0,
            "total_cost_chf": 0,
            "response_time": 0,
        }
        st.session_state["session_events"] += [session_event]
        return {
            'id': 'chatcmpl-86wTdbCLS1wxeEOKNCtWPu7vMgyoq',
            'object': 'chat.completion',
            'created': 1696665445,
            'model': 'gpt-4',
            'prompt_filter_results': [{
                'prompt_index': 0,
                'content_filter_results': {
                    'hate': {'filtered': False, 'severity': 'safe'},
                    'self_harm': {'filtered': False, 'severity': 'safe'},
                    'sexual': {'filtered': False, 'severity': 'safe'},
                    'violence': {'filtered': False, 'severity': 'safe'},
                },
            }],
            'choices': [{
                'index': 0,
                'finish_reason': 'stop',
                'message': {
                    'role': 'assistant',
                    'content': 'MOCKED LLM RESPONSE: GP: Patient cannot be treated remotely',
                },
                'content_filter_results': {
                    'hate': {'filtered': False, 'severity': 'safe'},
                    'self_harm': {'filtered': False, 'severity': 'safe'},
                    'sexual': {'filtered': False, 'severity': 'safe'},
                    'violence': {'filtered': False, 'severity': 'safe'},
                },
            }],
            'usage': {'completion_tokens': 16, 'prompt_tokens': 518, 'total_tokens': 534},
        }
    if engine == ENGINE_GPT_3_5:
        endpoint_var, key_var = "AZURE_OPENAI_GPT3_5_ENDPOINT", "AZURE_OPENAI_GPT3_5_KEY"
    elif engine == ENGINE_GPT_4:
        endpoint_var, key_var = "AZURE_OPENAI_GPT4_ENDPOINT", "AZURE_OPENAI_GPT4_KEY"
    else:
        raise Exception("Engine not yet supported: " + str(engine))
    url = os.getenv(endpoint_var)
    if not url:
        # str(None) used to produce the URL "None", which then failed inside
        # requests with an unrelated-looking error.
        raise Exception(
            "Missing environment variable " + endpoint_var + " for engine " + str(engine)
        )
    api_key = os.getenv(key_var)
    headers = {"Content-Type": "application/json", "api-key": api_key}
    payload = {"temperature": temperature, "top_p": top_p, "messages": messages}
    if DEBUG:
        print("Querying " + url + " ...")
    st.session_state["llm_messages"] += messages
    response_json = requests.post(url, headers=headers, json=payload).json()
    print("Response:")
    print(response_json)
    retries = 0
    while "error" in response_json:
        if not _is_rate_limit_error(response_json["error"]):
            raise Exception("OpenAI error: " + str(response_json))
        if retries >= max_rate_limit_retries:
            raise Exception(
                "OpenAI rate limit still reached after " + str(retries)
                + " retries: " + str(response_json)
            )
        retries += 1
        print('OpenAI rate limit reached, waiting 2s before retrying...')
        time.sleep(2)
        response_json = requests.post(url, headers=headers, json=payload).json()
        print(response_json)
    request_end = timer()
    try:
        prompt_tokens = response_json["usage"]["prompt_tokens"]
        prompt_cost = get_token_costs(prompt_tokens, engine, "prompt")
        completion_tokens = response_json["usage"]["completion_tokens"]
        completion_cost = get_token_costs(completion_tokens, engine, "completion")
        session_event = {
            "event_name": event_name,
            "prompt_tokens": prompt_tokens,
            "prompt_cost_chf": prompt_cost,
            "completion_tokens": completion_tokens,
            "completion_cost_chf": completion_cost,
            "total_cost_chf": prompt_cost + completion_cost,
            "response_time": request_end - request_start,
        }
        st.session_state["session_events"] += [session_event]
        if DEBUG:
            print(session_event)
    except (KeyError, TypeError):
        # Cost tracking is best-effort (e.g. a response without "usage").
        # Narrowed from a bare except so KeyboardInterrupt etc. still propagate.
        print("Unable to update prompt and response tokens")
    return response_json


def _is_rate_limit_error(error):
    """Return True if an OpenAI ``error`` payload reports rate limiting (code 429)."""
    code = error.get("code") if isinstance(error, dict) else None
    # Error codes are not necessarily numeric strings. The previous int(code)
    # raised ValueError on those and masked the real API error.
    return str(code).strip() == "429"
def send_patient_reply(
    engine, search_query_type, temperature, selected_guidelines, top_p, chat_array
):
    """Request the next proposal after a patient reply.

    Builds a user message from ``selected_guidelines`` and ``chat_array``. It is
    sent after ``st.session_state["past_messages"]`` without being added to them.
    The request and the returned message are stored as ``last_request`` /
    ``last_proposal`` in session state.

    Returns:
        The assistant message dict, or ``None`` when the response had no choices.
    """
    print("Submitting patient reply...")
    new_message = {
        "role": "user",
        "content": build_query_msg_content(selected_guidelines, chat_array),
    }
    state = st.session_state
    state["last_request"] = new_message
    outgoing = state["past_messages"] + [new_message]
    reply = send_openai_request(
        engine, search_query_type, temperature, top_p, outgoing, "send_dr_patient_msg"
    )
    proposal = get_openai_response_msg(reply)
    state["last_proposal"] = proposal
    return proposal
def get_num_tokens(text, engine):
    """Count the tokens in ``text`` using the tiktoken encoding for ``engine``.

    ``ENGINE_GPT_3_5`` maps to the "gpt-3.5-turbo" encoding and
    ``ENGINE_GPT_4`` to "gpt-4".

    Raises:
        Exception: If ``engine`` is neither of the two supported engines.
    """
    if engine == ENGINE_GPT_3_5:
        model_name = "gpt-3.5-turbo"
    elif engine == ENGINE_GPT_4:
        model_name = "gpt-4"
    else:
        raise Exception("Unknown model: " + engine)
    tokens = tiktoken.encoding_for_model(model_name).encode(text)
    return len(tokens)
# Source: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ | |
def get_token_costs(num_tokens, engine, query_type):
    """Return the cost in CHF of ``num_tokens`` tokens.

    Prices are CHF per 1,000 tokens, converted from the USD list prices noted
    beside each entry.

    Args:
        num_tokens: Number of tokens consumed.
        engine: ``ENGINE_GPT_3_5``, ``ENGINE_GPT_4`` or ``"embedding"``.
        query_type: ``"prompt"`` or ``"completion"``. Ignored for
            ``"embedding"``, which has a single flat price.

    Raises:
        Exception: Unknown engine, or unknown ``query_type`` for a chat engine.
    """
    if engine == ENGINE_GPT_3_5:
        chf_by_1k_token = {
            "prompt": 0.0028,  # USD 0.003
            "completion": 0.0037,  # USD 0.004
        }
    elif engine == ENGINE_GPT_4:
        chf_by_1k_token = {
            # Was 0.0028 (the GPT-3.5 value), a 10x under-count. USD 0.03 is
            # ~CHF 0.028 at the rate implied by the completion entry (0.06 -> 0.055).
            "prompt": 0.028,  # USD 0.03
            "completion": 0.055,  # USD 0.06
        }
    elif engine == "embedding":
        return 0.0001 * num_tokens / 1000
    else:
        raise Exception("Unknown model: " + str(engine))
    if query_type not in chf_by_1k_token:
        raise Exception("Unknown type: " + str(query_type))
    return chf_by_1k_token[query_type] * num_tokens / 1000
# No API ref; allowed values obtained from OpenAI error messages | |
def get_search_query_type_options():
    """Return the selectable search query types, starting with ``None`` (no search).

    Returns a new list on every call.
    """
    options = [None]
    options.extend(
        ("simple", "semantic", "vector", "vectorSimpleHybrid", "vectorSemanticHybrid")
    )
    return options
# Display names of the evaluation datasets. N is presumably the record count,
# written with ' as the thousands separator.
DATASET_AIDA_JIRA_TICKETS = "aida reviewed jira tickets (N=1'407)"
DATASET_GT_CASES = "gt-cases (N=2'434)"
DATASET_APP_CHATS = "app chats (N=300)"


def get_dataset_names():
    """Return the dataset display names as a new list, app chats first."""
    ordered = (DATASET_APP_CHATS, DATASET_GT_CASES, DATASET_AIDA_JIRA_TICKETS)
    return list(ordered)