File size: 3,167 Bytes
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pandas as pd
import streamlit as st
from streamlit_option_menu import option_menu

from load import load_context
from subpages import (
    DebugPage,
    FindDuplicatesPage,
    HomePage,
    LossesPage,
    LossySamplesPage,
    MetricsPage,
    MisclassifiedPage,
    Page,
    ProbingPage,
    RandomSamplesPage,
    RawDataPage,
)
from subpages.attention import AttentionPage
from subpages.embeddings import EmbeddingsPage
from subpages.inspect import InspectPage

# Page-wide settings (layout, browser-tab title and icon) are applied at
# import time, before any page content is drawn.
st.set_page_config(
    page_title="Error Analysis",
    page_icon="🏷️",
    layout="wide",
)

# Short module-level alias for the sidebar container.
sts = st.sidebar


def _show_menu(pages: list[Page]) -> int:
    """Draw the sidebar navigation menu and report which page was picked.

    The chosen menu entry (a page name) is also stored in
    ``st.session_state.active_page``.

    Args:
        pages: Pages to offer, in menu order. Each supplies a ``name`` (the
            menu label) and an ``icon`` (forwarded to ``option_menu``).

    Returns:
        Index into ``pages`` of the first page whose name equals the selected
        menu entry.
    """
    with st.sidebar:
        labels = [page.name for page in pages]
        choice = option_menu(
            menu_title="ExplaiNER",
            options=labels,
            icons=[page.icon for page in pages],
            menu_icon="layout-wtf",
            default_index=0,  # pre-select the first page
        )
        st.session_state.active_page = choice
    return labels.index(choice)


def _initialize_session_state(pages: list[Page]):
    """Seed per-page widget defaults once, then write session state back onto itself.

    In this file ``active_page`` is written only by the sidebar menu, so its
    absence marks a session in which the menu has not been rendered yet. Only
    then are the defaults from every page's ``get_widget_defaults()`` copied in.

    Args:
        pages: Pages whose widget defaults should be installed.
    """
    is_fresh_session = "active_page" not in st.session_state
    if is_fresh_session:
        for page in pages:
            widget_defaults = page.get_widget_defaults()
            st.session_state.update(**widget_defaults)

    # Re-assign every key to its current value. Presumably this is the
    # Streamlit workaround that keeps widget-bound values from being discarded
    # when their widget is not rendered on the current page — confirm before
    # removing.
    st.session_state.update(st.session_state)


def _write_color_legend(context):
    """Render a colour legend of the dataset's entity types in the sidebar.

    Each label in ``context.labels`` is reduced to its entity type: the second
    ``-``-separated segment if it contains a ``-`` (e.g. ``PER`` for
    ``B-PER``), otherwise the label itself (e.g. ``O``). Each distinct type is
    shown once. It is displayed as its emoji from ``labelmap``, or as the raw
    type if unmapped. It sits on the background colour stored in
    ``st.session_state[f"color_{type}"]``, which defaults to ``#000000``.

    Args:
        context: Object with a ``labels`` attribute holding an iterable of
            label strings; presumably the value returned by ``load_context``
            — confirm.
    """
    labelmap = {
        "O": "O",
        "person": "🙎",
        "PER": "🙎",
        "location": "🌎",
        "LOC": "🌎",
        "corporation": "🏀",
        "ORG": "🏀",
        "product": "📱",
        "creative": "🎷",
        "group": "🎷",
        "MISC": "🎷",
    }

    # De-duplicate while keeping first-seen order. set() order depends on the
    # per-process string hash seed, so the legend layout would otherwise vary
    # between server restarts.
    labels = list(
        dict.fromkeys(
            lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels
        )
    )
    colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]

    def style(_column):
        # One CSS rule per row; rows are in the same order as `labels`/`colors`.
        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

    # Unknown entity types fall back to their raw name instead of raising
    # KeyError, so datasets with other tag sets still get a legend.
    legend_df = pd.DataFrame(
        [labelmap.get(lbl, lbl) for lbl in labels], columns=["label"], index=labels
    )
    st.sidebar.write(
        legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )


def main():
    """Build the page list, show the navigation menu and render the chosen page.

    The home page is rendered without a context. Any other page requires that
    setup has run, signalled by ``model_name`` in session state. Without it,
    an error is shown instead. Otherwise a context is loaded from the whole
    session state, the colour legend is drawn, and the page is rendered with
    that context.
    """
    all_pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        EmbeddingsPage(),
        ProbingPage(),
        MetricsPage(),
        MisclassifiedPage(),
        LossesPage(),
        LossySamplesPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]

    _initialize_session_state(all_pages)
    page = all_pages[_show_menu(all_pages)]

    if isinstance(page, HomePage):
        page.render()
    elif "model_name" not in st.session_state:
        # Happens when another page is opened directly without going through home.
        st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
    else:
        context = load_context(**st.session_state)
        _write_color_legend(context)
        page.render(context)


if __name__ == "__main__":
    # Run the app only when this file is executed as the main script
    # (presumably how `streamlit run` executes it — confirm), not on import.
    main()