jacopoteneggi commited on
Commit
0aef92c
·
verified ·
1 Parent(s): c3af76c
app_lib/main.py CHANGED
@@ -29,7 +29,7 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
29
  image_col, concepts_col = st.columns(2)
30
 
31
  with image_col:
32
- image = get_image()
33
  st.image(image, use_column_width=True)
34
 
35
  change_image_button = st.button(
@@ -42,8 +42,8 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
42
  st.experimental_rerun()
43
  with concepts_col:
44
  model_name = get_model_name()
45
- class_name, class_ready, class_error = get_class_name()
46
- concepts, concepts_ready, concepts_error = get_concepts()
47
 
48
  ready = class_ready and concepts_ready
49
 
 
29
  image_col, concepts_col = st.columns(2)
30
 
31
  with image_col:
32
+ image_name, image = get_image()
33
  st.image(image, use_column_width=True)
34
 
35
  change_image_button = st.button(
 
42
  st.experimental_rerun()
43
  with concepts_col:
44
  model_name = get_model_name()
45
+ class_name, class_ready, class_error = get_class_name(image_name)
46
+ concepts, concepts_ready, concepts_error = get_concepts(image_name)
47
 
48
  ready = class_ready and concepts_ready
49
 
app_lib/user_input.py CHANGED
@@ -1,9 +1,17 @@
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  from streamlit_image_select import image_select
4
 
5
  from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
6
 
 
 
 
 
 
7
 
8
  def _validate_class_name(class_name):
9
  if class_name is None:
@@ -125,37 +133,49 @@ def get_model_name():
125
  def get_image():
126
  with st.sidebar:
127
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
128
- image = uploaded_file or image_select(
129
- label="or select one",
130
- images=[
131
- "assets/ace.jpg",
132
- "assets/ace.jpg",
133
- "assets/ace.jpg",
134
- "assets/ace.jpg",
135
- ],
136
- )
137
- return Image.open(image)
138
-
139
-
140
- def get_class_name():
 
 
 
 
 
141
  class_name = st.text_input(
142
  "Class to test",
143
  help="Name of the class to build the zero-shot CLIP classifier with.",
144
- value="cat",
145
  disabled=st.session_state.disabled,
 
146
  )
147
 
148
  class_ready, class_error = _validate_class_name(class_name)
149
  return class_name, class_ready, class_error
150
 
151
 
152
- def get_concepts():
 
 
 
 
 
153
  concepts = st.text_area(
154
  "Concepts to test",
155
  help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
156
  height=160,
157
- value="piano\ncute\nwhiskers\nmusic\nwild",
158
  disabled=st.session_state.disabled,
 
159
  )
160
  concepts = concepts.split("\n")
161
  concepts = [concept.strip() for concept in concepts]
 
1
+ import json
2
+ import os
3
+
4
  import streamlit as st
5
  from PIL import Image
6
  from streamlit_image_select import image_select
7
 
8
  from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
9
 
10
+ IMAGE_DIR = os.path.join("assets", "images")
11
+ IMAGE_NAMES = list(sorted(filter(lambda x: x.endswith(".jpg"), os.listdir(IMAGE_DIR))))
12
+ IMAGE_PATHS = list(map(lambda x: os.path.join(IMAGE_DIR, x), IMAGE_NAMES))
13
+ IMAGE_PRESETS = json.load(open("assets/image_presets.json"))
14
+
15
 
16
  def _validate_class_name(class_name):
17
  if class_name is None:
 
133
  def get_image():
134
  with st.sidebar:
135
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
136
+ if uploaded_file is not None:
137
+ return (None, Image.open(uploaded_file))
138
+ else:
139
+ DEFAULT = IMAGE_NAMES.index("ace.jpg")
140
+ image_idx = image_select(
141
+ label="or select one",
142
+ images=IMAGE_PATHS,
143
+ index=DEFAULT,
144
+ return_value="index",
145
+ )
146
+ image_name, image_path = IMAGE_NAMES[image_idx], IMAGE_PATHS[image_idx]
147
+ return (image_name, Image.open(image_path))
148
+
149
+
150
+ def get_class_name(image_name=None):
151
+ DEFAULT = (
152
+ IMAGE_PRESETS[image_name.split(".")[0]]["class_name"] if image_name else ""
153
+ )
154
  class_name = st.text_input(
155
  "Class to test",
156
  help="Name of the class to build the zero-shot CLIP classifier with.",
157
+ value=DEFAULT,
158
  disabled=st.session_state.disabled,
159
+ placeholder="Type class name here",
160
  )
161
 
162
  class_ready, class_error = _validate_class_name(class_name)
163
  return class_name, class_ready, class_error
164
 
165
 
166
+ def get_concepts(image_name=None):
167
+ DEFAULT = (
168
+ "\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
169
+ if image_name
170
+ else ""
171
+ )
172
  concepts = st.text_area(
173
  "Concepts to test",
174
  help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
175
  height=160,
176
+ value=DEFAULT,
177
  disabled=st.session_state.disabled,
178
+ placeholder="Type one concept\nper line",
179
  )
180
  concepts = concepts.split("\n")
181
  concepts = [concept.strip() for concept in concepts]
app_lib/viz.py CHANGED
@@ -41,23 +41,32 @@ def _viz_rank(results):
41
  y=rank_df["concept"],
42
  orientation="h",
43
  marker=dict(color="#a6cee3"),
44
- name="Normalized rejection time",
45
  )
46
  )
