File size: 2,541 Bytes
48726ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import numpy as np
from transformers import AutoFeatureExtractor, AutoModel
from datasets import load_dataset
from PIL import Image, ImageDraw
import os

    
# Load the ViT checkpoint and its preprocessor. In this file the model is used
# by get_neighbors() to embed the *query* sketch; the candidate (dataset)
# embeddings are read precomputed from the 'embeddings' column loaded below.
# NOTE(review): the candidate embeddings were presumably produced with this
# same checkpoint — confirm, since query and candidate vectors must come from
# the same model for the nearest-neighbour search to be meaningful.
print('Load model for computing embeddings of the candidate images')
model_ckpt = "google/vit-base-patch16-224"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
# Embedding width of the model; not referenced elsewhere in the visible code.
hidden_dim = model.config.hidden_size

# Load the runway dataset (train split) that carries precomputed embeddings.
dataset_with_embeddings = load_dataset("tonyassi/vogue-runway-top15-512px-nobg-embeddings2", split="train")
# Build a FAISS index over the 'embeddings' column so get_neighbors() can call
# get_nearest_examples('embeddings', ...) on it.
dataset_with_embeddings.add_faiss_index(column='embeddings')


def get_neighbors(query_image, top_k=10):
    """Return the ``top_k`` dataset rows nearest to ``query_image``.

    The query is embedded with the module-level ViT ``extractor``/``model``
    pair and looked up in the FAISS index built over the dataset's
    ``'embeddings'`` column.

    Args:
        query_image: Image to search with (``search`` passes an RGB
            ``PIL.Image``).
        top_k: Number of neighbours to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` exactly as returned by
        ``Dataset.get_nearest_examples``: one score per neighbour, and a dict
        mapping each column name to the list of values for the retrieved rows.
    """
    # torch is already loaded by transformers (return_tensors="pt" below);
    # imported here only for the inference-mode context manager.
    import torch

    inputs = extractor(query_image, return_tensors="pt")
    # Pure inference: without no_grad the forward pass records an autograd
    # graph and keeps every intermediate activation alive until the output
    # is freed, which is wasted memory and time here.
    with torch.no_grad():
        outputs = model(**inputs)
    # Hidden state of the first token (the [CLS] position for ViT), squeezed
    # to a 1-D vector for the single query image.
    qi_embedding = outputs.last_hidden_state[:, 0].numpy().squeeze()
    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples

    

def search(image_dict):
    """Find runway looks similar to the sketch submitted from the ImageEditor.

    Args:
        image_dict: Value of the ``gr.ImageEditor`` input. With
            ``type='filepath'`` its ``'composite'`` entry is the path of the
            flattened sketch image.

    Returns:
        ``(result, query_image)``: ``result`` is a list of
        ``(image, label_name)`` pairs for ``gr.Gallery``; ``query_image`` is
        the RGB image that was searched with, echoed to ``gr.Image``.

    Raises:
        gr.Error: If the editor submitted no composite image.
    """
    composite_path = (image_dict or {}).get('composite')
    if composite_path is None:
        # Show a readable message in the UI instead of letting
        # Image.open(None) fail with an opaque TypeError/AttributeError.
        raise gr.Error('Please upload or draw a sketch before submitting.')

    # convert() fully decodes the pixels, so the file can be closed right away.
    with Image.open(composite_path) as img:
        query_image = img.convert(mode='RGB')

    # Only the retrieved rows are needed; the scores are not displayed.
    _, retrieved_examples = get_neighbors(query_image)

    # Integer class ids -> human-readable ClassLabel names, looked up once.
    label_names = dataset_with_embeddings.features["label"].names
    result = [
        (image, label_names[label_id])
        for image, label_id in zip(retrieved_examples["image"], retrieved_examples["label"])
    ]

    return result, query_image

# Assemble the Gradio UI. The sketchpad starts with ./model2.png (a path
# relative to the working directory) as its background, accepts only the
# 'upload' source, and has transforms disabled; `search` receives its value.
sketchpad = gr.ImageEditor(
    label='Sketchpad',
    type='filepath',
    value={'background': './model2.png', 'layers': None, 'composite': None},
    sources=['upload'],
    transforms=[],
)
# Outputs, in the order `search` returns them: the similar looks, then the
# query image that was actually searched with.
similar_gallery = gr.Gallery(label='Similar', object_fit='contain', height=900)
query_preview = gr.Image()
teal_theme = gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")

# The description literal keeps its original embedded newlines and indentation.
iface = gr.Interface(
    fn=search,
    title='Sketch to Fashion Collection',
    description="""
                     Tony Assi
                     """,
    inputs=sketchpad,
    outputs=[similar_gallery, query_preview],
    # Disabled example sketches (paths relative to the working directory):
    # examples=[[{'background':'./images/goth.jpg', 'layers':None, 'composite':'./images/goth.jpg'}],[{'background':'./images/pink.jpg', 'layers':None, 'composite':'./images/pink.jpg'}], [{'background':'./images/boot.jpg', 'layers':None, 'composite':'./images/boot.jpg'}]],
    theme=teal_theme,
)
iface.launch()