import ast
from collections import defaultdict
from functools import reduce
from itertools import chain

import numpy as np
import pandas as pd
import streamlit as st

LABEL_ENTAILS = "label_is_entails"
LABEL_CONTRA = "label_is_contradict"
LABELS = [LABEL_ENTAILS, LABEL_CONTRA]

N_QUESTIONS = 201

VERSIONS = (
    "BARD",
    "PALM-2",
    "Mistral",
    "Bing",
    "GPT-3.5",
    "GPT-3.5+ICL",
    "GPT-3.5+RAG",
    "GPT-3.5+ICL+RAG",
    "GPT-4",
    "GPT-4+ICL",
    "GPT-4+RAG",
    "GPT-4+ICL+RAG",
    "MedAlpaca",
)


@st.cache_data
def load_data(version_name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load per-question statistics and statement-level reasoning labels."""
    path_dir = "./data/" + version_name
    df = pd.read_csv(path_dir + "/dataset_statistics.csv", encoding="utf-8")
    df.rename(
        columns={
            "precision_['Must_have', 'Nice_to_have']": "precision",
            "recall_Must_have": "comprehensiveness",
        },
        inplace=True,
    )
    reasoning = pd.read_csv(path_dir + "/reasoning.csv", encoding="utf-8")
    return df, reasoning


@st.cache_data
def load_params(version_name: str) -> pd.DataFrame:
    """Load run parameters; an empty frame signals a missing params file."""
    path_dir = "./data/" + version_name
    try:
        return pd.read_json(path_dir + "/params.json", orient="records")
    except FileNotFoundError:
        return pd.DataFrame()


@st.cache_data
def load_prompt(version_name: str) -> str:
    """Read the prompt file and join its cells into a single display string."""
    path_dir = "./data/" + version_name
    prompt = list(
        chain.from_iterable(
            pd.read_csv(path_dir + "/prompt.txt", sep="|", header=None).values
        )
    )
    return "\n \n".join(prompt)
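# The loaders above imply a per-version directory layout. A sketch, inferred
# from the read calls and the column accesses later in this file (the exact
# column sets are assumptions, not a documented schema):
#
#   ./data/<version_name>/
#       dataset_statistics.csv   # Question, result, precision_..., recall_Must_have, misses, contra
#       reasoning.csv            # question, statement, answer, label_is_entails, label_is_contradict
#       params.json              # records-oriented JSON containing count_question
#       prompt.txt               # prompt text, read with "|" as a separator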
", value=" ", regex=True, inplace=True) for col in ["misses", "contra"]: df[col].replace(to_replace="
", value=" * \n", regex=True, inplace=True) df = df.round(3) return df def on_more_click(show_more, idx): show_more[idx] = True def on_less_click(show_more, idx): show_more[idx] = False def main(): st.set_page_config( layout="wide", page_title="AskMed Evaluation", initial_sidebar_state="expanded" ) st.write(""" # K-QA Evaluation""") current_versions = st.sidebar.multiselect("select_version", VERSIONS) if len(current_versions) == 1: current_version = current_versions[0] df, reasoning = load_data(version_name=current_version) df = preprocess_data(df) params = load_params(version_name=current_version) count_question = params["count_question"].iloc[0] try: prompt = load_prompt(version_name=current_version) except FileNotFoundError: prompt = "" col1, col2, col3 = st.columns(3) comp = 100 * df.comprehensiveness.sum() / N_QUESTIONS col1.metric("Comprehensiveness", f"{comp:.2f}") hall = 100 * sum(reasoning.label_is_contradict) / N_QUESTIONS col2.metric("Hallucination rate", f"{hall:.3f}") col3.metric( "Ratio-answered", f"{len(df)/count_question:.2f} ({len(df)}/{count_question})", ) # display prompt st.caption("**prompt:**") st.code(f"{prompt}") if f"show_more_{current_version}" not in st.session_state: st.session_state[f"show_more_{current_version}"] = dict.fromkeys( np.arange(len(df)), False ) show_more = st.session_state[f"show_more_{current_version}"] # order of rows order_by = st.radio( "**How would you like to order the rows?**", ("Question", "comprehensiveness", "precision"), horizontal=True, ) st.markdown("----") df = df.sort_values(by=[order_by]) fields = [ "Question", "result", "comprehensiveness", "precision", "misses", "contra", "", ] cols_width = [3, 7, 1, 1, 2.3, 2.3, 1.1] cols_header = st.columns(cols_width) # header for col, field in zip(cols_header, fields): col.write("**" + field + "**") # # rows for index, row in df.iterrows(): cols = st.columns(cols_width) for ind, field in enumerate(fields[:-1]): cols[ind].caption(row[field]) placeholder = cols[-1].empty() if show_more[index]: placeholder.button( "less", key=str(index) + "_", on_click=on_less_click, args=[show_more, index], ) question = row["Question"] st.write(reasoning.loc[reasoning.question == question, :].iloc[:, 3:]) else: placeholder.button( "more", key=index, on_click=on_more_click, args=[show_more, index], ) elif len(current_versions) > 1: res_dict, metrics = defaultdict(dict), defaultdict(dict) for current_version in current_versions: df, reasoning = load_data(version_name=current_version) res_dict["df"][current_version] = preprocess_data(df) res_dict["reasoning"][current_version] = reasoning.loc[ :, ["question", "statement", "answer"] + LABELS ] metrics["N answers"][current_version] = len(df) if N_QUESTIONS > len(df): n_rows_to_add = N_QUESTIONS - len(df) rows = pd.DataFrame( [{"comprehensiveness": 0, "precision": 1}] * n_rows_to_add ) df_w_empty_rows = pd.concat( [res_dict["df"][current_version], rows], ignore_index=True ) else: df_w_empty_rows = res_dict["df"][current_version].copy() metrics["comprehensiveness"][ current_version ] = df_w_empty_rows.comprehensiveness.mean().round(3) metrics["hallucination_rate"][current_version] = ( 100 * reasoning.label_is_contradict.sum() / N_QUESTIONS ) metrics["words"][current_version] = round( reasoning.answer.apply(lambda x: len(x.split(" "))).mean(), 2 ) data_frames = [res_dict["reasoning"][ver] for ver in current_versions] for i, ver in enumerate(current_versions): data_frames[i].rename( columns={ "answer": f"answer+{ver}", LABEL_ENTAILS: f"{LABEL_ENTAILS}+{ver}", 
LABEL_CONTRA: f"{LABEL_CONTRA}+{ver}", }, inplace=True, ) df_reasoning = reduce( lambda left, right: pd.merge( left, right, on=["question", "statement"], how="inner" ), data_frames, ) st.write("#### Metrics") st.write(pd.DataFrame(metrics)) # Order samples st.write("**" + "Display samples in an ascending order" + "**") c1, c2, c3 = st.columns(3) version = c1.selectbox("version", current_versions) metric = c2.selectbox("metric", ("precision", "comprehensiveness")) n_samples = c3.selectbox("n_samples", (5, 10, 20)) questions = ( res_dict["df"][version] .sort_values(by=metric, inplace=False)["Question"] .to_list()[:n_samples] ) for question in questions: q_a_dict = {"question": question} scores = {"question": "."} for i, version_name in enumerate(current_versions): cond = res_dict["df"][version_name].Question == question if cond.sum() == 0: q_a_dict.update({f"answer_{version_name}": "NOT ANSWERED"}) scores[f"{version_name}"] = f"NONE" else: q_a_dict[f"answer_{version_name}"] = ( res_dict["df"][version_name].loc[cond, "result"].values[0] ) precision = round( res_dict["df"][version_name].loc[cond, "precision"].values[0], 3 ) recall = round( res_dict["df"][version_name] .loc[cond, "comprehensiveness"] .values[0], 3, ) # Style misses = ast.literal_eval( res_dict["df"][version_name].loc[cond, "misses"].values[0] ) if len(misses) > 0: misses[0] = f"- {misses[0]}" misses = "\n -".join(misses) contra = ast.literal_eval( res_dict["df"][version_name].loc[cond, "contra"].values[0] ) if len(contra) > 0: contra[0] = f"- {contra[0]}" contra = "\n -".join(contra) scores[f"{version_name}"] = ( f"**Pr**: {precision}, " f"**Re**: {recall}\n \n" f"**misses**: \n {misses} \n \n " f"**contra**: \n {contra}" ) cols = st.columns( [ int(len(q_a_dict.keys()) * 0.5) if ind == 0 else len(q_a_dict.keys()) for ind, i in enumerate(q_a_dict.keys()) ] ) for col, field in zip(cols, q_a_dict.keys()): col.write("**" + field + "**") for ind, field, score in zip(range(len(q_a_dict.keys())), q_a_dict, scores): cols[ind].caption(q_a_dict[field]) cols[ind].caption(scores[score]) st.markdown("""---""") if __name__ == "__main__": main()