piyushgrover commited on
Commit
d3992a1
·
1 Parent(s): e41b4ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -25
app.py CHANGED
@@ -33,10 +33,52 @@ else:
33
  # Print some statistics
34
  print(f"Photos loaded: {len(photo_ids)}")
35
 
 
36
 
37
- def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Encode the search query
39
- if not query_text and not query_photo_id:
40
  return []
41
 
42
  text_features = encode_search_query(model, query_text)
@@ -53,8 +95,12 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
53
  # Find the best match
54
  best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
55
 
56
- elif query_img:
57
- query_photo_features = model.encode_image(query_img)
 
 
 
 
58
  query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
59
 
60
  # Combine the test and photo queries and normalize again
@@ -66,7 +112,7 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
66
  else:
67
  # Display the results
68
  print("Test search result")
69
- best_photo_ids = search_unslash(query_text, photo_features, photo_ids, 10)
70
 
71
  return best_photo_ids
72
 
@@ -76,20 +122,21 @@ with gr.Blocks() as app:
76
  gr.Markdown(
77
  """
78
  # CLIP Image Search Engine!
79
- ### Enter search query or/and input image to find the similar images from the database -
80
  """)
81
 
82
  with gr.Row(visible=True):
83
  with gr.Column():
84
  with gr.Row():
85
- search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')
86
 
87
  with gr.Row():
88
  submit_btn = gr.Button("Submit", variant='primary')
89
  clear_btn = gr.ClearButton()
90
 
91
- with gr.Column():
92
- search_image = gr.Image(label='Upload Image or Select from results')
 
93
 
94
  with gr.Row(visible=True):
95
  output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
@@ -102,44 +149,75 @@ with gr.Blocks() as app:
102
  return {
103
  search_image: None,
104
  output_images: None,
105
- search_text: None
 
 
106
  }
107
 
108
 
109
- clear_btn.click(clear_data, None, [search_image, output_images, search_text])
110
 
111
 
112
  def on_select(evt: gr.SelectData, output_image_ids):
113
  return {
114
- search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=100"
 
 
115
  }
116
 
117
 
118
- output_images.select(on_select, output_image_ids, search_image)
119
 
120
 
121
- def func_search(query, img):
122
- best_photo_ids = search_by_text_and_photo(query, img)
123
- img_urls = []
124
- for p_id in best_photo_ids:
125
- url = f"https://unsplash.com/photos/{p_id}/download?w=100"
126
- img_urls.append(url)
 
 
 
127
 
128
- valid_images = filter_invalid_urls(img_urls, best_photo_ids)
 
 
 
 
 
 
 
 
 
 
129
 
130
- return {
131
- output_image_ids: valid_images['image_ids'],
132
- output_images: valid_images['image_urls']
133
- }
 
 
134
 
135
 
136
  submit_btn.click(
137
  func_search,
138
- [search_text, search_image],
139
  [output_images, output_image_ids]
140
  )
141
 
 
 
 
 
 
 
 
 
 
142
  '''
143
  Launch the app
144
  '''
145
  app.launch()
 
 
 
 
33
  # Print some statistics
34
  print(f"Photos loaded: {len(photo_ids)}")
35
 
36
+ from PIL import Image
37
 
38
+
39
+ def encode_search_query(net, search_query):
40
+ with torch.no_grad():
41
+ tokenized_query = clip.tokenize(search_query)
42
+ # print("tokenized_query: ", tokenized_query.shape)
43
+ # Encode and normalize the search query using CLIP
44
+ text_encoded = net.encode_text(tokenized_query.to(device))
45
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
46
+
47
+ # Retrieve the feature vector
48
+ # print("text_encoded: ", text_encoded.shape)
49
+ return text_encoded
50
+
51
+
52
+ def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
53
+ # Compute the similarity between the search query and each photo using the Cosine similarity
54
+ # print("text_features: ", text_features.shape)
55
+ # print("photo_features: ", photo_features.shape)
56
+ similarities = (photo_features @ text_features.T).squeeze(1)
57
+
58
+ # Sort the photos by their similarity score
59
+ best_photo_idx = (-similarities).argsort()
60
+ # print("best_photo_idx: ", best_photo_idx.shape)
61
+ # print("best_photo_idx: ", best_photo_idx[:results_count])
62
+
63
+ result_list = [photo_ids[i] for i in best_photo_idx[:results_count]]
64
+ # print("result_list: ", len(result_list))
65
+ # Return the photo IDs of the best matches
66
+ return result_list
67
+
68
+
69
+ def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
70
+ # Encode the search query
71
+ text_features = encode_search_query(net, search_query)
72
+
73
+ # Find the best matches
74
+ best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)
75
+
76
+ return best_photo_ids
77
+
78
+
79
+ def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5):
80
  # Encode the search query
