Spaces:
Runtime error
Runtime error
File size: 4,900 Bytes
c8d36ae 597bf7d 7a75a86 597bf7d e18be25 597bf7d a2351d6 597bf7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
"""Show every example sorted by loss (descending) for close inspection."""
import pandas as pd
import streamlit as st
from src.subpages.page import Context, Page
from src.utils import (
colorize_classes,
get_bg_color,
get_fg_color,
htmlify_labeled_example,
)
class LossySamplesPage(Page):
    """Page that lists samples in descending order of total loss.

    For every listed sample it shows a small badge (running number, dataset
    index and summed token loss), an annotated HTML rendering of the sample,
    and optionally a color-coded, transposed per-token dataframe.
    """

    name = "Samples by Loss"
    icon = "sort-numeric-down-alt"
    # Only samples whose ``total_loss`` is strictly greater than this value are
    # listed. Exposed as a class attribute so the cutoff is not buried in render().
    min_total_loss = 0.5

    def get_widget_defaults(self):
        """Return the initial value of each widget on this page, keyed by widget key."""
        return {
            "skip_correct": True,
            "samples_by_loss_show_df": True,
        }

    @staticmethod
    def _colorize_col(col):
        """Return one CSS string per cell for a row of the transposed token table.

        Intended for ``Styler.apply(..., axis=1)`` on ``df.T``, so ``col.name`` is
        the original column name. Only the ``labels`` and ``preds`` rows are
        styled; every other row gets empty styles.

        For a label ``v``:
        - background: ``get_bg_color`` of the part after the first ``-``, or
          white when ``v`` contains no ``-``;
        - foreground: ``get_fg_color`` of that background;
        - opacity: ``1`` for labels whose prefix is ``B`` and for ``O``,
          ``0.5`` otherwise (the lighter "continuation" look).
        """
        if col.name not in ("labels", "preds"):
            return [""] * len(col)
        styles = []
        for v in col.values:
            bg = get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff"
            fg = get_fg_color(bg)
            op = "1" if v.split("-")[0] == "B" or v == "O" else "0.5"
            styles.append(f"background-color: {bg}; color: {fg}; opacity: {op};")
        return styles

    def _render_token_table(self, sample):
        """Render ``sample`` transposed (one row per column, one column per token) as styled HTML.

        The ``labels``/``preds`` rows are colored via ``_colorize_col`` and the
        ``losses`` row gets an in-cell bar chart.
        """
        df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
        # After transposing, "losses" is a row label, so select that row across all columns.
        losses_slice = pd.IndexSlice["losses", :]
        styler = (
            df.T.style.apply(self._colorize_col, axis=1)
            .bar(subset=losses_slice, axis=1)
            .format(precision=3)
        )
        st.write(styler.to_html(), unsafe_allow_html=True)
        st.write("")

    def render(self, context: Context):
        """Draw the page: help text, filter widgets, then one entry per high-loss sample."""
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show every example sorted by loss (descending) for close inspection.")
            st.write(
                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
            )
            st.write(
                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
            )
            st.write(
                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
            )

        st.subheader("💥 Samples ⬇loss")
        # Both initial values come from get_widget_defaults() via the keyed session
        # state (the show_df checkbox already relies on this). Also passing value=
        # for a session-state-seeded key makes Streamlit warn about a conflicting default.
        skip_correct = st.checkbox("Skip correct examples", key="skip_correct")
        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
        # Hide table headers and keep cells compact for the transposed token tables.
        st.write(
            """<style>
thead {
display: none;
}
td {
white-space: nowrap;
padding: 0 5px !important;
}
</style>""",
            unsafe_allow_html=True,
        )

        ranked = context.df.sort_values(by="total_loss", ascending=False)
        top_indices = ranked[ranked["total_loss"] > self.min_total_loss].index

        cnt = 0
        for idx in top_indices:
            sample = context.df_tokens_merged.loc[idx]
            # .loc yields a Series rather than a DataFrame when only one row has this
            # index label (presumably a single-token sample); such samples are skipped.
            if isinstance(sample, pd.Series):
                continue
            if skip_correct and not (sample.labels != sample.preds).any():
                continue

            if show_df:
                self._render_token_table(sample)

            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
            cnt += 1
            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")
            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
|