RobotJelly commited on
Commit
147b3ce
1 Parent(s): ffaf3d7
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -50,14 +50,14 @@ def encode_search_query(search_query, model, device):
50
  with torch.no_grad():
51
  inputs = tokenizer([search_query], padding=True, return_tensors="pt")
52
  #inputs = processor(text=[search_query], images=None, return_tensors="pt", padding=True)
53
- text_features = model.get_text_features(**inputs).detach().numpy()
54
  return text_features
55
 
56
  # Find all matched photos
57
- def find_matches(text_features, photo_features, photo_ids, results_count=4):
58
  # Compute the similarity between the search query and each photo using the Cosine similarity
59
- text_features = np.array(text_features)
60
- similarities = (photo_features @ text_features.T).squeeze(1)
61
  # Sort the photos by their similarity score
62
  best_photo_idx = (-similarities).argsort()
63
  # Return the photo IDs of the best matches
@@ -85,9 +85,9 @@ def image_search(search_text, search_image, option):
85
  processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
86
  image_feature = model.get_image_features(processed_image.to(device))
87
  image_feature /= image_feature.norm(dim=-1, keepdim=True)
88
- image_feature = image_feature.detach().numpy()
89
  # Find the matched Images
90
- matched_images = find_matches(image_feature, photo_features, photo_ids, 4)
91
  return show_output_image(matched_images)
92
 
93
  gr.Interface(fn=image_search,
50
  with torch.no_grad():
51
  inputs = tokenizer([search_query], padding=True, return_tensors="pt")
52
  #inputs = processor(text=[search_query], images=None, return_tensors="pt", padding=True)
53
+ text_features = model.get_text_features(**inputs).cpu().numpy()
54
  return text_features
55
 
56
  # Find all matched photos
57
+ def find_matches(features, photo_ids, results_count=4):
58
  # Compute the similarity between the search query and each photo using the Cosine similarity
59
+ #text_features = np.array(text_features)
60
+ similarities = (photo_features @ features.T).squeeze(1)
61
  # Sort the photos by their similarity score
62
  best_photo_idx = (-similarities).argsort()
63
  # Return the photo IDs of the best matches
85
  processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
86
  image_feature = model.get_image_features(processed_image.to(device))
87
  image_feature /= image_feature.norm(dim=-1, keepdim=True)
88
+ image_feature = image_feature.cpu().numpy()
89
  # Find the matched Images
90
+ matched_images = find_matches(image_feature, photo_ids, 4)
91
  return show_output_image(matched_images)
92
 
93
  gr.Interface(fn=image_search,