jacopoteneggi commited on
Commit
8e05eba
Β·
verified Β·
1 Parent(s): 21d3461
app_lib/ckde.py CHANGED
@@ -31,11 +31,10 @@ class cKDE:
31
  def _sample(self, z, cond_idx, m):
32
  sample_idx = list(set(range(len(z))) - set(cond_idx))
33
 
34
- kde, _ = self.kde(z, cond_idx)
35
-
36
  sample_z = np.tile(z, (m, 1))
37
- sample_z[:, sample_idx] = kde.resample(m).T
38
-
 
39
  return sample_z
40
 
41
  def kde(self, z, cond_idx):
 
31
  def _sample(self, z, cond_idx, m):
32
  sample_idx = list(set(range(len(z))) - set(cond_idx))
33
 
 
 
34
  sample_z = np.tile(z, (m, 1))
35
+ if len(sample_idx) > 0:
36
+ kde, _ = self.kde(z, cond_idx)
37
+ sample_z[:, sample_idx] = kde.resample(m).T
38
  return sample_z
39
 
40
  def kde(self, z, cond_idx):
app_lib/defaults.py CHANGED
@@ -7,8 +7,8 @@ SIGNIFICANCE_LEVEL_STEP = 0.01
7
  TAU_MAX_VALUE = 200
8
  TAU_MAX_STEP = 50
9
 
10
- R_VALUE = 10
11
  R_STEP = 5
12
 
13
- CARDINALITY_VALUE = lambda concepts: int(len(concepts) / 2)
14
  CARDINALITY_STEP = 1
 
7
  TAU_MAX_VALUE = 200
8
  TAU_MAX_STEP = 50
9
 
10
+ R_VALUE = 20
11
  R_STEP = 5
12
 
13
+ CARDINALITY_VALUE = 1
14
  CARDINALITY_STEP = 1
app_lib/main.py CHANGED
@@ -16,6 +16,21 @@ def _disable():
16
  st.session_state.disabled = True
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
20
  columns = st.columns([0.40, 0.60])
21
 
@@ -26,13 +41,6 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
26
 
27
  with image_col:
28
  image_name, image = get_image()
29
- if image_name != st.session_state.image_name:
30
- st.session_state.image_name = image_name
31
- st.session_state.tested = False
32
-
33
- if image_name is not None and not st.session_state.tested:
34
- st.session_state.results = load_precomputed_results(image_name)
35
-
36
  st.image(image, use_column_width=True)
37
 
38
  change_image_button = st.button(
@@ -40,9 +48,8 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
40
  use_container_width=False,
41
  disabled=st.session_state.disabled,
42
  )
43
- if change_image_button:
44
- st.session_state.sidebar_state = "expanded"
45
- st.experimental_rerun()
46
  with concepts_col:
47
  model_name = get_model_name()
48
  class_name, class_ready, class_error = get_class_name(image_name)
@@ -74,13 +81,18 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
74
  disabled=st.session_state.disabled or not ready,
75
  )
76
 
77
- if test_button:
78
- st.session_state.results = None
79
 
80
- testing_config = get_testing_config(
81
- significance_level=significance_level, tau_max=tau_max, r=r
82
- )
83
- test(
 
 
 
 
 
84
  testing_config,
85
  image,
86
  class_name,
@@ -91,8 +103,12 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
91
  device=device,
92
  )
93
 
94
- st.session_state.tested = True
 
 
 
 
 
95
 
96
- with columns[1]:
97
- st.header("Results")
98
- viz_results()
 
16
  st.session_state.disabled = True
17
 
18
 
19
+ def _toggle_sidebar(button):
20
+ if button:
21
+ st.session_state.sidebar_state = "expanded"
22
+ st.experimental_rerun()
23
+
24
+
25
+ def _preload_results(image_name):
26
+ if image_name != st.session_state.image_name:
27
+ st.session_state.image_name = image_name
28
+ st.session_state.tested = False
29
+
30
+ if st.session_state.image_name is not None and not st.session_state.tested:
31
+ st.session_state.results = load_precomputed_results(image_name)
32
+
33
+
34
  def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
35
  columns = st.columns([0.40, 0.60])
36
 
 
41
 
42
  with image_col:
