# wadood's picture
# added tabs
# c50c9d2
# raw
# history blame
# No virus
# 5.25 kB
import pandas as pd
import streamlit as st
# from annotated_text import annotated_text
from annotated_text.util import get_annotated_html
from streamlit_annotation_tools import text_labeler
from constants import PREDICTION_ADDITION_INSTRUCTION
from evaluation_metrics import EVALUATION_METRICS
from predefined_example import EXAMPLES
from span_dataclass_converters import (
get_highlight_spans_from_ner_spans,
get_ner_spans_from_annotations,
)
@st.cache_resource
def get_examples_attributes(selected_example):
    """Unpack the attributes of a predefined example.

    Wrapped in ``st.cache_resource`` so the returned objects are cached per
    example and are not rebuilt on every Streamlit interaction.

    Args:
        selected_example: one of the predefined examples; it must expose
            ``text``, ``gt_labels``, ``gt_spans``, ``predictions`` and ``tags``.

    Returns:
        The tuple ``(text, gt_labels, gt_spans, predictions, tags)``.
    """
    attribute_names = ("text", "gt_labels", "gt_spans", "predictions", "tags")
    return tuple(getattr(selected_example, name) for name in attribute_names)
if __name__ == "__main__":
st.set_page_config(layout="wide")
st.title("πŸ“ˆ NER Metrics Comparison βš–οΈ")
st.write(
"Evaluation for the NER task requires a ground truth and a prediction that will be evaluated. The ground truth is shown below, add predictions in the next section to compare the evaluation metrics."
)
explanation_tab, comparision_tab = st.tabs(["πŸ“™ Explanation", "βš–οΈ Comparision"])
with explanation_tab:
st.write("This is the place holder for explanation of all the metrics")
with comparision_tab:
# with st.container():
st.subheader("Ground Truth & Predictions") # , divider='rainbow')
selected_example = st.selectbox(
"Select an example text from the drop down below",
[example for example in EXAMPLES],
format_func=lambda ex: ex.text,
)
text, gt_labels, gt_spans, predictions, tags = get_examples_attributes(
selected_example
)
# annotated_text(
# get_highlight_spans_from_ner_spans(
# get_ner_spans_from_annotations(gt_labels), text
# )
# )
annotated_predictions = [
get_annotated_html(get_highlight_spans_from_ner_spans(ner_span, text))
for ner_span in predictions
]
predictions_df = pd.DataFrame(
{
# "ID": [f"Prediction_{index}" for index in range(len(predictions))],
"Prediction": annotated_predictions,
"ner_spans": predictions,
},
index=["Ground Truth"]
+ [f"Prediction_{index}" for index in range(len(predictions) - 1)],
)
# st.subheader("Predictions") # , divider='rainbow')
with st.expander("Click to Add Predictions"):
st.subheader("Adding predictions")
st.markdown(PREDICTION_ADDITION_INSTRUCTION)
st.write(
"Note: Only the spans of the selected label name is shown at a given instance.",
)
labels = text_labeler(text, gt_labels)
st.json(labels, expanded=False)
# if st.button("Add Prediction"):
# labels = text_labeler(text)
if st.button("Add!"):
spans = get_ner_spans_from_annotations(labels)
spans = sorted(spans, key=lambda span: span["start"])
predictions.append(spans)
annotated_predictions.append(
get_annotated_html(get_highlight_spans_from_ner_spans(spans, text))
)
predictions_df = pd.DataFrame(
{
# "ID": [f"Prediction_{index}" for index in range(len(predictions))],
"Prediction": annotated_predictions,
"ner_spans": predictions,
},
index=["Ground Truth"]
+ [f"Prediction_{index}" for index in range(len(predictions) - 1)],
)
print("added")
highlighted_predictions_df = predictions_df[["Prediction"]]
st.write(
highlighted_predictions_df.to_html(escape=False), unsafe_allow_html=True
)
st.divider()
### EVALUATION METRICS COMPARISION ###
st.subheader("Evaluation Metrics Comparision") # , divider='rainbow')
st.markdown(
"The different evaluation metrics we have for the NER task are\n"
f"{''.join(['- '+evaluation_metric.name+'\n' for evaluation_metric in EVALUATION_METRICS])}"
)
with st.expander("View Predictions Details"):
st.write(predictions_df.to_html(escape=False), unsafe_allow_html=True)
if st.button("Get Metrics!"):
for evaluation_metric in EVALUATION_METRICS:
predictions_df[evaluation_metric.name] = predictions_df.ner_spans.apply(
lambda ner_spans: evaluation_metric.get_evaluation_metric(
# metric_type=evaluation_metric_type,
gt_ner_span=gt_spans,
pred_ner_span=ner_spans,
text=text,
tags=tags,
)
)
metrics_df = predictions_df.drop(["ner_spans"], axis=1)
st.write(metrics_df.to_html(escape=False), unsafe_allow_html=True)