tonyassi commited on
Commit
48726ec
β€’
1 Parent(s): 13140bc

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +59 -0
  3. model1.png +0 -0
  4. model2.png +3 -0
  5. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model2.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from transformers import AutoFeatureExtractor, AutoModel
4
+ from datasets import load_dataset
5
+ from PIL import Image, ImageDraw
6
+ import os
7
+
8
+
9
+ # Load model for computing embeddings of the candidate images
10
+ print('Load model for computing embeddings of the candidate images')
11
+ model_ckpt = "google/vit-base-patch16-224"
12
+ extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
13
+ model = AutoModel.from_pretrained(model_ckpt)
14
+ hidden_dim = model.config.hidden_size
15
+
16
+ # Load dataset
17
+ dataset_with_embeddings = load_dataset("tonyassi/vogue-runway-top15-512px-nobg-embeddings2", split="train")
18
+ dataset_with_embeddings.add_faiss_index(column='embeddings')
19
+
20
+
21
+ def get_neighbors(query_image, top_k=10):
22
+ qi_embedding = model(**extractor(query_image, return_tensors="pt"))
23
+ qi_embedding = qi_embedding.last_hidden_state[:, 0].detach().numpy().squeeze()
24
+ scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples('embeddings', qi_embedding, k=top_k)
25
+ return scores, retrieved_examples
26
+
27
+
28
+
29
+ def search(image_dict):
30
+
31
+ # Open query image
32
+ query_image = Image.open(image_dict['composite']).convert(mode='RGB')
33
+
34
+ # Get similar image
35
+ scores, retrieved_examples = get_neighbors(query_image)
36
+
37
+ #final_md = ""
38
+
39
+ # Create result diction for gr.Gallery
40
+ result = []
41
+ for i in range(len(retrieved_examples["image"])):
42
+ id = retrieved_examples["label"][i]
43
+ print('id', id)
44
+ label = dataset_with_embeddings.features["label"].names[id]
45
+ print('label', label)
46
+ result.append((retrieved_examples["image"][i], label))
47
+
48
+ return result, query_image
49
+
50
+ iface = gr.Interface(fn=search,
51
+ title='Sketch to Fashion Collection',
52
+ description="""
53
+ Tony Assi
54
+ """,
55
+ inputs=gr.ImageEditor(label='Sketchpad' ,type='filepath', value={'background':'./model2.png', 'layers':None, 'composite':None}, sources=['upload'], transforms=[]),
56
+ outputs=[gr.Gallery(label='Similar', object_fit='contain', height=900), gr.Image()],
57
+ #examples=[[{'background':'./images/goth.jpg', 'layers':None, 'composite':'./images/goth.jpg'}],[{'background':'./images/pink.jpg', 'layers':None, 'composite':'./images/pink.jpg'}], [{'background':'./images/boot.jpg', 'layers':None, 'composite':'./images/boot.jpg'}]],
58
+ theme = gr.themes.Base(primary_hue="teal",secondary_hue="teal",neutral_hue="slate"),)
59
+ iface.launch()
model1.png ADDED
model2.png ADDED

Git LFS Details

  • SHA256: 7e2b75e16c06ef4635a018a567a24ccf63fff4c04951380c85f790b6b770f63f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ datasets
4
+ faiss-cpu
5
+ https://gradio-builds.s3.amazonaws.com/f3e3c5c02f1069c1180004a46909309544f42523/gradio-4.16.0-py3-none-any.whl