43
  image_name, image = get_image()
 
 
 
 
 
 
 
44
  st.image(image, use_column_width=True)
45
 
46
  change_image_button = st.button(
 
48
  use_container_width=False,
49
  disabled=st.session_state.disabled,
50
  )
51
+ _toggle_sidebar(change_image_button)
52
+
 
53
  with concepts_col:
54
  model_name = get_model_name()
55
  class_name, class_ready, class_error = get_class_name(image_name)
 
81
  disabled=st.session_state.disabled or not ready,
82
  )
83
 
84
+ if test_button:
85
+ st.session_state.results = None
86
 
87
+ with columns[1]:
88
+ viz_results()
89
+
90
+ testing_config = get_testing_config(
91
+ significance_level=significance_level, tau_max=tau_max, r=r
92
+ )
93
+
94
+ with columns[0]:
95
+ results = test(
96
  testing_config,
97
  image,
98
  class_name,
 
103
  device=device,
104
  )
105
 
106
+ st.session_state.tested = True
107
+ st.session_state.results = results
108
+ st.session_state.disabled = False
109
+ st.experimental_rerun()
110
+ else:
111
+ _preload_results(image_name)
112
 
113
+ with columns[1]:
114
+ viz_results()
 
app_lib/test.py CHANGED
@@ -255,9 +255,4 @@ def test(
255
  "wealth": wealth,
256
  }
257
 
258
- if with_streamlit:
259
- st.session_state.results = results
260
- st.session_state.disabled = False
261
- st.experimental_rerun()
262
- else:
263
- return results
 
255
  "wealth": wealth,
256
  }
257
 
258
+ return results
 
 
 
 
 
app_lib/user_input.py CHANGED
@@ -86,11 +86,11 @@ def _get_cardinality(concepts, concepts_ready):
86
  "Size of conditioning set",
87
  help=(
88
  "The number of concepts to condition model predictions on. "
89
- "Defaults to half of the number of concepts.",
90
  ),
91
  min_value=1,
92
  max_value=max(2, len(concepts) - 1),
93
- value=default(concepts),
94
  step=step,
95
  disabled=st.session_state.disabled or not concepts_ready,
96
  )
@@ -132,7 +132,7 @@ def get_image():
132
  if uploaded_file is not None:
133
  return (None, Image.open(uploaded_file))
134
  else:
135
- DEFAULT = IMAGE_NAMES.index("ace.jpg")
136
  image_idx = image_select(
137
  label="or select one",
138
  images=IMAGE_PATHS,
@@ -171,7 +171,7 @@ def get_concepts(image_name=None):
171
  "List of concepts to test the predictions of the model with. "
172
  "Write one concept per line. Maximum 10 concepts allowed."
173
  ),
174
- height=160,
175
  value=default,
176
  disabled=st.session_state.disabled,
177
  placeholder="Type one concept\nper line",
 
86
  "Size of conditioning set",
87
  help=(
88
  "The number of concepts to condition model predictions on. "
89
+ "Defaults to {default}."
90
  ),
91
  min_value=1,
92
  max_value=max(2, len(concepts) - 1),
93
+ value=default,
94
  step=step,
95
  disabled=st.session_state.disabled or not concepts_ready,
96
  )
 
132
  if uploaded_file is not None:
133
  return (None, Image.open(uploaded_file))
134
  else:
135
+ DEFAULT = IMAGE_NAMES.index("bowl_ace.jpg")
136
  image_idx = image_select(
137
  label="or select one",
138
  images=IMAGE_PATHS,
 
171
  "List of concepts to test the predictions of the model with. "
172
  "Write one concept per line. Maximum 10 concepts allowed."
173
  ),
174
+ height=180,
175
  value=default,
176
  disabled=st.session_state.disabled,
177
  placeholder="Type one concept\nper line",
app_lib/viz.py CHANGED
@@ -15,20 +15,37 @@ def _viz_rank(results):
15
  sorted_tau = tau_mu[sorted_idx]
16
  sorted_concepts = [concepts[idx] for idx in sorted_idx]
17
 