81
+ if not query_text and query_photo is None and not query_photo_id:
82
  return []
83
 
84
  text_features = encode_search_query(model, query_text)
 
95
  # Find the best match
96
  best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
97
 
98
+ elif query_photo is not None:
99
+ query_photo = preprocess(query_photo)
100
+ query_photo = torch.tensor(query_photo).permute(2, 0, 1)
101
+
102
+ print(query_photo.shape)
103
+ query_photo_features = model.encode_image(query_photo)
104
  query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
105
 
106
  # Combine the test and photo queries and normalize again
 
112
  else:
113
  # Display the results
114
  print("Test search result")
115
+ best_photo_ids = search_unslash(model, query_text, photo_features, photo_ids, 10)
116
 
117
  return best_photo_ids
118
 
 
122
  gr.Markdown(
123
  """
124
  # CLIP Image Search Engine!
125
+ ### Enter search query or/and select image to find the similar images
126
  """)
127
 
128
  with gr.Row(visible=True):
129
  with gr.Column():
130
  with gr.Row():
131
+ search_text = gr.Textbox(value='', placeholder='Search..', label='Enter search query')
132
 
133
  with gr.Row():
134
  submit_btn = gr.Button("Submit", variant='primary')
135
  clear_btn = gr.ClearButton()
136
 
137
+ with gr.Column(visible=True) as input_image_col:
138
+ search_image = gr.Image(label='Select from results', interactive=False)
139
+ search_image_id = gr.State(None)
140
 
141
  with gr.Row(visible=True):
142
  output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
 
149
  return {
150
  search_image: None,
151
  output_images: None,
152
+ search_text: None,
153
+ search_image_id: None,
154
+ input_image_col: gr.update(visible=True)
155
  }
156
 
157
 
158
+ clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])
159
 
160
 
161
  def on_select(evt: gr.SelectData, output_image_ids):
162
  return {
163
+ search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=320",
164
+ search_image_id: output_image_ids[evt.index],
165
+ input_image_col: gr.update(visible=True)
166
  }
167
 
168
 
169
+ output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])
170
 
171
 
172
+ def func_search(query, img, img_id):
173
+ best_photo_ids = []
174
+ if img_id:
175
+ best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
176
+ elif img is not None:
177
+ img = Image.open(img)
178
+ best_photo_ids = search_by_text_and_photo(query, query_photo=img)
179
+ elif query:
180
+ best_photo_ids = search_by_text_and_photo(query)
181
 
182
+ if len(best_photo_ids) == 0:
183
+ print("Invalid Search Request")
184
+ return {
185
+ output_image_ids: [],
186
+ output_images: []
187
+ }
188
+ else:
189
+ img_urls = []
190
+ for p_id in best_photo_ids:
191
+ url = f"https://unsplash.com/photos/{p_id}/download?w=20"
192
+ img_urls.append(url)
193
 
194
+ valid_images = filter_invalid_urls(img_urls, best_photo_ids)
195
+
196
+ return {
197
+ output_image_ids: valid_images['image_ids'],
198
+ output_images: valid_images['image_urls']
199
+ }
200
 
201
 
202
  submit_btn.click(
203
  func_search,
204
+ [search_text, search_image, search_image_id],
205
  [output_images, output_image_ids]
206
  )
207
 
208
+
209
+ def on_upload(evt: gr.SelectData):
210
+ return {
211
+ search_image_id: None
212
+ }
213
+
214
+
215
+ search_image.upload(on_upload, None, search_image_id)
216
+
217
  '''
218
  Launch the app
219
  '''
220
  app.launch()
221
+
222
+
223
+