"""Streamlit proof-of-concept UI for querying local GPTQ-quantized LLMs."""
import inspect
import io
import re
import subprocess
import sys
import uuid

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, confusion_matrix

from utils.default_values import get_guidelines_dict, get_system_prompt
from utils.epfl_meditron_utils import get_llm_response, gptq_model_options
from utils.openai_utils import get_available_engines, get_search_query_type_options

POC_VERSION = "0.1.1"

# Llama-2 chat format: the system message is wrapped in <<SYS>> tags inside [INST].
_LLAMA2_CHAT_TEMPLATE = (
    "[INST] <<SYS>>\n"
    "{system_message}\n"
    "<</SYS>>\n"
    "{prompt}[/INST]"
)

# ChatML format, used here for the Meditron GPTQ builds.
_CHATML_TEMPLATE = (
    "<|im_start|>system\n"
    "{system_message}<|im_end|>\n"
    "<|im_start|>user\n"
    "{prompt}<|im_end|>\n"
    "<|im_start|>assistant"
)

# Known model ids -> prompt template. Unknown models get an empty template,
# which format_prompt() treats as "system message + prompt" concatenation.
_PROMPT_TEMPLATES = {
    "TheBloke/Llama-2-13B-chat-GPTQ": _LLAMA2_CHAT_TEMPLATE,
    "TheBloke/Llama-2-7B-Chat-GPTQ": _LLAMA2_CHAT_TEMPLATE,
    "TheBloke/meditron-7B-GPTQ": _CHATML_TEMPLATE,
    "TheBloke/meditron-70B-GPTQ": _CHATML_TEMPLATE,
}

# (session_events column, label) pairs summed in the session overview.
_SESSION_SUMMARY_COLUMNS = (
    ("prompt_tokens", "Prompt tokens: "),
    ("prompt_cost_chf", "Prompt CHF: "),
    ("completion_tokens", "Completion tokens: "),
    ("completion_cost_chf", "Completion CHF: "),
    ("total_cost_chf", "Total costs CHF: "),
    ("response_time", "Total compute time (ms): "),
)

st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')


def display_streamlit_sidebar():
    """Render the sidebar parameter form and start a session on submit.

    Widget defaults come from ``st.session_state``. On submit, the chosen
    values are written back to session state, ``session_started`` is set to
    True, the event log is reset, and the script is re-run.
    """
    st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
    st.sidebar.write('**Parameters**')
    form = st.sidebar.form("config_form", clear_on_submit=True)

    model_name_or_path = form.selectbox("Select model", gptq_model_options(),
                                        index=st.session_state["model_index"])
    model_name_or_path_other = form.text_input('Or input any GPTQ model',
                                               value=st.session_state["model_name_or_path_other"])
    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01,
                              value=st.session_state["temperature"])
    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01,
                        value=st.session_state["top_p"])
    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1,
                        value=st.session_state["top_k"])
    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096, step=1,
                                 value=st.session_state["max_new_tokens"])
    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0,
                                     step=0.01, value=st.session_state["repetition_penalty"])

    submitted = form.form_submit_button("Start session")
    if not submitted:
        return

    print('Parameters updated...')
    st.session_state['session_started'] = True
    st.session_state["session_events"] = []

    other_model = model_name_or_path_other.strip()
    if other_model:
        # A free-text model id overrides the dropdown selection.
        st.session_state["model_name"] = other_model
        st.session_state["model_name_or_path_other"] = other_model
    else:
        st.session_state["model_name"] = model_name_or_path
        st.session_state["model_index"] = gptq_model_options().index(model_name_or_path)
        # Clear the remembered custom model; otherwise the form would be
        # re-populated with it and a later submit would silently reuse it.
        st.session_state["model_name_or_path_other"] = ""

    st.session_state["model_name_or_path"] = model_name_or_path
    st.session_state["temperature"] = temperature
    st.session_state["do_sample"] = do_sample
    st.session_state["top_p"] = top_p
    st.session_state["top_k"] = top_k
    st.session_state["max_new_tokens"] = max_new_tokens
    st.session_state["repetition_penalty"] = repetition_penalty
    st.rerun()


def init_session_state():
    """Reset ``st.session_state`` to the app's default parameters."""
    print('init_session_state()')
    st.session_state['session_started'] = False
    st.session_state["session_events"] = []
    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
    # Model actually sent to get_llm_response; kept in sync with the default above.
    st.session_state["model_name"] = "TheBloke/meditron-7B-GPTQ"
    st.session_state["model_name_or_path_other"] = ""
    st.session_state["model_index"] = 0
    st.session_state["temperature"] = 0.01
    st.session_state["do_sample"] = True
    st.session_state["top_p"] = 0.95
    st.session_state["top_k"] = 40
    st.session_state["max_new_tokens"] = 4096
    st.session_state["repetition_penalty"] = 1.1
    st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
    st.session_state["prompt"] = ""
    st.session_state["llm_messages"] = []


def display_session_overview():
    """Show the LLM query history and summed cost/usage columns of the session events.

    Only columns actually present in ``session_events`` are summarised, so a
    partially populated event log does not raise ``KeyError``.
    """
    st.subheader('History of LLM queries')
    # NOTE(review): nothing in this file appends to "llm_messages" or
    # "session_events"; presumably populated elsewhere — confirm.
    st.write(st.session_state["llm_messages"])

    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)

    for column, label in _SESSION_SUMMARY_COLUMNS:
        if column in df_session_overview.columns:
            st.write(label + str(df_session_overview[column].sum()))


def get_prompt_format(model_name):
    """Return the prompt template for ``model_name``.

    The template contains ``{system_message}`` and ``{prompt}`` placeholders;
    an empty string is returned for models without a known template.
    """
    return _PROMPT_TEMPLATES.get(model_name, "")


def format_prompt(template, system_message, prompt):
    """Fill ``template`` with the system message and user prompt.

    An empty (or whitespace-only) template falls back to joining the system
    message and prompt with a single space.
    """
    if not (template and template.strip()):
        return f"{system_message} {prompt}"
    return template.format(system_message=system_message, prompt=prompt)


def display_llm_output():
    """Render the prompt form and, on submit, query the LLM and display its answer."""
    st.header("LLM")
    form = st.form('llm')

    # Pick the template for the model that will actually be queried, which
    # may be a free-text model id rather than the dropdown selection.
    active_model = st.session_state.get("model_name", st.session_state["model_name_or_path"])
    prompt_format_str = get_prompt_format(active_model)
    prompt_format = form.text_area('Prompt format', value=prompt_format_str, height=170)
    system_prompt = form.text_area('System message', value=st.session_state["system_prompt"], height=170)
    prompt = form.text_area('Prompt', value=st.session_state["prompt"], height=170)

    submitted = form.form_submit_button('Submit')
    if not submitted:
        return

    st.session_state["system_prompt"] = system_prompt
    st.session_state["prompt"] = prompt

    formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
    print(f"Formatted prompt: {formatted_prompt}")

    llm_response = get_llm_response(
        st.session_state["model_name"],
        st.session_state["temperature"],
        st.session_state["do_sample"],
        st.session_state["top_p"],
        st.session_state["top_k"],
        st.session_state["max_new_tokens"],
        st.session_state["repetition_penalty"],
        formatted_prompt)
    st.write(llm_response)


def main():
    """Entry point: show only the config sidebar until a session is started."""
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()

    if not st.session_state.get("session_started", False):
        init_session_state()
        display_streamlit_sidebar()
        return

    display_streamlit_sidebar()
    session_inactive_info.empty()
    display_llm_output()
    display_session_overview()


if __name__ == '__main__':
    main()