Spaces:
Runtime error
Runtime error
"""See the data as seen by your model."""
import pandas as pd | |
import streamlit as st | |
from src.subpages.page import Context, Page | |
from src.utils import aggrid_interactive_table | |
def convert_df(df):
    """Serialize *df* to CSV text and return it as UTF-8 encoded bytes.

    Args:
        df: Object exposing ``to_csv()`` (a pandas DataFrame in this file).

    Returns:
        The CSV produced by ``df.to_csv()`` (index included, pandas defaults),
        encoded as UTF-8 bytes.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
class RawDataPage(Page):
    """Page that displays the processed token dataframe and the raw split.

    Each of the two tables is rendered with ``aggrid_interactive_table`` and
    offered as a CSV download.
    """

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the page using the data carried by *context*.

        Args:
            context: Supplies the dataset identifiers (``ds_name``,
                ``ds_config_name``, ``ds_split_name``), the dataframe
                ``df_tokens`` and the dataset ``split`` (which must offer
                ``to_pandas()``).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        # --- Which dataset is being shown --------------------------------
        st.subheader("Dataset")
        st.code(
            f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        )

        # --- Processed data ------------------------------------------------
        st.write("**Data after processing and inference**")
        tokens_df = context.df_tokens.drop("hidden_states", axis=1)
        tokens_df = tokens_df.drop("attention_mask", axis=1)
        # Round numeric values to three decimals for display/export.
        tokens_df = tokens_df.round(3)

        # "token_type_ids" is only listed when the dataframe actually has it;
        # every other column is required and is selected in this fixed order.
        has_token_types = "token_type_ids" in tokens_df.columns
        shown_columns = [
            column
            for column in (
                "ids",
                "input_ids",
                "token_type_ids",
                "word_ids",
                "losses",
                "tokens",
                "labels",
                "preds",
                "total_loss",
            )
            if has_token_types or column != "token_type_ids"
        ]
        tokens_df = tokens_df[shown_columns]

        aggrid_interactive_table(tokens_df)
        st.download_button(
            label="Download csv",
            data=convert_df(tokens_df),
            file_name="processed_data.csv",
            mime="text/csv",
        )

        # --- Raw data ------------------------------------------------------
        st.write("**Raw data (exploded by tokens)**")
        # Explode every column so list-valued cells become one row per element
        # (presumably one row per token, as the heading says).
        exploded_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        aggrid_interactive_table(exploded_df)
        st.download_button(
            label="Download csv",
            data=convert_df(exploded_df),
            file_name="raw_data.csv",
            mime="text/csv",
        )