ryaalbr commited on
Commit
6afb63d
1 Parent(s): d042d50

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import clip
3
+ import pickle
4
+ import requests
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+
9
+ is_gpu = False
10
+ device = CUDA(0) if is_gpu else "cpu"
11
+
12
+ from datasets import load_dataset
13
+ dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
14
+
15
+ emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
16
+ with open(emb_filename, 'rb') as emb:
17
+ id2url, img_names, img_emb = pickle.load(emb)
18
+
19
+
20
+ orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
21
+
22
+ def search(search_query):
23
+
24
+ with torch.no_grad():
25
+ # Encode and normalize the description using CLIP
26
+ text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
27
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
28
+
29
+
30
+ # Retrieve the description vector
31
+ text_features = text_encoded.cpu().numpy()
32
+
33
+ # Compute the similarity between the descrption and each photo using the Cosine similarity
34
+ similarities = (text_features @ img_emb.T).squeeze(0)
35
+
36
+ # Sort the photos by their similarity score
37
+ best_photos = similarities.argsort()[::-1]
38
+ best_photos = best_photos[:15]
39
+ #best_photos = sorted(zip(similarities, range(img_emb.shape[0])), key=lambda x: x[0], reverse=True)
40
+
41
+ best_photo_ids = img_names[best_photos]
42
+
43
+ imgs = []
44
+
45
+ # Iterate over the top 5 results
46
+ for id in best_photo_ids:
47
+
48
+ id, _ = id.split('.')
49
+ url = id2url.get(id, "")
50
+ if url == "": continue
51
+
52
+
53
+ r = requests.get(url + "?w=512", stream=True)
54
+ img = Image.open(r.raw)
55
+ #credits = f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'
56
+ imgs.append(img)
57
+ #display(HTML(f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'))
58
+ print()
59
+
60
+ if len(imgs) == 5: break
61
+
62
+ return imgs
63
+
64
+ with gr.Blocks() as demo:
65
+ with gr.Column(variant="panel"):
66
+ with gr.Row(variant="compact"):
67
+ text = gr.Textbox(
68
+ label="Enter your prompt",
69
+ show_label=False,
70
+ max_lines=1,
71
+ placeholder="Enter your prompt",
72
+ ).style(
73
+ container=False,
74
+ )
75
+ search_btn = gr.Button("Search for images").style(full_width=False)
76
+
77
+ gallery = gr.Gallery(
78
+ label="Generated images", show_label=False, elem_id="gallery"
79
+ ).style(grid=[3,3,5], height="auto")
80
+
81
+ search_btn.click(search, text, gallery)
82
+
83
+ demo.launch()