Spaces:
Runtime error
Runtime error
"""Show count, mean and median loss per token and label.""" | |
import streamlit as st | |
from src.subpages.page import Context, Page | |
from src.utils import AgGrid, aggrid_interactive_table | |
def get_loss_by_token(df_tokens):
    """Summarize the ``losses`` column per distinct value of ``tokens``.

    Args:
        df_tokens: DataFrame with at least a ``tokens`` and a ``losses`` column.

    Returns:
        A DataFrame with the columns ``tokens``, ``count``, ``mean``,
        ``median`` and ``sum``, ordered by ``sum`` from largest to smallest,
        with a fresh integer index.
    """
    per_token = df_tokens.groupby("tokens")[["losses"]]
    stats = per_token.agg(["count", "mean", "median", "sum"])
    # Aggregating a one-column selection yields ("losses", <stat>) column
    # pairs; drop the outer level so only the statistic names remain.
    stats = stats.droplevel(level=0, axis=1)
    ranked = stats.sort_values(by="sum", ascending=False)
    return ranked.reset_index()
def get_loss_by_label(df_tokens):
    """Summarize the ``losses`` column per distinct value of ``labels``.

    Args:
        df_tokens: DataFrame with at least a ``labels`` and a ``losses`` column.

    Returns:
        A DataFrame with the columns ``labels``, ``count``, ``mean``,
        ``median`` and ``sum``, ordered by ``mean`` from largest to smallest,
        with a fresh integer index.
    """
    label_losses = df_tokens.groupby("labels")[["losses"]]
    summary = label_losses.agg(["count", "mean", "median", "sum"])
    # Flatten the two-level column header down to the statistic names.
    flat = summary.droplevel(level=0, axis=1)
    by_mean = flat.sort_values(by="mean", ascending=False)
    return by_mean.reset_index()
class LossesPage(Page):
    """Page that tabulates loss statistics per token and per label."""

    # Display metadata; presumably consumed by the Page base class / navigation
    # (not visible from this file).
    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Draw the page: a hint box, a per-token table and a per-label table.

        Args:
            context: Supplies ``df_tokens_merged`` and ``df_tokens_cleaned``,
                the token-level frames the statistics are computed from.
        """
        st.title(self.name)

        hints = (
            "Show count, mean and median loss per token and label.",
            "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues.",
        )
        with st.expander("💡", expanded=True):
            for hint in hints:
                st.write(hint)

        token_col, _, label_col = st.columns([8, 1, 6])

        with token_col:
            st.subheader("💬 Loss by Token")

            # The checkbox stores its state under "merge_tokens"; the value is
            # additionally mirrored under "_merge_tokens".
            merge_checked = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            st.session_state["_merge_tokens"] = merge_checked

            if st.session_state["merge_tokens"]:
                token_frame = context.df_tokens_merged
            else:
                token_frame = context.df_tokens_cleaned
            loss_by_token = get_loss_by_token(token_frame)
            aggrid_interactive_table(loss_by_token.round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with label_col:
            st.subheader("🏷️ Loss by Label")
            # Label statistics always use the cleaned (unmerged) tokens.
            loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(loss_by_label.round(3), height=200)