18
- min_size, max_size = 14, 50
19
-
20
- _, centercol, _ = st.columns(3)
21
- with centercol:
22
- with st.container():
23
- for concept, tau in zip(sorted_concepts, sorted_tau):
24
- style = (
25
- "text-align:center;"
26
- f"font-size:{max_size - tau * (max_size - min_size)}px"
27
- )
28
- st.write(
29
- f"<p style='{style}'>{concept}</p>",
30
- unsafe_allow_html=True,
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def _viz_test(results):
@@ -125,46 +142,47 @@ def _viz_wealth(results):
125
  annotation_position="bottom right",
126
  )
127
  fig.update_yaxes(range=[0, 1.5 * 1 / significance_level])
 
128
  st.plotly_chart(fig, use_container_width=True)
129
 
130
 
131
  def viz_results():
132
  results = st.session_state.results
133
 
134
- if results is None:
135
- st.info("Test concepts to show results", icon="ℹ️")
136
- else:
137
- rank_tab, test_tab, wealth_tab = st.tabs(
138
- ["Rank of importance", "Testing results", "Wealth process"]
139
- )
140
 
141
- with rank_tab:
142
- st.subheader("Rank of Importance")
143
- st.write(
144
- """
145
  This tab visually shows the rank of importance of the specified concepts
146
  for the prediction of the model on the input image. Larger font sizes indicate
147
  higher importance. See the other two tabs for more details.
148
- """
149
- )
150
 
151
- if results is not None:
152
- _viz_rank(results)
153
- st.divider()
 
 
154
 
155
- with test_tab:
156
- st.subheader("Testing Results")
157
- st.write(
158
- """
159
  Importance is measured by performing sequential tests of statistical independence.
160
  This tab shows the results of these tests and how the rank of importance is computed.
161
  Concepts are sorted by increasing rejection time, where a shorter rejection time indicates
162
  higher importance.
 
 
 
 
163
  """
164
- )
165
- with st.expander("Details"):
166
- st.markdown(
167
- """
168
  Results are averaged over multiple random draws of conditioning subsets of
169
  concepts. The number of tests can be controlled under `Advanced settings`.
170
 
@@ -172,24 +190,28 @@ def viz_results():
172
  - **Rejection time**: The (normalized) average number of steps before the test is
173
  rejected for a concept.
174
  - **Significance level**: The level at which the test is rejected for a concept.
175
- """
176
- )
177
 
178
- if results is not None:
179
- _viz_test(results)
180
- st.divider()
 
 
181
 
182
- with wealth_tab:
183
- st.subheader("Wealth Process of Testing Procedures")
184
- st.markdown(
185
- """
186
  Sequential tests instantiate a wealth process for each concept. Once the
187
  wealth reaches a value of 1/Ξ±, the test is rejected with Type I error control at
188
  level Ξ±. This tab shows the average wealth process of the testing procedures for
189
  each concept.
190
- """
191
- )
192
 
193
- if results is not None:
194
- _viz_wealth(results)
195
- st.divider()
 
 
 
15
  sorted_tau = tau_mu[sorted_idx]
16
  sorted_concepts = [concepts[idx] for idx in sorted_idx]
17
 
18
+ sorted_width = 1 - sorted_tau
19
+ sorted_width /= sorted_width.max()
20
+ sorted_width *= 100
21
+
22
+ m = min(len(concepts), 7)
23
+ nrows = np.ceil(len(sorted_concepts) / m).astype(int)
24
+ with st.container():
25
+ for i in range(nrows):
26
+ cols = st.columns(m)
27
+ for j in range(m):
28
+ concept_idx = i * m + j
29
+ if concept_idx >= len(sorted_concepts):
30
+ break
31
+ concept = sorted_concepts[concept_idx]
32
+ col = cols[j]
33
+ with col:
34
+ circle_style = f"""
35
+ background: #418FDE;
36
+ border-radius: 50%;
37
+ width: {sorted_width[concept_idx]}%;
38
+ padding-bottom: {sorted_width[concept_idx]}%;
39
+ """
40
+ st.markdown(
41
+ f"""
42
+ <p id='concept'>
43
+ <strong>{concept}</strong>
44
+ </p>
45
+ <div style='{circle_style}'></div>
46
+ """,
47
+ unsafe_allow_html=True,
48
+ )
49
 
50
 
51
  def _viz_test(results):
 
142
  annotation_position="bottom right",
143
  )
