File size: 3,096 Bytes
2d4811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""The App module is the main entry point for the application.

    Run `streamlit run app.py` to start the app.
"""

import pandas as pd
import streamlit as st
from streamlit_option_menu import option_menu

from src.load import load_context
from src.subpages import (
    DebugPage,
    FindDuplicatesPage,
    HomePage,
    LossesPage,
    LossySamplesPage,
    MetricsPage,
    MisclassifiedPage,
    Page,
    ProbingPage,
    RandomSamplesPage,
    RawDataPage,
)
from src.subpages.attention import AttentionPage
from src.subpages.hidden_states import HiddenStatesPage
from src.subpages.inspect import InspectPage
from src.utils import classmap

# Page-wide layout, browser-tab title and favicon. Streamlit requires
# set_page_config to be the first Streamlit command of the script run.
st.set_page_config(
    page_title="Error Analysis",
    page_icon="🏷️",
    layout="wide",
)
# Short module-level alias for the sidebar container.
sts = st.sidebar


def _show_menu(pages: list[Page]) -> int:
    """Render the sidebar navigation menu and return the selected page's index.

    The selected page name is also stored in ``st.session_state.active_page``;
    ``_initialize_session_state`` checks for that key to detect the first run.

    Args:
        pages: Pages to offer, in display order. Each contributes its ``name``
            (menu label) and ``icon``.

    Returns:
        Position in ``pages`` of the selected page (the first match if two
        pages share a name).
    """
    page_names = [p.name for p in pages]
    page_icons = [p.icon for p in pages]
    with st.sidebar:
        selected_menu_item = option_menu(
            menu_title="ExplaiNER",
            options=page_names,
            icons=page_icons,
            menu_icon="layout-wtf",
            default_index=0,
        )
    st.session_state.active_page = selected_menu_item
    return page_names.index(selected_menu_item)


def _initialize_session_state(pages: list[Page]):
    """Seed widget defaults on a session's first run, then re-store all values.

    ``active_page`` is set by ``_show_menu``; while it is absent, this is the
    first script run of the session. Only in that case are each page's widget
    defaults copied into ``st.session_state``.
    """
    state = st.session_state
    is_first_run = "active_page" not in state
    if is_first_run:
        for pg in pages:
            defaults = pg._get_widget_defaults()
            state.update(**defaults)
    # Write every key back onto itself. NOTE(review): this looks like the usual
    # Streamlit workaround that keeps widget-bound values from being discarded
    # when their widget is not rendered on the current page — confirm.
    state.update(state)


def _write_color_legend(context):
    """Render a color legend for the labels of ``context`` in the sidebar.

    Each label in ``context.labels`` is reduced to the part after its first
    ``-`` (labels without a ``-`` are kept as-is); duplicates are dropped.
    Every remaining label is shown with its ``classmap`` entry, in white text
    on the background color stored in ``st.session_state["color_<label>"]``
    (``"#000000"`` if unset).

    Args:
        context: Object exposing a ``labels`` iterable of label strings
            (as returned by ``load_context``).
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the legend
    # order is deterministic; iterating a set of str depends on the per-process
    # hash seed and could change between server restarts.
    # NOTE(review): split("-")[1] keeps only the second segment of labels that
    # contain more than one "-".
    labels = list(
        dict.fromkeys(
            lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels
        )
    )
    colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]

    def style(_column):
        # One CSS rule per row; rows follow the same order as ``colors``.
        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

    # One row per label, with a single "label" column holding its classmap entry.
    color_legend_df = pd.DataFrame(
        [classmap[lbl] for lbl in labels], columns=["label"], index=labels
    )
    st.sidebar.write(
        color_legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )


def main():
    """The main entry point for the application.

    Builds the page list, seeds session state, shows the sidebar menu and
    renders the selected page. The home page is rendered without a context.
    Every other page requires that setup has stored ``model_name`` in
    ``st.session_state``; it then receives the context built by
    ``load_context`` from the session state, and the sidebar also shows the
    label color legend.
    """
    # HomePage comes first so it is the menu's default selection
    # (_show_menu uses default_index=0).
    pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        HiddenStatesPage(),
        ProbingPage(),
        MetricsPage(),
        LossySamplesPage(),
        LossesPage(),
        MisclassifiedPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]

    _initialize_session_state(pages)

    selected_page_idx = _show_menu(pages)
    selected_page = pages[selected_page_idx]

    if isinstance(selected_page, HomePage):
        selected_page.render()
        return

    if "model_name" not in st.session_state:
        # this can happen if someone loads another page directly (without going through home)
        st.error("Setup not complete. Please click on 'Home / Setup' in the left menu bar.")
        return

    # Every session-state entry is forwarded as a keyword argument.
    context = load_context(**st.session_state)
    _write_color_legend(context)
    selected_page.render(context)


# Script entry point; the module docstring says to start the app with
# `streamlit run app.py`. NOTE: main() is only reached through this guard, so it
# relies on Streamlit executing the script with __name__ == "__main__".
if __name__ == "__main__":
    main()