LucyintheSky commited on
Commit
1748d76
β€’
1 Parent(s): e46255a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from transformers import AutoFeatureExtractor, AutoModel
4
+ from datasets import load_dataset
5
+ from PIL import Image, ImageDraw
6
+
7
+
8
+
9
+
10
+ # Load model for computing embeddings of the candidate images
11
+ print('Load model for computing embeddings of the candidate images')
12
+ model_ckpt = "google/vit-base-patch16-224"
13
+ extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
14
+ model = AutoModel.from_pretrained(model_ckpt)
15
+ hidden_dim = model.config.hidden_size
16
+
17
+
18
+ dataset_with_embeddings = load_dataset("LucyintheSky/24-1-8-ds-embeddings", split="train")
19
+ dataset_with_embeddings.add_faiss_index(column='embeddings')
20
+
21
+
22
+ def get_neighbors(query_image, top_k=8):
23
+ qi_embedding = model(**extractor(query_image, return_tensors="pt"))
24
+ qi_embedding = qi_embedding.last_hidden_state[:, 0].detach().numpy().squeeze()
25
+ scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples('embeddings', qi_embedding, k=top_k)
26
+ return scores, retrieved_examples
27
+
28
+
29
+
30
+ def search(img):
31
+
32
+ query_image = Image.open(img).convert(mode='RGB')
33
+
34
+ #query_image.thumbnail((1000,1000))
35
+
36
+
37
+ #query_image = query_image.resize((160,160))
38
+
39
+ print('search')
40
+ scores, retrieved_examples = get_neighbors(query_image)
41
+
42
+ print('return example')
43
+
44
+ result = []
45
+ for i in range(len(retrieved_examples["image"])):
46
+ id = str(retrieved_examples["text"][i]) + ' ' + str(scores[i])
47
+ print('id', id)
48
+ #label = dataset_with_embeddings.features["label"].names[id]
49
+ #print('label', label)
50
+ result.append((retrieved_examples["image"][i], id))
51
+
52
+ return result, query_image
53
+
54
+ iface = gr.Interface(fn=search,
55
+ title='Celebrity Look-a-Like',
56
+ inputs=gr.Image(type='filepath', label='Your Photo'),
57
+ outputs=[gr.Gallery(label='Similar', object_fit='contain'), gr.Image(label='Face')],
58
+ #examples=[['./images/tony.jpg'],['./images/jessica.jpg'],['./images/scarlett.jpg'],['./images/christian.jpg']],
59
+ theme = gr.themes.Base(primary_hue="teal",secondary_hue="teal",neutral_hue="slate"),)
60
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ datasets
4
+ faiss-cpu