"""Gradio app: sketch a dress and retrieve visually similar catalogue images.

The user's sketch is embedded with a ViT model and matched against a
Hugging Face dataset of precomputed embeddings via a FAISS index.
"""

import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from transformers import AutoFeatureExtractor, AutoModel

# Load model for computing embeddings of the candidate images
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
hidden_dim = model.config.hidden_size  # embedding width; not used below

# Candidate images with precomputed embeddings. The TOKEN environment
# variable (if set) is passed to load_dataset as the Hub access token.
dataset_with_embeddings = load_dataset(
    "LucyintheSky/24-1-8-ds-embeddings",
    split="train",
    token=os.environ.get('TOKEN'),
)
dataset_with_embeddings.add_faiss_index(column='embeddings')


def get_neighbors(query_image, top_k=8):
    """Find the dataset rows whose embeddings are nearest to *query_image*.

    Args:
        query_image: PIL image to embed with the ViT model.
        top_k: number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` exactly as returned by
        ``Dataset.get_nearest_examples`` on the 'embeddings' FAISS index.
    """
    inputs = extractor(query_image, return_tensors="pt")
    # Inference only: disable autograd so no gradient graph is built per query.
    with torch.no_grad():
        outputs = model(**inputs)
    # Embedding = first token of the last hidden layer (the [CLS] position in ViT).
    qi_embedding = outputs.last_hidden_state[:, 0].detach().numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples


def search(image_dict):
    """Return gallery items for the catalogue images most similar to a sketch.

    Args:
        image_dict: value produced by ``gr.ImageEditor(type='filepath')``;
            its ``'composite'`` entry is used as the path of the sketch.

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``. The list is
        empty when there is no composite image to search with.
    """
    composite_path = image_dict.get('composite') if image_dict else None
    if not composite_path:
        # Nothing drawn/uploaded yet: show an empty gallery instead of crashing.
        return []

    # convert() returns an independent copy, so the file can be closed here.
    with Image.open(composite_path) as raw_image:
        query_image = raw_image.convert(mode='RGB')

    print('search')
    scores, retrieved_examples = get_neighbors(query_image)
    print('return example')

    result = []
    for text, score, image in zip(
        retrieved_examples["text"], scores, retrieved_examples["image"]
    ):
        item_id = str(text) + ' ' + str(score)
        print('id', item_id)
        # NOTE(review): every caption is this fixed string, which looks like a
        # leftover placeholder; item_id may be the intended caption - confirm.
        result.append((image, """Visit W3Schools"""))
    return result


iface = gr.Interface(
    fn=search,
    description="""

Sketch to find your favorite Lucy in the Sky dress!

""",
    inputs=gr.ImageEditor(
        label='Sketchpad',
        type='filepath',
        value={'background': './template.JPG', 'layers': None, 'composite': None},
        sources=['upload'],
        transforms=[],
    ),
    outputs=gr.Gallery(label='Similar', object_fit='contain', height=1200),
    theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate"),
)
iface.launch()