"""This page contains all misclassified examples and allows filtering by specific error types."""
from collections import defaultdict

import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example
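

# Illustrative sketch (hypothetical helper, not called by the page below): the
# confusion-pair counting done in `MisclassifiedPage.render`, restated as a pure
# function so the logic can be checked on its own. It assumes `cm` follows the
# scikit-learn `confusion_matrix` convention (rows = true labels, columns =
# predicted labels) and that rows/columns are ordered like `labels`.
def count_confusions(cm, labels):
    """Return {(true_label, pred_label): count} for every nonzero off-diagonal cell."""
    confusions = {}
    for i, true_label in enumerate(labels):
        for j, pred_label in enumerate(labels):
            # Diagonal cells are correct predictions and zero cells are not confusions.
            if i != j and cm[i][j] > 0:
                confusions[(true_label, pred_label)] = int(cm[i][j])
    return confusions


# Example: count_confusions([[5, 2], [1, 7]], ["O", "PER"])
# returns {("O", "PER"): 2, ("PER", "O"): 1}.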


class MisclassifiedPage(Page):
    name = "Misclassified"
    icon = "x-octagon"

    def render(self, context: Context):
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )

        # st.pyplot(
        #     plot_confusion_matrix(
        #         y_preds=misclassified_samples["preds"],
        #         y_true=misclassified_samples["labels"],
        #         labels=labels,
        #         normalize=None,
        #         zero_diagonal=True,
        #     ),
        # )
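        df = pd.DataFrame(cm, index=context.labels, columns=context.labels)
        # Blank out the diagonal (correctly classified tokens) and zero cells so that
        # only actual confusions are visible in the table.
        for label in context.labels:
            df.loc[label, label] = 0
        st.dataframe(df.astype(str).replace("0", ""))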
        # import matplotlib.pyplot as plt
        # st.pyplot(df.style.background_gradient(cmap='RdYlGn_r').to_html())
        # selection = aggrid_interactive_table(df)

        # st.write(df.to_html(escape=False, index=False), unsafe_allow_html=True)

        # Collect every (true, predicted) label pair that was confused at least once.
        confusions = defaultdict(int)
        for i, true_label in enumerate(context.labels):
            for j, pred_label in enumerate(context.labels):
                if i != j and cm[i, j] > 0:
                    confusions[(true_label, pred_label)] += cm[i, j]

        def format_func(item):
            return (
                f"true: {item[0][0]} <> pred: {item[0][1]} ||| count: {item[1]}" if item else "All"
            )

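        # Without any confusions the radio has no options and returns None.
        if not confusions:
            st.info("No misclassified examples found.")
            return

        conf = st.radio(
            "Filter by Class Confusion",
            options=list(confusions.items()),
            format_func=format_func,
        )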

        # st.write(
        #     f"**Filtering Examples:** True class: `{conf[0][0]}`, Predicted class: `{conf[0][1]}`"
        # )

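        true_label, pred_label = conf[0]
        # Boolean masks avoid quoting problems when labels are interpolated into a
        # query string; `unique()` shows each example once even if the same confusion
        # occurs on several of its tokens.
        filtered_indices = misclassified_samples[
            (misclassified_samples.labels == true_label)
            & (misclassified_samples.preds == pred_label)
        ].index.unique()
        for idx in filtered_indices:
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")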
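

# Illustrative sketch (hypothetical helper, not called by the page above): the
# display rule for the confusion table, where correct predictions (the diagonal)
# and zero counts are shown as empty cells so only actual confusions stand out.
# It assumes a square matrix of non-negative integer counts, rows = true labels
# and columns = predicted labels.
def blank_confusion_display(cm):
    """Return the matrix as a list of string rows with diagonal and zero cells blanked."""
    n = len(cm)
    return [
        ["" if i == j or cm[i][j] == 0 else str(cm[i][j]) for j in range(n)]
        for i in range(n)
    ]


# Example: blank_confusion_display([[5, 2], [0, 7]]) returns [["", "2"], ["", ""]].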