Spaces:
Paused
Paused
| import re | |
| import uuid | |
| import pandas as pd | |
| import streamlit as st | |
| import re | |
| import matplotlib.pyplot as plt | |
| import subprocess | |
| import sys | |
| import io | |
| from utils.default_values import get_system_prompt, get_guidelines_dict | |
| from utils.epfl_meditron_utils import get_llm_response | |
| from utils.openai_utils import get_available_engines, get_search_query_type_options | |
| from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay | |
| from sklearn.metrics import classification_report | |
# Folder prefix for data files such as diarized transcripts.
DATA_FOLDER = "data/"
POC_VERSION = "0.1.0"
MAX_QUESTIONS = 10
AVAILABLE_LANGUAGES = ["DE", "EN", "FR"]
# Azure apparently truncates system messages longer than this limit.
MAX_SYSTEM_MESSAGE_TOKENS = 200

st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')
def format_question(q, citations=None):
    """Clean up a question string produced by the LLM for display.

    Strips a leading enumeration prefix such as ``'1. '`` or ``'12) '`` and
    replaces ``[docN]`` references with ``[<title of citation N>]``
    (``[doc0]`` is treated like ``[doc1]``). References whose number has no
    matching citation are left unchanged.

    Args:
        q: Raw question text.
        citations: Sequence of citation dicts with a ``"title"`` key.
            Defaults to ``st.session_state["citations"]``.

    Returns:
        The cleaned question with surrounding whitespace stripped.
    """
    if citations is None:
        citations = st.session_state["citations"]

    # One or more digits followed by a literal '.' or ')' and whitespace.
    # An unescaped '.' would miss '10. ' and wrongly strip e.g. '1a '.
    res = re.sub(r'^[0-9]+[.)]\s', '', q)

    if citations:
        def _replace_ref(match):
            number = int(match.group(1))
            index = number - 1 if number > 0 else 0
            if index >= len(citations):
                # Unknown reference: keep it rather than raising IndexError.
                return match.group(0)
            return '[' + citations[index]["title"] + ']'

        res = re.sub(r'\[doc([0-9]+)\]', _replace_ref, res)
    return res.strip()
def get_text_from_row(text):
    """Return ``text`` as a string, mapping NaN (rendered as ``"nan"``) to ``""``."""
    as_string = str(text)
    return "" if as_string == "nan" else as_string
def get_questions_from_df(df, lang, test_scenario_name):
    """Build one question dict per row of ``df``.

    Each dict holds the question text (column ``"<lang>: Fragen"``), the
    answer from column ``test_scenario_name`` (``""`` when empty) and a fresh
    random ``uuid.UUID`` as ``question_id``.
    """
    return [
        {
            "question": row[lang + ": Fragen"],
            "answer": get_text_from_row(row[test_scenario_name]),
            "question_id": uuid.uuid4(),
        }
        for _, row in df.iterrows()
    ]
def get_questions(df, lead_symptom, lang, test_scenario_name):
    """Return the questions for a lead symptom and test scenario, cached in session state.

    The list is rebuilt only when ``lead_symptom`` or ``test_scenario_name``
    differs from the values cached in ``st.session_state``. Otherwise the cached
    list is returned, which may have been edited since (``remove_question``
    filters it).

    Args:
        df: Questionnaire DataFrame with ``"<lang>: Symptome"`` and
            ``"<lang>: Fragen"`` columns plus one column per test scenario.
        lead_symptom: Lead symptom name, in language ``lang``.
        lang: Language prefix (e.g. ``"DE"``) used to select columns.
        test_scenario_name: Column holding the answers for the scenario.

    Returns:
        ``st.session_state["questions"]``.
    """
    # .get(): these keys are only set by the sidebar form, not by init_session_state.
    cached_symptom = st.session_state.get("lead_symptom")
    cached_scenario = st.session_state.get("scenario_name")
    print(str(cached_symptom) + " -> " + str(lead_symptom))
    print(str(cached_scenario) + " -> " + str(test_scenario_name))
    if cached_symptom != lead_symptom or cached_scenario != test_scenario_name:
        st.session_state["lead_symptom"] = lead_symptom
        st.session_state["scenario_name"] = test_scenario_name
        # Filter on the same language as the question column. The
        # st.session_state["language"] key is never set by init_session_state or
        # restart_session.
        symptom_col_name = lang + ": Symptome"
        df_questions = df[df[symptom_col_name] == lead_symptom]
        st.session_state["questions"] = get_questions_from_df(df_questions, lang, test_scenario_name)
    return st.session_state["questions"]
