JavierFnts committed on
Commit
40497e6
β€’
1 Parent(s): 0454d20

Base functionality working again πŸŽ‰

Browse files
Files changed (2) hide show
  1. app.py +45 -47
  2. clip_model.py +16 -3
app.py CHANGED
@@ -230,8 +230,7 @@ class Sections:
230
 
231
  @staticmethod
232
  def classification_output(model: ClipModel):
233
- # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
234
- if st.button("Predict") and is_valid_prediction_state(): # PREDICT πŸš€
235
  with st.spinner("Predicting..."):
236
 
237
  st.markdown("### Results")
@@ -247,7 +246,6 @@ class Sections:
247
  st.markdown(f"### {st.session_state.prompts[0]}")
248
 
249
  scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
250
- st.json(scores)
251
  scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
252
  sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
253
 
@@ -272,47 +270,47 @@ class Sections:
272
  # " It can be whatever you can think of",
273
  # unsafe_allow_html=True)
274
 
275
-
276
- Sections.header()
277
- col1, col2 = st.columns([1, 2])
278
- col1.markdown(" "); col1.markdown(" ")
279
- col1.markdown("#### Task selection")
280
- task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
281
- st.markdown("<br>", unsafe_allow_html=True)
282
- init_state()
283
- model = load_model()
284
- if task_name == "Image classification":
285
- Sections.image_uploader(accept_multiple_files=False)
286
- if st.session_state.images is None:
287
- st.markdown("or choose one from")
288
- Sections.image_picker(default_text_input="banana; boat; bird")
289
- input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
290
- Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
291
- limit_number_images()
292
- Sections.single_image_input_preview()
293
- Sections.classification_output(model)
294
- elif task_name == "Prompt ranking":
295
- Sections.image_uploader(accept_multiple_files=False)
296
- if st.session_state.images is None:
297
- st.markdown("or choose one from")
298
- Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
299
- "A beautiful creature;"
300
- " Something that grows in tropical regions")
301
- input_label = "Enter the prompts to choose from separated by a semi-colon. " \
302
- "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
303
- Sections.prompts_input(input_label)
304
- limit_number_images()
305
- Sections.single_image_input_preview()
306
- Sections.classification_output(model)
307
- elif task_name == "Image ranking":
308
- Sections.image_uploader(accept_multiple_files=True)
309
- if st.session_state.images is None or len(st.session_state.images) < 2:
310
- st.markdown("or use this random dataset")
311
- Sections.dataset_picker()
312
- Sections.prompts_input("Enter the prompt to query the images by")
313
- limit_number_prompts()
314
- Sections.multiple_images_input_preview()
315
- Sections.classification_output(model)
316
-
317
- st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
318
- "", unsafe_allow_html=True)
 
230
 
231
  @staticmethod
232
  def classification_output(model: ClipModel):
233
+ if st.button("Predict") and is_valid_prediction_state():
 
234
  with st.spinner("Predicting..."):
235
 
236
  st.markdown("### Results")
 
246
  st.markdown(f"### {st.session_state.prompts[0]}")
247
 
248
  scores = model.compute_images_probabilities(st.session_state.images, st.session_state.prompts[0])
 
249
  scored_images = [(image, score) for image, score in zip(st.session_state.images, scores)]
250
  sorted_scored_images = sorted(scored_images, key=lambda x: x[1], reverse=True)
251
 
 
270
  # " It can be whatever you can think of",
271
  # unsafe_allow_html=True)
272
 
