File size: 7,105 Bytes
597bf7d
 
 
 
 
 
7a75a86
 
 
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb9cb6e
597bf7d
d5ecc0d
 
 
61fec8d
 
 
597bf7d
d5ecc0d
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb9cb6e
597bf7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2918df9
597bf7d
 
fb9cb6e
597bf7d
fb9cb6e
597bf7d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
import random
from typing import Optional

import streamlit as st

from src.data import get_data
from src.subpages.page import Context, Page
from src.utils import PROJ, classmap, color_map_color

# Default settings for the setup form. Each setting lists candidate values and
# the entry at index 0 is the one actually used; reorder a list (or change the
# index) to switch the default without losing the alternatives.

# Sentence encoder used for duplicate detection.
_SENTENCE_ENCODER_MODEL = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
][0]

# NER model to analyse.
_MODEL_NAME = [
    "elastic/distilbert-base-uncased-finetuned-conll03-english",
    "gagan3012/bert-tiny-finetuned-ner",
    "socialmediaie/bertweet-base_wnut17_ner",
    "sberbank-ai/bert-base-NER-reptile-5-datasets",
    "aseifert/comma-xlm-roberta-base",
    "dslim/bert-base-NER",
    "aseifert/distilbert-base-german-cased-comma-derstandard",
][0]

# Dataset to load.
_DATASET_NAME = [
    "conll2003",
    "wnut_17",
    "aseifert/comma",
][0]

# Dataset config. NOTE(review): entries appear to line up positionally with
# _DATASET_NAME (same index -> matching config); keep the indices in sync.
_CONFIG_NAME = [
    "conll2003",
    "wnut_17",
    "seifertverlag",
][0]


class HomePage(Page):
    """Landing page: project intro, model/dataset selection and per-class styling."""

    name = "Home / Setup"
    icon = "house"

    def get_widget_defaults(self):
        """Return the initial widget values, keyed by the widget ``key`` used in ``render``."""
        return {
            "encoder_model_name": _SENTENCE_ENCODER_MODEL,
            "model_name": _MODEL_NAME,
            "ds_name": _DATASET_NAME,
            "ds_split_name": "validation",
            "ds_config_name": _CONFIG_NAME,
            "split_sample_size": 512,
        }

    @staticmethod
    def _entity_labels(tag_names) -> list:
        """Return the sorted, de-duplicated entity types found in *tag_names*.

        The outside tag ``"O"`` is skipped; for every other tag the segment
        between the first and (optional) second hyphen is kept, e.g.
        ``"B-PER"`` -> ``"PER"``. Sorting makes the order (and therefore the
        index-based default colors) stable across processes, unlike raw
        ``set`` iteration order.

        NOTE(review): a type that itself contains a hyphen (e.g.
        ``"B-creative-work"``) is truncated to its first segment. This is kept
        as-is because the result is used as a session-state / ``classmap`` key;
        confirm before changing.
        """
        return sorted({tag.split("-")[1] for tag in tag_names if tag != "O"})

    @staticmethod
    def _load_emoji_keys() -> list:
        """Return the top-level keys of the bundled emoji JSON file (icon choices)."""
        # Explicit UTF-8: the file holds emoji, and the platform default
        # encoding is not guaranteed to decode it. ``with`` closes the handle.
        with open(PROJ / "subpages/emoji-en-US.json", encoding="utf-8") as fh:
            return list(json.load(fh).keys())

    def render(self, context: Optional[Context] = None):
        """Render the home/setup page.

        Shows an intro expander and a form whose widgets are bound to
        session-state keys (see ``get_widget_defaults``). It then loads the
        selected dataset split via ``get_data`` and offers a color picker and
        an icon selector for each entity class in that split.

        Args:
            context: Not used by this page.
        """
        st.title("ExplaiNER")

        with st.expander("💡", expanded=True):
            st.write(
                "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
            )
            st.write(
                "**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
            )
            st.write(
                "_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
            )

        col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])

        with col1:
            # FIXME: the form key is randomized on every rerun as a workaround.
            # With a fixed key, st.form_submit_button raised (via
            # check_session_state_rules in streamlit/elements/utils.py):
            #   streamlit.errors.StreamlitAPIException: Values for st.button,
            #   st.download_button, st.file_uploader, and st.form cannot be set
            #   using st.session_state.
            # (observed 2022-05-05, streamlit on Python 3.9). Root cause not yet found.
            random_form_key = f"settings-{random.randint(0, 100000)}"
            with st.form(key=random_form_key):
                st.subheader("Model & Data Selection")
                st.text_input(
                    label="NER Model:",
                    key="model_name",
                    help="Path or name of the model to use",
                )
                st.text_input(
                    label="Encoder Model:",
                    key="encoder_model_name",
                    help="Path or name of the encoder to use for duplicate detection",
                )
                ds_name = st.text_input(
                    label="Dataset:",
                    key="ds_name",
                    help="Path or name of the dataset to use",
                )
                ds_config_name = st.text_input(
                    label="Config (optional):",
                    key="ds_config_name",
                )
                ds_split_name = st.selectbox(
                    label="Split:",
                    options=["train", "validation", "test"],
                    key="ds_split_name",
                )
                split_sample_size = st.number_input(
                    "Sample size:",
                    step=16,
                    key="split_sample_size",
                    help="Sample size for the split, speeds up processing inside streamlit",
                )
                st.form_submit_button("Load Model & Data")

        split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
        labels = self._entity_labels(split.features["ner_tags"].feature.names)

        with col2a:
            st.subheader("Classes")
            st.write("**Color**")
            colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
            for label in labels:
                # Seed the default only once so a user-picked color survives reruns.
                if f"color_{label}" not in st.session_state:
                    st.session_state[f"color_{label}"] = colors[label]
                st.color_picker(label, key=f"color_{label}")
        with col2b:
            st.subheader("—")
            st.write("**Icon**")
            emojis = self._load_emoji_keys()
            for label in labels:
                # Seed from classmap only once; afterwards the widget state wins.
                if f"icon_{label}" not in st.session_state:
                    st.session_state[f"icon_{label}"] = classmap[label]
                st.selectbox(label, key=f"icon_{label}", options=emojis)
                # Write the chosen icon back into the module-level classmap
                # imported from src.utils.
                classmap[label] = st.session_state[f"icon_{label}"]

        # TODO: add a "Reset to defaults" button that writes get_widget_defaults()
        # back into st.session_state and triggers a rerun.