ryaalbr committed
Commit dbf8f9d
1 parent: 6218a98

Update app.py

Files changed (1): app.py (+63 -5)
app.py CHANGED
@@ -4,6 +4,13 @@ import random
 import numpy as np
 from transformers import CLIPProcessor, CLIPModel
 from os import environ
+import clip
+import pickle
+import requests
+import torch
+
+is_gpu = False
+device = "cuda:0" if is_gpu else "cpu"
 
 # Load the pre-trained model and processor
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -33,10 +40,61 @@ get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key
 def generate_text(image, model_name):
     return get_caption(image, model_name)
 
-get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
-def search_images(text):
-    return get_images(text, api_name="images")
-
+# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
+# def search_images(text):
+#     return get_images(text, api_name="images")
+
+emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
+with open(emb_filename, 'rb') as emb:
+    id2url, img_names, img_emb = pickle.load(emb)
+
+orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
+
+def search(search_query):
+
+    with torch.no_grad():
+        # Encode and normalize the description using CLIP
+        text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
+        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+
+
+    # Retrieve the description vector
+    text_features = text_encoded.cpu().numpy()
+
+    # Compute the similarity between the description and each photo using cosine similarity
+    similarities = (text_features @ img_emb.T).squeeze(0)
+
+    # Sort the photos by their similarity score
+    best_photos = similarities.argsort()[::-1]
+    best_photos = best_photos[:15]
+    #best_photos = sorted(zip(similarities, range(img_emb.shape[0])), key=lambda x: x[0], reverse=True)
+
+    best_photo_ids = img_names[best_photos]
+
+    imgs = []
+
+    # Iterate over the top candidates until 5 results are collected
+    for id in best_photo_ids:
+
+        id, _ = id.split('.')
+        url = id2url.get(id, "")
+        if url == "": continue
+
+        img = url + "?h=512"
+        # r = requests.get(url + "?w=512", stream=True)
+        # img = Image.open(r.raw)
+        #credits = f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'
+        imgs.append(img)
+        #display(HTML(f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'))
+
+        if len(imgs) == 5: break
+
+    return imgs
+
+
+
+
+
 with gr.Blocks() as demo:
 
     with gr.Tab("Zero-Shot Classification"):
@@ -87,6 +145,6 @@ with gr.Blocks() as demo:
         desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
         search_btn = gr.Button("Find Images").style(full_width=False)
         gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
-        search_btn.click(search_images,inputs=desc, outputs=gallery)
+        search_btn.click(search,inputs=desc, outputs=gallery)
 
 demo.launch()
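
Note on the new search path: `search()` assumes that `unsplash-25k-photos-embeddings-indexes.pkl` unpickles into three objects: `id2url` (a dict mapping Unsplash photo IDs to image URLs), `img_names` (an array of photo file names such as `<photo_id>.jpg`), and `img_emb` (a matrix of unit-normalized CLIP image embeddings, one row per photo). This commit does not include the code that builds that file; the sketch below is one plausible way to produce it. The TSV name, its column names, the photo folder, and the batch size are illustrative assumptions, not taken from this repository.

# Sketch (assumed, not part of this commit): precompute CLIP image embeddings for a
# folder of Unsplash photos and pickle (id2url, img_names, img_emb) in the layout
# app.py expects. Paths and TSV columns below are hypothetical.
import os
import pickle

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image

device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Hypothetical Unsplash metadata file mapping photo IDs to image URLs
photos = pd.read_csv("photos.tsv000", sep="\t", usecols=["photo_id", "photo_image_url"])
id2url = dict(zip(photos["photo_id"], photos["photo_image_url"]))

img_dir = "photos"  # assumed folder of downloaded JPEGs named <photo_id>.jpg
img_names = np.array(sorted(os.listdir(img_dir)))

batches = []
with torch.no_grad():
    for start in range(0, len(img_names), 256):
        names = img_names[start:start + 256]
        images = torch.stack(
            [preprocess(Image.open(os.path.join(img_dir, n))) for n in names]
        ).to(device)
        feats = model.encode_image(images)
        # Normalize so a dot product with a normalized text vector is a cosine similarity
        feats /= feats.norm(dim=-1, keepdim=True)
        batches.append(feats.cpu().numpy())

img_emb = np.concatenate(batches)

with open("unsplash-25k-photos-embeddings-indexes.pkl", "wb") as f:
    pickle.dump((id2url, img_names, img_emb), f)

With a file shaped like this, `text_features @ img_emb.T` in `search()` is exactly the cosine similarity between the query and every photo, so `argsort()[::-1]` ranks photos from most to least relevant; a call such as `search("two dogs playing in the snow")` then returns up to five Unsplash image URLs with an `?h=512` size hint appended, which the gallery displays directly.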