"""
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
"""
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from seqeval.metrics import classification_report
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

from src.subpages.page import Context, Page


def _get_evaluation(df):
    # Drop positions labeled "IGN" (ignored subword tokens) from both the gold
    # labels and the predictions before scoring.
    y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
    y_pred = df.apply(
        lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
        axis=1,
    )
    report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3)  # type: ignore
    # Prefix the report header with a rule so the column labels stand out.
    return report.replace(
        "precision    recall  f1-score   support",
        "=" * 12 + "  precision    recall  f1-score   support",
    )
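
# Example (illustrative only; assumes the column layout used throughout this
# app, where each row holds token-level `labels` and `preds` lists):
#
#     toy = pd.DataFrame(
#         {
#             "labels": [["B-PER", "I-PER", "O", "IGN"]],
#             "preds": [["B-PER", "I-PER", "O", "O"]],
#         }
#     )
#     print(_get_evaluation(toy))  # one PER row with perfect scores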


def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
    cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
    if zero_diagonal:
        # Blank out correct predictions so only the errors remain visible.
        np.fill_diagonal(cm, 0)
    fig, ax = plt.subplots(figsize=(10, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    # Raw counts are integers; normalized cells are fractions.
    fmt = "d" if normalize is None else ".3f"
    disp.plot(
        cmap="Blues",
        include_values=True,
        xticks_rotation="vertical",
        values_format=fmt,
        ax=ax,
        colorbar=False,
    )
    return fig
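
# Example (illustrative only): flat token-level label/prediction sequences in,
# matplotlib figure out.
#
#     fig = plot_confusion_matrix(
#         y_true=["O", "B-PER", "O", "B-LOC"],
#         y_preds=["O", "B-PER", "B-LOC", "O"],
#         labels=["O", "B-PER", "B-LOC"],
#         normalize="true",  # row-normalize: cells become P(pred | true)
#         zero_diagonal=False,
#     )
#     fig.savefig("confusion_matrix.png")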


class MetricsPage(Page):
    name = "Metrics"
    icon = "graph-up-arrow"

    def _get_widget_defaults(self):
        return {
            "normalize": True,
            "zero_diagonal": False,
        }

    def render(self, context: Context):
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
            )
            st.write(
                "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
            )

        eval_results = _get_evaluation(context.df)
        # Use a side-by-side layout only if the report is short enough to fit.
        if len(eval_results.splitlines()) < 8:
            col1, _, col2 = st.columns([8, 1, 1])
        else:
            col1 = col2 = st

        col1.subheader("🎯 Evaluation Results")
        col1.code(eval_results)

        # Parse the per-class rows of the report (skipping the header and the
        # trailing average rows) into (class, support, f1) tuples.
        results = [re.split(r" +", line.lstrip()) for line in eval_results.splitlines()[2:-4]]
        data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
        df = pd.DataFrame(data, columns="class support f1".split())
        fig = px.scatter(
            df,
            x="support",
            y="f1",
            range_y=(0, 1.05),
            color="class",
        )
        col1.plotly_chart(fig)

        col2.subheader("🔠 Confusion Matrix")
        # sklearn expects normalize in {"true", "pred", "all", None}; "true"
        # normalizes over rows, i.e. over the gold labels.
        normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
        zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
        col2.pyplot(
            plot_confusion_matrix(
                y_true=context.df_tokens_cleaned["labels"],
                y_preds=context.df_tokens_cleaned["preds"],
                labels=context.labels,
                normalize=normalize,
                zero_diagonal=zero_diagonal,
            ),
        )