ImageSearch / app.py
ryaalbr's picture
Update app.py
cc20aca
raw history blame
No virus
3.01 kB
import gradio as gr
import clip
import pickle
import requests
from PIL import Image
import numpy as np
import torch
# Inference runs on the CPU by default; set ``is_gpu`` to True to use the
# first CUDA device instead.
is_gpu = False
# A plain device string is used here. The previous ``CUDA(0)`` expression
# referenced an undefined name and raised NameError as soon as ``is_gpu``
# was switched on.
device = "cuda:0" if is_gpu else "cpu"
from datasets import load_dataset
# Fetch the "train" split of the Unsplash 25k photos dataset.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

# The pickle file stores three objects, unpacked in this order:
#   id2url    - mapping queried with ``.get(photo_id, "")`` to obtain a URL
#   img_names - presumably an array of image file names ("<id>.<ext>")
#   img_emb   - presumably a 2-D array with one embedding row per image
# NOTE(review): pickle.load can execute arbitrary code; this assumes the
# embeddings file is shipped with the app and is trusted.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb:
    id2url, img_names, img_emb = pickle.load(emb)

# Load the CLIP ViT-B/32 model (non-JIT) onto the selected device. The
# second return value is kept under its original name for compatibility.
orig_clip_model, orig_clip_processor = clip.load(
    "ViT-B/32",
    device=device,
    jit=False,
)
def search(search_query):
    """Return URLs of the photos that best match a text description.

    The query is embedded with the CLIP text encoder, scored against every
    row of the precomputed ``img_emb`` matrix, and the highest-scoring
    photos that have an entry in ``id2url`` are returned.

    Args:
        search_query: Free-text description of the desired image.

    Returns:
        A list of at most five URL strings, best match first. Each URL has
        ``"?h=512"`` appended. Photos without a known URL are skipped.
    """
    num_candidates = 15  # how many top-ranked photos are examined
    max_results = 5      # how many URLs are returned at most

    with torch.no_grad():
        # Put the tokens on the same device as the model so the call also
        # works when a GPU is selected. ``truncate=True`` makes over-long
        # prompts get cut to the context length instead of raising.
        tokens = clip.tokenize(search_query, truncate=True).to(device)
        # Encode and normalize the description using CLIP
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the description vector as a NumPy array
    text_features = text_encoded.cpu().numpy()

    # Score the description against each photo with a dot product.
    # NOTE(review): this equals cosine similarity only if the rows of
    # img_emb are already L2-normalized - confirm against the file that
    # produced the embeddings.
    similarities = (text_features @ img_emb.T).squeeze(0)

    # Indices of the best-scoring photos, highest similarity first
    best_photos = similarities.argsort()[::-1][:num_candidates]
    best_photo_names = img_names[best_photos]

    imgs = []
    # Walk the candidates in rank order and stop once enough URLs are found
    for file_name in best_photo_names:
        # Drop only the trailing extension; a name with no dot or with
        # several dots no longer raises ValueError on unpacking.
        photo_id = file_name.rsplit('.', 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == max_results:
            break
    return imgs
# Build the web UI: a prompt box and a search button side by side, with a
# gallery underneath that shows the URLs returned by ``search``.
# NOTE(review): the nesting below was reconstructed from the statement
# order; confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            # Single-line prompt input; the visible label is hidden and the
            # placeholder text serves as the hint.
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(container=False)
            search_btn = gr.Button("Search for images").style(full_width=False)

        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
        ).style(grid=[3, 3, 5])

    # Clicking the button calls search(text) and sends the result to the
    # gallery; the event is registered under the API name "images".
    search_btn.click(search, text, gallery, api_name="images")

# Start the Gradio server.
demo.launch()