#!/usr/bin/env python
# coding: utf-8

# Norsk (Multilingual) Image Search
#
# Based on [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search)
# by [Vladimir Haltakov](https://twitter.com/haltakov).

# In[ ]:


import gradio as gr
from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import torch
from transformers import AutoTokenizer


# In[ ]:


# Select the device and load the multilingual CLIP text model with its tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"

model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# In[ ]:


# Load the image metadata and the list of image IDs
images_info = pd.read_csv("./metadata.csv")
with open("./images_list.txt", "r", encoding="utf-8") as f:
    image_ids = f.read().strip().split("\n")

# Load the precomputed image feature vectors
image_features = np.load("./image_features.npy")

# Convert the features to tensors: float32 on CPU and float16 on GPU
if device == "cpu":
    image_features = torch.from_numpy(image_features).float().to(device)
else:
    image_features = torch.from_numpy(image_features).to(device)

# Normalize the vectors so that cosine similarity reduces to a dot product
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
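
# The image features above are precomputed offline and only loaded here. As a rough
# sketch of how they could be produced (an assumption, not something this app runs):
# M-CLIP pairs `XLM-Roberta-Large-Vit-B-16Plus` with OpenCLIP's `ViT-B-16-plus-240`
# image tower, so a one-off script along these lines would generate
# `image_features.npy`. The helper below is illustrative only and is never called.

# In[ ]:


def precompute_image_features(image_paths, output_path="./image_features.npy"):
    # Illustrative sketch only: encode images with the (assumed) matching vision tower
    import open_clip

    vision_model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-16-plus-240", pretrained="laion400m_e32"
    )
    vision_model.eval()

    features = []
    with torch.no_grad():
        for path in image_paths:
            image = preprocess(Image.open(path).convert("RGB")).unsqueeze(0)
            features.append(vision_model.encode_image(image))

    np.save(output_path, torch.cat(features).cpu().numpy())
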

# ## Define Functions
#
# The helper functions that encode a query and rank the images are defined here.

# The `encode_search_query` function takes a text description and encodes it into a feature vector using the multilingual CLIP text model. The result is moved to the same device and dtype as the precomputed image features so the two can be compared directly.

# In[ ]:


def encode_search_query(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using the multilingual text model
        text_encoded = model.forward(search_query, tokenizer)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Match the device and dtype of the precomputed image features
    return text_encoded.to(device=image_features.device, dtype=image_features.dtype)
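

# Each query maps to a single unit-length feature vector regardless of language.
# A quick sanity check (the query string is only an illustrative example):

# In[ ]:


example_query = encode_search_query("two dogs playing in the snow")
print(example_query.shape, float(example_query.norm(dim=-1)))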


# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and returns the IDs of the best matching images together with their similarity scores.

# In[ ]:


def find_best_matches(text_features, image_features, image_ids, results_count=3):
    # Cosine similarity between the query and each image (both are already normalized)
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort the images by similarity score, best first
    best_image_idx = (-similarities).argsort()

    # Return [image ID, similarity score] pairs for the top matches
    return [
        [image_ids[i], similarities[i].item()] for i in best_image_idx[:results_count]
    ]
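

# Chaining the two helpers gives a minimal console search, independent of the
# Gradio UI below (again, the query string is only an illustrative example):

# In[ ]:


for match_id, match_score in find_best_matches(
    encode_search_query("a fisherman at sea in a boat"), image_features, image_ids
):
    print(f"{match_id}: {match_score:.3f}")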


# In[ ]:


def clip_search(search_query):
    # Ignore very short queries
    if len(search_query) < 3:
        return []

    # Encode the query and find the most similar images
    text_features = encode_search_query(search_query)
    matches = find_best_matches(
        text_features, image_features, image_ids, results_count=15
    )

    # Build [image URL, permalink] pairs for the gallery
    images = []
    for image_id, _score in matches:
        image_row = images_info[images_info["filename"] == image_id]
        images.append(
            [
                image_row["image_url"].values[0],
                image_row["permalink"].values[0],
            ]
        )

    return images


css = (
    "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
)
with gr.Blocks(css=css) as gr_app:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something, or click a suggested search below.",
            ).style(
                container=False,
            )
            btn = gr.Button("Search", variant="primary").style(full_width=False)
        with gr.Row(variant="compact"):
            suggest1 = gr.Button(
                "två hundar som leker i snön", variant="secondary"
            ).style(size="sm")
            suggest2 = gr.Button(
                "en fisker til sjøs i en båt", variant="secondary"
            ).style(size="sm")
            suggest3 = gr.Button(
                "cold dark alone on the street", variant="secondary"
            ).style(size="sm")
            suggest4 = gr.Button("도로 위의 자동차들", variant="secondary").style(size="sm")
        gallery = gr.Gallery(label=False, show_label=False, elem_id="gallery").style(
            grid=[6],
            height="100%",
        )

    suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
    suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
    suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
    suggest4.click(clip_search, inputs=suggest4, outputs=gallery)
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, search_string, gallery)

if __name__ == "__main__":
    gr_app.launch(share=False)