RobotJelly committed on
Commit 7c286b0
1 Parent(s): 4a33ca4
Files changed (1)
  1. app.py +65 -35
app.py CHANGED
@@ -3,25 +3,27 @@ from pathlib import Path
import pandas as pd
import numpy as np
import torch
+ import pickle
from PIL import Image
from io import BytesIO
import requests
import gradio as gr
import os
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
- import urllib.request
+ import sentence_transformers
+ from sentence_transformers import SentenceTransformer, util

# check if CUDA available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the openAI's CLIP model
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+ #model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ #processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ #tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# taking photo IDs
- photo_ids = pd.read_csv("./photo_ids.csv")
- photo_ids = list(photo_ids['photo_id'])
+ #photo_ids = pd.read_csv("./photo_ids.csv")
+ #photo_ids = list(photo_ids['photo_id'])

# Photo dataset
photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)
@@ -31,38 +33,56 @@ photo_features = np.load("./features.npy")

IMAGES_DIR = './photos'

- def show_output_image(matched_images) :
-     image=[]
-     for photo_id in matched_images:
-         photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
+
+
+ #def show_output_image(matched_images) :
+ #    image=[]
+ #    for photo_id in matched_images:
+ #        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
#response = requests.get(photo_image_url, stream=True)
#img = Image.open(BytesIO(response.content))
-         response = requests.get(photo_image_url, stream=True).raw
-         img = Image.open(response)
+ #        response = requests.get(photo_image_url, stream=True).raw
+ #        img = Image.open(response)
#photo = photo_id + '.jpg'
#img = Image.open(response).convert("RGB")
#img = Image.open(os.path.join(IMAGES_DIR, photo))
-         image.append(img)
-     return image
+ #    image.append(img)
+ #    return image
+

# Encode and normalize the search query using CLIP
- def encode_search_query(search_query, model, device):
-     with torch.no_grad():
-         inputs = tokenizer([search_query], padding=True, return_tensors="pt")
+ #def encode_search_query(search_query, model, device):
+ #    with torch.no_grad():
+ #        inputs = tokenizer([search_query], padding=True, return_tensors="pt")
#inputs = processor(text=[search_query], images=None, return_tensors="pt", padding=True)
-         text_features = model.get_text_features(**inputs).cpu().numpy()
-     return text_features
+ #        text_features = model.get_text_features(**inputs).cpu().numpy()
+ #    return text_features

# Find all matched photos
- def find_matches(features, photo_ids, results_count=4):
+ #def find_matches(features, photo_ids, results_count=4):
# Compute the similarity between the search query and each photo using the Cosine similarity
#text_features = np.array(text_features)
-     similarities = (photo_features @ features.T).squeeze(1)
+ #similarities = (photo_features @ features.T).squeeze(1)
# Sort the photos by their similarity score
-     best_photo_idx = (-similarities).argsort()
+ #best_photo_idx = (-similarities).argsort()
# Return the photo IDs of the best matches
-     matches = [photo_ids[i] for i in best_photo_idx[:results_count]]
-     return matches
+ #matches = [photo_ids[i] for i in best_photo_idx[:results_count]]
+ #return matches
+
+ #Load CLIP model
+ model = SentenceTransformer('clip-ViT-B-32')
+
+ # pre-computed embeddings
+ emb_filename = 'unsplash-25k-photos-embeddings.pkl'
+ with open(emb_filename, 'rb') as fIn:
+     img_names, img_emb = pickle.load(fIn)
+
+ def display_matches(similarity):
+     best_matched_images = []
+     for best_img in torch.topk(similarity.squeeze(0), 4).indices:
+         img = Image.open(os.path.join('./photos', img_names[best_img]))
+         best_matched_images.append(img)
+     return best_matched_images

def image_search(search_text, search_image, option):

@@ -70,25 +90,35 @@ def image_search(search_text, search_image, option):
    #search_query = "The feeling when your program finally works"

    if option == "Text-To-Image" :
-         # Extracting text features
-         text_features = encode_search_query(search_text, model, device)
+         # Extracting text features embeddings
+         #text_features = encode_search_query(search_text, model, device)
+         text_emb = model.encode([search_text], convert_to_tensor=True)

        # Find the matched Images
-         matched_images = find_matches(text_features, photo_features, photo_ids, 4)
+         #matched_images = find_matches(text_features, photo_features, photo_ids, 4)
+         similarity = util.cos_sim(text_emb, img_emb)

-         return show_output_image(matched_images)
+         # top 4 highest ranked images
+         return display_matches(similarity)
    elif option == "Image-To-Image":
        # Input Image for Search
        #search_image = Image.fromarray(search_image.astype('uint8'), 'RGB')

-         with torch.no_grad():
-             processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
-             image_feature = model.get_image_features(processed_image.to(device))
-             image_feature /= image_feature.norm(dim=-1, keepdim=True)
-             image_feature = image_feature.cpu().numpy()
+         #with torch.no_grad():
+         #    processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
+         #    image_feature = model.get_image_features(processed_image.to(device))
+         #    image_feature /= image_feature.norm(dim=-1, keepdim=True)
+         #image_feature = image_feature.cpu().numpy()
        # Find the matched Images
-         matched_images = find_matches(image_feature, photo_ids, 4)
-         return show_output_image(matched_images)
+         #matched_images = find_matches(image_feature, photo_ids, 4)
+
+         image_emb = model.encode(Image.open(search_image), convert_to_tensor=True)
+
+         # Find the matched Images
+         #matched_images = find_matches(text_features, photo_features, photo_ids, 4)
+         similarity = util.cos_sim(image_emb, img_emb)
+
+         return display_matches(similarity)

gr.Interface(fn=image_search,
             inputs=[gr.inputs.Textbox(lines=7, label="Input Text"),
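
For context, a minimal, self-contained sketch of the retrieval path this commit switches to: one sentence-transformers CLIP model encodes both text and image queries into the same embedding space, and cosine similarity against the pre-computed Unsplash embeddings selects the top matches. The model name and the (img_names, img_emb) pickle layout come from the diff above; the helper name top_k_matches, the example queries, and the example photo path are illustrative assumptions, not part of the commit.

# Sketch (not part of the commit): one query path shared by text and image search.
import os
import pickle

import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('clip-ViT-B-32')

# Pre-computed Unsplash photo embeddings, assumed to unpickle as
# (list of photo file names, embeddings of shape (num_photos, 512)).
with open('unsplash-25k-photos-embeddings.pkl', 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)

def top_k_matches(query, k=4):
    # `query` may be a text string or a PIL.Image; CLIP maps both into one space.
    query_emb = model.encode([query], convert_to_tensor=True)   # shape (1, 512)
    similarity = util.cos_sim(query_emb, img_emb)                # shape (1, num_photos)
    best = torch.topk(similarity.squeeze(0), k).indices          # k highest-scoring photos
    return [os.path.join('./photos', img_names[int(i)]) for i in best]

# Text-to-image and image-to-image searches share the same call:
print(top_k_matches("two dogs playing in the snow"))
print(top_k_matches(Image.open('./photos/example.jpg')))         # hypothetical local photo

The embeddings file is presumably generated offline with the same model (e.g., model.encode over the photo files), which is what lets the app drop CLIPProcessor/CLIPTokenizer work at request time.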