File size: 1,767 Bytes
c8d36ae
597bf7d
 
 
7a75a86
 
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""See the data as seen by your model."""
import pandas as pd
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table


@st.cache
def convert_df(df):
    """Serialize a dataframe to CSV and return it as UTF-8 encoded bytes.

    ``to_csv`` is called with its defaults, so the index is written as the
    first column. The function is wrapped in ``st.cache`` so the encoded
    payload does not have to be rebuilt for the same input.

    Args:
        df: Dataframe-like object exposing ``to_csv()``.

    Returns:
        The CSV text encoded as UTF-8 bytes.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")


class RawDataPage(Page):
    """Page showing the processed token dataframe and the raw split data.

    Both tables are rendered with ``aggrid_interactive_table`` and each is
    accompanied by a CSV download button.
    """

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the page: dataset summary, processed table, raw table.

        Args:
            context: Supplies the dataset identifiers (``ds_name``,
                ``ds_config_name``, ``ds_split_name``), the token-level
                dataframe ``df_tokens`` and the dataset ``split``.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        # --- Dataset identification ---------------------------------------
        st.subheader("Dataset")
        dataset_summary = (
            f"Dataset: {context.ds_name}\n"
            f"Config: {context.ds_config_name}\n"
            f"Split: {context.ds_split_name}"
        )
        st.code(dataset_summary)

        # --- Processed data (model inputs/outputs per token) --------------
        st.write("**Data after processing and inference**")

        # Drop the two columns that are not shown in the table, then round
        # numeric values to three decimals for display.
        tokens_df = context.df_tokens.drop("hidden_states", axis=1)
        tokens_df = tokens_df.drop("attention_mask", axis=1)
        tokens_df = tokens_df.round(3)

        # Fixed display order; "token_type_ids" is only kept when the
        # dataframe actually has that column.
        has_token_types = "token_type_ids" in tokens_df.columns
        display_cols = [
            col
            for col in (
                "ids",
                "input_ids",
                "token_type_ids",
                "word_ids",
                "losses",
                "tokens",
                "labels",
                "preds",
                "total_loss",
            )
            if has_token_types or col != "token_type_ids"
        ]
        tokens_df = tokens_df[display_cols]

        aggrid_interactive_table(tokens_df)
        tokens_csv = convert_df(tokens_df)
        st.download_button(
            label="Download csv",
            data=tokens_csv,
            file_name="processed_data.csv",
            mime="text/csv",
        )

        # --- Raw split data -------------------------------------------------
        st.write("**Raw data (exploded by tokens)**")
        # Apply Series.explode to every column so list-like cells are
        # expanded into one row per element.
        split_df = context.split.to_pandas()
        exploded_df = split_df.apply(pd.Series.explode)  # type: ignore

        aggrid_interactive_table(exploded_df)
        exploded_csv = convert_df(exploded_df)
        st.download_button(
            label="Download csv",
            data=exploded_csv,
            file_name="raw_data.csv",
            mime="text/csv",
        )