File size: 2,064 Bytes
ffccc19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d7f07
ffccc19
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pickle

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

# `LSH` and `Table` imports are necessary in order for the
# `lsh.pickle` file to load successfully.
from similarity_utils import LSH, BuildLSHTable, Table

# Seed used to shuffle the candidate dataset below.
# NOTE(review): `query` uses the image ids stored in the LSH index as positional
# indices into the *shuffled* split, so this presumably must match the seed
# that was used when `lsh.pickle` was built — confirm before changing it.
seed = 42

# Executed once, at module import time: deserialize the LSH index from
# `lsh.pickle` in the current working directory. `pickle.load` can execute
# arbitrary code, so this file must come from a trusted source.
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Load the model used for computing embeddings, wrap it in the LSH table
# builder, and attach the index loaded from `lsh.pickle` to it.
model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images: the "train" split, shuffled deterministically with `seed`.
dataset = load_dataset("gjuggler/bird-data")
candidate_dataset = dataset["train"].shuffle(seed=seed)

def query(image, top_k):
    """Return up to ``top_k`` candidate images most similar to ``image``.

    Args:
        image: The query image as delivered by the ``gr.Image(type="pil")``
            input; it is forwarded unchanged to ``lsh_builder.query``.
        top_k: Maximum number of matches to return. Gradio sliders may
            deliver this as a float, so it is coerced with ``int``.

    Returns:
        A list of ``(file_path, caption)`` tuples for ``gr.Gallery``, ordered
        by descending value in the mapping returned by ``lsh_builder.query``.
    """
    results = lsh_builder.query(image)

    # NOTE(review): `results` is assumed to map "<image_id>_<label>" keys to a
    # score/vote count where higher means more similar — confirm against
    # similarity_utils.BuildLSHTable.query.
    ranked_keys = sorted(results, key=results.get, reverse=True)
    # Explicit slice instead of an `idx == top_k` equality check, which
    # silently returned everything for non-integral or negative values.
    top_keys = ranked_keys[: max(int(top_k), 0)]

    gallery = []
    for i, key in enumerate(top_keys):
        # Split only on the first underscore so labels that themselves
        # contain underscores are kept intact.
        image_id, label = key.split("_", 1)
        candidate = candidate_dataset[int(image_id)]["image"]

        # The gallery is fed file paths, so each match is written to disk.
        # NOTE(review): these fixed filenames are shared by all requests, so
        # concurrent queries may overwrite each other's images.
        filename = f"similar_{i}.png"
        candidate.save(filename)
        # Each gallery entry is (path to image file, optional caption).
        gallery.append((filename, f"Label: {label}"))

    return gallery


# The query image reaches `query` as a PIL image (type="pil"); gr.Image can
# also deliver numpy arrays or file paths if `query` is adapted accordingly.
# NOTE(review): `Gallery.style(...)` is deprecated in later Gradio 3 releases
# and removed in Gradio 4 — migrate these options when upgrading Gradio.
query_inputs = [
    gr.Image(type="pil"),
    gr.Slider(value=5, minimum=1, maximum=10, step=1),
]
gallery_output = gr.Gallery().style(grid=[3], height="auto")
example_rows = [
    ["0eea9ce8757c431e99b05afbc2bfbee2.jpg", 5],
    ["8cedc91ec6584e7c847877d4f7ac4d65.jpg", 5],
    ["0b47c4950bbf4b999f0bcd3e5d61dc94.jpg", 5],
]

# Build the UI and start serving it immediately when the script runs.
demo = gr.Interface(
    query,
    inputs=query_inputs,
    outputs=gallery_output,
    examples=example_rows,
)
demo.launch()