"""Text-to-image search demo over 25k Unsplash photos using CLIP ViT-B/32 and Gradio."""
import pickle

import clip
import gradio as gr
import numpy as np
import requests
import torch
from datasets import load_dataset
from PIL import Image

# Set to True to run CLIP on the first CUDA device instead of the CPU.
is_gpu = False
# Device string understood by both clip.load() and torch.Tensor.to().
device = "cuda:0" if is_gpu else "cpu"

# Maximum number of image URLs returned per query.
TOP_K = 5

# NOTE(review): `dataset` is not referenced anywhere in this script; kept so the
# download/cache side effect is unchanged — confirm whether it is still needed.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(emb_filename, 'rb') as emb:
    # id2url: photo id -> Unsplash URL; img_names: image file names aligned with
    # the rows of img_emb; img_emb: one CLIP image embedding per row.
    id2url, img_names, img_emb = pickle.load(emb)

orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)


def _photo_id(filename):
    """Return *filename* without its final extension (``"abc.jpg"`` -> ``"abc"``)."""
    return str(filename).rsplit('.', 1)[0]


def search(search_query):
    """Return up to TOP_K image URLs ranked by CLIP similarity to *search_query*.

    Args:
        search_query: Free-text description of the wanted image. Queries longer
            than CLIP's 77-token context are truncated instead of raising.

    Returns:
        A list of at most TOP_K URL strings (each with ``?h=512`` appended), in
        descending similarity order. Photos whose id has no URL in ``id2url``
        are skipped, and the ranking is walked until TOP_K URLs are found.
    """
    with torch.no_grad():
        # Tokens must live on the same device as the model.
        tokens = clip.tokenize(search_query, truncate=True).to(device)
        text_encoded = orig_clip_model.encode_text(tokens)
        # Normalize so the dot product below is a cosine similarity
        # (assumes the rows of img_emb are already L2-normalized — TODO confirm).
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # float32 on CPU: on GPU the model may produce fp16 tensors.
    text_features = text_encoded.float().cpu().numpy()

    # One similarity score per photo.
    similarities = (text_features @ img_emb.T).squeeze(0)

    imgs = []
    # Walk photos from most to least similar until TOP_K have a known URL.
    for idx in similarities.argsort()[::-1]:
        url = id2url.get(_photo_id(img_names[idx]), "")
        if not url:
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == TOP_K:
            break
    return imgs


# NOTE: `.style(...)` is the Gradio 3.x styling API; it was removed in Gradio 4.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            ).style(
                container=False,
            )
            search_btn = gr.Button("Search for images").style(full_width=False)

        gallery = gr.Gallery(label="Generated images", show_label=False).style(grid=[3, 3, 5])

    search_btn.click(search, text, gallery, api_name="images")

if __name__ == "__main__":
    demo.launch()