def display_streamlit_sidebar():
    """Render the sidebar configuration form and start a new session on submit.

    On submit, the per-session state is reset via ``restart_session()``. The
    chosen model and sampling parameters are then stored in
    ``st.session_state``, and the app is rerun.
    """
    st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
    st.sidebar.write('**Parameters**')
    form = st.sidebar.form("config_form", clear_on_submit=True)
    model_repo_id = form.text_input(label="Repo", value=st.session_state["model_repo_id"])
    model_filename = form.text_input(label="File name", value=st.session_state["model_filename"])
    model_type = form.text_input(label="Model type", value=st.session_state["model_type"])
    gpu_layers = form.slider('GPU Layers', min_value=0,
                             max_value=100, value=st.session_state['gpu_layers'], step=1)
    # The system-prompt text area is disabled, so starting a session always
    # stores an empty system prompt.
    system_prompt = ""
    temperature = form.slider('Temperature (0 = deterministic, 1 = more freedom)', min_value=0.0,
                              max_value=1.0, value=st.session_state['temperature'], step=0.1)
    top_p = form.slider('top_p (0 = focused, 1 = broader answer range)', min_value=0.0,
                        max_value=1.0, value=st.session_state['top_p'], step=0.1)
    form.write('Best practice is to only modify temperature or top_p, not both')
    submitted = form.form_submit_button("Start session")
    if not submitted:
        return

    print('Parameters updated...')
    restart_session()
    st.session_state["session_started"] = True
    st.session_state["model_repo_id"] = model_repo_id
    st.session_state["model_filename"] = model_filename
    st.session_state["model_type"] = model_type
    st.session_state["gpu_layers"] = gpu_layers
    # Clear the question cache so get_questions() rebuilds it.
    st.session_state["questions"] = []
    st.session_state["lead_symptom"] = None
    st.session_state["scenario_name"] = None
    st.session_state["system_prompt"] = system_prompt
    st.session_state["temperature"] = temperature
    st.session_state["top_p"] = top_p
    st.rerun()
def to_str(text):
    """Return ``text`` as a string prefixed with one space; NaN (``"nan"``) gives ``" "``."""
    rendered = str(text)
    if rendered == "nan":
        rendered = ""
    return " " + rendered
def set_df_prompts(path, sheet_name):
    """Load the guideline prompts sheet into ``st.session_state["df_prompts"]``.

    The sheet is read without a header and is laid out column-wise:
    - row 0 holds the questionnaire name;
    - row 1 holds the used guideline;
    - rows 2 onward hold the prompt text, concatenated (space-separated) into one string.

    The result has one row per sheet column, with columns "Questionnaire",
    "Used Guideline" and "Prompt". The first sheet column (presumably a label
    column) is dropped.

    Args:
        path: Excel file path (anything ``pd.read_excel`` accepts).
        sheet_name: Sheet to read.
    """
    df_raw = pd.read_excel(path, sheet_name, header=None)

    # Normalise empty cells in the first prompt row to "". Otherwise a blank
    # cell becomes the literal prompt "nan", or raises TypeError (NaN + str)
    # when continuation rows are appended below.
    prompts = df_raw.iloc[2].apply(get_text_from_row)
    for row_index in range(3, df_raw.shape[0]):
        prompts = prompts + df_raw.iloc[row_index].apply(to_str)

    # Build the result from separate Series instead of mutating a row of
    # df_raw in place, which would force strings into numeric columns.
    df_prompts = pd.DataFrame({
        "Questionnaire": df_raw.iloc[0].astype(str),
        "Used Guideline": df_raw.iloc[1].astype(str),
        "Prompt": prompts.astype(str),
    })
    st.session_state["df_prompts"] = df_prompts[1:]
