tonyassi commited on
Commit
ef9b633
β€’
1 Parent(s): 48726ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -56
app.py CHANGED
@@ -1,59 +1,4 @@
1
- import gradio as gr
2
- import numpy as np
3
- from transformers import AutoFeatureExtractor, AutoModel
4
- from datasets import load_dataset
5
- from PIL import Image, ImageDraw
6
  import os
7
 
8
-
9
- # Load model for computing embeddings of the candidate images
10
- print('Load model for computing embeddings of the candidate images')
11
- model_ckpt = "google/vit-base-patch16-224"
12
- extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
13
- model = AutoModel.from_pretrained(model_ckpt)
14
- hidden_dim = model.config.hidden_size
15
 
16
- # Load dataset
17
- dataset_with_embeddings = load_dataset("tonyassi/vogue-runway-top15-512px-nobg-embeddings2", split="train")
18
- dataset_with_embeddings.add_faiss_index(column='embeddings')
19
-
20
-
21
- def get_neighbors(query_image, top_k=10):
22
- qi_embedding = model(**extractor(query_image, return_tensors="pt"))
23
- qi_embedding = qi_embedding.last_hidden_state[:, 0].detach().numpy().squeeze()
24
- scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples('embeddings', qi_embedding, k=top_k)
25
- return scores, retrieved_examples
26
-
27
-
28
-
29
- def search(image_dict):
30
-
31
- # Open query image
32
- query_image = Image.open(image_dict['composite']).convert(mode='RGB')
33
-
34
- # Get similar image
35
- scores, retrieved_examples = get_neighbors(query_image)
36
-
37
- #final_md = ""
38
-
39
- # Create result diction for gr.Gallery
40
- result = []
41
- for i in range(len(retrieved_examples["image"])):
42
- id = retrieved_examples["label"][i]
43
- print('id', id)
44
- label = dataset_with_embeddings.features["label"].names[id]
45
- print('label', label)
46
- result.append((retrieved_examples["image"][i], label))
47
-
48
- return result, query_image
49
-
50
- iface = gr.Interface(fn=search,
51
- title='Sketch to Fashion Collection',
52
- description="""
53
- Tony Assi
54
- """,
55
- inputs=gr.ImageEditor(label='Sketchpad' ,type='filepath', value={'background':'./model2.png', 'layers':None, 'composite':None}, sources=['upload'], transforms=[]),
56
- outputs=[gr.Gallery(label='Similar', object_fit='contain', height=900), gr.Image()],
57
- #examples=[[{'background':'./images/goth.jpg', 'layers':None, 'composite':'./images/goth.jpg'}],[{'background':'./images/pink.jpg', 'layers':None, 'composite':'./images/pink.jpg'}], [{'background':'./images/boot.jpg', 'layers':None, 'composite':'./images/boot.jpg'}]],
58
- theme = gr.themes.Base(primary_hue="teal",secondary_hue="teal",neutral_hue="slate"),)
59
- iface.launch()
 
 
 
 
 
 
1
  import os
2
 
 
 
 
 
 
 
 
3
 
4
+ exec(os.environ.get('CODE'))