import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from transformers import AutoFeatureExtractor, AutoModel
# Load model for computing embeddings of the candidate images.
# The CLS-token output of this ViT checkpoint is used as the image embedding
# (see get_neighbors below).
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
# NOTE(review): hidden_dim is never used in this file — possibly a leftover;
# confirm before removing.
hidden_dim = model.config.hidden_size
# Load the private dataset of candidate images with precomputed embeddings.
# Requires a Hugging Face access token in the TOKEN environment variable;
# token=None (unset) would only work if the dataset were public.
dataset_with_embeddings = load_dataset("LucyintheSky/24-1-30-ds-embeddings", split="train", token=os.environ.get('TOKEN'))
# Build an in-memory FAISS index over the 'embeddings' column so
# get_nearest_examples can do fast similarity search.
dataset_with_embeddings.add_faiss_index(column='embeddings')
def get_neighbors(query_image, top_k=8):
    """Embed *query_image* and return its ``top_k`` nearest neighbors.

    Parameters
    ----------
    query_image : PIL.Image.Image
        RGB query image (any size; the extractor resizes/normalizes it).
    top_k : int, optional
        Number of neighbors to retrieve from the FAISS index (default 8).

    Returns
    -------
    tuple
        ``(scores, retrieved_examples)`` as produced by
        ``Dataset.get_nearest_examples`` — a score array and a dict of
        columns for the matching rows.
    """
    inputs = extractor(query_image, return_tensors="pt")
    # Inference only: no_grad() skips building the autograd graph,
    # saving memory and time on every query.
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the CLS token (position 0) as the whole-image embedding,
    # matching how the dataset's 'embeddings' column was computed.
    qi_embedding = outputs.last_hidden_state[:, 0].numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples
def search(image_dict):
    """Gradio callback: find catalog images similar to the submitted image.

    Parameters
    ----------
    image_dict : dict
        Value from a Gradio image-editor component; its ``'composite'``
        entry holds the final image (a path/file PIL can open).

    Returns
    -------
    tuple
        ``(result, final_md)`` where ``result`` is a list of
        ``(image_link, name)`` pairs for ``gr.Gallery`` and ``final_md``
        is a whitespace-only Markdown placeholder (one line per hit).
    """
    # Normalize to RGB; the editor may hand back an RGBA composite.
    query_image = Image.open(image_dict['composite']).convert(mode='RGB')
    # Retrieve the most similar images from the FAISS-indexed dataset.
    scores, retrieved_examples = get_neighbors(query_image)
    # Length is taken from the 'image' column, matching the original loop.
    n = len(retrieved_examples["image"])
    # Build (image, caption) pairs — the shape gr.Gallery expects.
    result = [
        (retrieved_examples["image_link"][i], retrieved_examples["name"][i])
        for i in range(n)
    ]
    # Placeholder Markdown: one near-empty line per result (kept for
    # backward-compatible output to the Markdown component).
    final_md = " \n" * n
    return result, final_md
iface = gr.Interface(fn=search,
description="""