File size: 1,875 Bytes
3ae84a3
 
 
 
78cdedf
3ae84a3
 
 
 
 
0d69242
9b28e54
3ae84a3
 
0d69242
 
3ae84a3
 
0d69242
78cdedf
3ae84a3
 
0d69242
9b28e54
 
78cdedf
0d69242
c45624f
 
78cdedf
3ae84a3
 
22ac4b7
3ae84a3
 
 
 
 
 
0d69242
 
3ae84a3
 
 
c6bc4cd
3ae84a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import pickle

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor


# NOTE(review): `seed` is not referenced anywhere in the visible code; it is
# kept so any external importer of this module keeps working.
seed = 42

# Hugging Face Hub checkpoint shared by the feature extractor and the model.
# A single constant guarantees the preprocessing and the weights always come
# from the same checkpoint.
MODEL_CHECKPOINT = "sasha/autotrain-butterfly-similarity-2490576840"

# Only runs once when the script is first run.
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# index_768.pickle must be a trusted artefact shipped with this app — never a
# user-supplied file.
with open("index_768.pickle", "rb") as handle:
    index = pickle.load(handle)

# Load model for computing embeddings.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(MODEL_CHECKPOINT)

# Candidate images: the rows returned by the nearest-neighbour search are
# looked up in the training split of this dataset.
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
ds = dataset["train"]


def query(image, top_k=4):
    """Return the ``top_k`` dataset images most similar to ``image``.

    The image is preprocessed by the module-level ``feature_extractor`` and
    embedded by ``model``; the model's ``pooler_output`` is the embedding.
    The pre-built ``index`` is then searched for nearest neighbours, and the
    matching rows are fetched from the ``ds`` training split.

    Args:
        image: The query image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.
        top_k: Number of neighbours to retrieve from the index.

    Returns:
        The ``"image"`` column of the selected ``ds`` rows, in the order the
        index returned them; an empty list when ``image`` is ``None``.
    """
    # Gradio hands us None when submit is pressed with no upload; the feature
    # extractor cannot process that, so show an empty gallery instead.
    if image is None:
        return []

    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        model_output = model(**inputs)
    embedding = model_output.pooler_output.detach()

    # NOTE(review): assumes index.query returns (indices, distances), each with
    # one row per query embedding, as the results[0][0] access implies —
    # confirm against how index_768.pickle was built.
    results = index.query(embedding, k=top_k)
    neighbour_ids = results[0][0].tolist()
    return ds.select(neighbour_ids)["image"]


title = "Find my Butterfly 🦋"
# TODO: link the notebook that documents this system once its URL is known.
description = (
    "This Space demos an image similarity system. Upload an image of a "
    "butterfly on the left-hand side; on the right-hand side, you'll find the "
    "most similar images from the butterfly dataset returned by the system."
)

# gr.Image can deliver a PIL image, a numpy array or a filepath; "pil" is the
# format query() passes straight to the feature extractor.
demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil")],
    # NOTE(review): Component.style() is deprecated in late Gradio 3.x and
    # removed in Gradio 4 (there: gr.Gallery(columns=2, height="auto")).
    # Kept as-is because the Gradio version this Space pins is not visible.
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title=title,
    description=description,
)
# Binding the Interface to `demo` also lets Gradio's CLI reload mode find it.
demo.launch()