# ExplaiNER: subpages/losses.py
# Author: Alexander Seifert
# Commit 597bf7d ("initial commit")
import streamlit as st
from subpages.page import Context, Page
from utils import AgGrid, aggrid_interactive_table
@st.cache
def get_loss_by_token(df_tokens):
    """Summarise the loss of ``df_tokens`` per distinct token.

    Rows are grouped by the ``"tokens"`` column and, for each group, the
    count, mean, median and sum of the ``"losses"`` column are computed.

    Args:
        df_tokens: DataFrame with at least ``"tokens"`` and ``"losses"``
            columns.

    Returns:
        A new DataFrame with one row per token and the columns
        ``tokens, count, mean, median, sum``, ordered by ``sum`` with the
        largest total loss first.
    """
    statistics = ["count", "mean", "median", "sum"]
    per_token = df_tokens.groupby("tokens")[["losses"]].agg(statistics)
    # Aggregating a one-column selection produces ("losses", <stat>)
    # MultiIndex columns; drop the outer level to keep just the stat names.
    per_token = per_token.droplevel(level=0, axis=1)
    per_token = per_token.sort_values(by="sum", ascending=False)
    # Move the group key back from the index into a regular "tokens" column.
    return per_token.reset_index()
@st.cache
def get_loss_by_label(df_tokens):
    """Summarise the loss of ``df_tokens`` per label.

    Rows are grouped by the ``"labels"`` column and, for each group, the
    count, mean, median and sum of the ``"losses"`` column are computed.

    Args:
        df_tokens: DataFrame with at least ``"labels"`` and ``"losses"``
            columns.

    Returns:
        A new DataFrame with one row per label and the columns
        ``labels, count, mean, median, sum``, ordered by ``mean`` with the
        highest average loss first.
    """
    statistics = ["count", "mean", "median", "sum"]
    per_label = df_tokens.groupby("labels")[["losses"]].agg(statistics)
    # Flatten the ("losses", <stat>) MultiIndex columns to plain stat names.
    per_label = per_label.droplevel(level=0, axis=1)
    per_label = per_label.sort_values(by="mean", ascending=False)
    # Move the group key back from the index into a regular "labels" column.
    return per_label.reset_index()
class LossesPage(Page):
    """Page showing aggregated loss statistics per token and per label."""

    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Render the loss-by-token and loss-by-label tables side by side.

        Args:
            context: Shared page context; this page reads
                ``context.df_tokens_merged`` and ``context.df_tokens_cleaned``.
        """
        st.title(self.name)

        with st.expander("💡", expanded=True):
            st.write("Show count, mean, median and sum of the loss per token and label.")

        col1, _, col2 = st.columns([8, 1, 6])

        with col1:
            st.subheader("💬 Loss by Token")
            # st.checkbox returns its current value (the same value Streamlit
            # keeps under session_state["merge_tokens"] via `key`), so use the
            # return value directly instead of reading it back.
            merge_tokens = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            # NOTE(review): "_merge_tokens" is written but never read on this
            # page; presumably another page reads it — confirm before removing.
            st.session_state["_merge_tokens"] = merge_tokens
            df_tokens = (
                context.df_tokens_merged if merge_tokens else context.df_tokens_cleaned
            )
            loss_by_token = get_loss_by_token(df_tokens)
            aggrid_interactive_table(loss_by_token.round(3))
            st.write(
                "_Attention: This statistic disregards that tokens have contextual representations._"
            )

        with col2:
            st.subheader("🏷️ Loss by Label")
            # Labels are always aggregated over the cleaned (unmerged) tokens.
            loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(loss_by_label.round(3), height=200)