"""Natural-language image search over the Unsplash 25k photo set.

A Gradio demo: the text prompt is encoded with CLIP and matched against
precomputed image embeddings; the closest photos are shown in a gallery.
"""

import pickle

import clip
import gradio as gr
import numpy as np
import requests
import torch
from datasets import load_dataset
from PIL import Image

is_gpu = False
device = "cuda:0" if is_gpu else "cpu"

# Load the Unsplash 25k photo metadata (the image URLs used below come from
# the pickled index rather than from this dataset).
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb:
    id2url, img_names, img_emb = pickle.load(emb)
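
# The pickle holds three parallel structures:
#   id2url    - maps an Unsplash photo id to its image URL
#   img_names - image file names (photo id + extension), indexable as an array
#   img_emb   - precomputed image embedding matrix (one row per photo) in the
#               same space as the CLIP text embeddings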


orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
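
# NOTE (illustrative sketch, not part of this app): `img_emb` was presumably
# precomputed offline with the same CLIP image encoder, along the lines of:
#
#   with torch.no_grad():
#       batch = torch.stack([orig_clip_processor(Image.open(p)) for p in photo_paths])
#       img_emb = orig_clip_model.encode_image(batch.to(device))
#       img_emb /= img_emb.norm(dim=-1, keepdim=True)
#       img_emb = img_emb.cpu().numpy()
#
# where `photo_paths` is a hypothetical list of local image files.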

def search(search_query):

    with torch.no_grad():
        # Encode and normalize the description using CLIP
        text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the description vector
    text_features = text_encoded.cpu().numpy()

    # Compute the similarity between the description and each photo using cosine similarity
    similarities = (text_features @ img_emb.T).squeeze(0)

    # Sort the photos by their similarity score and keep the top 15 candidates
    best_photos = similarities.argsort()[::-1][:15]

    best_photo_ids = img_names[best_photos]

    imgs = []

    # Iterate over the candidates until five usable URLs have been collected
    for img_name in best_photo_ids:

        photo_id, _ = img_name.split('.')
        url = id2url.get(photo_id, "")
        if url == "":
            continue

        # Let the gallery load each image directly from its Unsplash URL
        imgs.append(url + "?h=512")
        # Alternative: download the image and return a PIL object instead
        # r = requests.get(url + "?w=512", stream=True)
        # imgs.append(Image.open(r.raw))

        if len(imgs) == 5:
            break

    return imgs
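
# Quick sanity check without the Gradio UI (uncomment to try):
# print(search("dogs playing in the snow"))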

with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(
                container=False,
            )
            search_btn = gr.Button("Search for images").style(full_width=False)

        gallery = gr.Gallery(label="Retrieved images", show_label=False).style(grid=[3,3,5])

    search_btn.click(search, text, gallery, api_name="images")
    #search_btn.click(search, text, temp, api_name="list")


demo.launch()
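# Tip: demo.launch(share=True) would additionally expose a temporary public URL.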