144
  fig.update_yaxes(range=[0, 1.5 * 1 / significance_level])
145
+ fig.update_layout(margin=dict(l=20, r=20, t=20, b=20))
146
  st.plotly_chart(fig, use_container_width=True)
147
 
148
 
149
  def viz_results():
150
  results = st.session_state.results
151
 
152
+ st.header("Results")
153
+ rank_tab, test_tab, wealth_tab = st.tabs(
154
+ ["Rank of importance", "Testing results", "Wealth process"]
155
+ )
 
 
156
 
157
+ with rank_tab:
158
+ st.subheader("Rank of Importance")
159
+ st.write(
160
+ """
161
  This tab visually shows the rank of importance of the specified concepts
162
  for the prediction of the model on the input image. Larger font sizes indicate
163
  higher importance. See the other two tabs for more details.
164
+ """
165
+ )
166
 
167
+ if results is not None:
168
+ _viz_rank(results)
169
+ st.divider()
170
+ else:
171
+ st.info("Waiting for results", icon="ℹ️")
172
 
173
+ with test_tab:
174
+ st.subheader("Testing Results")
175
+ st.write(
176
+ """
177
  Importance is measured by performing sequential tests of statistical independence.
178
  This tab shows the results of these tests and how the rank of importance is computed.
179
  Concepts are sorted by increasing rejection time, where a shorter rejection time indicates
180
  higher importance.
181
+ """
182
+ )
183
+ with st.expander("Details"):
184
+ st.markdown(
185
  """
 
 
 
 
186
  Results are averaged over multiple random draws of conditioning subsets of
187
  concepts. The number of tests can be controlled under `Advanced settings`.
188
 
 
190
  - **Rejection time**: The (normalized) average number of steps before the test is
191
  rejected for a concept.
192
  - **Significance level**: The level at which the test is rejected for a concept.
193
+ """
194
+ )
195
 
196
+ if results is not None:
197
+ _viz_test(results)
198
+ st.divider()
199
+ else:
200
+ st.info("Waiting for results", icon="ℹ️")
201
 
202
+ with wealth_tab:
203
+ st.subheader("Wealth Process of Testing Procedures")
204
+ st.markdown(
205
+ """
206
  Sequential tests instantiate a wealth process for each concept. Once the
207
  wealth reaches a value of 1/Ξ±, the test is rejected with Type I error control at
208
  level Ξ±. This tab shows the average wealth process of the testing procedures for
209
  each concept.
210
+ """
211
+ )
212
 
213
+ if results is not None:
214
+ _viz_wealth(results)
215
+ st.divider()
216
+ else:
217
+ st.info("Waiting for results", icon="ℹ️")
assets/image_presets.json CHANGED
@@ -1,52 +1,50 @@
1
  {
2
- "ace": {
3
  "class_name": "cat",
4
  "concepts": [
5
- "piano",
6
- "cute",
7
- "whiskers",
8
- "music",
9
- "wild"
 
 
10
  ]
11
  },
12
- "english_springer_1": {
13
- "class_name": "English springer",
14
- "concepts": [
15
- "spaniel",
16
- "sibling",
17
- "fluffy",
18
- "patch",
19
- "portrait"
20
- ]
21
- },
22
- "english_springer_2": {
23
- "class_name": "English springer",
24
  "concepts": [
25
- "spaniel",
26
- "fetch",
27
- "fishing",
28
- "trumpet",
29
- "cathedral"
 
 
30
  ]
31
  },
32
- "french_horn": {
33
- "class_name": "French horn",
34
  "concepts": [
35
- "trumpet",
36
- "band",
37
- "instrument",
38
- "major",
39
- "naval"
 
 
40
  ]
41
  },
42
- "parachute": {
43
- "class_name": "parachute",
44
  "concepts": [
45
- "flew",
46
- "descending",
47
- "tandem",
48
- "instrument",
49
- "band"
 
 
50
  ]
51
  }
52
  }
 
