Ashish08 committed on
Commit
bf88188
1 Parent(s): 8c2b29c

Update app.py

Files changed (1)
  1. app.py +19 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
 import torch
 import clip
 from PIL import Image
+from typing import List, Tuple
 import numpy as np
 
 # Load CLIP model and preprocessing
@@ -9,17 +10,34 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model, preprocess = clip.load("ViT-B/32", device=device)
 
 # Function to predict descriptions and probabilities
-def predict(image, descriptions):
+def predict(image: Image.Image, descriptions: List[str]) -> Tuple[str, float]:
+    """
+    Predict the best-matching description for the provided image.
+    Uses the CLIP model to compute similarities between the image and the text descriptions.
+
+    Args:
+        image (Image.Image): The input image for which the descriptions are being evaluated.
+        descriptions (List[str]): A list of textual descriptions to compare against the image.
+
+    Returns:
+        Tuple[str, float]: The best-matching description and the corresponding probability.
+    """
+    # Preprocess the image and move it to the appropriate device
     image = preprocess(image).unsqueeze(0).to(device)
+    # Tokenize the descriptions and move them to the appropriate device
     text = clip.tokenize(descriptions).to(device)
 
     with torch.no_grad():
+        # Encode image and text features using the CLIP model
         image_features = model.encode_image(image)
         text_features = model.encode_text(text)
 
+        # Compute the similarity scores (logits) between image and text
         logits_per_image, logits_per_text = model(image, text)
+        # Convert logits to probabilities
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
 
+    # Return the description with the highest probability and the corresponding probability
     return descriptions[np.argmax(probs)], np.max(probs)
 
 # Streamlit app
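
For context, here is a minimal sketch of how the updated predict() could be driven from the Streamlit UI that follows this hunk. The diff does not show that UI code, so the widget labels and variable names below are illustrative assumptions, not part of this commit:

# Illustrative sketch (not part of this commit): one possible Streamlit
# wiring for predict(). Widget labels and variable names are assumptions.
st.title("CLIP Image Classifier")

uploaded = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
labels_text = st.text_input(
    "Candidate descriptions (comma-separated)",
    "a photo of a cat, a photo of a dog",
)

if uploaded is not None and labels_text:
    # Open the uploaded file as a PIL image, as predict() expects
    image = Image.open(uploaded).convert("RGB")
    descriptions = [d.strip() for d in labels_text.split(",") if d.strip()]
    best, prob = predict(image, descriptions)
    st.image(image, caption="Input image")
    st.write(f"Best match: {best} ({prob:.1%})")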