import gradio as gr
import clip
import pickle
import requests
from PIL import Image
import numpy as np
import torch
# Set to True to run CLIP on the first CUDA GPU instead of the CPU.
is_gpu = False
# torch has no `CUDA(...)` helper; the original `CUDA(0)` would raise NameError as
# soon as is_gpu is True. The device is spelled as a string, like "cpu".
device = "cuda:0" if is_gpu else "cpu"
from datasets import load_dataset

# Unsplash 25k photo metadata from the Hugging Face Hub.
# NOTE(review): `dataset` is not referenced by any live code in this file (the
# photographer-credit lines in `search` that presumably used it are commented
# out) -- confirm it is still needed, since downloading it slows startup.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
# Precomputed search index, unpacked into three globals used by `search`:
#   id2url    -- lookup from photo id (file name without extension) to image URL
#   img_names -- photo file names, indexed by row of img_emb
#   img_emb   -- image embedding matrix; presumably (N, D) CLIP ViT-B/32 image
#                embeddings, L2-normalised -- TODO confirm against the script
#                that built this pickle.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
# NOTE: pickle.load can execute arbitrary code; only load a pickle you built or
# otherwise trust.
with open(emb_filename, 'rb') as emb:
    id2url, img_names, img_emb = pickle.load(emb)
# CLIP ViT-B/32 model (used by `search` to encode queries) and its image
# preprocessing transform (not used anywhere in this file).
orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
def search(search_query, *, num_results=5):
    """Return URLs of the indexed Unsplash photos that best match a text query.

    The query is encoded with the CLIP text encoder and L2-normalised, then
    scored against every row of ``img_emb`` by dot product. This is cosine
    similarity provided the image embeddings are unit-normalised as well,
    which is assumed here and not checked.

    Photos are visited from most to least similar. Each file name is mapped
    to a URL through ``id2url``, and photos without a URL are skipped.

    Args:
        search_query: Free-text description to search for.
        num_results: Maximum number of URLs to return. Keyword-only, so a
            UI framework passing a single input cannot set it by accident.

    Returns:
        A list of at most ``num_results`` URL strings, ordered by descending
        similarity. Each URL has ``?h=512`` appended, which presumably
        requests a 512px-high rendition from the image CDN.
    """
    with torch.no_grad():
        # The tokens must live on the same device as the model; without
        # .to(device) the GPU path (is_gpu = True) fails with a device mismatch.
        tokens = clip.tokenize(search_query).to(device)
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        # .float() keeps the scoring in float32 even if the model runs in
        # half precision on the GPU; it is a no-op on the CPU.
        text_features = text_encoded.float().cpu().numpy()

    # A single query gives a (1, D) feature row. Assuming img_emb is (N, D),
    # this yields one similarity score per photo.
    similarities = (text_features @ img_emb.T).squeeze(0)

    urls = []
    # Walk the full ranking (best first) instead of only the top 15, so that
    # photos missing from id2url cannot shrink the result below num_results.
    for photo_idx in similarities.argsort()[::-1]:
        if len(urls) >= num_results:
            break
        # Strip only the final extension. Names with no dot or with extra
        # dots no longer raise ValueError as `id, _ = id.split('.')` did.
        photo_id = img_names[photo_idx].rsplit(".", 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        urls.append(url + "?h=512")
    return urls
# Search UI, laid out inside a panel:
#   - a compact row holding a one-line prompt box and a "Search for images" button;
#   - an image gallery below that row.
# Clicking the button calls `search` with the prompt text. Its result goes to
# the gallery with postprocess=False, i.e. Gradio's component postprocessing
# step is skipped.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(container=False)
            search_btn = gr.Button("Search for images").style(full_width=False)

        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
        ).style(grid=[3, 3, 5], height="auto")

    search_btn.click(
        fn=search,
        inputs=text,
        outputs=gallery,
        postprocess=False,
    )

demo.launch()