# Import Libraries
import os
import pickle

import torch
import gradio as gr
from PIL import Image
from sentence_transformers import SentenceTransformer, util

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory holding the photo files named in the embeddings file
IMAGES_DIR = './photos/'

# ---------------------------------------------------------------------------
# Earlier approach (not used): OpenAI's CLIP loaded through transformers, with
# photo IDs, photo metadata and pre-computed feature vectors read from disk.
# It relies on these extra imports:
# import numpy as np
# import pandas as pd
# import requests
# from io import BytesIO
# from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
#
# # Load OpenAI's CLIP model
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#
# # Photo IDs
# photo_ids = pd.read_csv("./photo_ids.csv")
# photo_ids = list(photo_ids['photo_id'])
#
# # Photo dataset
# photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)
#
# # Feature vectors
# photo_features = np.load("./features.npy")
#
# def show_output_image(matched_images):
#     image = []
#     for photo_id in matched_images:
#         # Either download the photo from Unsplash...
#         photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
#         response = requests.get(photo_image_url, stream=True)
#         img = Image.open(BytesIO(response.content)).convert("RGB")
#         # ...or read it from the local photo directory
#         # photo = photo_id + '.jpg'
#         # img = Image.open(os.path.join(IMAGES_DIR, photo))
#         image.append(img)
#     return image
#
# # Encode and normalize the search query using CLIP
# def encode_search_query(search_query, model, device):
#     with torch.no_grad():
#         inputs = tokenizer([search_query], padding=True, return_tensors="pt")
#         # inputs = processor(text=[search_query], images=None, return_tensors="pt", padding=True)
#         text_features = model.get_text_features(**inputs).cpu().numpy()
#     text_features /= np.linalg.norm(text_features, axis=-1, keepdims=True)
#     return text_features
#
# # Find all matched photos. With L2-normalized features, the dot product
# # is the cosine similarity between the query and each photo.
# def find_matches(features, photo_ids, results_count=4):
#     similarities = (photo_features @ features.T).squeeze(1)
#     # Sort the photos by their similarity score
#     best_photo_idx = (-similarities).argsort()
#     # Return the photo IDs of the best matches
#     return [photo_ids[i] for i in best_photo_idx[:results_count]]
#
# # Image query: encode and normalize the uploaded image
# with torch.no_grad():
#     processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
#     image_feature = model.get_image_features(processed_image.to(device))
#     image_feature /= image_feature.norm(dim=-1, keepdim=True)
# image_feature = image_feature.cpu().numpy()
# matched_images = find_matches(image_feature, photo_ids, 4)
# ---------------------------------------------------------------------------

# Load the CLIP model through sentence-transformers
model = SentenceTransformer('clip-ViT-B-32', device=device)

# Pre-computed embeddings: a pickled (img_names, img_emb) pair, where
# img_emb[i] is the CLIP embedding of the photo file img_names[i]
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)

# Number of matches returned for each query
TOP_K = 4


def display_matches(hits):
    """Open the photo files for the hits returned by util.semantic_search."""
    return [Image.open(os.path.join(IMAGES_DIR, img_names[hit['corpus_id']]))
            for hit in hits]


def image_search(search_text, search_image, option):
    if option == "Text-To-Image":
        # Input text query, e.g. "The feeling when your program finally works"
        # Extract the text embedding
        text_emb = model.encode([search_text], convert_to_tensor=True)
        # Find the TOP_K highest-ranked images; [0] selects the hits for the
        # first (and only) query. top_k must be passed by keyword: the third
        # positional parameter of semantic_search is query_chunk_size.
        matched_results = util.semantic_search(text_emb, img_emb, top_k=TOP_K)[0]
        return display_matches(matched_results)
    elif option == "Image-To-Image":
        if search_image is None:
            raise gr.Error("Upload an image to run an Image-To-Image search.")
        # The input component uses type="pil", so search_image is already a
        # PIL image and can be encoded directly (no Image.open needed)
        image_emb = model.encode([search_image], convert_to_tensor=True)
        # semantic_search ranks the corpus with util.cos_sim by default
        matched_results = util.semantic_search(image_emb, img_emb, top_k=TOP_K)[0]
        return display_matches(matched_results)


gr.Interface(
    fn=image_search,
    inputs=[
        gr.Textbox(lines=7, label="Input Text"),
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(["Text-To-Image", "Image-To-Image"],
                    value="Text-To-Image", label="Search Mode"),
    ],
    # A gallery shows the list of matched PIL images returned by image_search
    outputs=gr.Gallery(label="Top 4 Matches"),
).queue().launch(debug=True, share=True)