File size: 11,069 Bytes
90f4ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# streamlit
import streamlit as st
from streamlit_vega_lite import altair_component
import base64


# data
import pandas as pd

# utils
from numpy import round
from interactive_model_cards import utils as ut


def _png_base64(path):
    """Return the bytes of the file at ``path`` base64-encoded as an ASCII str.

    Used to inline the instruction icons as ``data:image/png`` URIs. The file
    is read inside a ``with`` block so its handle is closed right away rather
    than whenever the file object happens to be garbage-collected.
    """
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode()


def perf_interact(type="model perf", min_size=0):
    """Render instructions for interacting with the chart shown alongside.

    Parameters
    ----------
    type : str
        ``"model perf"`` renders the notes for the performance bar chart (a
        small-subpopulation warning and a click hint); any other value renders
        the notes for the embedding comparison chart.
        NOTE: the name shadows the builtin ``type`` but is kept because callers
        pass it by keyword.
    min_size : int
        Subpopulation size threshold quoted in the warning text; only used when
        ``type == "model perf"``.
    """
    if type == "model perf":
        warning_icon = _png_base64("./assets/img/warning-black.png")
        st.markdown(
            f"""
            <span>
                <img src="data:image/png;base64,{warning_icon}"> All subpopulations with <strong>fewer than {min_size}</strong> sentences are reporting potentially unreliable results. These are <strong style="color:red">identified with a red border</strong> around the bar.
            </span>
            """,
            unsafe_allow_html=True,
        )
        st.markdown("")  # empty element used purely as vertical spacing

        click_icon = _png_base64("./assets/img/click.png")
        st.markdown(
            f"""
            <span>
                <img src="data:image/png;base64,{click_icon}"> Click on the bars to see example sentences.
            </span>
            """,
            unsafe_allow_html=True,
        )
        st.markdown("")  # empty element used purely as vertical spacing
    else:
        st.write("This visualization shows a representation of the data according to how similar two sentences are *relative to the data the model was trained on*.")

        click_icon = _png_base64("./assets/img/click.png")
        st.markdown(
            f"""
            <span>
                <img src="data:image/png;base64,{click_icon}"> <strong>Here are ways to interact with this view</strong>:
            </span>
            """,
            unsafe_allow_html=True,
        )

        st.write("* You can `zoom in and out` of the visualization")
        st.write("* You can `hover` over a data point to see the sentence and sentiment")
        st.write("*  You can `click on the legend` to emphasize subpopulations in the data according to positive or negative sentiment.")

        


def _slice_df(sst_db, slice_name):
    """Return the slice of ``sst_db`` identified by ``slice_name`` as a DataFrame.

    The slice is found by matching ``slice_name`` against the ids produced by
    ``ut.get_sliceid`` for ``sst_db.slices``; ``list.index`` raises
    ``ValueError`` when no slice has that id.
    """
    slices = list(sst_db.slices)
    position = ut.get_sliceid(slices).index(slice_name)
    return ut.slice_to_df(slices[position])


def _collect_metrics(sst_db):
    """Flatten ``st.session_state["quant_ex"]`` into one dict keyed by slice id.

    Each value is ``{"metrics": ..., "source": ..., "size": ...}``, the input
    expected by ``ut.visualize_metrics``. Groups whose value is ``None`` are
    skipped.
    """
    all_metrics = {}
    for source, slice_metrics in st.session_state["quant_ex"].items():
        if slice_metrics is None:
            continue
        for slice_id, metrics in slice_metrics.items():
            entry = {"metrics": metrics, "source": source}
            if source == "Overall Performance":
                entry["size"] = _slice_df(sst_db, slice_id).shape[0]
                # Custom and protected-class slices arrive in this group and
                # can only be told apart by substrings of their id, so
                # re-label them here (presumably a consequence of how slices
                # are added to sst_db).
                if "RGDataset" in slice_id:
                    entry["source"] = "Custom Slice"
                elif "protected" in slice_id:
                    entry["source"] = "US Protected Class"
            else:
                entry["size"] = st.session_state["user_data"].shape[0]
            all_metrics[slice_id] = entry
    return all_metrics


def _display_name(slice_name):
    """Return the readable part of a slice id.

    For ids shaped like ``"<prefix> -> <name> @ <suffix>"`` this is ``<name>``
    (stripped); any other id is returned as ``str(slice_name)``.
    """
    parts = str(slice_name).split("->")
    if len(parts) > 1:
        return parts[1].split("@")[0].strip()
    return parts[0]


def _render_sample_form(n_rows):
    """Render the sampling controls; values land in session keys ``sampleNum``/``sampleType``."""
    with st.expander("Customize Data Sample"):
        with st.form("Sample Form"):
            st.number_input(
                "Number of Samples",
                value=min(n_rows, 10),
                min_value=1,
                max_value=n_rows,
                key="sampleNum",
            )
            st.selectbox(
                "Sample Type",
                [
                    "Random Sample",
                    "Highest Probabilities",
                    "Lowest Probabilities",
                    "Mid Probabilities",
                ],
                index=0,
                key="sampleType",
            )
            st.form_submit_button("Generate Sample")