def handle_nbq_click(c):
    """Copy a suggested next-best question into ``st.session_state['doctor_question']``.

    Bracketed source references such as ``[Guideline A]`` are removed and the
    result is stripped.

    Args:
        c: Suggested question text, possibly containing ``[...]`` references.
    """
    # Non-greedy: a greedy '\[.*\]' would also delete the text between two
    # separate references, e.g. 'A [x] and B [y]' -> 'A'.
    question_without_source = re.sub(r'\[.*?\]', '', c).strip()
    st.session_state['doctor_question'] = question_without_source
def get_doctor_question_value():
    """Return the stored doctor question, or ``''`` if none has been stored yet."""
    has_question = 'doctor_question' in st.session_state
    return st.session_state['doctor_question'] if has_question else ''
def update_chat_history(dr_question, patient_reply):
    """Append the doctor question and/or patient reply to the session chat history.

    ``None`` values are skipped. The doctor message is appended before the
    patient message.

    Returns:
        The updated ``st.session_state["chat_history_array"]`` list.
    """
    print("update_chat_history" + str(dr_question) + " - " + str(patient_reply) + '...\n')
    for role, content in (("Doctor", dr_question), ("Patient", patient_reply)):
        if content is None:
            continue
        st.session_state["chat_history_array"].append({"role": role, "content": content})
    return st.session_state["chat_history_array"]
def get_chat_history_string(chat_history):
    """Render chat messages as Markdown, one ``**Role**: content`` entry per message.

    Doctor entries end with ``" \\n "`` and patient entries with ``" \\n\\n "``.

    Raises:
        Exception: for a message whose role is neither "Doctor" nor "Patient".
    """
    parts = []
    for message in chat_history:
        role = message["role"]
        if role == "Doctor":
            label, suffix = "Doctor", " \n "
        elif role == "Patient":
            label, suffix = "Patient", " \n\n "
        else:
            raise Exception('Unknown role: ' + str(role))
        parts.append('**' + label + '**: ' + str(message["content"].strip()) + suffix)
    return ''.join(parts)
def restart_session():
    """Reset the per-session state: chat, triage prompt variants, guidelines and cost events.

    Model and sampling parameters (repo, temperature, top_p, ...) are not
    touched here.
    """
    print("Resetting params...")
    triage_prompt_variants = [
        '''You are a telemedicine triage agent that decides between the following:
Emergency: Patient health is at risk if he doesn't speak to a Doctor urgently
Telecare: Patient can likely be treated remotely
General Practitioner: Patient should visit a GP for an ad-real consultation''',
        '''You are a Doctor assistant that decides if a medical case can likely be treated remotely by a Doctor or not.
The remote Doctor can write prescriptions and request the patient to provide a picture.
Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
Treat remotely: <your reasoning>
Treat ad-real: <your reasoning>''',
        '''You are a medical triage agent working for the telemedicine Company Medgate based in Switzerland.
You decide if a case can be treated remotely or not, knowing that the remote Doctor can write prescriptions and request pictures.
Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
Treat remotely: <your reasoning>
Treat ad-real: <your reasoning>''',
    ]
    # A fresh dict per call, so the list values are never shared between sessions.
    reset_values = {
        "emg_class_enabled": False,
        "enable_llm_summary": False,
        "num_variants": 3,
        "lang_index": 0,
        "llm_message": "",
        "llm_messages": [],
        "triage_prompt_variants": triage_prompt_variants,
        "nbqs": [],
        # NOTE(review): format_question() indexes citations by position
        # (citations[i]); confirm whether a list was intended instead of a dict.
        "citations": {},
        "past_messages": [],
        "last_request": None,
        "last_proposal": None,
        "doctor_question": '',
        "patient_reply": '',
        "chat_history_array": [],
        "chat_history": '',
        "feed_summary": '',
        "summary": '',
        "selected_guidelines": ["General"],
    }
    for state_key, state_value in reset_values.items():
        st.session_state[state_key] = state_value
    st.session_state["guidelines_dict"] = get_guidelines_dict()
    st.session_state["triage_recommendation"] = ''
    st.session_state["session_events"] = []
