ExplaiNER / subpages /misclassified.py
Alexander Seifert
initial commit
597bf7d
raw
history blame
No virus
2.84 kB
from collections import defaultdict
import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix
from subpages.page import Context, Page
from utils import htmlify_labeled_example
class MisclassifiedPage(Page):
    """Streamlit page that lists misclassified examples, filterable by class confusion."""

    name = "Misclassified"
    icon = "x-octagon"

    @staticmethod
    def _confusion_table(cm, labels):
        """Return the confusion matrix as a string DataFrame with diagonal and zero cells blanked."""
        cells = [
            ["" if i == j or count == 0 else str(count) for j, count in enumerate(row)]
            for i, row in enumerate(cm)
        ]
        return pd.DataFrame(cells, index=labels, columns=labels)

    @staticmethod
    def _confusion_options(cm, labels):
        """Return ``((true_label, pred_label), count)`` for every non-zero off-diagonal cell.

        Cells are visited row by row, so the options keep the matrix order.
        """
        return [
            ((labels[i], labels[j]), int(count))
            for i, row in enumerate(cm)
            for j, count in enumerate(row)
            if i != j and count != 0
        ]

    def render(self, context: Context) -> None:
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context; this method reads ``context.df_tokens_merged``
                (its ``labels`` and ``preds`` columns) and ``context.labels``.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        if misclassified_indices.empty:
            # Nothing to tabulate or filter; avoid building an empty radio below.
            st.info("No misclassified examples found.")
            return

        # Every row sharing an index value with at least one mismatching row.
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]

        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )
        # The cells are built from `cm` directly instead of writing through
        # `df.values`, which is not guaranteed to be a writable view.
        st.dataframe(self._confusion_table(cm, context.labels))

        def format_func(item):
            return (
                f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
            )

        conf = st.radio(
            "Filter by Class Confusion",
            options=self._confusion_options(cm, context.labels),
            format_func=format_func,
        )
        if conf is None:
            # st.radio yields None when there is nothing to choose from.
            return

        (true_label, pred_label), _ = conf
        # A boolean mask avoids interpolating label text into a query string,
        # which would break on labels that contain quote characters.
        matching = misclassified_samples[
            (misclassified_samples["labels"] == true_label)
            & (misclassified_samples["preds"] == pred_label)
        ]

        # Index values can repeat when several rows of one sample match;
        # render each sample only once.
        for idx in matching.index.unique():
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")