import streamlit as st
import numpy as np
import pandas as pd
from collections import defaultdict
from functools import reduce
from itertools import chain
import ast
# Column names of the two NLI-style labels in each version's reasoning.csv.
LABEL_ENTAILS = "label_is_entails"
LABEL_CONTRA = "label_is_contradict"
LABELS = [
    LABEL_ENTAILS,
    LABEL_CONTRA,
]

# Fixed benchmark size; metrics below are normalized by this, not by the
# number of questions a model actually answered.
N_QUESTIONS = 201

# Selectable model/configuration names; each must match a subdirectory of
# ./data/ containing dataset_statistics.csv, reasoning.csv, etc.
VERSIONS = (
    "BARD",
    "PALM-2",
    "Mistral",
    "Bing",
    "GPT-3.5",
    "GPT-3.5+ICL",
    "GPT-3.5+RAG",
    "GPT-3.5+ICL+RAG",
    "GPT-4",
    "GPT-4+ICL",
    "GPT-4+RAG",
    "GPT-4+ICL+RAG",
    "MedAlpaca"
)
@st.cache_data
def load_data(version_name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load per-question statistics and per-statement reasoning for a version.

    Args:
        version_name: Model version name; must be a subdirectory of ./data/.

    Returns:
        (df, reasoning): the dataset-statistics frame — with the pipeline's
        verbose metric columns renamed to "precision" and "comprehensiveness" —
        and the raw reasoning frame.

    Raises:
        FileNotFoundError: if either CSV file is missing.
    """
    # Fix: the original return annotation `(pd.DataFrame, pd.DataFrame)` is a
    # tuple *instance*, not a valid type annotation; use tuple[...] instead.
    path_dir = "./data/" + version_name
    df = pd.read_csv(path_dir + "/dataset_statistics.csv", encoding="utf-8")
    # Shorten the evaluation pipeline's verbose column names for display.
    df.rename(
        columns={
            "precision_['Must_have', 'Nice_to_have']": "precision",
            "recall_Must_have": "comprehensiveness",
        },
        inplace=True,
    )
    reasoning = pd.read_csv(path_dir + "/reasoning.csv", encoding="utf-8")
    return df, reasoning
@st.cache_data
def load_params(version_name: str) -> "pd.DataFrame | dict":
    """Load the run parameters (params.json) for a model version.

    Args:
        version_name: Model version name; must be a subdirectory of ./data/.

    Returns:
        A DataFrame with one row per JSON record (callers index it like
        ``params["count_question"].iloc[0]``), or an empty dict when the
        file is missing (original fallback, preserved for compatibility).
    """
    # Fix: the original annotation said `-> dict`, but the success path
    # returns the DataFrame produced by pd.read_json.
    path_dir = "./data/" + version_name
    try:
        # orient="records": a JSON list of objects becomes one row per object.
        return pd.read_json(path_dir + "/params.json", orient="records")
    except FileNotFoundError:
        # No params file for this version; callers must handle the empty dict.
        return {}
@st.cache_data
def load_prompt(version_name: str) -> str:
    """Read the prompt file for *version_name* and flatten it to one string.

    The file is parsed as a "|"-separated table with no header; every cell
    is collected in row order and the cells are joined with "\\n \\n".

    Raises:
        FileNotFoundError: if the prompt file is missing.
    """
    table = pd.read_csv(
        "./data/" + version_name + "/prompt.txt", sep="|", header=None
    )
    cells = [cell for row in table.values for cell in row]
    return "\n \n".join(cells)
def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize newlines in the text columns and round numeric columns.

    Args:
        df: Statistics frame with at least "result", "misses" and "contra"
            string columns.

    Returns:
        The same frame (text columns mutated in place) rounded to 3 decimals.
    """
    # Fix: the original source contained raw, unescaped newlines inside these
    # string literals — a SyntaxError; "\n" restores the evident intent.
    # Newlines inside an answer are collapsed to spaces for the table view.
    df["result"].replace(to_replace="\n", value=" ", regex=True, inplace=True)
    # In the list-like columns, newlines become " * \n" bullet separators.
    for col in ["misses", "contra"]:
        df[col].replace(to_replace="\n", value=" * \n", regex=True, inplace=True)
    df = df.round(3)
    return df
def on_more_click(show_more, idx):
    """Streamlit button callback: mark the row at position *idx* as expanded."""
    show_more.update({idx: True})
def on_less_click(show_more, idx):
    """Streamlit button callback: mark the row at position *idx* as collapsed."""
    show_more.update({idx: False})
def main():
    """Render the K-QA evaluation Streamlit dashboard.

    The sidebar multiselect drives the view: exactly one selected version
    shows headline metrics plus a per-question table with expandable
    reasoning rows; two or more selected versions show a side-by-side
    metric and answer comparison. With nothing selected, only the header
    is rendered.
    """
    st.set_page_config(
        layout="wide", page_title="AskMed Evaluation", initial_sidebar_state="expanded"
    )
    st.write(""" # K-QA Evaluation""")
    current_versions = st.sidebar.multiselect("select_version", VERSIONS)
    if len(current_versions) == 1:
        # ---- Single-version detail view ----
        current_version = current_versions[0]
        df, reasoning = load_data(version_name=current_version)
        df = preprocess_data(df)
        params = load_params(version_name=current_version)
        # NOTE(review): load_params returns {} when params.json is missing,
        # in which case this lookup raises — presumably every version ships
        # a params.json; confirm.
        count_question = params["count_question"].iloc[0]
        try:
            prompt = load_prompt(version_name=current_version)
        except FileNotFoundError:
            prompt = ""  # prompt display is optional
        col1, col2, col3 = st.columns(3)
        # Headline metrics, normalized by the fixed benchmark size
        # N_QUESTIONS rather than by the number of answered questions.
        comp = 100 * df.comprehensiveness.sum() / N_QUESTIONS
        col1.metric("Comprehensiveness", f"{comp:.2f}")
        hall = 100 * sum(reasoning.label_is_contradict) / N_QUESTIONS
        col2.metric("Hallucination rate", f"{hall:.3f}")
        col3.metric(
            "Ratio-answered",
            f"{len(df)/count_question:.2f} ({len(df)}/{count_question})",
        )
        # display prompt
        st.caption("**prompt:**")
        st.code(f"{prompt}")
        # Per-row expand/collapse flags, persisted per version in session
        # state so they survive Streamlit reruns.
        if f"show_more_{current_version}" not in st.session_state:
            st.session_state[f"show_more_{current_version}"] = dict.fromkeys(
                np.arange(len(df)), False
            )
        show_more = st.session_state[f"show_more_{current_version}"]
        # order of rows
        order_by = st.radio(
            "**How would you like to order the rows?**",
            ("Question", "comprehensiveness", "precision"),
            horizontal=True,
        )
        st.markdown("----")
        df = df.sort_values(by=[order_by])
        # Table layout: one column per field; the last (unnamed) column
        # holds the more/less toggle button.
        fields = [
            "Question",
            "result",
            "comprehensiveness",
            "precision",
            "misses",
            "contra",
            "",
        ]
        cols_width = [3, 7, 1, 1, 2.3, 2.3, 1.1]
        cols_header = st.columns(cols_width)
        # header
        for col, field in zip(cols_header, fields):
            col.write("**" + field + "**")
        # # rows
        for index, row in df.iterrows():
            cols = st.columns(cols_width)
            for ind, field in enumerate(fields[:-1]):
                cols[ind].caption(row[field])
            placeholder = cols[-1].empty()
            if show_more[index]:
                # Expanded: "less" button plus this question's per-statement
                # reasoning rows (columns from index 3 onward only).
                placeholder.button(
                    "less",
                    key=str(index) + "_",
                    on_click=on_less_click,
                    args=[show_more, index],
                )
                question = row["Question"]
                st.write(reasoning.loc[reasoning.question == question, :].iloc[:, 3:])
            else:
                placeholder.button(
                    "more",
                    key=index,
                    on_click=on_more_click,
                    args=[show_more, index],
                )
    elif len(current_versions) > 1:
        # ---- Multi-version comparison view ----
        res_dict, metrics = defaultdict(dict), defaultdict(dict)
        for current_version in current_versions:
            df, reasoning = load_data(version_name=current_version)
            res_dict["df"][current_version] = preprocess_data(df)
            res_dict["reasoning"][current_version] = reasoning.loc[
                :, ["question", "statement", "answer"] + LABELS
            ]
            metrics["N answers"][current_version] = len(df)
            # Pad unanswered questions with comprehensiveness 0 and
            # precision 1 so means are comparable across versions.
            if N_QUESTIONS > len(df):
                n_rows_to_add = N_QUESTIONS - len(df)
                rows = pd.DataFrame(
                    [{"comprehensiveness": 0, "precision": 1}] * n_rows_to_add
                )
                df_w_empty_rows = pd.concat(
                    [res_dict["df"][current_version], rows], ignore_index=True
                )
            else:
                df_w_empty_rows = res_dict["df"][current_version].copy()
            metrics["comprehensiveness"][
                current_version
            ] = df_w_empty_rows.comprehensiveness.mean().round(3)
            metrics["hallucination_rate"][current_version] = (
                100 * reasoning.label_is_contradict.sum() / N_QUESTIONS
            )
            # Mean answer length in space-separated words.
            metrics["words"][current_version] = round(
                reasoning.answer.apply(lambda x: len(x.split(" "))).mean(), 2
            )
        # Suffix the per-version columns so the merged frame is unambiguous.
        data_frames = [res_dict["reasoning"][ver] for ver in current_versions]
        for i, ver in enumerate(current_versions):
            data_frames[i].rename(
                columns={
                    "answer": f"answer+{ver}",
                    LABEL_ENTAILS: f"{LABEL_ENTAILS}+{ver}",
                    LABEL_CONTRA: f"{LABEL_CONTRA}+{ver}",
                },
                inplace=True,
            )
        # NOTE(review): df_reasoning is built but never displayed below —
        # possibly left over from an earlier version of the page; confirm.
        df_reasoning = reduce(
            lambda left, right: pd.merge(
                left, right, on=["question", "statement"], how="inner"
            ),
            data_frames,
        )
        st.write("#### Metrics")
        st.write(pd.DataFrame(metrics))
        # Order samples
        st.write("**" + "Display samples in an ascending order" + "**")
        c1, c2, c3 = st.columns(3)
        version = c1.selectbox("version", current_versions)
        metric = c2.selectbox("metric", ("precision", "comprehensiveness"))
        n_samples = c3.selectbox("n_samples", (5, 10, 20))
        # The n_samples lowest-scoring questions for the chosen version/metric.
        questions = (
            res_dict["df"][version]
            .sort_values(by=metric, inplace=False)["Question"]
            .to_list()[:n_samples]
        )
        for question in questions:
            q_a_dict = {"question": question}
            scores = {"question": "."}  # placeholder so keys align with q_a_dict
            for i, version_name in enumerate(current_versions):
                cond = res_dict["df"][version_name].Question == question
                if cond.sum() == 0:
                    # This version never answered the question.
                    q_a_dict.update({f"answer_{version_name}": "NOT ANSWERED"})
                    scores[f"{version_name}"] = f"NONE"
                else:
                    q_a_dict[f"answer_{version_name}"] = (
                        res_dict["df"][version_name].loc[cond, "result"].values[0]
                    )
                    precision = round(
                        res_dict["df"][version_name].loc[cond, "precision"].values[0], 3
                    )
                    recall = round(
                        res_dict["df"][version_name]
                        .loc[cond, "comprehensiveness"]
                        .values[0],
                        3,
                    )
                    # Style
                    # misses/contra are stringified Python lists (parsed via
                    # ast.literal_eval) turned into markdown-ish bullet lists.
                    misses = ast.literal_eval(
                        res_dict["df"][version_name].loc[cond, "misses"].values[0]
                    )
                    if len(misses) > 0:
                        misses[0] = f"- {misses[0]}"
                    misses = "\n -".join(misses)
                    contra = ast.literal_eval(
                        res_dict["df"][version_name].loc[cond, "contra"].values[0]
                    )
                    if len(contra) > 0:
                        contra[0] = f"- {contra[0]}"
                    contra = "\n -".join(contra)
                    scores[f"{version_name}"] = (
                        f"**Pr**: {precision}, "
                        f"**Re**: {recall}\n \n"
                        f"**misses**: \n {misses} \n \n "
                        f"**contra**: \n {contra}"
                    )
            # The question column gets half the width of each answer column.
            cols = st.columns(
                [
                    int(len(q_a_dict.keys()) * 0.5)
                    if ind == 0
                    else len(q_a_dict.keys())
                    for ind, i in enumerate(q_a_dict.keys())
                ]
            )
            for col, field in zip(cols, q_a_dict.keys()):
                col.write("**" + field + "**")
            for ind, field, score in zip(range(len(q_a_dict.keys())), q_a_dict, scores):
                cols[ind].caption(q_a_dict[field])
                cols[ind].caption(scores[score])
            st.markdown("""---""")
# Standard script entry guard.
if __name__ == "__main__":
    main()