Javi commited on
Commit
af5047d
1 Parent(s): 081801e

Introduced file uploader hack

Browse files
Files changed (1) hide show
  1. streamlit_app.py +83 -22
streamlit_app.py CHANGED
@@ -1,9 +1,40 @@
1
  import random
2
  from typing import Optional, List
 
3
 
4
- import booste
5
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
 
 
 
7
  from session_state import SessionState, get_state
8
 
9
  # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
@@ -28,7 +59,7 @@ IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_9
28
  "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
29
  ]
30
 
31
- @st.cache
32
  def select_random_dataset():
33
  return random.sample(IMAGES_LINKS, 10)
34
 
@@ -46,9 +77,17 @@ class Sections:
46
  st.markdown(" ")
47
 
48
  @staticmethod
49
- def image_uploader(accept_multiple_files: bool) -> Optional[List[str]]:
50
- uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
51
- accept_multiple_files=accept_multiple_files)
 
 
 
 
 
 
 
 
52
 
53
  @staticmethod
54
  def image_picker(state: SessionState):
@@ -117,9 +156,12 @@ class Sections:
117
  col1.image(state.images[idx], use_column_width=True)
118
  else:
119
  col2.image(state.images[idx], use_column_width=True)
 
 
120
  else:
121
  col1.warning("Select an image")
122
 
 
123
  with col3:
124
  st.markdown("Query prompt")
125
  if state.prompts is not None:
@@ -133,10 +175,19 @@ class Sections:
133
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
134
  if st.button("Predict"):
135
  with st.spinner("Predicting..."):
136
- clip_response = booste.clip(BOOSTE_API_KEY,
137
- prompts=state.prompts,
138
- images=state.images,
139
- pretty_print=True)
 
 
 
 
 
 
 
 
 
140
  st.markdown("### Results")
141
  # st.write(clip_response)
142
  if len(state.images) == 1:
@@ -152,8 +203,13 @@ class Sections:
152
  else:
153
  st.markdown(f"### {state.prompts[0]}")
154
  assert len(state.prompts) == 1
155
- simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
156
- in list(clip_response.values())[0].items()]
 
 
 
 
 
157
  simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
158
  for image, probability in simplified_clip_results[:5]:
159
  col1, col2 = st.beta_columns([1, 3])
@@ -162,35 +218,40 @@ class Sections:
162
  col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
163
 
164
 
165
-
166
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
167
- session_state = get_state()
168
  if task_name == "Image classification":
 
169
  Sections.header()
170
- Sections.image_uploader(accept_multiple_files=False)
171
- st.markdown("or choose one from")
172
- Sections.image_picker(session_state)
 
173
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
174
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
175
  Sections.single_image_input_preview(session_state)
176
  Sections.classification_output(session_state)
177
  elif task_name == "Prompt ranking":
 
178
  Sections.header()
179
- Sections.image_uploader(accept_multiple_files=False)
180
- st.markdown("or choose one from")
181
- Sections.image_picker(session_state)
 
182
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
183
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
184
  Sections.prompts_input(session_state, input_label)
185
  Sections.single_image_input_preview(session_state)
186
  Sections.classification_output(session_state)
187
  elif task_name == "Image ranking":
 
188
  Sections.header()
189
- Sections.image_uploader(accept_multiple_files=True)
190
- st.markdown("or use random dataset")
191
- Sections.dataset_picker(session_state)
 
192
  Sections.prompts_input(session_state, "Enter the prompt to query the images by")
193
  Sections.multiple_images_input_preview(session_state)
194
  Sections.classification_output(session_state)
 
195
 
196
  session_state.sync()
 
1
  import random
2
  from typing import Optional, List
3
+ import uuid
4
 
 
5
  import streamlit as st