def _render_slice_details(sst_db, selected, source):
    """Show size, sampling controls, matching terms and a sample of one training-data slice."""
    df = _slice_df(sst_db, selected)
    n_rows = df.shape[0]
    display_name = _display_name(selected)

    st.warning("**Data Details**")
    if n_rows == 0:
        # The sample-size input requires min_value (1) <= max_value (n_rows);
        # an empty slice would make Streamlit raise, so report it and stop.
        st.markdown(f"* The slice `{display_name}` has a total size of `{n_rows} sentences`")
        return

    _render_sample_form(n_rows)

    st.markdown(f"* The slice `{display_name}` has a total size of `{n_rows} sentences`")
    st.markdown(
        f"* Shown is a subsample of `{st.session_state['sampleNum']}` sentences, "
        f"sampled by `{st.session_state['sampleType']}`"
    )

    if source == "Custom Slice":
        terms_str = ", ".join(st.session_state["slice_terms"][selected])
        st.markdown(f"* This slice contains sentences with one or more of the following terms: `{terms_str}`")
    elif source == "US Protected Class":
        terms_str = ", ".join(st.session_state["protected_class"][display_name])
        st.markdown(f"* Sentences pertaining to this US Protected Class contain the following terms: `{terms_str}`")
        # Read the icon inside ``with`` so the file handle is closed promptly.
        with open("./assets/img/warning-black.png", "rb") as fh:
            warning_icon = base64.b64encode(fh.read()).decode()
        st.markdown(
            f"""
            <span>
                <img src="data:image/png;base64,{warning_icon}"> Detecting US Protected Classes by keyword search is not perfect. Some sentences below may not be pertinent to a protected class; for example, the word 'black' can refer to individuals, but not always.
            </span>
            """,
            unsafe_allow_html=True,
        )

    st.table(
        ut.subsample_df(
            df,
            st.session_state["sampleNum"],
            st.session_state["sampleType"],
        )
    )


def _render_performance(sst_db):
    """Render the metrics bar chart and, if a bar was clicked, that slice's details."""
    st.warning("**Model Performance Metrics**")

    st.markdown("* Evaluation metrics include [accuracy](https://simple.wikipedia.org/wiki/Accuracy_and_precision), [precision](https://en.wikipedia.org/wiki/Precision_and_recall), and [recall](https://en.wikipedia.org/wiki/Precision_and_recall).")
    st.markdown(" * Performance is shown for the training and testing set, as well as special groups within this dataset that have been automatically associated with US protected groups")

    min_size = st.number_input(
        "Flag (with a red border) subpopulations with fewer than the following number of sentences:",
        value=100,
        min_value=30,
        max_value=10000,
    )
    perf_interact(type="model perf", min_size=min_size)

    all_metrics = _collect_metrics(sst_db)
    chart = ut.visualize_metrics(all_metrics, max_width=100, linked_vis=True, min_size=min_size)
    event_dict = altair_component(altair_chart=chart)

    # A click reports the bar's "name" and "source" as sequences; store the
    # first of each so the selection persists across reruns. ``event_dict`` is
    # guarded in case the component returns nothing.
    if event_dict and "name" in event_dict:
        st.session_state["selected_slice"] = {
            "name": event_dict["name"][0],
            "source": event_dict["source"][0],
        }

    selection = st.session_state.get("selected_slice")
    if selection is None:
        return

    source = selection["source"]
    if source in ("Overall Performance", "Custom Slice", "US Protected Class"):
        _render_slice_details(sst_db, selection["name"], source)
    elif source == "User Custom Sentence":
        df = st.session_state["user_data"]
        st.markdown("**Data Details**")
        st.markdown(f"These are your `{df.shape[0]}` custom sentences")
        st.write(df)


def _render_comparison(embedding):
    """Render the embedding comparison chart with its reading instructions."""
    st.warning("**Subpopulation Comparison**")
    perf_interact(type="comparison")

    with st.expander("how to read this chart:"):
        st.markdown("* each **point** is a single sentence")
        st.markdown("* the **position** of each dot is determined mathematically based upon an analysis of the words in a sentence. The **closer** two points on the visualization the **more similar** the sentences are. The **further apart** two points on the visualization the **more different** the sentences are")
        st.markdown(" * the **shape** of each point reflects whether it is a positive (diamond) or negative sentiment (circle)")
        # TODO(review): this sentence is unfinished; state what the color encodes.
        st.markdown("* the **color** of each point is the ")

    # Down-sample before charting because of Altair limitations (presumably
    # its default cap on the number of rows embedded in a chart).
    st.altair_chart(ut.data_comparison(ut.down_samp(embedding)))


def quant_panel(sst_db, embedding, col, data_view):
    """Render the quantitative panel inside ``col``.

    Parameters
    ----------
    sst_db :
        Dataset whose ``slices`` are looked up by id (via ``ut.get_sliceid``)
        to size and display slices.
    embedding :
        Data down-sampled with ``ut.down_samp`` and charted with
        ``ut.data_comparison`` in the comparison view.
    col :
        Container the panel is rendered into; entered with ``with col:``.
    data_view : str
        ``"Model Performance Metrics"`` shows the metrics bar chart plus the
        details of the clicked slice; any other value shows the comparison
        chart.

    Reads session keys ``quant_ex``, ``user_data``, ``selected_slice``,
    ``slice_terms`` and ``protected_class``; writes ``selected_slice`` and,
    through widget keys, ``sampleNum``/``sampleType``.
    """
    with col:
        if data_view == "Model Performance Metrics":
            _render_performance(sst_db)
        else:
            _render_comparison(embedding)


# Public API of this module: only the panel renderer is exported by
# ``from <module> import *``.
__all__ = [
    "quant_panel",
]