Spaces:
Runtime error
Runtime error
"""This page contains all misclassified examples and allows filtering by specific error types.""" | |
from collections import defaultdict

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example
class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by (true, predicted) label pair."""

    name = "Misclassified"
    icon = "x-octagon"

    def render(self, context: Context):
        """Render the confusion table, the confusion selector and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (its ``labels`` and ``preds`` columns and its index) and
                ``context.labels`` (the label order used for the confusion matrix).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        tokens = context.df_tokens_merged
        # Keep every row that shares an index value with at least one wrong prediction.
        misclassified_indices = tokens.query("labels != preds").index.unique()
        misclassified_samples = tokens.loc[misclassified_indices]

        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )
        st.dataframe(self._confusion_table(cm, context.labels))

        confusions = self._count_confusions(cm, context.labels)
        if not confusions:
            # With nothing to choose from the radio has no selection to return,
            # and indexing into that result below would raise.
            st.info("No misclassified examples to display.")
            return

        conf = st.radio(
            "Filter by Class Confusion",
            options=list(confusions.items()),
            format_func=self._format_confusion,
        )
        true_label, pred_label = conf[0]

        # Boolean mask instead of a string-built query(): label names that
        # contain quotes would otherwise produce an invalid expression.
        is_selected = (misclassified_samples["labels"] == true_label) & (
            misclassified_samples["preds"] == pred_label
        )
        # NOTE(review): the index is treated as repeating (see `.unique()` above),
        # presumably one value per example with one row per token - confirm.
        # De-duplicating keeps an example from being rendered once per matching row.
        filtered_indices = misclassified_samples.index[is_selected.to_numpy()].unique()

        for idx in filtered_indices:
            st.write(
                htmlify_labeled_example(tokens.loc[idx]),
                unsafe_allow_html=True,
            )
            st.write("---")

    @staticmethod
    def _confusion_table(cm, labels) -> pd.DataFrame:
        """Return ``cm`` as a string table with zero and diagonal cells blanked.

        The cells are edited on a standalone NumPy array before the DataFrame is
        built, so nothing depends on ``DataFrame.values`` being writable.
        """
        counts = np.asarray(cm)
        cells = counts.astype(str).astype(object)
        cells[counts == 0] = ""
        np.fill_diagonal(cells, "")
        return pd.DataFrame(cells, index=labels, columns=labels)

    @staticmethod
    def _count_confusions(cm, labels) -> dict:
        """Map each ``(true_label, predicted_label)`` pair to its non-zero off-diagonal count."""
        return {
            (labels[i], labels[j]): int(count)
            for i, row in enumerate(cm)
            for j, count in enumerate(row)
            if i != j and count != 0
        }

    @staticmethod
    def _format_confusion(item) -> str:
        """Build the radio label for one ``((true, pred), count)`` option."""
        if not item:
            return "All"
        (true_label, pred_label), count = item
        return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"