"""Inspect your whole dataset, either unfiltered or by id.""" import streamlit as st from src.subpages.page import Context, Page from src.utils import aggrid_interactive_table, colorize_classes class InspectPage(Page): name = "Inspect" icon = "search" def render(self, context: Context): st.title(self.name) with st.expander("💡", expanded=True): st.write("Inspect your whole dataset, either unfiltered or by id.") df = context.df_tokens cols = ( "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split() ) if "token_type_ids" not in df.columns: cols.remove("token_type_ids") df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols] if st.checkbox("Filter by id", value=True): ids = list(sorted(map(int, df.ids.unique()))) next_id = st.session_state.get("next_id", 0) example_id = st.selectbox("Select an example", ids, index=next_id) df = df[df.ids == str(example_id)][1:-1] # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses")) # type: ignore st.dataframe(colorize_classes(df.round(3).astype(str))) if st.button("Next example"): st.session_state.next_id = (ids.index(example_id) + 1) % len(ids) if st.button("Previous example"): st.session_state.next_id = (ids.index(example_id) - 1) % len(ids) else: aggrid_interactive_table(df.round(3))