File size: 1,582 Bytes
c8d36ae
597bf7d
 
7a75a86
 
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Inspect your whole dataset, either unfiltered or by id."""
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table, colorize_classes


class InspectPage(Page):
    """Page that displays the token-level dataframe, whole or for one example.

    With "Filter by id" checked (the default) a single example is picked via a
    select box or the next/previous buttons and shown as a styled dataframe.
    Otherwise the complete dataframe is shown in an interactive grid.
    """

    name = "Inspect"
    icon = "search"

    @staticmethod
    def _select_example(position: int) -> None:
        """Store ``position`` as the select-box index to use on the next run.

        Used as a button ``on_click`` callback so that the new position is
        already in session state when the script re-executes.
        """
        st.session_state.next_id = position

    def render(self, context: Context) -> None:
        """Draw the page.

        Args:
            context: Provides ``df_tokens``, the dataframe to inspect.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Inspect your whole dataset, either unfiltered or by id.")

        cols = (
            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
        )
        df = context.df_tokens
        if "token_type_ids" not in df.columns:
            cols.remove("token_type_ids")
        # Selecting `cols` already leaves out every other column (including
        # "hidden_states" and "attention_mask"), so no explicit drop is needed
        # and the page does not fail when those columns are absent.
        df = df[cols]

        if not st.checkbox("Filter by id", value=True):
            aggrid_interactive_table(df.round(3))
            return

        ids = sorted(map(int, df.ids.unique()))
        if not ids:
            st.warning("No examples to inspect.")
            return

        # The stored position may be stale (e.g. left over from a larger
        # dataset); an out-of-range index would make the select box raise.
        next_id = st.session_state.get("next_id", 0)
        if not 0 <= next_id < len(ids):
            next_id = 0

        example_id = st.selectbox("Select an example", ids, index=next_id)
        # NOTE(review): ids are matched as strings, which assumes the "ids"
        # column stores string values — confirm against the data loader.
        # The [1:-1] slice removes the first and last row of the example,
        # presumably the special start/end tokens — TODO confirm.
        example_df = df[df.ids == str(example_id)][1:-1]
        st.dataframe(colorize_classes(example_df.round(3).astype(str)))

        # Callbacks run before the script is re-executed, so the select box
        # picks up the new position on the very run triggered by the click.
        # Assigning session state after `st.button` returned True would only
        # take effect one interaction later.
        position = ids.index(example_id)
        st.button(
            "Next example",
            on_click=self._select_example,
            args=((position + 1) % len(ids),),
        )
        st.button(
            "Previous example",
            on_click=self._select_example,
            args=((position - 1) % len(ids),),
        )