def init_session_state():
    """Initialise ``st.session_state`` with default model, patient and sampling settings.

    Also resets the per-session state via ``restart_session()`` and loads the
    default system prompt from ``get_system_prompt()``.
    """
    print('init_session_state()')
    st.session_state['session_started'] = False
    st.session_state['guidelines_ignored'] = False
    # Default model repo, weights file name and model type.
    st.session_state['model_index'] = 1
    st.session_state["model_repo_id"] = "TheBloke/meditron-7B-GGUF"
    st.session_state["model_filename"] = "meditron-7b.Q5_K_S.gguf"
    st.session_state["model_type"] = "llama"
    st.session_state['gpu_layers'] = 1
    # Default patient profile.
    default_gender_index = 0
    st.session_state['gender'] = get_genders()[default_gender_index]
    st.session_state['gender_index'] = default_gender_index
    st.session_state['age'] = 30
    st.session_state['patient_medical_info'] = ''
    default_search_query = 0
    st.session_state['search_query_type'] = get_search_query_type_options()[default_search_query]
    st.session_state['search_query_type_index'] = default_search_query
    st.session_state['engine'] = get_available_engines()[0]
    st.session_state['temperature'] = 0.0
    st.session_state['top_p'] = 1.0
    st.session_state['feed_chat_transcript'] = ''
    st.session_state["llm_model"] = True
    st.session_state["hugging_face_models"] = True
    st.session_state["local_models"] = True
    restart_session()
    # Fetch the default prompt once and use it for both keys.
    default_system_prompt = get_system_prompt()
    st.session_state['system_prompt'] = default_system_prompt
    st.session_state['system_prompt_after_on_change'] = default_system_prompt
    # 'summary' is not reset again here: restart_session() already sets it to ''.
def get_genders():
    """Return the selectable patient genders as a new list on every call."""
    return list(("Male", "Female"))
def display_session_overview():
    """Show the LLM query history and the summed token, cost and time figures of the session.

    Totals are only shown once events carry a ``prompt_tokens`` column. Any
    other metric column missing from the events is skipped rather than
    raising ``KeyError``.
    """
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)
    if "prompt_tokens" not in df_session_overview:
        return

    # (column, label) pairs in display order. Each total gets its own label,
    # so the total cost is no longer stored under a "completion" name.
    metric_labels = (
        ("prompt_tokens", "Prompt tokens: "),
        ("prompt_cost_chf", "Prompt CHF: "),
        ("completion_tokens", "Completion tokens: "),
        ("completion_cost_chf", "Completion CHF: "),
        ("total_cost_chf", "Total costs CHF: "),
        ("response_time", "Total compute time (ms): "),
    )
    for column, label in metric_labels:
        if column in df_session_overview:
            st.write(label + str(df_session_overview[column].sum()))
def remove_question(question_id):
    """Drop the question whose ``question_id`` matches (compared as strings), then rerun the app."""
    remaining = []
    for question in st.session_state["questions"]:
        if str(question["question_id"]) != str(question_id):
            remaining.append(question)
    st.session_state["questions"] = remaining
    st.rerun()