273
+ if __name__ == "__main__":
274
+ Sections.header()
275
+ col1, col2 = st.columns([1, 2])
276
+ col1.markdown(" "); col1.markdown(" ")
277
+ col1.markdown("#### Task selection")
278
+ task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
279
+ st.markdown("<br>", unsafe_allow_html=True)
280
+ init_state()
281
+ model = load_model()
282
+ if task_name == "Image classification":
283
+ Sections.image_uploader(accept_multiple_files=False)
284
+ if st.session_state.images is None:
285
+ st.markdown("or choose one from")
286
+ Sections.image_picker(default_text_input="banana; boat; bird")
287
+ input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
288
+ Sections.prompts_input(input_label, prompt_prefix='A picture of a ')
289
+ limit_number_images()
290
+ Sections.single_image_input_preview()
291
+ Sections.classification_output(model)
292
+ elif task_name == "Prompt ranking":
293
+ Sections.image_uploader(accept_multiple_files=False)
294
+ if st.session_state.images is None:
295
+ st.markdown("or choose one from")
296
+ Sections.image_picker(default_text_input="A calm afternoon in the Mediterranean; "
297
+ "A beautiful creature;"
298
+ " Something that grows in tropical regions")
299
+ input_label = "Enter the prompts to choose from separated by a semi-colon. " \
300
+ "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
301
+ Sections.prompts_input(input_label)
302
+ limit_number_images()
303
+ Sections.single_image_input_preview()
304
+ Sections.classification_output(model)
305
+ elif task_name == "Image ranking":
306
+ Sections.image_uploader(accept_multiple_files=True)
307
+ if st.session_state.images is None or len(st.session_state.images) < 2:
308
+ st.markdown("or use this random dataset")
309
+ Sections.dataset_picker()
310
+ Sections.prompts_input("Enter the prompt to query the images by")
311
+ limit_number_prompts()
312
+ Sections.multiple_images_input_preview()
313
+ Sections.classification_output(model)
314
+
315
+ st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
316
+ "", unsafe_allow_html=True)
clip_model.py CHANGED
@@ -2,6 +2,8 @@ import clip
2
  from PIL.Image import Image
3
  import torch
4
 
 
 
5
  class ClipModel:
6
  def __init__(self, model_name: str = 'RN50') -> None:
7
  """
@@ -42,7 +44,7 @@ class ClipModel:
42
  preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
43
  tokenized_prompts = clip.tokenize(prompt)
44
  with torch.inference_mode():
45
- image_features = self._model.encode_image(torch.cat(preprocessed_images))
46
  text_features = self._model.encode_text(tokenized_prompts)
47
 
48
  # normalized features
@@ -51,8 +53,19 @@ class ClipModel:
51
 
52
  # cosine similarity as logits
53
  logit_scale = self._model.logit_scale.exp()
54
- logits_per_image = logit_scale * image_features @ text_features.t()
55
 
56
  probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
57
 
58
- return probs
 
 
 
 
 
 
 
 
 
 
 
 
2
  from PIL.Image import Image
3
  import torch
4
 
5
+
6
+
7
  class ClipModel:
8
  def __init__(self, model_name: str = 'RN50') -> None:
9
  """
 
44
  preprocessed_images = [self._img_preprocess(image).unsqueeze(0) for image in images]
45
  tokenized_prompts = clip.tokenize(prompt)
46
  with torch.inference_mode():
47
+ image_features = torch.cat([self._model.encode_image(preprocessed_image) for preprocessed_image in preprocessed_images])
48
  text_features = self._model.encode_text(tokenized_prompts)
49
 
50
  # normalized features
 
53
 
54
  # cosine similarity as logits
55
  logit_scale = self._model.logit_scale.exp()
56
+ logits_per_image = logit_scale * text_features @ image_features.t()
57
 
58
  probs = list(logits_per_image.softmax(dim=-1).cpu().numpy()[0])
59
 
60
+ return probs
61
+
62
+ if __name__ == "__main__":
63
+ from app import load_default_dataset
64
+
65
+ model = ClipModel()
66
+ images = load_default_dataset()
67
+ prompts = ['Hello', 'How are you', 'Goodbye']
68
+ prompts_scores = model.compute_prompts_probabilities(images[0], prompts)
69
+ images_scores = model.compute_images_probabilities(images, prompts[0])
70
+ print(f"Prompts scores: {prompts_scores}")
71
+ print(f"Images scores: {images_scores}")