ExplaiNER / src /subpages /losses.py
Alexander Seifert
add stuff for vis2
c8d36ae
raw history blame
No virus
2.26 kB
"""Show count, mean and median loss per token and label."""
import streamlit as st
from src.subpages.page import Context, Page
from src.utils import AgGrid, aggrid_interactive_table
@st.cache
def get_loss_by_token(df_tokens):
return (
df_tokens.groupby("tokens")[["losses"]]
.agg(["count", "mean", "median", "sum"])
.droplevel(level=0, axis=1) # Get rid of multi-level columns
.sort_values(by="sum", ascending=False)
.reset_index()
)
@st.cache
def get_loss_by_label(df_tokens):
return (
df_tokens.groupby("labels")[["losses"]]
.agg(["count", "mean", "median", "sum"])
.droplevel(level=0, axis=1)
.sort_values(by="mean", ascending=False)
.reset_index()
)
class LossesPage(Page):
name = "Loss by Token/Label"
icon = "sort-alpha-down"
def render(self, context: Context):
st.title(self.name)
with st.expander("💡", expanded=True):
st.write("Show count, mean and median loss per token and label.")
st.write(
"Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
)
col1, _, col2 = st.columns([8, 1, 6])
with col1:
st.subheader("💬 Loss by Token")
st.session_state["_merge_tokens"] = st.checkbox(
"Merge tokens", value=True, key="merge_tokens"
)
loss_by_token = (
get_loss_by_token(context.df_tokens_merged)
if st.session_state["merge_tokens"]
else get_loss_by_token(context.df_tokens_cleaned)
)
aggrid_interactive_table(loss_by_token.round(3))
# st.subheader("🏷️ Loss by Label")
# loss_by_label = get_loss_by_label(df_tokens_cleaned)
# st.dataframe(loss_by_label)
st.write(
"_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
)
with col2:
st.subheader("🏷️ Loss by Label")
loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
AgGrid(loss_by_label.round(3), height=200)