ExplaiNER / src /subpages /raw_data.py
Alexander Seifert
add stuff for vis2
c8d36ae
raw
history blame
1.77 kB
"""See the data as seen by your model."""
import pandas as pd
import streamlit as st
from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table
# ``st.cache`` is the legacy Streamlit caching decorator and is deprecated in
# favour of ``st.cache_data``. Prefer the modern API when the installed
# Streamlit provides it and only fall back to the legacy one otherwise, so the
# module keeps working on both old and new Streamlit versions.
_cache_data = getattr(st, "cache_data", None) or st.cache


@_cache_data
def convert_df(df):
    """Serialize a DataFrame to UTF-8 encoded CSV bytes.

    The call is wrapped in Streamlit's data cache so the conversion does not
    have to be recomputed for an unchanged ``df``.

    Args:
        df: DataFrame to serialize.

    Returns:
        The text produced by ``df.to_csv()``, encoded as UTF-8 bytes.
    """
    return df.to_csv().encode("utf-8")
class RawDataPage(Page):
    """Page that shows the processed token-level data and the raw dataset split.

    Each of the two tables is rendered interactively and is followed by a
    button that downloads the same table as CSV.
    """

    # Used as the page title in ``render``.
    name = "Raw data"
    # Icon name for this page; not read anywhere inside this class.
    icon = "qr-code"

    def render(self, context: Context):
        """Render the dataset summary, both data tables and their CSV downloads.

        Args:
            context: Supplies the dataset identifiers (``ds_name``,
                ``ds_config_name``, ``ds_split_name``), the token-level
                frame ``df_tokens`` and the dataset ``split``.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        # --- Dataset identification -------------------------------------
        st.subheader("Dataset")
        dataset_summary = f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        st.code(dataset_summary)

        # --- Token-level data after processing and inference ------------
        st.write("**Data after processing and inference**")
        tokens_df = context.df_tokens.drop(columns=["hidden_states", "attention_mask"])
        tokens_df = tokens_df.round(3)

        candidate_cols = [
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        ]
        # ``token_type_ids`` is optional; every other column is always selected.
        shown_cols = [
            col
            for col in candidate_cols
            if col != "token_type_ids" or col in tokens_df.columns
        ]
        tokens_df = tokens_df[shown_cols]

        aggrid_interactive_table(tokens_df)
        st.download_button(
            label="Download csv",
            data=convert_df(tokens_df),
            file_name="processed_data.csv",
            mime="text/csv",
        )

        # --- Raw split, one row per exploded list element ----------------
        st.write("**Raw data (exploded by tokens)**")
        split_df = context.split.to_pandas()
        exploded_df = split_df.apply(pd.Series.explode)  # type: ignore

        aggrid_interactive_table(exploded_df)
        st.download_button(
            label="Download csv",
            data=convert_df(exploded_df),
            file_name="raw_data.csv",
            mime="text/csv",
        )