def get_prompt_from_lead_symptom(df_config, df_prompt, lead_symptom, lang, fallback=True):
    """Return the guideline prompt for a lead symptom.

    For non-German ``lang``, the lead symptom is first translated to its German
    name via ``df_config``. The first prompt whose "Questionnaire" contains
    that German name is then returned.

    Args:
        df_config: DataFrame with ``"<lang>: Symptome"`` columns, including
            ``"DE: Symptome"``.
        df_prompt: DataFrame with "Questionnaire" and "Prompt" columns.
        lead_symptom: Lead symptom name in language ``lang``.
        lang: Language prefix such as ``"DE"``.
        fallback: When no prompt matches, show a toast and return
            ``st.session_state["system_prompt"]`` instead of ``""``.

    Returns:
        The matching prompt; otherwise the generic system prompt (``fallback``)
        or ``""``.
    """
    de_lead_symptom = lead_symptom
    if lang != "DE":
        de_names = df_config.loc[df_config[lang + ": Symptome"] == lead_symptom, "DE: Symptome"]
        # An unknown symptom (no row) or a blank DE cell would otherwise raise
        # IndexError / TypeError. Treat it as "no guidelines found" instead.
        if de_names.empty or pd.isna(de_names.iloc[0]):
            de_lead_symptom = None
        else:
            de_lead_symptom = de_names.iloc[0]
        print("DE lead symptom: " + str(de_lead_symptom))

    if de_lead_symptom is not None:
        for questionnaire, prompt in zip(df_prompt["Questionnaire"], df_prompt["Prompt"]):
            if de_lead_symptom in questionnaire:
                return prompt

    warning_text = ("No guidelines found for lead symptom " + str(lead_symptom)
                    + " (" + str(de_lead_symptom) + ")")
    if fallback:
        st.toast(warning_text + ", using generic prompt", icon='🚨')
        return st.session_state["system_prompt"]
    st.toast(warning_text, icon='🚨')
    return ""
def get_scenarios(df):
    """Return the test-scenario column labels of ``df`` (those starting with "TLC" or "GP").

    Labels are converted with ``str()`` before matching, as in
    ``get_test_scenarios``. Non-string labels (e.g. integers) are therefore
    skipped instead of raising ``AttributeError``.
    """
    return [col for col in df.columns.values if str(col).startswith(('TLC', 'GP'))]
def get_gender_age_from_test_scenario(test_scenario):
    """Extract patient gender and age from a test scenario name.

    The first ``F<digits>`` / ``M<digits>`` token is used, e.g.
    ``"TLC_F45_headache"`` -> ``("Female", 45)``.

    Returns:
        ``(gender, age)``. When no token is found (or the input is not a
        string), an error is shown and ``("Male", 30)`` is returned.
    """
    genders = {"M": "Male", "F": "Female"}
    try:
        result = re.search(r"([FM])(\d+)", test_scenario)
    except TypeError:
        # Non-string input; a narrow catch instead of a bare except so that
        # KeyboardInterrupt etc. are not swallowed.
        result = None
    if result is None:
        st.error("Unable to extract gender, age; using 30M as default")
        return "Male", 30
    return genders[result.group(1)], int(result.group(2))
def get_freetext_to_reco(reco_freetext_cased, emg_class_enabled=False):
    """Map an LLM's free-text triage answer to a structured recommendation.

    Leading keywords are checked first (e.g. ``"Treat remotely: ..."``), then
    English and German keywords anywhere in the text. Matching is
    case-insensitive.

    Args:
        reco_freetext_cased: Free-text recommendation; falsy values are treated as "".
        emg_class_enabled: When True, a text starting with an emergency keyword maps
            to ``'EMERGENCY'``; otherwise it is folded into ``'GP'``.

    Returns:
        ``'TELECARE'``, ``'GP'`` or ``'EMERGENCY'``. If nothing matches, a warning
        is toasted and printed and ``'TRIAGE_IMPOSSIBLE'`` is returned.
    """
    text = reco_freetext_cased.lower() if reco_freetext_cased else ""

    # 1) Recommendation stated as a leading keyword.
    if text.startswith(('treat remotely', 'telecare')):
        return 'TELECARE'
    if text.startswith(('treat ad-real', 'gp', 'general practitioner')):
        return 'GP'
    if text.startswith(('emergency', 'emg', 'urgent')):
        return 'EMERGENCY' if emg_class_enabled else 'GP'

    # 2) Keywords anywhere in the text; GP indicators take precedence.
    gp_markers = ('gp', 'general practitioner', 'nicht über tele',
                  'durch einen arzt erford', 'persönliche untersuchung erfordert')
    if any(marker in text for marker in gp_markers):
        return 'GP'
    telecare_markers = ('telecare', 'telemed', 'can be treated remotely')
    if any(marker in text for marker in telecare_markers):
        return 'TELECARE'
    # NOTE(review): unlike the leading-keyword check, an emergency keyword found
    # mid-text maps to 'GP' even when emg_class_enabled is True — confirm intended.
    in_person_markers = ('emergency', 'urgent', 'not be treated remotely', "nicht tele")
    if any(marker in text for marker in in_person_markers):
        return 'GP'

    warning_msg = 'Cannot extract reco from LLM text: ' + text
    st.toast(warning_msg)
    print(warning_msg)
    return 'TRIAGE_IMPOSSIBLE'
