sketch-to-dress / app.py
LucyintheSky's picture
Update app.py
fa69e46 verified
raw
history blame contribute delete
No virus
2.74 kB
import gradio as gr
import numpy as np
from transformers import AutoFeatureExtractor, AutoModel
from datasets import load_dataset
from PIL import Image, ImageDraw
import os
# ---------------------------------------------------------------------------
# Embedding model. get_neighbors() embeds the user's sketch with this
# checkpoint, using the hidden state at token position 0.
# NOTE(review): presumably the catalogue embeddings were produced with the
# same checkpoint; verify against the dataset build script.
# ---------------------------------------------------------------------------
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
# Tuple elements are evaluated left to right, so the extractor is still
# loaded before the model.
extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_ckpt),
    AutoModel.from_pretrained(model_ckpt),
)
hidden_dim = model.config.hidden_size  # embedding width (not used elsewhere in this view)

# ---------------------------------------------------------------------------
# Catalogue with precomputed embeddings. The access token, if any, is read
# from the TOKEN environment variable.
# ---------------------------------------------------------------------------
dataset_with_embeddings = load_dataset(
    "LucyintheSky/24-1-30-ds-embeddings",
    split="train",
    token=os.environ.get('TOKEN'),
)
# Index the 'embeddings' column so get_nearest_examples('embeddings', ...) works.
dataset_with_embeddings.add_faiss_index(column='embeddings')
def get_neighbors(query_image, top_k=8):
    """Return the ``top_k`` catalogue entries nearest to ``query_image``.

    The query is embedded with the module-level ``extractor`` and ``model``.
    The embedding is the hidden state at token position 0 of
    ``last_hidden_state``. It is then looked up in the FAISS index built over
    the ``'embeddings'`` column of ``dataset_with_embeddings``.

    Args:
        query_image: image accepted by ``extractor``; ``search`` passes an
            RGB PIL image.
        top_k: number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)``, exactly as returned by
        ``Dataset.get_nearest_examples``. ``search`` indexes
        ``retrieved_examples`` by column name, then by position.
    """
    import torch  # already required: the extractor returns torch tensors

    inputs = extractor(query_image, return_tensors="pt")
    # Inference only: don't build an autograd graph for this forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    # Position-0 token state of the single image in the batch -> 1-D vector.
    qi_embedding = outputs.last_hidden_state[:, 0].detach().numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples
def search(image_dict):
    """Find catalogue dresses that look most like the user's sketch.

    Args:
        image_dict: value produced by the ``gr.ImageEditor`` input. Its
            ``'composite'`` entry (a file path, since the editor uses
            ``type='filepath'``) is opened as the query image.

    Returns:
        A ``(gallery, markdown)`` pair:

        - ``gallery`` is a list of ``(image_link, name)`` tuples for
          ``gr.Gallery``.
        - ``markdown`` is an HTML string of 200px-wide thumbnails, each
          wrapped in a link to the item's ``link``.

        Both are empty when there is no composite image to search with.
    """
    composite = image_dict.get('composite') if image_dict else None
    if composite is None:
        # Nothing to search with: show no results instead of crashing
        # inside Image.open(None).
        return [], ""

    # Context manager releases the file handle once the pixels are converted.
    with Image.open(composite) as sketch:
        query_image = sketch.convert(mode='RGB')

    _scores, retrieved_examples = get_neighbors(query_image)

    names = retrieved_examples["name"]
    image_links = retrieved_examples["image_link"]
    links = retrieved_examples["link"]

    # (image, caption) pairs in the form gr.Gallery accepts.
    result = list(zip(image_links, names))
    final_md = "".join(
        f"<a href='{link}'> <img src='{image_link}' width='200'/> </a> "
        for link, image_link in zip(links, image_links)
    )
    return result, final_md
# Header shown above the interface: shop logo (linking to the store) and prompt.
# NOTE(review): the logo is a Discord CDN attachment URL carrying signed
# ex=/is=/hm= query parameters; such links appear to expire. Consider hosting
# the logo somewhere stable. TODO confirm the image still loads.
_description = """
<center><a href="https://www.lucyinthesky.com/"><img width="500" src="https://cdn.discordapp.com/attachments/1120417968032063538/1201666647157657640/LucyITS-2022-blk.png?ex=65caa646&is=65b83146&hm=09ad6fe279edc3a32981306d563e63af815d760fc0d8d0a3fbef4e4553c0a83a&"> </a> </center>
<br>
<center> Sketch to find your favorite Lucy in the Sky dress! </center>
<br>
"""

# Sketch input: pre-loaded with ./template.JPG as the background. The image
# reaches search() as a file path; upload is the only source and transforms
# are disabled.
_sketchpad = gr.ImageEditor(
    label='Sketchpad',
    type='filepath',
    value={'background': './template.JPG', 'layers': None, 'composite': None},
    sources=['upload'],
    transforms=[],
)

# Outputs, matching search()'s (gallery, markdown) return pair.
_outputs = [
    gr.Gallery(label='Similar', object_fit='contain', height=1200),
    gr.Markdown(),
]

iface = gr.Interface(
    fn=search,
    description=_description,
    inputs=_sketchpad,
    outputs=_outputs,
    theme=gr.themes.Base(
        primary_hue="teal",
        secondary_hue="teal",
        neutral_hue="slate",
    ),
)
iface.launch()