import pandas as pd
import streamlit as st
from annotated_text import annotated_text
from annotated_text.util import get_annotated_html
from streamlit_annotation_tools import text_labeler

from evaluation_metrics import EVALUATION_METRICS, get_evaluation_metric
from predefined_example import EXAMPLES
from span_dataclass_converters import (
    get_highlight_spans_from_ner_spans,
    get_ner_spans_from_annotations,
)
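
# Cache the selected example's attributes so Streamlit reruns (triggered by each
# widget interaction) do not reset the example or the user-added predictions.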
@st.cache_resource
def get_examples_attributes(selected_example):
    "Return example attributes so that they are not refreshed on every interaction"
    return (
        selected_example.text,
        selected_example.gt_labels,
        selected_example.gt_spans,
        selected_example.predictions,
    )


if __name__ == "__main__":
    st.set_page_config(layout="wide")
    st.title("NER Evaluation Metrics Comparison")

    st.write(
        "Evaluation for the NER task requires a ground truth and a prediction that will be evaluated. The ground truth is shown below; add predictions in the next section to compare the evaluation metrics."
    )
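
    # Ground-truth section: show the selected example with its gold annotations highlighted.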
    # with st.container():
    st.subheader("Ground Truth")  # , divider='rainbow')

    selected_example = st.selectbox(
        "Select an example text from the drop down below",
        [example for example in EXAMPLES],
        format_func=lambda ex: ex.text,
    )
    text, gt_labels, gt_spans, predictions = get_examples_attributes(selected_example)
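
    # Render the ground-truth spans as inline highlighted text.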
    annotated_text(
        get_highlight_spans_from_ner_spans(
            get_ner_spans_from_annotations(gt_labels), text
        )
    )
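
    # One row per prediction: an HTML-highlighted view plus the raw span list.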
    annotated_predictions = [
        get_annotated_html(get_highlight_spans_from_ner_spans(ner_span, text))
        for ner_span in predictions
    ]
    predictions_df = pd.DataFrame(
        {
            # "ID": [f"Prediction_{index}" for index in range(len(predictions))],
            "Prediction": annotated_predictions,
            "ner_spans": predictions,
        },
        index=[f"Prediction_{index}" for index in range(len(predictions))],
    )
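
    # Predictions section: users add their own predictions with the span labeler below.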
st.subheader("Predictions") # , divider='rainbow')
with st.expander("Click to Add Predictions"):
st.subheader("Adding predictions")
st.markdown(
"""
Add predictions to the list of predictions on which the evaluation metric will be caculated.
- Select the entity type/label name and then highlight the span in the text below.
- To remove a span, double click on the higlighted text.
- Once you have your desired prediction, click on the 'Add' button.(The prediction created is shown in a json below)
"""
)
st.write(
"Note: Only the spans of the selected label name is shown at a given instance.",
)
labels = text_labeler(text, gt_labels)
st.json(labels, expanded=False)
        # if st.button("Add Prediction"):
        #     labels = text_labeler(text)
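        # Parse the labeler output into sorted spans, append it as a new prediction,
        # and rebuild the predictions dataframe.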
if st.button("Add!"):
spans = get_ner_spans_from_annotations(labels)
spans = sorted(spans, key=lambda span: span["start"])
predictions.append(spans)
annotated_predictions.append(
get_annotated_html(get_highlight_spans_from_ner_spans(spans, text))
)
predictions_df = pd.DataFrame(
{
# "ID": [f"Prediction_{index}" for index in range(len(predictions))],
"Prediction": annotated_predictions,
"ner_spans": predictions,
},
index=[f"Prediction_{index}" for index in range(len(predictions))],
)
print("added")
    highlighted_predictions_df = predictions_df[["Prediction"]]
    st.write(highlighted_predictions_df.to_html(escape=False), unsafe_allow_html=True)

    st.divider()
    ### EVALUATION METRICS COMPARISON ###

    st.subheader("Evaluation Metrics Comparison")  # , divider='rainbow')

    st.markdown("""
        The different evaluation metrics available for the NER task are:
        - Span Based Evaluation with Partial Overlap
        - Token Based Evaluation with Micro Avg
        - Token Based Evaluation with Macro Avg
    """)
with st.expander("View Predictions Details"):
st.write(predictions_df.to_html(escape=False), unsafe_allow_html=True)
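
    # Score every prediction against the ground-truth spans with each registered metric.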
if st.button("Get Metrics!"):
for evaluation_metric_type in EVALUATION_METRICS:
predictions_df[evaluation_metric_type] = predictions_df.ner_spans.apply(
lambda ner_spans: get_evaluation_metric(
metric_type=evaluation_metric_type,
gt_ner_span=gt_spans,
pred_ner_span=ner_spans,
text=text,
)
)
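
        # Drop the raw span column so the table shows only the rendered predictions and their metric scores.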
        metrics_df = predictions_df.drop(["ner_spans"], axis=1)
        st.write(metrics_df.to_html(escape=False), unsafe_allow_html=True)
        print("compared")