Spaces:
Sleeping
Sleeping
import html

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
def _viz_rank(results):
    """Render the concepts as an HTML ranking with importance-sized circles.

    Concepts are ordered by increasing mean ``tau`` (``results["tau"]``
    averaged over axis 0). Each concept gets a circle whose diameter is
    proportional to ``1 - tau``, scaled so the largest circle is 80 px.

    Args:
        results: Mapping with ``"tau"`` (array averaged over axis 0) and
            ``"concepts"`` (names indexed by the positions of the averaged
            ``tau`` array).
    """
    concepts = results["concepts"]
    tau_mu = results["tau"].mean(axis=0)
    order = np.argsort(tau_mu)

    widths = 1 - tau_mu[order]
    # initial=0.0 keeps max() defined for an empty ranking and makes the
    # normalizer non-negative.
    peak = widths.max(initial=0.0)
    if peak > 0:
        widths = widths / peak * 80
    else:
        # Nothing positive to normalize against: draw zero-size circles
        # rather than dividing by zero (which would put "nanpx" in the CSS).
        widths = np.zeros_like(widths)

    rows = []
    for concept_idx, width in zip(order, widths):
        circle_style = (
            "background: #418FDE;border-radius: 50%;width:"
            f" {width}px;padding-bottom:"
            f" {width}px;"
        )
        # Escape the concept text: it is injected into raw HTML below.
        label = html.escape(str(concepts[concept_idx]))
        # NOTE(review): the ids repeat on every row; presumably they are
        # targeted by CSS defined elsewhere — a class would be valid HTML.
        rows.append(
            "<div id='conceptContainer'><p"
            f" id='concept'><strong>{label}</strong></p><div id='circleContainer'><div"
            f" style='{circle_style}'></div></div></div>"
        )
    st.markdown("".join(rows), unsafe_allow_html=True)
def _viz_test(results):
    """Plot mean rejection time and rejection rate for every concept.

    Draws a horizontal bar per concept for mean ``tau`` (rejection time), a
    dashed line for the mean rejection rate, and a dashed vertical line at
    the significance level. The figure is centred in a 1:3:1 column layout.

    Args:
        results: Mapping with ``"rejected"`` and ``"tau"`` (arrays averaged
            over axis 0), ``"concepts"`` (names indexed by the positions of
            the averaged arrays) and ``"significance_level"``.
    """
    concepts = results["concepts"]
    significance_level = results["significance_level"]
    tau_mu = results["tau"].mean(axis=0)
    rejected_mu = results["rejected"].mean(axis=0)

    # Sort by decreasing mean rejection time. Plotly places the first category
    # of a categorical y axis at the bottom, so the concept with the shortest
    # rejection time ends up at the top of the chart.
    order = np.argsort(tau_mu)[::-1]
    rank_df = pd.DataFrame(
        {
            "concept": [concepts[idx] for idx in order],
            "tau": tau_mu[order],
            "rejected": rejected_mu[order],
        }
    )

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=rank_df["rejected"],
            y=rank_df["concept"],
            marker=dict(size=8),
            line=dict(color="#1f78b4", dash="dash"),
            name="Rejection rate",
        )
    )
    fig.add_trace(
        go.Bar(
            x=rank_df["tau"],
            y=rank_df["concept"],
            orientation="h",
            marker=dict(color="#a6cee3"),
            name="Rejection time",
        )
    )
    if not rank_df.empty:
        # Zero-length trace at x = significance_level: it only provides a
        # legend entry for the dashed vertical line drawn by add_vline below.
        anchor = rank_df["concept"].iloc[0]
        fig.add_trace(
            go.Scatter(
                x=[significance_level, significance_level],
                y=[anchor, anchor],
                mode="lines",
                line=dict(color="black", dash="dash"),
                name="significance level",
            )
        )
    fig.add_vline(significance_level, line_dash="dash", line_color="black")
    fig.update_layout(
        yaxis_title="Rank of importance",
        xaxis_title="",
        margin=dict(l=20, r=20, t=20, b=20),
    )
    if rank_df["tau"].min() <= 0.3:
        # NOTE(review): moves the legend inside the plot at the top; presumably
        # relies on the shortest (top) bar leaving space there — confirm.
        fig.update_layout(
            legend=dict(
                x=0.3,
                y=1.0,
                bordercolor="black",
                borderwidth=1,
            ),
        )

    _, centercol, _ = st.columns([1, 3, 1])
    with centercol:
        st.plotly_chart(fig, use_container_width=True)