47
- fig.add_shape(
48
- type="line",
49
- yref="paper",
50
- line=dict(color="black", dash="dash"),
51
- x0=significance_level,
52
- x1=significance_level,
53
- y0=0,
54
- y1=1,
55
- name="significance level",
56
- showlegend=True,
57
  )
 
 
58
  fig.update_layout(yaxis_title="Rank of importance", xaxis_title="")
 
 
 
 
 
 
 
 
 
59
 
60
- _, centercol, _ = st.columns([1, 4, 1])
61
  with centercol:
62
  st.plotly_chart(fig, use_container_width=True)
63
 
@@ -86,7 +95,6 @@ def _viz_wealth(results):
86
  annotation_position="bottom right",
87
  )
88
  fig.update_yaxes(range=[0, 1.5 * 1 / significance_level])
89
- # fig.update_layout(legend=dict(orientation="h", x=0, y=1.2))
90
  st.plotly_chart(fig, use_container_width=True)
91
 
92
 
 
41
  y=rank_df["concept"],
42
  orientation="h",
43
  marker=dict(color="#a6cee3"),
44
+ name="Rejection time",
45
  )
46
  )
47
+ fig.add_trace(
48
+ go.Scatter(
49
+ x=[significance_level, significance_level],
50
+ y=[sorted_concepts[0], sorted_concepts[0]],
51
+ mode="lines",
52
+ line=dict(color="black", dash="dash"),
53
+ name="significance level",
54
+ )
 
 
55
  )
56
+ fig.add_vline(significance_level, line_dash="dash", line_color="black")
57
+
58
  fig.update_layout(yaxis_title="Rank of importance", xaxis_title="")
59
+ if rank_df["tau"].min() <= 0.3:
60
+ fig.update_layout(
61
+ legend=dict(
62
+ x=0.3,
63
+ y=1.0,
64
+ bordercolor="black",
65
+ borderwidth=1,
66
+ ),
67
+ )
68
 
69
+ _, centercol, _ = st.columns([1, 3, 1])
70
  with centercol:
71
  st.plotly_chart(fig, use_container_width=True)
72
 
 
95
  annotation_position="bottom right",
96
  )
97
  fig.update_yaxes(range=[0, 1.5 * 1 / significance_level])
 
98
  st.plotly_chart(fig, use_container_width=True)
99
 
100
 
assets/ace.jpg DELETED
Binary file (197 kB)
 
assets/image_presets.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ace": {
3
+ "class_name": "cat",
4
+ "concepts": [
5
+ "piano",
6
+ "cute",
7
+ "whiskers",
8
+ "music",
9
+ "wild"
10
+ ]
11
+ },
12
+ "english_springer_1": {
13
+ "class_name": "English springer",
14
+ "concepts": [
15
+ "spaniel",
16
+ "sibling",
17
+ "fluffy",
18
+ "patch",
19
+ "portrait"
20
+ ]
21
+ },
22
+ "english_springer_2": {
23
+ "class_name": "English springer",
24
+ "concepts": [
25
+ "spaniel",
26
+ "fetch",
27
+ "fishing",
28
+ "trumpet",
29
+ "cathedral"
30
+ ]
31
+ },
32
+ "french_horn": {
33
+ "class_name": "French horn",
34
+ "concepts": [
35
+ "trumpet",
36
+ "band",
37
+ "instrument",
38
+ "major",
39
+ "naval"
40
+ ]
41
+ },
42
+ "parachute": {
43
+ "class_name": "parachute",
44
+ "concepts": [
45
+ "flew",
46
+ "descending",
47
+ "tandem",
48
+ "instrument",
49
+ "band"
50
+ ]
51
+ }
52
+ }
assets/images/ace.jpg ADDED
assets/images/english_springer_1.jpg ADDED
assets/images/english_springer_2.jpg ADDED
assets/images/french_horn.jpg ADDED
assets/images/parachute.jpg ADDED