def get_structured_reco(row, index, emg_class_enabled):
    """Map the ``llm_reco_freetext_<index>`` cell of ``row`` to a structured recommendation."""
    lowered = row["llm_reco_freetext_" + str(index)].lower()
    return get_freetext_to_reco(lowered, emg_class_enabled)
def add_expected_dispo(row, emg_class_enabled):
    """Return the expected disposition label for an evaluation row.

    "GP" and "TELECARE" are returned as-is. "EMERGENCY" is kept only when
    ``emg_class_enabled`` is True and is otherwise folded into "GP".

    Args:
        row: Row (e.g. a ``pd.Series``) with a "disposition" entry and,
            optionally, "case_summary".
        emg_class_enabled: Whether EMERGENCY is a separate class.

    Raises:
        Exception: if the disposition is not GP, TELECARE or EMERGENCY.
    """
    disposition = row["disposition"]
    if disposition in ("GP", "TELECARE"):
        return disposition
    if disposition == "EMERGENCY":
        return "EMERGENCY" if emg_class_enabled else "GP"
    # str() so that a missing/NaN case summary does not turn this into a TypeError.
    raise Exception("Missing disposition for row " + str(row.name)
                    + " with summary " + str(row.get("case_summary")))
def get_test_scenarios(df):
    """Return the column labels of ``df`` whose string form starts with "GP" or "TLC"."""
    return [label for label in df.columns.values if str(label).startswith(('GP', 'TLC'))]
def get_transcript(df, test_scenario, lang):
    """Build a "Doctor: ..., Patient: ..." transcript with one line per row of ``df``.

    Each entry starts with a newline. Questions come from ``"<lang>: Fragen"``
    and answers (stringified) from column ``test_scenario``.
    """
    return "".join(
        "\nDoctor: " + row[lang + ": Fragen"] + ", Patient: " + str(row[test_scenario])
        for _, row in df.iterrows()
    )
def get_expected_from_scenario(test_scenario):
    """Map a scenario name's prefix (text before the first ``_``) to the expected recommendation.

    Raises:
        Exception: if the prefix is neither "GP" nor "TLC".
    """
    prefix = test_scenario.partition('_')[0]
    expected_by_prefix = {"GP": "GP", "TLC": "TELECARE"}
    if prefix not in expected_by_prefix:
        raise Exception('Unexpected reco: ' + prefix)
    return expected_by_prefix[prefix]
