# K-QA / app.py
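"""Streamlit dashboard for browsing K-QA evaluation results.

Selecting a single version shows per-question metrics with expandable
reasoning rows; selecting several versions compares aggregate metrics and
the lowest-scoring questions side by side.
"""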
import streamlit as st
import numpy as np
import pandas as pd
from collections import defaultdict
from functools import reduce
from itertools import chain
import ast
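
# Label columns produced by the evaluation pipeline in reasoning.csv.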
LABEL_ENTAILS = "label_is_entails"
LABEL_CONTRA = "label_is_contradict"
LABELS = [
LABEL_ENTAILS,
LABEL_CONTRA,
]
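# Total number of questions in the K-QA benchmark.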
N_QUESTIONS = 201
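# Evaluated model variants; each one has a matching folder under ./data/.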
VERSIONS = (
"BARD",
"PALM-2",
"Mistral",
"Bing",
"GPT-3.5",
"GPT-3.5+ICL",
"GPT-3.5+RAG",
"GPT-3.5+ICL+RAG",
"GPT-4",
"GPT-4+ICL",
"GPT-4+RAG",
"GPT-4+ICL+RAG",
"MedAlpaca"
)
@st.cache_data
def load_data(version_name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
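    """Load per-question statistics and per-statement reasoning for one version."""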
path_dir = "./data/" + version_name
df = pd.read_csv(path_dir + "/dataset_statistics.csv", encoding="utf-8")
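    # Normalize the pipeline's metric column names to the names shown in the UI.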
df.rename(
columns={
"precision_['Must_have', 'Nice_to_have']": "precision",
"recall_Must_have": "comprehensiveness",
},
inplace=True,
)
reasoning = pd.read_csv(path_dir + "/reasoning.csv", encoding="utf-8")
return df, reasoning
@st.cache_data
def load_params(version_name: str) -> pd.DataFrame:
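    """Load run parameters (e.g. the number of posed questions), if recorded."""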
path_dir = "./data/" + version_name
try:
params = pd.read_json(path_dir + "/params.json", orient="records")
return params
    except FileNotFoundError:
        # No params.json for this version; return an empty frame
        # so callers can simply test len(params).
        return pd.DataFrame()
@st.cache_data
def load_prompt(version_name: str) -> str:
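    """Load the prompt used for this version as a single display string."""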
path_dir = "./data/" + version_name
prompt = list(
chain.from_iterable(
pd.read_csv(path_dir + "/prompt.txt", sep="|", header=None).values
)
)
prompt = "\n \n".join(prompt)
return prompt
def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
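    """Strip HTML line breaks and round numeric columns for display."""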
df["result"].replace(to_replace="<br>", value=" ", regex=True, inplace=True)
for col in ["misses", "contra"]:
df[col].replace(to_replace="<br>", value=" * \n", regex=True, inplace=True)
df = df.round(3)
return df
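

# Session-state callbacks toggling a row's "more"/"less" expander.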
def on_more_click(show_more, idx):
show_more[idx] = True
def on_less_click(show_more, idx):
show_more[idx] = False
def main():
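    """Render the dashboard: one version in detail, or several side by side."""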
    st.set_page_config(
        layout="wide", page_title="K-QA Evaluation", initial_sidebar_state="expanded"
    )
st.write(""" # K-QA Evaluation""")
    current_versions = st.sidebar.multiselect("Select version(s)", VERSIONS)
if len(current_versions) == 1:
current_version = current_versions[0]
df, reasoning = load_data(version_name=current_version)
df = preprocess_data(df)
params = load_params(version_name=current_version)
        # params.json may be absent; fall back to the full benchmark size.
        count_question = (
            params["count_question"].iloc[0] if len(params) else N_QUESTIONS
        )
try:
prompt = load_prompt(version_name=current_version)
except FileNotFoundError:
prompt = ""
col1, col2, col3 = st.columns(3)
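        # Headline metrics; dividing by N_QUESTIONS counts unanswered questions as zero.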
comp = 100 * df.comprehensiveness.sum() / N_QUESTIONS
col1.metric("Comprehensiveness", f"{comp:.2f}")
        hall = 100 * reasoning[LABEL_CONTRA].sum() / N_QUESTIONS
col2.metric("Hallucination rate", f"{hall:.3f}")
col3.metric(
"Ratio-answered",
f"{len(df)/count_question:.2f} ({len(df)}/{count_question})",
)
# display prompt
st.caption("**prompt:**")
st.code(f"{prompt}")
if f"show_more_{current_version}" not in st.session_state:
st.session_state[f"show_more_{current_version}"] = dict.fromkeys(
np.arange(len(df)), False
)
show_more = st.session_state[f"show_more_{current_version}"]
# order of rows
order_by = st.radio(
"**How would you like to order the rows?**",
("Question", "comprehensiveness", "precision"),
horizontal=True,
)
st.markdown("----")
df = df.sort_values(by=[order_by])
fields = [
"Question",
"result",
"comprehensiveness",
"precision",
"misses",
"contra",
"",
]
cols_width = [3, 7, 1, 1, 2.3, 2.3, 1.1]
cols_header = st.columns(cols_width)
# header
for col, field in zip(cols_header, fields):
col.write("**" + field + "**")
        # rows
for index, row in df.iterrows():
cols = st.columns(cols_width)
for ind, field in enumerate(fields[:-1]):
cols[ind].caption(row[field])
placeholder = cols[-1].empty()
if show_more[index]:
placeholder.button(
"less",
key=str(index) + "_",
on_click=on_less_click,
args=[show_more, index],
)
question = row["Question"]
st.write(reasoning.loc[reasoning.question == question, :].iloc[:, 3:])
else:
placeholder.button(
"more",
key=index,
on_click=on_more_click,
args=[show_more, index],
)
elif len(current_versions) > 1:
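        # Comparison mode: collect per-version frames and aggregate metrics.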
res_dict, metrics = defaultdict(dict), defaultdict(dict)
for current_version in current_versions:
df, reasoning = load_data(version_name=current_version)
res_dict["df"][current_version] = preprocess_data(df)
res_dict["reasoning"][current_version] = reasoning.loc[
:, ["question", "statement", "answer"] + LABELS
]
metrics["N answers"][current_version] = len(df)
if N_QUESTIONS > len(df):
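                # Pad missing answers with comprehensiveness 0 and precision 1
                # so per-version means are taken over all N_QUESTIONS questions.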
n_rows_to_add = N_QUESTIONS - len(df)
rows = pd.DataFrame(
[{"comprehensiveness": 0, "precision": 1}] * n_rows_to_add
)
df_w_empty_rows = pd.concat(
[res_dict["df"][current_version], rows], ignore_index=True
)
else:
df_w_empty_rows = res_dict["df"][current_version].copy()
metrics["comprehensiveness"][
current_version
] = df_w_empty_rows.comprehensiveness.mean().round(3)
metrics["hallucination_rate"][current_version] = (
100 * reasoning.label_is_contradict.sum() / N_QUESTIONS
)
metrics["words"][current_version] = round(
reasoning.answer.apply(lambda x: len(x.split(" "))).mean(), 2
)
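        # Suffix each version's answer and label columns with its name, then
        # merge all reasoning frames on (question, statement).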
data_frames = [res_dict["reasoning"][ver] for ver in current_versions]
for i, ver in enumerate(current_versions):
data_frames[i].rename(
columns={
"answer": f"answer+{ver}",
LABEL_ENTAILS: f"{LABEL_ENTAILS}+{ver}",
LABEL_CONTRA: f"{LABEL_CONTRA}+{ver}",
},
inplace=True,
)
df_reasoning = reduce(
lambda left, right: pd.merge(
left, right, on=["question", "statement"], how="inner"
),
data_frames,
)
st.write("#### Metrics")
st.write(pd.DataFrame(metrics))
# Order samples
st.write("**" + "Display samples in an ascending order" + "**")
c1, c2, c3 = st.columns(3)
version = c1.selectbox("version", current_versions)
metric = c2.selectbox("metric", ("precision", "comprehensiveness"))
n_samples = c3.selectbox("n_samples", (5, 10, 20))
        questions = (
            res_dict["df"][version]
            .sort_values(by=metric)["Question"]
            .to_list()[:n_samples]
        )
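        # Render each selected question across versions; the "." placeholder
        # keeps the scores row aligned under the question column.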
for question in questions:
q_a_dict = {"question": question}
scores = {"question": "."}
for i, version_name in enumerate(current_versions):
cond = res_dict["df"][version_name].Question == question
if cond.sum() == 0:
q_a_dict.update({f"answer_{version_name}": "NOT ANSWERED"})
scores[f"{version_name}"] = f"NONE"
else:
q_a_dict[f"answer_{version_name}"] = (
res_dict["df"][version_name].loc[cond, "result"].values[0]
)
precision = round(
res_dict["df"][version_name].loc[cond, "precision"].values[0], 3
)
recall = round(
res_dict["df"][version_name]
.loc[cond, "comprehensiveness"]
.values[0],
3,
)
                    # misses/contra are stored as stringified lists;
                    # render them as markdown bullet lists.
misses = ast.literal_eval(
res_dict["df"][version_name].loc[cond, "misses"].values[0]
)
                    if len(misses) > 0:
                        misses[0] = f"- {misses[0]}"
                    misses = "\n- ".join(misses)
contra = ast.literal_eval(
res_dict["df"][version_name].loc[cond, "contra"].values[0]
)
                    if len(contra) > 0:
                        contra[0] = f"- {contra[0]}"
                    contra = "\n- ".join(contra)
scores[f"{version_name}"] = (
f"**Pr**: {precision}, "
f"**Re**: {recall}\n \n"
f"**misses**: \n {misses} \n \n "
f"**contra**: \n {contra}"
)
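            # Narrow first column for the question text, equal widths per version.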
            cols = st.columns(
                [
                    int(len(q_a_dict) * 0.5) if ind == 0 else len(q_a_dict)
                    for ind in range(len(q_a_dict))
                ]
            )
for col, field in zip(cols, q_a_dict.keys()):
col.write("**" + field + "**")
            for ind, (field, score) in enumerate(zip(q_a_dict, scores)):
                cols[ind].caption(q_a_dict[field])
                cols[ind].caption(scores[score])
st.markdown("""---""")
if __name__ == "__main__":
main()