1
  {
2
+ "bowl_ace": {
3
  "class_name": "cat",
4
  "concepts": [
5
+ "sink",
6
+ "bathroom",
7
+ "paws",
8
+ "tabby",
9
+ "white",
10
+ "soap",
11
+ "mirror"
12
  ]
13
  },
14
+ "gardener_ace": {
15
+ "class_name": "cat",
 
 
 
 
 
 
 
 
 
 
16
  "concepts": [
17
+ "box",
18
+ "plant",
19
+ "floor",
20
+ "gardening",
21
+ "whiskers",
22
+ "fur",
23
+ "brown"
24
  ]
25
  },
26
+ "gentleman_ace": {
27
+ "class_name": "cat",
28
  "concepts": [
29
+ "lamp",
30
+ "chair",
31
+ "blue",
32
+ "velvet",
33
+ "paw",
34
+ "pointy ears",
35
+ "wooden"
36
  ]
37
  },
38
+ "mathematician_ace": {
39
+ "class_name": "cat",
40
  "concepts": [
41
+ "desk",
42
+ "book",
43
+ "whiskers",
44
+ "green eyes",
45
+ "paws",
46
+ "math",
47
+ "play"
48
  ]
49
  }
50
  }
assets/images/ace.jpg DELETED
Binary file (64 kB)
 
assets/images/bowl_ace.jpg ADDED
assets/images/english_springer_1.jpg DELETED
Binary file (33 kB)
 
assets/images/english_springer_2.jpg DELETED
Binary file (110 kB)
 
assets/images/french_horn.jpg DELETED
Binary file (136 kB)
 
assets/images/gardener_ace.jpg ADDED
assets/images/gentleman_ace.jpg ADDED
assets/images/mathematician_ace.jpg ADDED
assets/images/parachute.jpg DELETED
Binary file (121 kB)
 
assets/results/{ace.npy β†’ bowl_ace.npy} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c82df211515a7d8c565c6e350d4183ba7c877d23ab03e0d639319a3429d8850b
3
- size 81407
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f7746293f59199f6872a9570757b2d9d827c2733935cedba2c161181a1cc19c
3
+ size 226871
assets/results/{english_springer_1.npy β†’ gardener_ace.npy} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:81c05c7f4cc612f2afefb06a67a970001605ccfe83ddc72e2eb46cafa27773dc
3
- size 81414
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e3e4a87d960a3591c6c97e10cc40a9e0727048118cc1d5670770bdd74e457ce
3
+ size 226873
assets/results/{english_springer_2.npy β†’ gentleman_ace.npy} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c2ef386512d02796f4570464e88ffaf6df8829858ad6a23836eaa7f86c6a1302
3
- size 81416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a319ecde9c2299323692edf385274b6939c15a9d6c296aa70629898d8798934f
3
+ size 226874
assets/results/{french_horn.npy β†’ mathematician_ace.npy} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:467bf67d0cbf873bb8c72b0bf68906e331573920cf019b78142245d37a9fafb9
3
- size 81412
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf0f277e9f06b957f03c39620a11514a373754b72cf1b5e97088964d07ff7b4a
3
+ size 226873
assets/results/parachute.npy DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:33b001c2390ecdd8975fc8a39b257813a0ff7a05fc65c6a52a981581693b6a42
3
- size 81415
 
 
 
 
precompute_results.py CHANGED
@@ -26,7 +26,7 @@ for _image_name, _image_presets in image_presets.items():
26
  _image = Image.open(_image_path)
27
  _class_name = _image_presets["class_name"]
28
  _concepts = _image_presets["concepts"]
29
- _cardinality = defaults.CARDINALITY_VALUE(_concepts)
30
 
31
  _results = test(
32
  testing_config,
 
26
  _image = Image.open(_image_path)
27
  _class_name = _image_presets["class_name"]
28
  _concepts = _image_presets["concepts"]
29
+ _cardinality = defaults.CARDINALITY_VALUE
30
 
31
  _results = test(
32
  testing_config,
style.css CHANGED
@@ -39,6 +39,14 @@ h1 {
39
  }
40
  }
41
 
 
 
 
 
 
 
 
 
42
  .stSpinner>div {
43
  display: flex;
44
  justify-content: center;
 
39
  }
40
  }
41
 
42
+ [data-testid="stMarkdownContainer"]:has(> #concept) {
43
+ display: flex;
44
+ flex-direction: column;
45
+ align-items: center;
46
+
47
+
48
+ }
49
+
50
  .stSpinner>div {
51
  display: flex;
52
  justify-content: center;