6
+ from mock import patch
7
+
8
+ class ImagesMocker:
9
+ """HACK ALERT: I needed a way to call the booste API without storing the images first
10
+ (as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
11
+
12
+ def __init__(self):
13
+ self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
14
+ self.path_patch = patch('os.path.exists', lambda x: True)
15
+ self.image_id2image_lookup = {}
16
+
17
+ def start_mocking(self):
18
+ self.pil_patch.start()
19
+ self.path_patch.start()
20
+
21
+ def stop_mocking(self):
22
+ self.pil_patch.stop()
23
+ self.path_patch.stop()
24
+
25
+ def image_id2image(self, image_id: str):
26
+ return self.image_id2image_lookup[image_id]
27
+
28
+ def calculate_image_id2image_lookup(self, images: List):
29
+ self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
30
+ @property
31
+ def image_ids(self):
32
+ return list(self.image_id2image_lookup.keys())
33
 
34
+ images_mocker = ImagesMocker()
35
+ import booste
36
+
37
+ from PIL import Image
38
  from session_state import SessionState, get_state
39
 
40
  # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
 
59
  "https://cdn.pixabay.com/photo/2016/11/29/04/52/architecture-1867411_960_720.jpg",
60
  ]
61
 
62
+ @st.cache # Cache this so that it doesn't change every time something changes in the page
63
  def select_random_dataset():
64
  return random.sample(IMAGES_LINKS, 10)
65
 
 
77
  st.markdown(" ")
78
 
79
  @staticmethod
80
+ def image_uploader(state: SessionState, accept_multiple_files: bool):
81
+ uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
82
+ accept_multiple_files=accept_multiple_files)
83
+ if uploaded_images is not None or (accept_multiple_files and len(uploaded_images) > 1):
84
+ images = []
85
+ if not accept_multiple_files:
86
+ uploaded_images = [uploaded_images]
87
+ for uploaded_image in uploaded_images:
88
+ images.append(Image.open(uploaded_image))
89
+ state.images = images
90
+
91
 
92
  @staticmethod
93
  def image_picker(state: SessionState):
 
156
  col1.image(state.images[idx], use_column_width=True)
157
  else:
158
  col2.image(state.images[idx], use_column_width=True)
159
+ if len(state.images) < 2:
160
+ col2.warning("At least 2 images required")
161
  else:
162
  col1.warning("Select an image")
163
 
164
+
165
  with col3:
166
  st.markdown("Query prompt")
167
  if state.prompts is not None:
 
175
  # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
176
  if st.button("Predict"):
177
  with st.spinner("Predicting..."):
178
+ if isinstance(state.images[0], str):
179
+ print("Regular call!")
180
+ clip_response = booste.clip(BOOSTE_API_KEY,
181
+ prompts=state.prompts,
182
+ images=state.images)
183
+ else:
184
+ print("Hacky call!")
185
+ images_mocker.calculate_image_id2image_lookup(state.images)
186
+ images_mocker.start_mocking()
187
+ clip_response = booste.clip(BOOSTE_API_KEY,
188
+ prompts=state.prompts,
189
+ images=images_mocker.image_ids)
190
+ images_mocker.stop_mocking()
191
  st.markdown("### Results")
192
  # st.write(clip_response)
193
  if len(state.images) == 1:
 
203
  else:
204
  st.markdown(f"### {state.prompts[0]}")
205
  assert len(state.prompts) == 1
206
+ if isinstance(state.images[0], str):
207
+ simplified_clip_results = [(image, results["probabilityRelativeToImages"]) for image, results
208
+ in list(clip_response.values())[0].items()]
209
+ else:
210
+ simplified_clip_results = [(images_mocker.image_id2image(image),
211
+ results["probabilityRelativeToImages"]) for image, results
212
+ in list(clip_response.values())[0].items()]
213
  simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
214
  for image, probability in simplified_clip_results[:5]:
215
  col1, col2 = st.beta_columns([1, 3])
 
218
  col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
219
 
220
 
 
221
  task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
 
222
  if task_name == "Image classification":
223
+ session_state = get_state()
224
  Sections.header()
225
+ Sections.image_uploader(session_state, accept_multiple_files=False)
226
+ if session_state.images is None:
227
+ st.markdown("or choose one from")
228
+ Sections.image_picker(session_state)
229
  input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
230
  Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
231
  Sections.single_image_input_preview(session_state)
232
  Sections.classification_output(session_state)
233
  elif task_name == "Prompt ranking":
234
+ session_state = get_state()
235
  Sections.header()
236
+ Sections.image_uploader(session_state, accept_multiple_files=False)
237
+ if session_state.images is None:
238
+ st.markdown("or choose one from")
239
+ Sections.image_picker(session_state)
240
  input_label = "Enter the prompts to choose from separated by a semi-colon. " \
241
  "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
242
  Sections.prompts_input(session_state, input_label)
243
  Sections.single_image_input_preview(session_state)
244
  Sections.classification_output(session_state)
245
  elif task_name == "Image ranking":
246
+ session_state = get_state()
247
  Sections.header()
248
+ Sections.image_uploader(session_state, accept_multiple_files=True)
249
+ if session_state.images is None:
250
+ st.markdown("or use this random dataset")
251
+ Sections.dataset_picker(session_state)
252
  Sections.prompts_input(session_state, "Enter the prompt to query the images by")
253
  Sections.multiple_images_input_preview(session_state)
254
  Sections.classification_output(session_state)
255
+ print(session_state.images)
256
 
257
  session_state.sync()