Spaces:
Sleeping
Sleeping
jacopoteneggi
commited on
Commit
•
21d3461
1
Parent(s):
0aef92c
Update
Browse files- .pre-commit-config.yaml +1 -0
- .streamlit/config.toml +2 -0
- app.py +4 -0
- app_lib/defaults.py +14 -0
- app_lib/main.py +15 -6
- app_lib/test.py +56 -23
- app_lib/user_input.py +52 -53
- app_lib/viz.py +78 -10
- assets/results/ace.npy +3 -0
- assets/results/english_springer_1.npy +3 -0
- assets/results/english_springer_2.npy +3 -0
- assets/results/french_horn.npy +3 -0
- assets/results/parachute.npy +3 -0
- header.md +2 -1
- ibydmt/test.py +0 -1
- precompute_results.py +46 -0
.pre-commit-config.yaml
CHANGED
@@ -8,3 +8,4 @@ repos:
|
|
8 |
rev: 22.6.0
|
9 |
hooks:
|
10 |
- id: black-jupyter
|
|
|
|
8 |
rev: 22.6.0
|
9 |
hooks:
|
10 |
- id: black-jupyter
|
11 |
+
args: ["--preview"]
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
base = "light"
|
app.py
CHANGED
@@ -6,6 +6,10 @@ if "sidebar_state" not in st.session_state:
|
|
6 |
st.session_state.sidebar_state = "collapsed"
|
7 |
if "disabled" not in st.session_state:
|
8 |
st.session_state.disabled = False
|
|
|
|
|
|
|
|
|
9 |
if "results" not in st.session_state:
|
10 |
st.session_state.results = None
|
11 |
|
|
|
6 |
st.session_state.sidebar_state = "collapsed"
|
7 |
if "disabled" not in st.session_state:
|
8 |
st.session_state.disabled = False
|
9 |
+
if "image_name" not in st.session_state:
|
10 |
+
st.session_state.image_name = None
|
11 |
+
if "tested" not in st.session_state:
|
12 |
+
st.session_state.tested = False
|
13 |
if "results" not in st.session_state:
|
14 |
st.session_state.results = None
|
15 |
|
app_lib/defaults.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET_NAME = "imagenette"
|
2 |
+
MODEL_NAME = "open_clip:ViT-B-32"
|
3 |
+
|
4 |
+
SIGNIFICANCE_LEVEL_VALUE = 0.05
|
5 |
+
SIGNIFICANCE_LEVEL_STEP = 0.01
|
6 |
+
|
7 |
+
TAU_MAX_VALUE = 200
|
8 |
+
TAU_MAX_STEP = 50
|
9 |
+
|
10 |
+
R_VALUE = 10
|
11 |
+
R_STEP = 5
|
12 |
+
|
13 |
+
CARDINALITY_VALUE = lambda concepts: int(len(concepts) / 2)
|
14 |
+
CARDINALITY_STEP = 1
|
app_lib/main.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
|
4 |
-
from app_lib.test import get_testing_config, test
|
5 |
from app_lib.user_input import (
|
6 |
get_advanced_settings,
|
7 |
get_class_name,
|
@@ -19,10 +19,6 @@ def _disable():
|
|
19 |
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
20 |
columns = st.columns([0.40, 0.60])
|
21 |
|
22 |
-
with columns[1]:
|
23 |
-
st.header("Results")
|
24 |
-
viz_results()
|
25 |
-
|
26 |
with columns[0]:
|
27 |
st.header("Choose Image and Concepts")
|
28 |
|
@@ -30,6 +26,13 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
30 |
|
31 |
with image_col:
|
32 |
image_name, image = get_image()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
st.image(image, use_column_width=True)
|
34 |
|
35 |
change_image_button = st.button(
|
@@ -85,5 +88,11 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
85 |
cardinality,
|
86 |
dataset_name,
|
87 |
model_name,
|
88 |
-
device,
|
89 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
|
4 |
+
from app_lib.test import get_testing_config, load_precomputed_results, test
|
5 |
from app_lib.user_input import (
|
6 |
get_advanced_settings,
|
7 |
get_class_name,
|
|
|
19 |
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
20 |
columns = st.columns([0.40, 0.60])
|
21 |
|
|
|
|
|
|
|
|
|
22 |
with columns[0]:
|
23 |
st.header("Choose Image and Concepts")
|
24 |
|
|
|
26 |
|
27 |
with image_col:
|
28 |
image_name, image = get_image()
|
29 |
+
if image_name != st.session_state.image_name:
|
30 |
+
st.session_state.image_name = image_name
|
31 |
+
st.session_state.tested = False
|
32 |
+
|
33 |
+
if image_name is not None and not st.session_state.tested:
|
34 |
+
st.session_state.results = load_precomputed_results(image_name)
|
35 |
+
|
36 |
st.image(image, use_column_width=True)
|
37 |
|
38 |
change_image_button = st.button(
|
|
|
88 |
cardinality,
|
89 |
dataset_name,
|
90 |
model_name,
|
91 |
+
device=device,
|
92 |
)
|
93 |
+
|
94 |
+
st.session_state.tested = True
|
95 |
+
|
96 |
+
with columns[1]:
|
97 |
+
st.header("Results")
|
98 |
+
viz_results()
|
app_lib/test.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
2 |
|
3 |
import clip
|
@@ -145,6 +146,14 @@ def get_testing_config(**kwargs):
|
|
145 |
return testing_config
|
146 |
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
def test(
|
149 |
testing_config,
|
150 |
image,
|
@@ -153,23 +162,44 @@ def test(
|
|
153 |
cardinality,
|
154 |
dataset_name,
|
155 |
model_name,
|
156 |
-
device,
|
|
|
157 |
):
|
158 |
-
|
|
|
|
|
|
|
159 |
model, preprocess, tokenizer = _load_model(model_name, device)
|
160 |
|
161 |
-
|
|
|
|
|
|
|
162 |
cbm = _encode_concepts(tokenizer, model, concepts, device)
|
163 |
|
164 |
-
|
|
|
|
|
|
|
165 |
h = _encode_image(model, preprocess, image, device)
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
embedding = _load_dataset(dataset_name, model_name)
|
175 |
semantics = embedding @ cbm.T
|
@@ -177,11 +207,6 @@ def test(
|
|
177 |
|
178 |
classifier = _encode_class_name(tokenizer, model, class_name, device)
|
179 |
|
180 |
-
progress_bar.progress(
|
181 |
-
1 / (len(concepts) + 1),
|
182 |
-
text=f"Testing concepts (can take up to a minute) [0 / {len(concepts)} completed]",
|
183 |
-
)
|
184 |
-
|
185 |
with ThreadPoolExecutor() as executor:
|
186 |
futures = [
|
187 |
executor.submit(
|
@@ -200,10 +225,14 @@ def test(
|
|
200 |
results = []
|
201 |
for idx, future in enumerate(as_completed(futures)):
|
202 |
results.append(future.result())
|
203 |
-
|
204 |
-
(
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
207 |
|
208 |
rejected = np.empty((testing_config.r, len(concepts)))
|
209 |
tau = np.empty((testing_config.r, len(concepts)))
|
@@ -218,7 +247,7 @@ def test(
|
|
218 |
|
219 |
tau /= testing_config.tau_max
|
220 |
|
221 |
-
|
222 |
"significance_level": testing_config.significance_level,
|
223 |
"concepts": concepts,
|
224 |
"rejected": rejected,
|
@@ -226,5 +255,9 @@ def test(
|
|
226 |
"wealth": wealth,
|
227 |
}
|
228 |
|
229 |
-
|
230 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
3 |
|
4 |
import clip
|
|
|
146 |
return testing_config
|
147 |
|
148 |
|
149 |
+
def load_precomputed_results(image_name):
|
150 |
+
results = np.load(
|
151 |
+
os.path.join("assets", "results", f"{image_name.split('.')[0]}.npy"),
|
152 |
+
allow_pickle=True,
|
153 |
+
).item()
|
154 |
+
return results
|
155 |
+
|
156 |
+
|
157 |
def test(
|
158 |
testing_config,
|
159 |
image,
|
|
|
162 |
cardinality,
|
163 |
dataset_name,
|
164 |
model_name,
|
165 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
166 |
+
with_streamlit=True,
|
167 |
):
|
168 |
+
if with_streamlit:
|
169 |
+
with st.spinner("Loading model"):
|
170 |
+
model, preprocess, tokenizer = _load_model(model_name, device)
|
171 |
+
else:
|
172 |
model, preprocess, tokenizer = _load_model(model_name, device)
|
173 |
|
174 |
+
if with_streamlit:
|
175 |
+
with st.spinner("Encoding concepts"):
|
176 |
+
cbm = _encode_concepts(tokenizer, model, concepts, device)
|
177 |
+
else:
|
178 |
cbm = _encode_concepts(tokenizer, model, concepts, device)
|
179 |
|
180 |
+
if with_streamlit:
|
181 |
+
with st.spinner("Encoding image"):
|
182 |
+
h = _encode_image(model, preprocess, image, device)
|
183 |
+
else:
|
184 |
h = _encode_image(model, preprocess, image, device)
|
185 |
+
z = h @ cbm.T
|
186 |
+
z = z.squeeze()
|
187 |
+
|
188 |
+
if with_streamlit:
|
189 |
+
progress_bar = st.progress(
|
190 |
+
0,
|
191 |
+
text=(
|
192 |
+
"Testing concepts (can take up to a minute) [0 /"
|
193 |
+
f" {len(concepts)} completed]"
|
194 |
+
),
|
195 |
+
)
|
196 |
+
progress_bar.progress(
|
197 |
+
1 / (len(concepts) + 1),
|
198 |
+
text=(
|
199 |
+
"Testing concepts (can take up to a minute) [0 /"
|
200 |
+
f" {len(concepts)} completed]"
|
201 |
+
),
|
202 |
+
)
|
203 |
|
204 |
embedding = _load_dataset(dataset_name, model_name)
|
205 |
semantics = embedding @ cbm.T
|
|
|
207 |
|
208 |
classifier = _encode_class_name(tokenizer, model, class_name, device)
|
209 |
|
|
|
|
|
|
|
|
|
|
|
210 |
with ThreadPoolExecutor() as executor:
|
211 |
futures = [
|
212 |
executor.submit(
|
|
|
225 |
results = []
|
226 |
for idx, future in enumerate(as_completed(futures)):
|
227 |
results.append(future.result())
|
228 |
+
if with_streamlit:
|
229 |
+
progress_bar.progress(
|
230 |
+
(idx + 2) / (len(concepts) + 1),
|
231 |
+
text=(
|
232 |
+
f"Testing concepts (can take up to a minute) [{idx + 1} /"
|
233 |
+
f" {len(concepts)} completed]"
|
234 |
+
),
|
235 |
+
)
|
236 |
|
237 |
rejected = np.empty((testing_config.r, len(concepts)))
|
238 |
tau = np.empty((testing_config.r, len(concepts)))
|
|
|
247 |
|
248 |
tau /= testing_config.tau_max
|
249 |
|
250 |
+
results = {
|
251 |
"significance_level": testing_config.significance_level,
|
252 |
"concepts": concepts,
|
253 |
"rejected": rejected,
|
|
|
255 |
"wealth": wealth,
|
256 |
}
|
257 |
|
258 |
+
if with_streamlit:
|
259 |
+
st.session_state.results = results
|
260 |
+
st.session_state.disabled = False
|
261 |
+
st.experimental_rerun()
|
262 |
+
else:
|
263 |
+
return results
|
app_lib/user_input.py
CHANGED
@@ -5,6 +5,7 @@ import streamlit as st
|
|
5 |
from PIL import Image
|
6 |
from streamlit_image_select import image_select
|
7 |
|
|
|
8 |
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
|
9 |
|
10 |
IMAGE_DIR = os.path.join("assets", "images")
|
@@ -30,102 +31,97 @@ def _validate_concepts(concepts):
|
|
30 |
|
31 |
|
32 |
def _get_significance_level():
|
33 |
-
|
|
|
34 |
return st.slider(
|
35 |
"Significance level",
|
36 |
-
help=" "
|
37 |
-
|
38 |
-
"The level of significance of the tests.",
|
39 |
-
f"Defaults to {DEFAULT:.2F}.",
|
40 |
-
]
|
41 |
-
),
|
42 |
-
min_value=STEP,
|
43 |
max_value=1.0,
|
44 |
-
value=
|
45 |
-
step=
|
46 |
disabled=st.session_state.disabled,
|
47 |
)
|
48 |
|
49 |
|
50 |
def _get_tau_max():
|
51 |
-
|
|
|
52 |
return int(
|
53 |
st.slider(
|
54 |
"Length of test",
|
55 |
-
help=" "
|
56 |
-
|
57 |
-
"The maximum number of steps for each test.",
|
58 |
-
f"Defaults to {DEFAULT}.",
|
59 |
-
]
|
60 |
-
),
|
61 |
-
min_value=STEP,
|
62 |
max_value=1000,
|
63 |
-
step=
|
64 |
-
value=
|
65 |
disabled=st.session_state.disabled,
|
66 |
)
|
67 |
)
|
68 |
|
69 |
|
70 |
def _get_number_of_tests():
|
71 |
-
|
|
|
72 |
return int(
|
73 |
st.slider(
|
74 |
"Number of tests per concept",
|
75 |
-
help=
|
76 |
-
|
77 |
-
|
78 |
-
f"Defaults to {DEFAULT}.",
|
79 |
-
]
|
80 |
),
|
81 |
-
min_value=
|
82 |
max_value=100,
|
83 |
-
step=
|
84 |
-
value=
|
85 |
disabled=st.session_state.disabled,
|
86 |
)
|
87 |
)
|
88 |
|
89 |
|
90 |
def _get_cardinality(concepts, concepts_ready):
|
91 |
-
|
|
|
92 |
return st.slider(
|
93 |
"Size of conditioning set",
|
94 |
-
help=
|
95 |
-
|
96 |
-
|
97 |
-
"Defaults to half of the number of concepts.",
|
98 |
-
]
|
99 |
),
|
100 |
min_value=1,
|
101 |
max_value=max(2, len(concepts) - 1),
|
102 |
-
value=
|
103 |
-
step=
|
104 |
disabled=st.session_state.disabled or not concepts_ready,
|
105 |
)
|
106 |
|
107 |
|
108 |
def _get_dataset_name():
|
109 |
-
|
|
|
110 |
return st.selectbox(
|
111 |
"Dataset",
|
112 |
-
options=
|
113 |
-
index=
|
114 |
-
help=
|
115 |
-
|
116 |
-
|
117 |
-
"Defaults to Imagenette.",
|
118 |
-
]
|
119 |
),
|
120 |
disabled=st.session_state.disabled,
|
121 |
)
|
122 |
|
123 |
|
124 |
def get_model_name():
|
|
|
|
|
125 |
return st.selectbox(
|
126 |
"Model to test",
|
127 |
-
options=
|
128 |
-
|
|
|
|
|
|
|
|
|
129 |
disabled=st.session_state.disabled,
|
130 |
)
|
131 |
|
@@ -148,13 +144,13 @@ def get_image():
|
|
148 |
|
149 |
|
150 |
def get_class_name(image_name=None):
|
151 |
-
|
152 |
IMAGE_PRESETS[image_name.split(".")[0]]["class_name"] if image_name else ""
|
153 |
)
|
154 |
class_name = st.text_input(
|
155 |
-
"Class to
|
156 |
help="Name of the class to build the zero-shot CLIP classifier with.",
|
157 |
-
value=
|
158 |
disabled=st.session_state.disabled,
|
159 |
placeholder="Type class name here",
|
160 |
)
|
@@ -164,16 +160,19 @@ def get_class_name(image_name=None):
|
|
164 |
|
165 |
|
166 |
def get_concepts(image_name=None):
|
167 |
-
|
168 |
"\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
|
169 |
if image_name
|
170 |
else ""
|
171 |
)
|
172 |
concepts = st.text_area(
|
173 |
"Concepts to test",
|
174 |
-
help=
|
|
|
|
|
|
|
175 |
height=160,
|
176 |
-
value=
|
177 |
disabled=st.session_state.disabled,
|
178 |
placeholder="Type one concept\nper line",
|
179 |
)
|
|
|
5 |
from PIL import Image
|
6 |
from streamlit_image_select import image_select
|
7 |
|
8 |
+
import app_lib.defaults as defaults
|
9 |
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
|
10 |
|
11 |
IMAGE_DIR = os.path.join("assets", "images")
|
|
|
31 |
|
32 |
|
33 |
def _get_significance_level():
|
34 |
+
default = defaults.SIGNIFICANCE_LEVEL_VALUE
|
35 |
+
step = defaults.SIGNIFICANCE_LEVEL_STEP
|
36 |
return st.slider(
|
37 |
"Significance level",
|
38 |
+
help=f"The level of significance of the tests. Defaults to {default:.2F}.",
|
39 |
+
min_value=step,
|
|
|
|
|
|
|
|
|
|
|
40 |
max_value=1.0,
|
41 |
+
value=default,
|
42 |
+
step=step,
|
43 |
disabled=st.session_state.disabled,
|
44 |
)
|
45 |
|
46 |
|
47 |
def _get_tau_max():
|
48 |
+
default = defaults.TAU_MAX_VALUE
|
49 |
+
step = defaults.TAU_MAX_STEP
|
50 |
return int(
|
51 |
st.slider(
|
52 |
"Length of test",
|
53 |
+
help=f"The maximum number of steps for each test. Defaults to {default}.",
|
54 |
+
min_value=step,
|
|
|
|
|
|
|
|
|
|
|
55 |
max_value=1000,
|
56 |
+
step=step,
|
57 |
+
value=default,
|
58 |
disabled=st.session_state.disabled,
|
59 |
)
|
60 |
)
|
61 |
|
62 |
|
63 |
def _get_number_of_tests():
|
64 |
+
default = defaults.R_VALUE
|
65 |
+
step = defaults.R_STEP
|
66 |
return int(
|
67 |
st.slider(
|
68 |
"Number of tests per concept",
|
69 |
+
help=(
|
70 |
+
"The number of tests to average for each concept. "
|
71 |
+
f"Defaults to {default}."
|
|
|
|
|
72 |
),
|
73 |
+
min_value=step,
|
74 |
max_value=100,
|
75 |
+
step=step,
|
76 |
+
value=default,
|
77 |
disabled=st.session_state.disabled,
|
78 |
)
|
79 |
)
|
80 |
|
81 |
|
82 |
def _get_cardinality(concepts, concepts_ready):
|
83 |
+
default = defaults.CARDINALITY_VALUE
|
84 |
+
step = defaults.CARDINALITY_STEP
|
85 |
return st.slider(
|
86 |
"Size of conditioning set",
|
87 |
+
help=(
|
88 |
+
"The number of concepts to condition model predictions on. "
|
89 |
+
"Defaults to half of the number of concepts.",
|
|
|
|
|
90 |
),
|
91 |
min_value=1,
|
92 |
max_value=max(2, len(concepts) - 1),
|
93 |
+
value=default(concepts),
|
94 |
+
step=step,
|
95 |
disabled=st.session_state.disabled or not concepts_ready,
|
96 |
)
|
97 |
|
98 |
|
99 |
def _get_dataset_name():
|
100 |
+
options = SUPPORTED_DATASETS
|
101 |
+
default_idx = options.index(defaults.DATASET_NAME)
|
102 |
return st.selectbox(
|
103 |
"Dataset",
|
104 |
+
options=options,
|
105 |
+
index=default_idx,
|
106 |
+
help=(
|
107 |
+
"Name of the dataset to use to train sampler."
|
108 |
+
f"Defaults to {SUPPORTED_DATASETS[default_idx]}."
|
|
|
|
|
109 |
),
|
110 |
disabled=st.session_state.disabled,
|
111 |
)
|
112 |
|
113 |
|
114 |
def get_model_name():
|
115 |
+
options = list(SUPPORTED_MODELS.keys())
|
116 |
+
default_idx = options.index(defaults.MODEL_NAME)
|
117 |
return st.selectbox(
|
118 |
"Model to test",
|
119 |
+
options=options,
|
120 |
+
index=default_idx,
|
121 |
+
help=(
|
122 |
+
"Name of the vision-language model to test the predictions of."
|
123 |
+
f"Defaults to {options[default_idx]}"
|
124 |
+
),
|
125 |
disabled=st.session_state.disabled,
|
126 |
)
|
127 |
|
|
|
144 |
|
145 |
|
146 |
def get_class_name(image_name=None):
|
147 |
+
default = (
|
148 |
IMAGE_PRESETS[image_name.split(".")[0]]["class_name"] if image_name else ""
|
149 |
)
|
150 |
class_name = st.text_input(
|
151 |
+
"Class to predict",
|
152 |
help="Name of the class to build the zero-shot CLIP classifier with.",
|
153 |
+
value=default,
|
154 |
disabled=st.session_state.disabled,
|
155 |
placeholder="Type class name here",
|
156 |
)
|
|
|
160 |
|
161 |
|
162 |
def get_concepts(image_name=None):
|
163 |
+
default = (
|
164 |
"\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
|
165 |
if image_name
|
166 |
else ""
|
167 |
)
|
168 |
concepts = st.text_area(
|
169 |
"Concepts to test",
|
170 |
+
help=(
|
171 |
+
"List of concepts to test the predictions of the model with. "
|
172 |
+
"Write one concept per line. Maximum 10 concepts allowed."
|
173 |
+
),
|
174 |
height=160,
|
175 |
+
value=default,
|
176 |
disabled=st.session_state.disabled,
|
177 |
placeholder="Type one concept\nper line",
|
178 |
)
|
app_lib/viz.py
CHANGED
@@ -6,6 +6,32 @@ import streamlit as st
|
|
6 |
|
7 |
|
8 |
def _viz_rank(results):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
rejected = results["rejected"]
|
10 |
tau = results["tau"]
|
11 |
concepts = results["concepts"]
|
@@ -55,7 +81,11 @@ def _viz_rank(results):
|
|
55 |
)
|
56 |
fig.add_vline(significance_level, line_dash="dash", line_color="black")
|
57 |
|
58 |
-
fig.update_layout(
|
|
|
|
|
|
|
|
|
59 |
if rank_df["tau"].min() <= 0.3:
|
60 |
fig.update_layout(
|
61 |
legend=dict(
|
@@ -104,24 +134,62 @@ def viz_results():
|
|
104 |
if results is None:
|
105 |
st.info("Test concepts to show results", icon="ℹ️")
|
106 |
else:
|
107 |
-
rank_tab, wealth_tab = st.tabs(
|
|
|
|
|
108 |
|
109 |
with rank_tab:
|
110 |
-
st.subheader("Rank of
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
with st.expander("Details"):
|
112 |
-
st.
|
113 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
)
|
115 |
|
116 |
if results is not None:
|
117 |
-
|
|
|
118 |
|
119 |
with wealth_tab:
|
120 |
st.subheader("Wealth Process of Testing Procedures")
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
125 |
|
126 |
if results is not None:
|
127 |
_viz_wealth(results)
|
|
|
|
6 |
|
7 |
|
8 |
def _viz_rank(results):
|
9 |
+
tau = results["tau"]
|
10 |
+
concepts = results["concepts"]
|
11 |
+
|
12 |
+
tau_mu = tau.mean(axis=0)
|
13 |
+
|
14 |
+
sorted_idx = np.argsort(tau_mu)
|
15 |
+
sorted_tau = tau_mu[sorted_idx]
|
16 |
+
sorted_concepts = [concepts[idx] for idx in sorted_idx]
|
17 |
+
|
18 |
+
min_size, max_size = 14, 50
|
19 |
+
|
20 |
+
_, centercol, _ = st.columns(3)
|
21 |
+
with centercol:
|
22 |
+
with st.container():
|
23 |
+
for concept, tau in zip(sorted_concepts, sorted_tau):
|
24 |
+
style = (
|
25 |
+
"text-align:center;"
|
26 |
+
f"font-size:{max_size - tau * (max_size - min_size)}px"
|
27 |
+
)
|
28 |
+
st.write(
|
29 |
+
f"<p style='{style}'>{concept}</p>",
|
30 |
+
unsafe_allow_html=True,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def _viz_test(results):
|
35 |
rejected = results["rejected"]
|
36 |
tau = results["tau"]
|
37 |
concepts = results["concepts"]
|
|
|
81 |
)
|
82 |
fig.add_vline(significance_level, line_dash="dash", line_color="black")
|
83 |
|
84 |
+
fig.update_layout(
|
85 |
+
yaxis_title="Rank of importance",
|
86 |
+
xaxis_title="",
|
87 |
+
margin=dict(l=20, r=20, t=20, b=20),
|
88 |
+
)
|
89 |
if rank_df["tau"].min() <= 0.3:
|
90 |
fig.update_layout(
|
91 |
legend=dict(
|
|
|
134 |
if results is None:
|
135 |
st.info("Test concepts to show results", icon="ℹ️")
|
136 |
else:
|
137 |
+
rank_tab, test_tab, wealth_tab = st.tabs(
|
138 |
+
["Rank of importance", "Testing results", "Wealth process"]
|
139 |
+
)
|
140 |
|
141 |
with rank_tab:
|
142 |
+
st.subheader("Rank of Importance")
|
143 |
+
st.write(
|
144 |
+
"""
|
145 |
+
This tab visually shows the rank of importance of the specified concepts
|
146 |
+
for the prediction of the model on the input image. Larger font sizes indicate
|
147 |
+
higher importance. See the other two tabs for more details.
|
148 |
+
"""
|
149 |
+
)
|
150 |
+
|
151 |
+
if results is not None:
|
152 |
+
_viz_rank(results)
|
153 |
+
st.divider()
|
154 |
+
|
155 |
+
with test_tab:
|
156 |
+
st.subheader("Testing Results")
|
157 |
+
st.write(
|
158 |
+
"""
|
159 |
+
Importance is measured by performing sequential tests of statistical independence.
|
160 |
+
This tab shows the results of these tests and how the rank of importance is computed.
|
161 |
+
Concepts are sorted by increasing rejection time, where a shorter rejection time indicates
|
162 |
+
higher importance.
|
163 |
+
"""
|
164 |
+
)
|
165 |
with st.expander("Details"):
|
166 |
+
st.markdown(
|
167 |
+
"""
|
168 |
+
Results are averaged over multiple random draws of conditioning subsets of
|
169 |
+
concepts. The number of tests can be controlled under `Advanced settings`.
|
170 |
+
|
171 |
+
- **Rejection rate**: The average number of times the test is rejected for a concept.
|
172 |
+
- **Rejection time**: The (normalized) average number of steps before the test is
|
173 |
+
rejected for a concept.
|
174 |
+
- **Significance level**: The level at which the test is rejected for a concept.
|
175 |
+
"""
|
176 |
)
|
177 |
|
178 |
if results is not None:
|
179 |
+
_viz_test(results)
|
180 |
+
st.divider()
|
181 |
|
182 |
with wealth_tab:
|
183 |
st.subheader("Wealth Process of Testing Procedures")
|
184 |
+
st.markdown(
|
185 |
+
"""
|
186 |
+
Sequential tests instantiate a wealth process for each concept. Once the
|
187 |
+
wealth reaches a value of 1/α, the test is rejected with Type I error control at
|
188 |
+
level α. This tab shows the average wealth process of the testing procedures for
|
189 |
+
each concept.
|
190 |
+
"""
|
191 |
+
)
|
192 |
|
193 |
if results is not None:
|
194 |
_viz_wealth(results)
|
195 |
+
st.divider()
|
assets/results/ace.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c82df211515a7d8c565c6e350d4183ba7c877d23ab03e0d639319a3429d8850b
|
3 |
+
size 81407
|
assets/results/english_springer_1.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81c05c7f4cc612f2afefb06a67a970001605ccfe83ddc72e2eb46cafa27773dc
|
3 |
+
size 81414
|
assets/results/english_springer_2.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c2ef386512d02796f4570464e88ffaf6df8829858ad6a23836eaa7f86c6a1302
|
3 |
+
size 81416
|
assets/results/french_horn.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:467bf67d0cbf873bb8c72b0bf68906e331573920cf019b78142245d37a9fafb9
|
3 |
+
size 81412
|
assets/results/parachute.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33b001c2390ecdd8975fc8a39b257813a0ff7a05fc65c6a52a981581693b6a42
|
3 |
+
size 81415
|
header.md
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
-
Test the
|
|
|
4 |
|
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
+
Test the effect of different concepts on the predictions of a classifier. Concepts are ranked by their *importance*: how much they change the prediction. [[paper]](https://arxiv.org/pdf/2405.19146) [[code]](https://github.com/Sulam-Group/IBYDMT)
|
4 |
+
|
5 |
|
ibydmt/test.py
CHANGED
@@ -120,7 +120,6 @@ class xSKIT(SequentialTester):
|
|
120 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
121 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
122 |
) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
|
123 |
-
|
124 |
if len(self._queue) == 0:
|
125 |
Cuj = C + [j]
|
126 |
|
|
|
120 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
121 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
122 |
) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
|
|
|
123 |
if len(self._queue) == 0:
|
124 |
Cuj = C + [j]
|
125 |
|
precompute_results.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import app_lib.defaults as defaults
|
8 |
+
from app_lib.test import get_testing_config, test
|
9 |
+
|
10 |
+
assets_dir = "assets"
|
11 |
+
image_dir = os.path.join(assets_dir, "images")
|
12 |
+
results_dir = os.path.join(assets_dir, "results")
|
13 |
+
os.makedirs(results_dir, exist_ok=True)
|
14 |
+
|
15 |
+
testing_config = get_testing_config(
|
16 |
+
significance_level=defaults.SIGNIFICANCE_LEVEL_VALUE,
|
17 |
+
tau_max=defaults.TAU_MAX_VALUE,
|
18 |
+
r=defaults.R_VALUE,
|
19 |
+
)
|
20 |
+
|
21 |
+
image_presets = json.load(open(os.path.join(assets_dir, "image_presets.json")))
|
22 |
+
for _image_name, _image_presets in image_presets.items():
|
23 |
+
_image_name = f"{_image_name}.jpg"
|
24 |
+
_image_path = os.path.join(image_dir, _image_name)
|
25 |
+
|
26 |
+
_image = Image.open(_image_path)
|
27 |
+
_class_name = _image_presets["class_name"]
|
28 |
+
_concepts = _image_presets["concepts"]
|
29 |
+
_cardinality = defaults.CARDINALITY_VALUE(_concepts)
|
30 |
+
|
31 |
+
_results = test(
|
32 |
+
testing_config,
|
33 |
+
_image,
|
34 |
+
_class_name,
|
35 |
+
_concepts,
|
36 |
+
_cardinality,
|
37 |
+
defaults.DATASET_NAME,
|
38 |
+
defaults.MODEL_NAME,
|
39 |
+
with_streamlit=False,
|
40 |
+
)
|
41 |
+
|
42 |
+
np.save(
|
43 |
+
os.path.join(results_dir, _image_name.split(".")[0]),
|
44 |
+
_results,
|
45 |
+
allow_pickle=True,
|
46 |
+
)
|