Spaces:
Runtime error
Runtime error
File size: 2,744 Bytes
1748d76 0c81859 1748d76 bff355f b1098aa 1748d76 474fa74 bff355f 3b85df0 1748d76 bff355f 1748d76 429d7b2 bff355f 1748d76 b1098aa e5e67a3 429d7b2 5a1647d fa69e46 1748d76 429d7b2 1748d76 828058c 23b9e76 265f964 23b9e76 265f964 819ea68 0572dbd 429d7b2 1748d76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import html
import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from transformers import AutoFeatureExtractor, AutoModel
# --- Embedding model -------------------------------------------------------
# ViT checkpoint whose encoder output is used to embed the query sketch.
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_ckpt),
    AutoModel.from_pretrained(model_ckpt),
)
# Width of the encoder's hidden states (size of one embedding vector).
hidden_dim = model.config.hidden_size

# --- Candidate catalogue ---------------------------------------------------
# Dataset with a precomputed "embeddings" column; authenticates with the
# TOKEN environment variable when it is set. A FAISS index is then built on
# that column so get_nearest_examples() can query it.
dataset_with_embeddings = load_dataset(
    "LucyintheSky/24-1-30-ds-embeddings",
    split="train",
    token=os.environ.get('TOKEN'),
)
dataset_with_embeddings.add_faiss_index(column='embeddings')
def get_neighbors(query_image, top_k=8):
    """Return the ``top_k`` catalogue rows whose embeddings are closest to ``query_image``.

    The query is embedded with the module-level ``extractor``/``model`` and
    looked up in the FAISS index on ``dataset_with_embeddings['embeddings']``.

    Args:
        query_image: image accepted by ``extractor`` (``search`` passes an RGB PIL image).
        top_k: number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` exactly as returned by
        ``Dataset.get_nearest_examples``.
    """
    inputs = extractor(query_image, return_tensors="pt")
    # Pure inference: disable autograd so no gradient graph is built or retained.
    with torch.no_grad():
        outputs = model(**inputs)
    # First token of the single-image batch (the [CLS] token for ViT) is the query
    # vector; squeeze() drops the batch dimension.
    qi_embedding = outputs.last_hidden_state[:, 0].cpu().numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples
def search(image_dict):
    """Find catalogue items that look like the user's sketch.

    Args:
        image_dict: value produced by the ``gr.ImageEditor`` input
            (``type='filepath'``); its ``'composite'`` entry is the path of the
            flattened sketch. May be ``None`` when the editor is empty.

    Returns:
        ``(gallery, markdown)``: ``gallery`` is a list of ``(image_link, name)``
        tuples for ``gr.Gallery``; ``markdown`` is an HTML snippet of clickable
        200px thumbnails for ``gr.Markdown``. Both are empty when there is no
        composite image to search with.
    """
    composite = image_dict.get('composite') if image_dict else None
    if composite is None:
        # Nothing to search with (editor cleared / nothing submitted).
        return [], ""

    # Use a context manager so the file handle is released; convert() returns
    # an independent, fully loaded copy that outlives the `with` block.
    with Image.open(composite) as sketch:
        query_image = sketch.convert(mode='RGB')

    _, retrieved_examples = get_neighbors(query_image)

    gallery = []
    thumbnails = []
    for image_link, name, link in zip(
        retrieved_examples["image_link"],
        retrieved_examples["name"],
        retrieved_examples["link"],
    ):
        gallery.append((image_link, name))
        # Escape attribute values so a quote or '&' in a URL cannot break the markup.
        thumbnails.append(
            f"<a href='{html.escape(link)}'> "
            f"<img src='{html.escape(image_link)}' width='200'/> </a> "
        )
    return gallery, "".join(thumbnails)
# --- User interface --------------------------------------------------------
# A sketchpad (pre-filled with ./template.JPG as its background) feeds `search`;
# results are shown both as a gallery and as linked thumbnails in Markdown.
# NOTE(review): the logo below is a Discord CDN attachment URL carrying
# `ex`/`is`/`hm` query parameters, which looks like an expiring signed link —
# confirm it still resolves, or host the logo somewhere stable.
_description = """
<center><a href="https://www.lucyinthesky.com/"><img width="500" src="https://cdn.discordapp.com/attachments/1120417968032063538/1201666647157657640/LucyITS-2022-blk.png?ex=65caa646&is=65b83146&hm=09ad6fe279edc3a32981306d563e63af815d760fc0d8d0a3fbef4e4553c0a83a&"> </a> </center>
<br>
<center> Sketch to find your favorite Lucy in the Sky dress! </center>
<br>
"""

_sketchpad = gr.ImageEditor(
    label='Sketchpad',
    type='filepath',
    value={'background': './template.JPG', 'layers': None, 'composite': None},
    sources=['upload'],
    transforms=[],
)
_result_views = [
    gr.Gallery(label='Similar', object_fit='contain', height=1200),
    gr.Markdown(),
]
_theme = gr.themes.Base(
    primary_hue="teal",
    secondary_hue="teal",
    neutral_hue="slate",
)

iface = gr.Interface(
    fn=search,
    description=_description,
    inputs=_sketchpad,
    outputs=_result_views,
    theme=_theme,
)
iface.launch()