JavierFnts committed
Commit
0454d20
•
1 Parent(s): 30c8dd0

prompt ranking working

Files changed (3)
  1. streamlit_app.py → app.py +95 -92
  2. clip_model.py +13 -16
  3. requirements.txt +3 -2
streamlit_app.py → app.py RENAMED
@@ -4,9 +4,6 @@ import requests
 import streamlit as st
 from clip_model import ClipModel
 
-from session_state import SessionState, get_state
-from images_mocker import ImagesMocker
-
 from PIL import Image
 
 IMAGES_LINKS = ["https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg",
@@ -32,24 +29,34 @@ def load_image_from_url(url: str) -> Image.Image:
 def load_model() -> ClipModel:
     return ClipModel()
 
 
+def init_state():
+    if "images" not in st.session_state:
+        st.session_state.images = None
+    if "prompts" not in st.session_state:
+        st.session_state.prompts = None
+    if "predictions" not in st.session_state:
+        st.session_state.predictions = None
+    if "default_text_input" not in st.session_state:
+        st.session_state.default_text_input = None
+
+
-def limit_number_images(state: SessionState):
+def limit_number_images():
     """When moving between tasks sometimes the state of images can have too many samples"""
-    if state.images is not None and len(state.images) > 1:
-        state.images = [state.images[0]]
+    if st.session_state.images is not None and len(st.session_state.images) > 1:
+        st.session_state.images = [st.session_state.images[0]]
 
 
-def limit_number_prompts(state: SessionState):
+def limit_number_prompts():
     """When moving between tasks sometimes the state of prompts can have too many samples"""
-    if state.prompts is not None and len(state.prompts) > 1:
-        state.prompts = [state.prompts[0]]
+    if st.session_state.prompts is not None and len(st.session_state.prompts) > 1:
+        st.session_state.prompts = [st.session_state.prompts[0]]
 
 
-def is_valid_prediction_state(state: SessionState) -> bool:
-    if state.images is None or len(state.images) < 1:
+def is_valid_prediction_state() -> bool:
+    if st.session_state.images is None or len(st.session_state.images) < 1:
         st.error("Choose at least one image before predicting")
         return False
-    if state.prompts is None or len(state.prompts) < 1:
+    if st.session_state.prompts is None or len(st.session_state.prompts) < 1:
         st.error("Write at least one prompt before predicting")
         return False
     return True
@@ -97,16 +104,16 @@ class Sections:
         st.markdown("### Try OpenAI's CLIP model in your browser")
         st.markdown(" ")
         st.markdown(" ")
-        with st.beta_expander("What is CLIP?"):
+        with st.expander("What is CLIP?"):
             st.markdown("CLIP is a machine learning model that computes similarity between text "
                         "(also called prompts) and images. It has been trained on a dataset with millions of diverse"
                         " image-prompt pairs, which allows it to generalize to unseen examples."
                         " <br /> Check out [OpenAI's blogpost](https://openai.com/blog/clip/) for more details",
                         unsafe_allow_html=True)
-        col1, col2 = st.beta_columns(2)
+        col1, col2 = st.columns(2)
        col1.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-a.svg")
        col2.image("https://openaiassets.blob.core.windows.net/$web/clip/draft/20210104b/overview-b.svg")
-        with st.beta_expander("What can CLIP do?"):
+        with st.expander("What can CLIP do?"):
             st.markdown("#### Prompt ranking")
             st.markdown("Given different prompts and an image CLIP will rank the different prompts based on how well they describe the image")
             st.markdown("#### Image ranking")
@@ -118,7 +125,7 @@ class Sections:
         st.markdown(" ")
 
     @staticmethod
-    def image_uploader(state: SessionState, accept_multiple_files: bool):
+    def image_uploader(accept_multiple_files: bool):
         uploaded_images = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
                                            accept_multiple_files=accept_multiple_files)
         if (not accept_multiple_files and uploaded_images is not None) or (accept_multiple_files and len(uploaded_images) >= 1):
@@ -129,125 +136,123 @@ class Sections:
                 pil_image = Image.open(uploaded_image)
                 pil_image = preprocess_image(pil_image)
                 images.append(pil_image)
-            state.images = images
+            st.session_state.images = images
 
 
     @staticmethod
-    def image_picker(state: SessionState, default_text_input: str):
-        col1, col2, col3 = st.beta_columns(3)
+    def image_picker(default_text_input: str):
+        col1, col2, col3 = st.columns(3)
         with col1:
             default_image_1 = load_image_from_url("https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg")
             st.image(default_image_1, use_column_width=True)
             if st.button("Select image 1"):
-                state.images = [default_image_1]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_1]
+                st.session_state.default_text_input = default_text_input
         with col2:
             default_image_2 = load_image_from_url("https://cdn.pixabay.com/photo/2019/11/11/14/30/zebra-4618513_960_720.jpg")
             st.image(default_image_2, use_column_width=True)
             if st.button("Select image 2"):
-                state.images = [default_image_2]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_2]
+                st.session_state.default_text_input = default_text_input
         with col3:
             default_image_3 = load_image_from_url("https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg")
             st.image(default_image_3, use_column_width=True)
             if st.button("Select image 3"):
-                state.images = [default_image_3]
-                state.default_text_input = default_text_input
+                st.session_state.images = [default_image_3]
+                st.session_state.default_text_input = default_text_input
 
     @staticmethod
-    def dataset_picker(state: SessionState):
-        columns = st.beta_columns(5)
-        state.dataset = load_default_dataset()
+    def dataset_picker():
+        columns = st.columns(5)
+        st.session_state.dataset = load_default_dataset()
         image_idx = 0
         for col in columns:
-            col.image(state.dataset[image_idx])
+            col.image(st.session_state.dataset[image_idx])
             image_idx += 1
-            col.image(state.dataset[image_idx])
+            col.image(st.session_state.dataset[image_idx])
             image_idx += 1
         if st.button("Select random dataset"):
-            state.images = state.dataset
-            state.default_text_input = "A sign that says 'SLOW DOWN'"
+            st.session_state.images = st.session_state.dataset
+            st.session_state.default_text_input = "A sign that says 'SLOW DOWN'"
 
     @staticmethod
-    def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
+    def prompts_input(input_label: str, prompt_prefix: str = ''):
         raw_text_input = st.text_input(input_label,
-                                       value=state.default_text_input if state.default_text_input is not None else "")
-        state.is_default_text_input = raw_text_input == state.default_text_input
+                                       value=st.session_state.default_text_input if st.session_state.default_text_input is not None else "")
+        st.session_state.is_default_text_input = raw_text_input == st.session_state.default_text_input
         if raw_text_input:
-            state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]
+            st.session_state.prompts = [prompt_prefix + class_name for class_name in raw_text_input.split(";") if len(class_name) > 1]
 
     @staticmethod
-    def single_image_input_preview(state: SessionState):
+    def single_image_input_preview():
         st.markdown("### Preview")
-        col1, col2 = st.beta_columns([1, 2])
+        col1, col2 = st.columns([1, 2])
         with col1:
             st.markdown("Image to classify")
-            if state.images is not None:
-                st.image(state.images[0], use_column_width=True)
+            if st.session_state.images is not None:
+                st.image(st.session_state.images[0], use_column_width=True)
             else:
                 st.warning("Select an image")
 
         with col2:
             st.markdown("Labels to choose from")
-            if state.prompts is not None:
-                for prompt in state.prompts:
+            if st.session_state.prompts is not None:
+                for prompt in st.session_state.prompts:
                     st.markdown(f"* {prompt}")
-                if len(state.prompts) < 2:
+                if len(st.session_state.prompts) < 2:
                     st.warning("At least two prompts/classes are needed")
             else:
                 st.warning("Enter the prompts/classes to classify from")
 
     @staticmethod
-    def multiple_images_input_preview(state: SessionState):
+    def multiple_images_input_preview():
         st.markdown("### Preview")
         st.markdown("Images to classify")
-        col1, col2, col3 = st.beta_columns(3)
-        if state.images is not None:
-            for idx, image in enumerate(state.images):
-                if idx < len(state.images) / 2:
-                    col1.image(state.images[idx], use_column_width=True)
+        col1, col2, col3 = st.columns(3)
+        if st.session_state.images is not None:
+            for idx, image in enumerate(st.session_state.images):
+                if idx < len(st.session_state.images) / 2:
+                    col1.image(st.session_state.images[idx], use_column_width=True)
                 else:
-                    col2.image(state.images[idx], use_column_width=True)
-            if len(state.images) < 2:
+                    col2.image(st.session_state.images[idx], use_column_width=True)
+            if len(st.session_state.images) < 2:
                 col2.warning("At least 2 images required")
         else:
             col1.warning("Select an image")
 
         with col3:
             st.markdown("Query prompt")
-            if state.prompts is not None:
-                for prompt in state.prompts:
+            if st.session_state.prompts is not None:
+                for prompt in st.session_state.prompts:
                     st.write(prompt)
             else:
                 st.warning("Enter the prompt to classify")
 
     @staticmethod
-    def classification_output(state: SessionState, model: ClipModel):
+    def classification_output(model: ClipModel):
         # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
-        if st.button("Predict") and is_valid_prediction_state(state):  # PREDICT 🚀
+        if st.button("Predict") and is_valid_prediction_state():  # PREDICT 🚀
             with st.spinner("Predicting..."):
 
                 st.markdown("### Results")
-                # st.write(clip_response)
-                if len(state.images) == 1:
-                    scores = model.compute_prompts_probabilities(state.images[0], state.prompts)
-                    scored_prompts = [(prompt, score) for prompt, score in zip(state.prompts, scores)]
-                    st.json(scores)
+                if len(st.session_state.images) == 1:
+                    scores = model.compute_prompts_probabilities(st.session_state.images[0], st.session_state.prompts)
+                    scored_prompts = [(prompt, score) for prompt, score in zip(st.session_state.prompts, scores)]
                     sorted_scored_prompts = sorted(scored_prompts, key=lambda x: x[1], reverse=True)
                     for prompt, probability in sorted_scored_prompts:
                         percentage_prob = int(probability * 100)
                         st.markdown(
-                            f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")
-                elif len(state.prompts) == 1:
-                    st.markdown(f"### {state.prompts[0]}")
+                            f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) {prompt}")
+                elif len(st.session_state.prompts) == 1:
+                    st.markdown(f"### {st.session_state.prompts[0]}")
 
-                    scores = model.compute_prompts_probabilities(state.images[0], state.prompts)
-                    scored_images = [(image, score) for image, score in zip(state.images, scores)]
+                    scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
                     st.json(scores)
-                    sorted_scored_images = sorted(scored_prompts, key=lambda x: x[1], reverse=True)
+                    scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
+                    sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
 
                     for image, probability in sorted_scored_images[:5]:
-                        col1, col2 = st.beta_columns([1, 3])
+                        col1, col2 = st.columns([1, 3])
                         col1.image(image, use_column_width=True)
                         percentage_prob = int(probability * 100)
                         col2.markdown(f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200)")
@@ -269,47 +274,45 @@ class Sections:
 
 
 Sections.header()
-col1, col2 = st.beta_columns([1, 2])
+col1, col2 = st.columns([1, 2])
 col1.markdown(" "); col1.markdown(" ")
 col1.markdown("#### Task selection")
 task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
 st.markdown("<br>", unsafe_allow_html=True)
-
+init_state()
 model = load_model()
-session_state = get_state()
 if task_name == "Image classification":
-    Sections.image_uploader(session_state, accept_multiple_files=False)
-    if session_state.images is None:
+    Sections.image_uploader(accept_multiple_files=False)
+    if st.session_state.images is None:
         st.markdown("or choose one from")
-        Sections.image_picker(session_state, default_text_input="banana; boat; bird")
+        Sections.image_picker(default_text_input="banana; boat; bird")
     input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
-    Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
-    limit_number_images(session_state)
-    Sections.single_image_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+    Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
+    limit_number_images()
+    Sections.single_image_input_preview()
+    Sections.classification_output(model)
elif task_name == "Prompt ranking":
-    Sections.image_uploader(session_state, accept_multiple_files=False)
-    if session_state.images is None:
+    Sections.image_uploader(accept_multiple_files=False)
+    if st.session_state.images is None:
         st.markdown("or choose one from")
-        Sections.image_picker(session_state, default_text_input="A calm afternoon in the Mediterranean; "
+        Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
                                                 "A beautiful creature;"
                                                 " Something that grows in tropical regions")
     input_label = "Enter the prompts to choose from separated by a semi-colon. " \
                   "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
-    Sections.prompts_input(session_state, input_label)
-    limit_number_images(session_state)
-    Sections.single_image_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+    Sections.prompts_input(input_label)
+    limit_number_images()
+    Sections.single_image_input_preview()
+    Sections.classification_output(model)
elif task_name == "Image ranking":
-    Sections.image_uploader(session_state, accept_multiple_files=True)
-    if session_state.images is None or len(session_state.images) < 2:
+    Sections.image_uploader(accept_multiple_files=True)
+    if st.session_state.images is None or len(st.session_state.images) < 2:
         st.markdown("or use this random dataset")
-        Sections.dataset_picker(session_state)
-    Sections.prompts_input(session_state, "Enter the prompt to query the images by")
-    limit_number_prompts(session_state)
-    Sections.multiple_images_input_preview(session_state)
-    Sections.classification_output(session_state, model)
+        Sections.dataset_picker()
+    Sections.prompts_input("Enter the prompt to query the images by")
+    limit_number_prompts()
+    Sections.multiple_images_input_preview()
+    Sections.classification_output(model)
 
 st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
             "", unsafe_allow_html=True)
-session_state.sync()
 
clip_model.py CHANGED
@@ -9,7 +9,7 @@ class ClipModel:
         ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32',
          'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
         """
-        self.model, self.img_preprocess = clip.load(model_name)
+        self._model, self._img_preprocess = clip.load(model_name)
 
     def predict(self, images: list[Image], prompts: list[str]) -> dict:
         if len(images) == 1:
@@ -19,43 +19,40 @@ class ClipModel:
         else:
             raise ValueError('Either images or prompts must be a single element')
 
-    def compute_prompts_probabilities(self, image: Image, prompts: list[str]) -> dict[str, float]:
-        preprocessed_image = self.img_preprocess(image).unsqueeze(0)
+    def compute_prompts_probabilities(self, image: Image, prompts: list[str]) -> list[float]:
+        preprocessed_image = self._img_preprocess(image).unsqueeze(0)
         tokenized_prompts = clip.tokenize(prompts)
         with torch.inference_mode():
-            image_features = self.model.encode_image(preprocessed_image)
-            text_features = self.model.encode_text(tokenized_prompts)
+            image_features = self._model.encode_image(preprocessed_image)
+            text_features = self._model.encode_text(tokenized_prompts)
 
             # normalized features
             image_features = image_features / image_features.norm(dim=1, keepdim=True)
             text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
             # cosine similarity as logits
-            logit_scale = self.model.logit_scale.exp()
+            logit_scale = self._model.logit_scale.exp()
             logits_per_image = logit_scale * image_features @ text_features.t()
 
             probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        scored_prompts = {tag: float(prob) for tag, prob in zip(prompts, probs)}
-        return scored_prompts
+        return probs
 
-    def compute_images_probabilities(self, images: list[Image], prompt: str) -> dict[Image, float]:
-        raise
-        preprocessed_images = [self.img_preprocess(image).unsqueeze(0) for image in images]
+    def compute_images_probabilities(self, images: list[Image], prompt: str) -> list[float]:
+        preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
         tokenized_prompts = clip.tokenize(prompt)
         with torch.inference_mode():
-            image_features = self.model.encode_image(preprocessed_image)
-            text_features = self.model.encode_text(tokenized_prompts)
+            image_features = self._model.encode_image(torch.cat(preprocessed_images))
+            text_features = self._model.encode_text(tokenized_prompts)
 
             # normalized features
             image_features = image_features / image_features.norm(dim=1, keepdim=True)
             text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
             # cosine similarity as logits
-            logit_scale = self.model.logit_scale.exp()
+            logit_scale = self._model.logit_scale.exp()
             logits_per_image = logit_scale * image_features @ text_features.t()
 
             probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
 
-        scored_prompts = {tag: float(prob) for tag, prob in zip(prompts, probs)}
-        return scored_prompts
+        return probs
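
After this change both scoring methods return a plain list[float] aligned with their input order, leaving pairing and sorting to the caller (as app.py now does). A minimal usage sketch, assuming a local image file:

    from PIL import Image
    from clip_model import ClipModel

    model = ClipModel()  # loads the default CLIP checkpoint via clip.load
    image = Image.open("boat.jpg")  # hypothetical input file
    prompts = ["A boat", "A zebra", "A banana"]

    # Probabilities come back in the same order as the prompts.
    scores = model.compute_prompts_probabilities(image, prompts)
    ranked = sorted(zip(prompts, scores), key=lambda pair: pair[1], reverse=True)
    print(ranked[0])  # best-matching prompt and its probability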
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
-streamlit~=0.76.0
+streamlit~=1.11.1
 git+https://github.com/openai/CLIP@b46f5ac
 Pillow==8.1.0
-mock==4.0.3
+mock==4.0.3
+protobuf==3.20.0 # It raises errors otherwise