import streamlit as st
import numpy as np
import pandas as pd
from collections import defaultdict
from functools import reduce
from itertools import chain
from typing import Tuple
import ast
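
# Streamlit dashboard for browsing K-QA evaluation results. Each "version" under
# ./data/ holds the outputs of one model configuration (e.g. GPT-4+RAG); selecting a
# single version shows per-question results, while selecting several shows a
# side-by-side comparison of their metrics and answers.
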
LABEL_ENTAILS = "label_is_entails"
LABEL_CONTRA = "label_is_contradict"
LABELS = [
    LABEL_ENTAILS,
    LABEL_CONTRA,
]
N_QUESTIONS = 201
VERSIONS = (
    "BARD",
    "PALM-2",
    "Mistral",
    "Bing",
    "GPT-3.5",
    "GPT-3.5+ICL",
    "GPT-3.5+RAG",
    "GPT-3.5+ICL+RAG",
    "GPT-4",
    "GPT-4+ICL",
    "GPT-4+RAG",
    "GPT-4+ICL+RAG",
    "MedAlpaca",
)


def load_data(version_name: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the per-question statistics and the statement-level reasoning table."""
    path_dir = "./data/" + version_name
    df = pd.read_csv(path_dir + "/dataset_statistics.csv", encoding="utf-8")
    df.rename(
        columns={
            "precision_['Must_have', 'Nice_to_have']": "precision",
            "recall_Must_have": "comprehensiveness",
        },
        inplace=True,
    )
    reasoning = pd.read_csv(path_dir + "/reasoning.csv", encoding="utf-8")
    return df, reasoning
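

# Expected on-disk layout (inferred from the loaders in this file):
#   ./data/<version>/dataset_statistics.csv  - per-question precision/recall scores
#   ./data/<version>/reasoning.csv           - per-statement entailment/contradiction labels
#   ./data/<version>/params.json             - run parameters (optional)
#   ./data/<version>/prompt.txt              - prompt used for the run (optional)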


def load_params(version_name: str) -> pd.DataFrame:
    """Load run parameters; fall back to an empty frame when params.json is absent."""
    path_dir = "./data/" + version_name
    try:
        params = pd.read_json(path_dir + "/params.json", orient="records")
        return params
    except FileNotFoundError:
        return pd.DataFrame()


def load_prompt(version_name: str) -> str:
    """Read the stored prompt and rejoin its lines, separated by blank lines."""
    path_dir = "./data/" + version_name
    prompt = list(
        chain.from_iterable(
            pd.read_csv(path_dir + "/prompt.txt", sep="|", header=None).values
        )
    )
    prompt = "\n \n".join(prompt)
    return prompt


def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    """Replace HTML line breaks with plain-text markers and round numeric columns."""
    df["result"] = df["result"].replace(to_replace="<br>", value=" ", regex=True)
    for col in ["misses", "contra"]:
        df[col] = df[col].replace(to_replace="<br>", value=" * \n", regex=True)
    df = df.round(3)
    return df


def on_more_click(show_more, idx):
    show_more[idx] = True


def on_less_click(show_more, idx):
    show_more[idx] = False
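

# `show_more` maps each row index to a bool kept in st.session_state, so the
# "more"/"less" buttons below can toggle a row's detailed reasoning view across reruns.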


def main():
    st.set_page_config(
        layout="wide", page_title="AskMed Evaluation", initial_sidebar_state="expanded"
    )
    st.write("# K-QA Evaluation")

    current_versions = st.sidebar.multiselect("select_version", VERSIONS)

    if len(current_versions) == 1:
        current_version = current_versions[0]
        df, reasoning = load_data(version_name=current_version)
        df = preprocess_data(df)
        params = load_params(version_name=current_version)
        count_question = params["count_question"].iloc[0]
        try:
            prompt = load_prompt(version_name=current_version)
        except FileNotFoundError:
            prompt = ""
        col1, col2, col3 = st.columns(3)
        comp = 100 * df.comprehensiveness.sum() / N_QUESTIONS
        col1.metric("Comprehensiveness", f"{comp:.2f}")
        hall = 100 * sum(reasoning.label_is_contradict) / N_QUESTIONS
        col2.metric("Hallucination rate", f"{hall:.3f}")
        col3.metric(
            "Ratio-answered",
            f"{len(df)/count_question:.2f} ({len(df)}/{count_question})",
        )

        # display prompt
        st.caption("**prompt:**")
        st.code(f"{prompt}")
if f"show_more_{current_version}" not in st.session_state: | |
st.session_state[f"show_more_{current_version}"] = dict.fromkeys( | |
np.arange(len(df)), False | |
) | |
show_more = st.session_state[f"show_more_{current_version}"] | |
# order of rows | |
order_by = st.radio( | |
"**How would you like to order the rows?**", | |
("Question", "comprehensiveness", "precision"), | |
horizontal=True, | |
) | |
st.markdown("----") | |
df = df.sort_values(by=[order_by]) | |
fields = [ | |
"Question", | |
"result", | |
"comprehensiveness", | |
"precision", | |
"misses", | |
"contra", | |
"", | |
] | |
cols_width = [3, 7, 1, 1, 2.3, 2.3, 1.1] | |
cols_header = st.columns(cols_width) | |
# header | |
for col, field in zip(cols_header, fields): | |
col.write("**" + field + "**") | |
# # rows | |
for index, row in df.iterrows(): | |
cols = st.columns(cols_width) | |
for ind, field in enumerate(fields[:-1]): | |
cols[ind].caption(row[field]) | |
placeholder = cols[-1].empty() | |
if show_more[index]: | |
placeholder.button( | |
"less", | |
key=str(index) + "_", | |
on_click=on_less_click, | |
args=[show_more, index], | |
) | |
question = row["Question"] | |
st.write(reasoning.loc[reasoning.question == question, :].iloc[:, 3:]) | |
else: | |
placeholder.button( | |
"more", | |
key=index, | |
on_click=on_more_click, | |
args=[show_more, index], | |
) | |
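
    # Comparison view: when several versions are selected, show aggregate metrics
    # side by side and the lowest-scoring questions for a chosen version/metric.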
    elif len(current_versions) > 1:
        res_dict, metrics = defaultdict(dict), defaultdict(dict)
        for current_version in current_versions:
            df, reasoning = load_data(version_name=current_version)
            res_dict["df"][current_version] = preprocess_data(df)
            res_dict["reasoning"][current_version] = reasoning.loc[
                :, ["question", "statement", "answer"] + LABELS
            ]
            metrics["N answers"][current_version] = len(df)

            # Pad unanswered questions with zero comprehensiveness so that the mean
            # is comparable across versions that answered different numbers of questions.
            if N_QUESTIONS > len(df):
                n_rows_to_add = N_QUESTIONS - len(df)
                rows = pd.DataFrame(
                    [{"comprehensiveness": 0, "precision": 1}] * n_rows_to_add
                )
                df_w_empty_rows = pd.concat(
                    [res_dict["df"][current_version], rows], ignore_index=True
                )
            else:
                df_w_empty_rows = res_dict["df"][current_version].copy()

            metrics["comprehensiveness"][
                current_version
            ] = df_w_empty_rows.comprehensiveness.mean().round(3)
            metrics["hallucination_rate"][current_version] = (
                100 * reasoning.label_is_contradict.sum() / N_QUESTIONS
            )
            metrics["words"][current_version] = round(
                reasoning.answer.apply(lambda x: len(x.split(" "))).mean(), 2
            )

        data_frames = [res_dict["reasoning"][ver] for ver in current_versions]
        for i, ver in enumerate(current_versions):
            data_frames[i].rename(
                columns={
                    "answer": f"answer+{ver}",
                    LABEL_ENTAILS: f"{LABEL_ENTAILS}+{ver}",
                    LABEL_CONTRA: f"{LABEL_CONTRA}+{ver}",
                },
                inplace=True,
            )
        df_reasoning = reduce(
            lambda left, right: pd.merge(
                left, right, on=["question", "statement"], how="inner"
            ),
            data_frames,
        )
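
        # df_reasoning aligns the per-version reasoning tables on (question, statement),
        # with the answer and label columns suffixed by each version name.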
st.write("#### Metrics") | |
st.write(pd.DataFrame(metrics)) | |
# Order samples | |
st.write("**" + "Display samples in an ascending order" + "**") | |
c1, c2, c3 = st.columns(3) | |
version = c1.selectbox("version", current_versions) | |
metric = c2.selectbox("metric", ("precision", "comprehensiveness")) | |
n_samples = c3.selectbox("n_samples", (5, 10, 20)) | |
questions = ( | |
res_dict["df"][version] | |
.sort_values(by=metric, inplace=False)["Question"] | |
.to_list()[:n_samples] | |
) | |
for question in questions: | |
q_a_dict = {"question": question} | |
scores = {"question": "."} | |
for i, version_name in enumerate(current_versions): | |
cond = res_dict["df"][version_name].Question == question | |
if cond.sum() == 0: | |
q_a_dict.update({f"answer_{version_name}": "NOT ANSWERED"}) | |
scores[f"{version_name}"] = f"NONE" | |
else: | |
q_a_dict[f"answer_{version_name}"] = ( | |
res_dict["df"][version_name].loc[cond, "result"].values[0] | |
) | |
precision = round( | |
res_dict["df"][version_name].loc[cond, "precision"].values[0], 3 | |
) | |
recall = round( | |
res_dict["df"][version_name] | |
.loc[cond, "comprehensiveness"] | |
.values[0], | |
3, | |
) | |
# Style | |
misses = ast.literal_eval( | |
res_dict["df"][version_name].loc[cond, "misses"].values[0] | |
) | |
if len(misses) > 0: | |
misses[0] = f"- {misses[0]}" | |
misses = "\n -".join(misses) | |
contra = ast.literal_eval( | |
res_dict["df"][version_name].loc[cond, "contra"].values[0] | |
) | |
if len(contra) > 0: | |
contra[0] = f"- {contra[0]}" | |
contra = "\n -".join(contra) | |
scores[f"{version_name}"] = ( | |
f"**Pr**: {precision}, " | |
f"**Re**: {recall}\n \n" | |
f"**misses**: \n {misses} \n \n " | |
f"**contra**: \n {contra}" | |
) | |
cols = st.columns( | |
[ | |
int(len(q_a_dict.keys()) * 0.5) | |
if ind == 0 | |
else len(q_a_dict.keys()) | |
for ind, i in enumerate(q_a_dict.keys()) | |
] | |
) | |
for col, field in zip(cols, q_a_dict.keys()): | |
col.write("**" + field + "**") | |
for ind, field, score in zip(range(len(q_a_dict.keys())), q_a_dict, scores): | |
cols[ind].caption(q_a_dict[field]) | |
cols[ind].caption(scores[score]) | |
st.markdown("""---""") | |


if __name__ == "__main__":
    main()
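

# To run the dashboard locally, assuming this file is the app entry point (e.g. app.py):
#   streamlit run app.py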