File size: 3,429 Bytes
dbb1308
83305ef
cf029f9
83305ef
77403d5
83305ef
77403d5
3bc904e
 
 
 
 
 
6c40b92
a42c923
dbb1308
 
 
d21fc46
77403d5
6c40b92
dbb1308
 
 
 
 
 
77403d5
b2f105a
a193d64
dbb1308
a193d64
83305ef
 
 
b33e6dd
cf029f9
 
83305ef
28de0db
cf029f9
83305ef
a193d64
 
 
 
 
83305ef
a193d64
83305ef
 
 
 
dbb1308
 
 
77403d5
a42c923
b2f105a
dbb1308
b2f105a
 
dbb1308
83305ef
b33e6dd
cf029f9
 
83305ef
a42c923
83305ef
28de0db
dbb1308
a193d64
 
 
 
 
 
 
 
 
 
b2f105a
dbb1308
a42c923
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
from upstash_vector import AsyncIndex
from transformers import AutoFeatureExtractor, AutoModel
from datasets import load_dataset

# Upper bound for the "Image Count" slider and for the top_k sent to the index.
MAX_K = 50

# Async client for the vector index, constructed via its from_env() factory.
index = AsyncIndex.from_env()

# One checkpoint name drives both the preprocessing extractor and the model.
model_ckpt = "google/vit-base-patch16-224-in21k"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
hidden_dim = model.config.hidden_size

# Source of the images returned to the user; query hit ids are used as row
# indices into its "train" split.
dataset = load_dataset("BounharAbdelaziz/Face-Aging-Dataset")

# Top-level Gradio Blocks app; everything indented below is declared inside it.
with gr.Blocks() as demo:
    # Introductory Markdown rendered above the tabs.
    gr.Markdown(
        """
        # Find Your Twins

        Upload your face and find the most similar faces from [Face Aging Dataset](https://huggingface.co/datasets/BounharAbdelaziz/Face-Aging-Dataset) using Google's [VIT](https://huggingface.co/google/vit-base-patch16-224-in21k) model. For best results please use 1x1 ratio face images, take a look at examples. Also increasing count in the advanced section results with more accurate searches. The Vector similarity search is powered by [Upstash Vector](https://upstash.com) 🚀. You can check our blog [post](https://huggingface.co/blog/omerXfaruq/serverless-image-similarity-with-upstash-vector) to learn more.
        """
    )

    # "Basic" tab: one row with two columns (scale 1 : 2) holding the image
    # input and the result gallery.
    with gr.Tab("Basic"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="pil" is requested so the handler can pass the value
                # straight to the feature extractor.
                input_image = gr.Image(type="pil")
            with gr.Column(scale=2):
                output_images = gr.Gallery()
        @input_image.change(inputs=input_image, outputs=output_images)
        async def find_similar_faces(image):
            """Return the four dataset images most similar to ``image``.

            Registered as the ``change`` handler of ``input_image``.

            Args:
                image: The current value of the image input, or ``None``
                    (presumably when the component has been cleared).

            Returns:
                A list with one image per query hit, taken from the dataset's
                ``train`` split, or ``None`` when no image was supplied.
            """
            if image is None:
                return None

            # Function-scope import keeps this handler self-contained; torch is
            # already required for return_tensors="pt".
            import torch

            inputs = extractor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                outputs = model(**inputs)
            # Sequence position 0 of the first batch item is used as the
            # image embedding.
            embed = outputs.last_hidden_state[0][0]
            result = await index.query(vector=embed.tolist(), top_k=4)
            # Hit ids are used as integer row indices into the train split.
            return [dataset["train"][int(vector.id)]["image"] for vector in result]

        # Offer rows 6, 7 and 8 of the train split as ready-made example inputs.
        sample_rows = (6, 7, 8)
        gr.Examples(
            examples=[dataset["train"][row]["image"] for row in sample_rows],
            inputs=input_image,
            outputs=output_images,
            fn=find_similar_faces,
            cache_examples=False,
        )

    # "Advanced" tab: same two-column layout as the basic tab, plus a result
    # count slider and an explicit Submit button.
    with gr.Tab("Advanced"):
        with gr.Row():
            with gr.Column(scale=1):
                adv_input_image = gr.Image(type="pil")
                # Slider arguments: lower bound 1, upper bound MAX_K, and 10
                # (presumably the initial value — confirm against gr.Slider).
                adv_image_count = gr.Slider(1, MAX_K, 10, label="Image Count")
                adv_button = gr.Button("Submit")

            with gr.Column(scale=2):
                adv_output_image = gr.Gallery()

        async def find_similar_faces(image, count):
            """Return up to ``count`` dataset images most similar to ``image``.

            Args:
                image: The current value of the advanced image input, or
                    ``None`` when Submit is pressed without an image.
                count: Requested number of results; converted to ``int`` and
                    clamped to the range ``1..MAX_K`` before querying.

            Returns:
                A list with one image per query hit, taken from the dataset's
                ``train`` split, or ``None`` when no image was supplied.
            """
            # Submit can be clicked with an empty input; without this guard the
            # extractor would be called with None and the handler would crash.
            if image is None:
                return None

            # Function-scope import keeps this handler self-contained; torch is
            # already required for return_tensors="pt".
            import torch

            inputs = extractor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                outputs = model(**inputs)
            # Sequence position 0 of the first batch item is used as the
            # image embedding.
            embed = outputs.last_hidden_state[0][0]
            # Never trust the raw widget value: keep top_k inside [1, MAX_K].
            top_k = max(1, min(MAX_K, int(count)))
            result = await index.query(vector=embed.tolist(), top_k=top_k)
            # Hit ids are used as integer row indices into the train split.
            return [dataset["train"][int(vector.id)]["image"] for vector in result]

        # Run the same search on an explicit Submit click and on a new upload;
        # registration order (click, then upload) is kept.
        for register in (adv_button.click, adv_input_image.upload):
            register(
                fn=find_similar_faces,
                inputs=[adv_input_image, adv_image_count],
                outputs=[adv_output_image],
            )

if __name__ == "__main__":
    # Configure the queue with a default concurrency limit of 40, then start
    # the server.
    demo.queue(default_concurrency_limit=40).launch()