#!/usr/bin/env python
# coding: utf-8

# # Norsk (Multilingual) Image Search
#
# Based on [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search)
# by [Vladimir Haltakov](https://twitter.com/haltakov).

# In[ ]:
import gradio as gr
from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import torch
from transformers import AutoTokenizer
# In[ ]:

# Pick the device and load the multilingual CLIP text model and its tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# In[ ]:
# Load the image metadata and IDs
images_info = pd.read_csv("./metadata.csv")
with open("./images_list.txt", "r", encoding="utf-8") as f:
    image_ids = f.read().strip().split("\n")

# Load the precomputed image feature vectors
image_features = np.load("./image_features.npy")
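# The feature matrix is assumed to have been precomputed offline with the
# OpenCLIP image encoder that pairs with this text model (the M-CLIP model
# card points to ViT-B-16-plus-240 trained on LAION-400M). A minimal sketch
# of how such features could be regenerated; the checkpoint name, the helper
# name, and the lack of batching are assumptions, not part of this app:
def precompute_image_features(image_paths):
    import open_clip  # assumed extra dependency, only needed offline

    clip_model, _, preprocess = open_clip.create_model_and_transforms(
        "ViT-B-16-plus-240", pretrained="laion400m_e32"
    )
    clip_model.eval()
    features = []
    with torch.no_grad():
        for path in image_paths:
            image = preprocess(Image.open(path)).unsqueeze(0)
            features.append(clip_model.encode_image(image))
    return torch.cat(features).cpu().numpy()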
# Convert the features to tensors: float32 on CPU, float16 on GPU
if device == "cpu":
    image_features = torch.from_numpy(image_features).float().to(device)
else:
    image_features = torch.from_numpy(image_features).to(device)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
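# With both the image features (above) and the text features (below) scaled to
# unit length, the dot products computed in find_best_matches are exactly
# cosine similarities, so ranking by dot product ranks by cosine similarity.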
# ## Define Functions
#
# Some important functions for processing the data are defined here.
#
# The `encode_search_query` function takes a text description and encodes it into a feature vector using the multilingual CLIP model.

# In[ ]:
def encode_search_query(search_query):
    with torch.no_grad():
        # Encode and normalize the search query using the multilingual model
        text_encoded = model.forward(search_query, tokenizer)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    # Match the device and dtype of the precomputed image features
    # (they are float16 on CUDA, float32 on CPU)
    return text_encoded.to(image_features.device, image_features.dtype)
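# Example (hypothetical query; XLM-RoBERTa covers roughly 100 languages):
#   text_features = encode_search_query("en solnedgang over fjorden")
#   text_features.shape  # (1, D), where D == image_features.shape[1]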
# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and returns the IDs of the best-matching images together with their similarity scores.

# In[ ]:
def find_best_matches(text_features, image_features, image_ids, results_count=3):
    # Cosine similarity between the query and every image
    # (both feature sets are already unit-normalized, so a dot product suffices)
    similarities = (image_features @ text_features.T).squeeze(1)
    # Sort the images by similarity score, descending
    best_image_idx = (-similarities).argsort()
    # Return [image ID, score] pairs for the top matches
    return [
        [image_ids[i], similarities[i].item()] for i in best_image_idx[:results_count]
    ]
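# Example (illustrative filenames and scores):
#   find_best_matches(text_features, image_features, image_ids, results_count=2)
#   # -> [["0123.jpg", 0.31], ["0456.jpg", 0.29]]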
# In[ ]:

def clip_search(search_query):
    # Ignore queries that are too short to be meaningful
    if len(search_query) < 3:
        return []
    text_features = encode_search_query(search_query)
    # Find the 15 images that best match the query
    matches = find_best_matches(
        text_features, image_features, image_ids, results_count=15
    )
    images = []
    for image_id, _similarity in matches:
        # Look up the image URL and permalink for this filename
        row = images_info[images_info["filename"] == image_id]
        images.append([row["image_url"].values[0], row["permalink"].values[0]])
    return images
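# clip_search returns a list of [image, caption] pairs, the input format that
# gr.Gallery accepts; here the image is a remote URL and the caption is the
# photo's permalink.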
css = (
    "footer {display: none !important;} .gradio-container {min-height: 0px !important;}"
)
with gr.Blocks(css=css) as gr_app:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            search_string = gr.Textbox(
                label="Evocative Search",
                show_label=True,
                max_lines=1,
                placeholder="Type something, or click a suggested search below.",
                container=False,
            )
            btn = gr.Button("Search", variant="primary")
        # Suggested queries in Swedish, Norwegian, English, and Korean
        with gr.Row(variant="compact"):
            suggest1 = gr.Button(
                "två hundar som leker i snön",  # "two dogs playing in the snow"
                variant="secondary",
                size="sm",
            )
            suggest2 = gr.Button(
                "en fisker til sjøs i en båt",  # "a fisherman at sea in a boat"
                variant="secondary",
                size="sm",
            )
            suggest3 = gr.Button(
                "cold dark alone on the street", variant="secondary", size="sm"
            )
            suggest4 = gr.Button(
                "도로 위의 자동차들",  # "cars on the road"
                variant="secondary",
                size="sm",
            )
        gallery = gr.Gallery(
            show_label=False, elem_id="gallery", height="100%", columns=6
        )
    # Clicking a suggestion passes the button's own label as the search query
    suggest1.click(clip_search, inputs=suggest1, outputs=gallery)
    suggest2.click(clip_search, inputs=suggest2, outputs=gallery)
    suggest3.click(clip_search, inputs=suggest3, outputs=gallery)
    suggest4.click(clip_search, inputs=suggest4, outputs=gallery)
    btn.click(clip_search, inputs=search_string, outputs=gallery)
    search_string.submit(clip_search, inputs=search_string, outputs=gallery)
if __name__ == "__main__":
    gr_app.launch(share=False)