File size: 6,598 Bytes
c3af76c
5ead791
 
c3af76c
5e91161
5ead791
 
c3af76c
21d3461
 
 
 
 
 
 
 
 
8e05eba
 
f2c0884
8e05eba
f2c0884
 
 
 
 
 
 
 
 
 
 
 
 
21d3461
 
 
c3af76c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aef92c
c3af76c
 
0aef92c
 
 
 
 
 
 
 
c3af76c
0aef92c
 
21d3461
 
 
 
 
0aef92c
 
 
 
 
 
 
 
 
c3af76c
0aef92c
c3af76c
 
 
 
5ead791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6eb5e3
 
 
 
 
 
 
 
8e05eba
5ead791
 
 
 
 
 
8e05eba
 
 
 
5ead791
8e05eba
 
 
 
21d3461
 
 
8e05eba
 
21d3461
8e05eba
 
 
 
 
21d3461
8e05eba
 
 
 
21d3461
 
 
 
8e05eba
 
 
 
21d3461
 
 
 
 
 
 
 
8e05eba
 
c3af76c
8e05eba
 
 
 
 
c3af76c
8e05eba
 
 
 
21d3461
 
 
 
8e05eba
 
f6eb5e3
8e05eba
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import html

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st


def _viz_rank(results):
    """Render the concepts as a ranked list of labelled circles.

    Concepts are ordered by increasing mean ``tau`` (mean over axis 0). Each
    label is drawn next to a circle whose diameter is proportional to
    ``1 - tau``, rescaled so that the largest circle is 80px wide.

    Args:
        results: Mapping providing ``"tau"`` (an array reduced with
            ``.mean(axis=0)`` to one value per concept) and ``"concepts"``
            (labels indexed by position).
            NOTE(review): assumes ``tau`` is normalized to [0, 1] so that
            ``1 - tau`` is a non-negative size -- confirm against the producer.
    """
    tau_mu = results["tau"].mean(axis=0)
    concepts = results["concepts"]

    order = np.argsort(tau_mu)
    sorted_concepts = [concepts[idx] for idx in order]

    widths = 1 - tau_mu[order]
    # Rescale only when there is a positive maximum: dividing by zero (every
    # mean tau equal to 1) would emit "nanpx" sizes, and ``max()`` raises on an
    # empty array.
    peak = widths.max() if widths.size else 0
    if peak > 0:
        widths = widths / peak * 80
    else:
        widths = np.zeros_like(widths)

    # NOTE(review): the same element ids are repeated for every concept; they
    # are kept as-is because a stylesheet outside this view presumably targets
    # them.
    items = []
    for concept, width in zip(sorted_concepts, widths):
        circle_style = (
            "background: #418FDE;border-radius: 50%;width:"
            f" {width}px;padding-bottom:"
            f" {width}px;"
        )
        # Labels go into raw HTML (unsafe_allow_html=True), so escape them to
        # keep characters such as "<" or "&" from breaking or injecting markup.
        label = html.escape(str(concept))
        items.append(
            "<div id='conceptContainer'><p"
            f" id='concept'><strong>{label}</strong></p><div id='circleContainer'><div"
            f" style='{circle_style}'></div></div></div>"
        )
    st.markdown("".join(items), unsafe_allow_html=True)


def _viz_test(results):
    """Plot per-concept rejection rate and rejection time in one chart.

    The ``"rejected"`` and ``"tau"`` arrays are averaged over axis 0 and the
    concepts are ordered by decreasing mean ``tau``. The figure holds a
    horizontal bar per concept for the mean ``tau`` ("Rejection time"), a
    dashed line for the mean ``rejected`` ("Rejection rate"), and a dashed
    black vertical line at the significance level. The chart is rendered in
    the middle of three Streamlit columns with relative widths 1:3:1.

    Args:
        results: Mapping providing ``"rejected"``, ``"tau"``, ``"concepts"``
            and ``"significance_level"``.
    """
    rejected_all = results["rejected"]
    tau_all = results["tau"]
    concepts = results["concepts"]
    significance_level = results["significance_level"]

    mean_rejected = rejected_all.mean(axis=0)
    mean_tau = tau_all.mean(axis=0)

    # Largest mean tau first.
    descending = np.argsort(mean_tau)[::-1]
    ranked_tau = mean_tau[descending]
    ranked_rejected = mean_rejected[descending]
    ranked_concepts = [concepts[pos] for pos in descending]

    rank_df = pd.DataFrame(
        [
            {"concept": name, "tau": tau_value, "rejected": rate}
            for name, tau_value, rate in zip(
                ranked_concepts, ranked_tau, ranked_rejected
            )
        ]
    )

    fig = go.Figure()

    rate_trace = go.Scatter(
        x=rank_df["rejected"],
        y=rank_df["concept"],
        marker={"size": 8},
        line={"color": "#1f78b4", "dash": "dash"},
        name="Rejection rate",
    )
    fig.add_trace(rate_trace)

    time_trace = go.Bar(
        x=rank_df["tau"],
        y=rank_df["concept"],
        orientation="h",
        marker={"color": "#a6cee3"},
        name="Rejection time",
    )
    fig.add_trace(time_trace)

    # Zero-length line on the first concept; presumably present only so the
    # legend gets a "significance level" entry for the vertical line below.
    first_concept = ranked_concepts[0]
    legend_stub = go.Scatter(
        x=[significance_level, significance_level],
        y=[first_concept, first_concept],
        mode="lines",
        line={"color": "black", "dash": "dash"},
        name="significance level",
    )
    fig.add_trace(legend_stub)
    fig.add_vline(significance_level, line_dash="dash", line_color="black")

    fig.update_layout(
        yaxis_title="Rank of importance",
        xaxis_title="",
        margin={"l": 20, "r": 20, "t": 20, "b": 20},
    )
    # When the smallest mean tau is at most 0.3 the legend is moved inside the
    # plot area and given a border.
    if rank_df["tau"].min() <= 0.3:
        boxed_legend = {
            "x": 0.3,
            "y": 1.0,
            "bordercolor": "black",
            "borderwidth": 1,
        }
        fig.update_layout(legend=boxed_legend)

    _, centercol, _ = st.columns([1, 3, 1])
    with centercol:
        st.plotly_chart(fig, use_container_width=True)


