ExplaiNER / src /subpages /misclassified.py
Alexander Seifert
add stuff for vis2
c8d36ae
raw history blame
No virus
2.94 kB
"""This page contains all misclassified examples and allows filtering by specific error types."""
from collections import defaultdict

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example
class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by class confusion."""

    name = "Misclassified"
    icon = "x-octagon"

    def render(self, context: Context):
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (its ``labels`` and ``preds`` columns and its index) and
                ``context.labels`` (the label order used for the confusion matrix).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        tokens = context.df_tokens_merged
        labels = context.labels

        # Keep every row that shares an index value with at least one wrong
        # prediction. NOTE(review): index values appear to repeat (hence the
        # ``unique()``), presumably one value per example -- confirm upstream.
        misclassified_indices = tokens.query("labels != preds").index.unique()
        misclassified_samples = tokens.loc[misclassified_indices]

        cm = confusion_matrix(
            misclassified_samples["labels"],
            misclassified_samples["preds"],
            labels=labels,
        )

        # Build the display cells before creating the DataFrame: blank out zero
        # counts and the diagonal (correct predictions). Writing through
        # ``df.values`` afterwards is not reliable under pandas copy-on-write.
        cells = cm.astype(str).astype(object)
        cells[cm == 0] = ""
        np.fill_diagonal(cells, "")
        st.dataframe(pd.DataFrame(cells, index=labels, columns=labels))

        # (true label, predicted label) -> count, for every non-empty
        # off-diagonal cell, in row-major order.
        confusions = defaultdict(int)
        for i, j in zip(*np.nonzero(cm)):
            if i != j:
                confusions[(labels[i], labels[j])] += int(cm[i, j])

        if not confusions:
            # Without options the radio has no selection and there is nothing to filter.
            st.info("No misclassified examples found.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=list(confusions.items()),
            format_func=format_func,
        )
        if conf is None:
            # Defensive: the widget may report no selection.
            return

        (true_label, pred_label), _ = conf

        # Boolean mask instead of an interpolated ``query`` string, so labels
        # containing quotes or other special characters cannot break the filter.
        is_selected = (misclassified_samples["labels"] == true_label) & (
            misclassified_samples["preds"] == pred_label
        )
        # De-duplicate so an index value whose rows show the selected confusion
        # several times is rendered only once.
        filtered_indices = misclassified_samples.index[is_selected].unique()

        for idx in filtered_indices:
            sample = tokens.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")