def plot_report(title, expected, predicted, display_labels):
    """Render evaluation results for one set of predictions.

    Shows:
    - a confusion matrix restricted to ``display_labels``;
    - the full ``classification_report`` table;
    - a bar chart of the per-class metric columns (the report without
      "support" and without the aggregate rows).

    Args:
        title: Section title.
        expected: Ground-truth labels.
        predicted: Predicted labels.
        display_labels: Labels (and their order) for the confusion matrix.
    """
    st.markdown('#### ' + title)
    conf_matrix = confusion_matrix(expected, predicted, labels=display_labels)
    conf_matrix_plot = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=display_labels)
    conf_matrix_plot.plot()
    st.pyplot(conf_matrix_plot.figure_)
    # pyplot keeps every figure alive until it is closed. Close each one after
    # rendering so figures do not accumulate across Streamlit reruns.
    plt.close(conf_matrix_plot.figure_)

    report = classification_report(expected, predicted, output_dict=True)
    df_report = pd.DataFrame(report).transpose()
    st.write(df_report)
    df_rp = df_report.drop('support', axis=1)
    df_rp = df_rp.drop(['accuracy', 'macro avg', 'weighted avg'], errors='ignore')

    fig, ax = plt.subplots()
    try:
        df_rp.plot(kind="bar", legend=True, ax=ax)
        for container in ax.containers:
            ax.bar_label(container, fontsize=7)
        plt.setp(ax.get_xticklabels(), rotation=45)
        ax.legend(loc=(1.04, 0))
        st.pyplot(fig)
    except Exception as e:
        # Plotting is best effort: the metrics table above is already shown, so
        # log the failure instead of silently ignoring it.
        print('Unable to plot classification report: ' + str(e))
    finally:
        plt.close(fig)
def get_complete_prompt(generic_prompt, guidelines_prompt):
    """Join the non-empty prompts, separated by ``".\\n\\n"`` when both are present."""
    return ".\n\n".join(part for part in (generic_prompt, guidelines_prompt) if part)
def run_command(args):
    """Run a command without a shell, print the completed process and return it.

    Args:
        args: Program and arguments as a sequence, e.g. ``["ls", "-l"]``. A
            single string is treated as the program name.

    Returns:
        The ``subprocess.CompletedProcess``, with ``stdout`` and ``stderr``
        captured as text. Errors are not raised; check ``returncode``.
    """
    # Pass the argument list as-is. Joining it into one string without
    # shell=True makes subprocess look for a program literally named
    # "<cmd> <arg1> ...", and it breaks arguments that contain spaces.
    cmd = [args] if isinstance(args, str) else list(args)
    result = subprocess.run(cmd, capture_output=True, text=True)
    print(result)
    return result
def get_diarized_f_path(audio_f_name, data_folder=None):
    """Return the path of the diarized transcript (``.txt``) for an audio file.

    Only the final extension is replaced, so dotted names keep their full stem
    (``"call.2024.wav"`` -> ``"call.2024.txt"``).

    Args:
        audio_f_name: Audio file name, appended to the data folder as-is.
        data_folder: Folder prefix, concatenated without a separator;
            defaults to ``DATA_FOLDER``.

    Returns:
        ``data_folder + <name without extension> + ".txt"``.
    """
    import os

    if data_folder is None:
        data_folder = DATA_FOLDER
    base_name, _extension = os.path.splitext(audio_f_name)
    return data_folder + base_name + ".txt"
def display_llm_output():
    """Render a form that sends a free-text message to the configured LLM.

    On submit, ``get_llm_response`` is called with the model settings stored
    in ``st.session_state`` and its response is written to the page.
    """
    st.header("LLM")
    llm_form = st.form('llm')
    message = llm_form.text_area('Message', value=st.session_state["llm_message"])
    if not llm_form.form_submit_button('Submit'):
        return
    state = st.session_state
    response = get_llm_response(
        state["model_repo_id"],
        state["model_filename"],
        state["model_type"],
        state["gpu_layers"],
        message,
    )
    st.write(response)
    st.write('Done displaying LLM response')
def main():
    """Streamlit entry point.

    On the first run (no started session), initialises the session state.
    Always renders the sidebar. Once a session has been started, clears the
    placeholder and shows the LLM form and the session overview.
    """
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()
    session_active = "session_started" in st.session_state and st.session_state["session_started"]
    if not session_active:
        init_session_state()
    display_streamlit_sidebar()
    if session_active:
        session_inactive_info.empty()
        display_llm_output()
        display_session_overview()
if __name__ == "__main__":
    # Entry point when this file is executed as the main script.
    main()