def _viz_wealth(results):
    """Plot the mean wealth trajectory of every concept as a line chart.

    ``results["wealth"]`` is averaged over axis 0 and reshaped into a long
    table with one row per (concept, time step). A dashed horizontal line
    marks the rejection threshold ``1 / significance_level`` and the y axis
    is fixed to ``[0, 1.5 / significance_level]``.

    Args:
        results: Mapping providing ``"wealth"``, ``"concepts"`` and
            ``"significance_level"``.
            NOTE(review): the indexing implies ``wealth`` has time on axis 1
            and concepts on axis 2 -- confirm against the producer.
    """
    wealth = results["wealth"]
    concepts = results["concepts"]
    significance_level = results["significance_level"]

    mean_wealth = wealth.mean(axis=0)

    # Long format: concepts vary slowest, time steps fastest.
    records = [
        {"time": step, "concept": name, "wealth": mean_wealth[step, column]}
        for column, name in enumerate(concepts)
        for step in range(wealth.shape[1])
    ]
    long_df = pd.DataFrame(records)

    fig = px.line(long_df, x="time", y="wealth", color="concept")
    fig.add_hline(
        y=1 / significance_level,
        line_dash="dash",
        line_color="black",
        annotation_text="Rejection threshold (1 / α)",
        annotation_position="bottom right",
    )
    fig.update_yaxes(range=[0, 1.5 / significance_level])
    fig.update_layout(margin={"l": 20, "r": 20, "t": 20, "b": 20})
    st.plotly_chart(fig, use_container_width=True)


def _render_or_wait(results, render):
    """Call ``render(results)`` and add a divider, or show a waiting notice.

    Args:
        results: The results object, or ``None`` when nothing is available.
        render: Callable invoked with ``results`` when it is not ``None``.
    """
    if results is None:
        st.info("Waiting for results", icon="ℹ️")
        return
    render(results)
    st.divider()


def viz_results():
    """Render the "Results" section with its three tabs.

    Results are read from ``st.session_state`` under the ``"results"`` key.
    Each tab shows an explanatory text followed by its visualization, or a
    "Waiting for results" notice while the value is missing or ``None``.
    """
    # ``.get`` keeps the page usable when the key has not been initialised
    # yet; attribute access would raise AttributeError in that case.
    results = st.session_state.get("results")

    st.header("Results")
    rank_tab, test_tab, wealth_tab = st.tabs(
        ["Rank of importance", "Testing results", "Wealth process"]
    )

    with rank_tab:
        st.subheader("Rank of Importance")
        # The ranking is drawn with circles of varying diameter, so the text
        # refers to circle size rather than font size.
        st.write(
            """
                This tab visually shows the rank of importance of the specified concepts 
                for the prediction of the model on the input image. Larger circles indicate 
                higher importance. See the other two tabs for more details.
            """
        )
        _render_or_wait(results, _viz_rank)

    with test_tab:
        st.subheader("Testing Results")
        st.write(
            """
                Importance is measured by performing sequential tests of statistical independence. 
                This tab shows the results of these tests and how the rank of importance is computed.
                Concepts are sorted by increasing rejection time, where a shorter rejection time indicates 
                higher importance.
            """
        )
        with st.expander("Details"):
            st.markdown(
                """
                    Results are averaged over multiple random draws of conditioning subsets of
                    concepts. The number of tests can be controlled under `Advanced settings`.

                    - **Rejection rate**: The average number of times the test is rejected for a concept.
                    - **Rejection time**: The (normalized) average number of steps before the test is 
                    rejected for a concept.
                    - **Significance level**: The level at which the test is rejected for a concept.
                """
            )
        _render_or_wait(results, _viz_test)

    with wealth_tab:
        st.subheader("Wealth Process of Testing Procedures")
        st.markdown(
            """
                Sequential tests instantiate a wealth process for each concept. Once the 
                wealth reaches a value of 1/α, the test is rejected with Type I error control at
                level α. This tab shows the average wealth process of the testing procedures for
                each concept.
            """
        )
        _render_or_wait(results, _viz_wealth)