def _viz_wealth(results):
    """Plot the mean wealth process of each concept's test over time.

    ``results["wealth"]`` is averaged over axis 0, and the averaged array is
    indexed as ``[time, concept]``. A dashed horizontal line marks the
    rejection threshold ``1 / significance_level``, and the y axis is capped
    at 1.5 times that threshold.

    Args:
        results: Mapping with ``"wealth"``, ``"concepts"`` and
            ``"significance_level"``.
    """
    concepts = results["concepts"]
    threshold = 1 / results["significance_level"]
    wealth_mu = results["wealth"].mean(axis=0)
    n_steps = wealth_mu.shape[0]
    n_concepts = len(concepts)

    # Long-format frame, rows ordered by concept then time, built in one
    # vectorized step. Only the first len(concepts) columns are labelled.
    wealth_df = pd.DataFrame(
        {
            "time": np.tile(np.arange(n_steps), n_concepts),
            "concept": np.repeat(np.asarray(concepts, dtype=object), n_steps),
            "wealth": wealth_mu[:, :n_concepts].T.ravel(),
        }
    )

    fig = px.line(wealth_df, x="time", y="wealth", color="concept")
    fig.add_hline(
        y=threshold,
        line_dash="dash",
        line_color="black",
        annotation_text="Rejection threshold (1 / α)",
        annotation_position="bottom right",
    )
    fig.update_yaxes(range=[0, 1.5 * threshold])
    fig.update_layout(margin=dict(l=20, r=20, t=20, b=20))
    st.plotly_chart(fig, use_container_width=True)
def _show_or_wait(results, render):
    """Call ``render(results)`` plus a divider, or show a waiting notice if None."""
    if results is None:
        st.info("Waiting for results", icon="ℹ️")
        return
    render(results)
    st.divider()


def viz_results():
    """Render the "Results" section with rank, testing and wealth tabs.

    Reads ``results`` from ``st.session_state``. While it is missing or None,
    each tab shows its explanation plus a "Waiting for results" notice.
    """
    # .get so the page renders before any run has stored results.
    results = st.session_state.get("results")
    st.header("Results")
    rank_tab, test_tab, wealth_tab = st.tabs(
        ["Rank of importance", "Testing results", "Wealth process"]
    )

    with rank_tab:
        st.subheader("Rank of Importance")
        # The rank view scales circle sizes (not fonts) by importance.
        st.write(
            """
            This tab visually shows the rank of importance of the specified concepts
            for the prediction of the model on the input image. Larger circles indicate
            higher importance. See the other two tabs for more details.
            """
        )
        _show_or_wait(results, _viz_rank)

    with test_tab:
        st.subheader("Testing Results")
        st.write(
            """
            Importance is measured by performing sequential tests of statistical independence.
            This tab shows the results of these tests and how the rank of importance is computed.
            Concepts are sorted by increasing rejection time, where a shorter rejection time indicates
            higher importance.
            """
        )
        with st.expander("Details"):
            st.markdown(
                """
                Results are averaged over multiple random draws of conditioning subsets of
                concepts. The number of tests can be controlled under `Advanced settings`.
                - **Rejection rate**: The average number of times the test is rejected for a concept.
                - **Rejection time**: The (normalized) average number of steps before the test is
                rejected for a concept.
                - **Significance level**: The level at which the test is rejected for a concept.
                """
            )
        _show_or_wait(results, _viz_test)

    with wealth_tab:
        st.subheader("Wealth Process of Testing Procedures")
        st.markdown(
            """
            Sequential tests instantiate a wealth process for each concept. Once the
            wealth reaches a value of 1/α, the test is rejected with Type I error control at
            level α. This tab shows the average wealth process of the testing procedures for
            each concept.
            """
        )
        _show_or_wait(results, _viz_wealth)