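"""Streamlit proof-of-concept UI for querying local GPTQ-quantized LLMs
(e.g. Meditron and Llama 2 chat variants) with configurable sampling parameters.

Run with `streamlit run <this_file>.py` (the actual filename depends on the repo layout).
"""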
import inspect

import pandas as pd
import streamlit as st

from utils.epfl_meditron_utils import get_llm_response, gptq_model_options

POC_VERSION = "0.1.1"

st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')
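
# Sidebar form: pick a GPTQ model (or type in any model path) and set the sampling
# parameters; submitting stores everything in session state and starts a new session.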
def display_streamlit_sidebar():
    st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
    st.sidebar.write('**Parameters**')
    form = st.sidebar.form("config_form", clear_on_submit=True)
    model_name_or_path = form.selectbox("Select model", gptq_model_options(), index=st.session_state["model_index"])
    model_name_or_path_other = form.text_input('Or input any GPTQ model', value=st.session_state["model_name_or_path_other"])
    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["temperature"])
    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096, step=1, value=st.session_state["max_new_tokens"])
    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])
    submitted = form.form_submit_button("Start session")
    if submitted:
        print('Parameters updated...')
        st.session_state['session_started'] = True
        st.session_state["session_events"] = []
        if len(model_name_or_path_other) > 0:
            # A manually entered model path takes precedence over the dropdown selection.
            st.session_state["model_name"] = model_name_or_path_other
            st.session_state["model_name_or_path_other"] = model_name_or_path_other
        else:
            st.session_state["model_name"] = model_name_or_path
            st.session_state["model_index"] = gptq_model_options().index(model_name_or_path)
            st.session_state["model_name_or_path"] = model_name_or_path
        st.session_state["temperature"] = temperature
        st.session_state["do_sample"] = do_sample
        st.session_state["top_p"] = top_p
        st.session_state["top_k"] = top_k
        st.session_state["max_new_tokens"] = max_new_tokens
        st.session_state["repetition_penalty"] = repetition_penalty
        st.rerun()
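
# Populate st.session_state with default model and sampling settings for a fresh session.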
def init_session_state():
    print('init_session_state()')
    st.session_state['session_started'] = False
    st.session_state["session_events"] = []
    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
    st.session_state["model_name_or_path_other"] = ""
    st.session_state["model_index"] = 0
    st.session_state["temperature"] = 0.01
    st.session_state["do_sample"] = True
    st.session_state["top_p"] = 0.95
    st.session_state["top_k"] = 40
    st.session_state["max_new_tokens"] = 4096
    st.session_state["repetition_penalty"] = 1.1
    st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
    st.session_state["prompt"] = ""
    st.session_state["llm_messages"] = []
def display_session_overview():
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)
    if "prompt_tokens" in df_session_overview:
        prompt_tokens = df_session_overview["prompt_tokens"].sum()
        st.write("Prompt tokens: " + str(prompt_tokens))
        prompt_cost = df_session_overview["prompt_cost_chf"].sum()
        st.write("Prompt CHF: " + str(prompt_cost))
        completion_tokens = df_session_overview["completion_tokens"].sum()
        st.write("Completion tokens: " + str(completion_tokens))
        completion_cost = df_session_overview["completion_cost_chf"].sum()
        st.write("Completion CHF: " + str(completion_cost))
        total_cost = df_session_overview["total_cost_chf"].sum()
        st.write("Total costs CHF: " + str(total_cost))
        total_time = df_session_overview["response_time"].sum()
        st.write("Total compute time (ms): " + str(total_time))
def get_prompt_format(model_name):
    formatted_text = ""
    if model_name == "TheBloke/Llama-2-13B-chat-GPTQ" or model_name == "TheBloke/Llama-2-7B-Chat-GPTQ":
        formatted_text = '''[INST] <<SYS>>
        {system_message}
        <</SYS>>
        {prompt}[/INST]
        '''
    if model_name == "TheBloke/meditron-7B-GPTQ" or model_name == "TheBloke/meditron-70B-GPTQ":
        formatted_text = '''<|im_start|>system
        {system_message}<|im_end|>
        <|im_start|>user
        {prompt}<|im_end|>
        <|im_start|>assistant
        '''
    return inspect.cleandoc(formatted_text)
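
# Fill the template with the system message and user prompt; with no template,
# fall back to plain concatenation.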
def format_prompt(template, system_message, prompt):
    if template == "":
        return f"{system_message} {prompt}"
    return template.format(system_message=system_message, prompt=prompt)
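
# Main panel: editable prompt format, system message, and prompt; on submit, the
# formatted prompt is sent to the selected model via get_llm_response.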
def display_llm_output():
    st.header("LLM")
    form = st.form('llm')
    prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
    prompt_format = form.text_area('Prompt format', value=prompt_format_str, height=170)
    system_prompt = form.text_area('System message', value=st.session_state["system_prompt"], height=170)
    prompt = form.text_area('Prompt', value=st.session_state["prompt"], height=170)
    submitted = form.form_submit_button('Submit')
    if submitted:
        st.session_state["system_prompt"] = system_prompt
        st.session_state["prompt"] = prompt
        formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
        print(f"Formatted prompt: {formatted_prompt}")
        llm_response = get_llm_response(
            st.session_state["model_name"],
            st.session_state["temperature"],
            st.session_state["do_sample"],
            st.session_state["top_p"],
            st.session_state["top_k"],
            st.session_state["max_new_tokens"],
            st.session_state["repetition_penalty"],
            formatted_prompt)
        st.write(llm_response)
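
# Entry point: initialize state and show only the sidebar on first load; once a
# session has started, also render the LLM panel and the session overview.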
def main():
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()
    if "session_started" not in st.session_state or not st.session_state["session_started"]:
        init_session_state()
        display_streamlit_sidebar()
    else:
        display_streamlit_sidebar()
        session_inactive_info.empty()
        display_llm_output()
        display_session_overview()

if __name__ == '__main__':
    main()