File size: 2,744 Bytes
1748d76
 
 
 
 
0c81859
1748d76
 
 
 
 
 
 
 
 
 
bff355f
b1098aa
1748d76
 
 
 
 
 
 
 
 
 
 
474fa74
bff355f
 
3b85df0
1748d76
bff355f
1748d76
 
429d7b2
 
bff355f
1748d76
 
b1098aa
e5e67a3
429d7b2
5a1647d
fa69e46
1748d76
429d7b2
1748d76
 
828058c
23b9e76
265f964
23b9e76
265f964
819ea68
0572dbd
429d7b2
1748d76
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os

import gradio as gr
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModel
from datasets import load_dataset
from PIL import Image, ImageDraw

    

# --- Embedding model -------------------------------------------------------
# ViT checkpoint used to embed the user's sketch at query time.
# NOTE(review): the precomputed 'embeddings' column of the dataset below is
# presumably produced by this same checkpoint — query and index vectors must
# come from the same model for the similarity search to be meaningful.
model_ckpt = "google/vit-base-patch16-224"
print('Load model for computing embeddings of the candidate images')
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
# Width of the model's hidden states (and hence of each query embedding);
# not referenced elsewhere in this file.
hidden_dim = model.config.hidden_size

# --- Searchable catalogue --------------------------------------------------
# The access token (if any) is read from the TOKEN environment variable;
# when it is unset, None is passed through to load_dataset.
dataset_with_embeddings = load_dataset(
    "LucyintheSky/24-1-30-ds-embeddings",
    split="train",
    token=os.getenv('TOKEN'),
)
# Index the 'embeddings' column so get_nearest_examples('embeddings', ...)
# can perform the nearest-neighbour lookup.
dataset_with_embeddings.add_faiss_index(column='embeddings')


def get_neighbors(query_image, top_k=8):
    """Return the ``top_k`` dataset rows whose embeddings are nearest the query.

    The query image is embedded with the module-level ViT ``model`` /
    ``extractor``, taking the final hidden state of the first token (ViT's
    [CLS] token). That vector is looked up in the FAISS index built on the
    dataset's ``'embeddings'`` column.

    Args:
        query_image: Image to search with (``search`` passes an RGB PIL image).
        top_k: Number of neighbours to retrieve.

    Returns:
        A ``(scores, retrieved_examples)`` tuple, as produced by
        ``Dataset.get_nearest_examples``. ``retrieved_examples`` maps column
        names to lists of values for the retrieved rows.
    """
    inputs = extractor(query_image, return_tensors="pt")
    # Inference only: disabling autograd avoids building (and then discarding)
    # a gradient graph for every query.
    with torch.no_grad():
        outputs = model(**inputs)
    # [:, 0] picks the first token of each sequence; squeeze() drops the
    # batch axis of size 1 so a single query vector is passed to FAISS.
    qi_embedding = outputs.last_hidden_state[:, 0].numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples

    

def search(image_dict):
    """Find catalogue items visually similar to the user's sketch.

    Args:
        image_dict: Payload from the ``gr.ImageEditor`` input (configured with
            ``type='filepath'``). Its ``'composite'`` entry is the path of the
            flattened sketch.

    Returns:
        A ``(gallery, html)`` pair. ``gallery`` is a list of
        ``(image_link, name)`` tuples for ``gr.Gallery``. ``html`` is a string
        of thumbnails, each wrapped in a link to its product page, for
        ``gr.Markdown``.

    Raises:
        gr.Error: If the editor supplied no composite image to search with.
    """
    composite_path = (image_dict or {}).get('composite')
    if composite_path is None:
        # Without this guard, Image.open(None) fails with an opaque error.
        raise gr.Error('Please draw or upload a sketch before searching.')

    # Open the query image. convert() returns an independent, fully loaded
    # copy, so the source file can be closed as soon as the block exits.
    with Image.open(composite_path) as sketch:
        query_image = sketch.convert(mode='RGB')

    # Retrieve the nearest catalogue entries; the scores are not displayed.
    _scores, retrieved_examples = get_neighbors(query_image)

    gallery = []
    cards = []
    for name, image_link, link in zip(
        retrieved_examples["name"],
        retrieved_examples["image_link"],
        retrieved_examples["link"],
    ):
        gallery.append((image_link, name))
        # NOTE(review): dataset values are interpolated into HTML unescaped;
        # this assumes the dataset's links are trusted and contain no quotes.
        cards.append(
            f"<a href='{link}'> <img src='{image_link}' width='200'/> </a> "
        )

    return gallery, "".join(cards)

def _build_interface():
    """Assemble the sketch-search Gradio UI around ``search``."""
    # Input: a sketchpad whose initial background is ./template.JPG. Upload is
    # the only extra image source, and transforms=[] turns off the editor's
    # transform tools.
    sketchpad = gr.ImageEditor(
        label='Sketchpad',
        type='filepath',
        value={'background': './template.JPG', 'layers': None, 'composite': None},
        sources=['upload'],
        transforms=[],
    )
    # Outputs, in the order search() returns them: the gallery of
    # (image_link, name) pairs, then the HTML strip of linked thumbnails.
    gallery = gr.Gallery(label='Similar', object_fit='contain', height=1200)
    thumbnails = gr.Markdown()
    theme = gr.themes.Base(
        primary_hue="teal",
        secondary_hue="teal",
        neutral_hue="slate",
    )
    # NOTE(review): the logo URL in the description looks like a signed
    # Discord CDN attachment link (ex=/is=/hm= query parameters), which
    # Discord expires — confirm it still resolves, or host the logo elsewhere.
    return gr.Interface(
        fn=search,
        description="""
                     <center><a href="https://www.lucyinthesky.com/"><img width="500" src="https://cdn.discordapp.com/attachments/1120417968032063538/1201666647157657640/LucyITS-2022-blk.png?ex=65caa646&is=65b83146&hm=09ad6fe279edc3a32981306d563e63af815d760fc0d8d0a3fbef4e4553c0a83a&"> </a> </center>
                     <br>
                     <center> Sketch to find your favorite Lucy in the Sky dress! </center>
                     <br>
                     """,
        inputs=sketchpad,
        outputs=[gallery, thumbnails],
        theme=theme,
    )


iface = _build_interface()
iface.launch()