Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from datasets import load_dataset
|
5 |
+
from transformers import AutoModel
|
6 |
+
|
7 |
+
# `LSH` and `Table` imports are necessary in order for the
|
8 |
+
# `lsh.pickle` file to load successfully.
|
9 |
+
from similarity_utils import LSH, BuildLSHTable, Table
|
10 |
+
|
11 |
+
seed = 42
|
12 |
+
|
13 |
+
# Only runs once when the script is first run.
|
14 |
+
with open("lsh.pickle", "rb") as handle:
|
15 |
+
loaded_lsh = pickle.load(handle)
|
16 |
+
|
17 |
+
# Load model for computing embeddings.
|
18 |
+
model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
|
19 |
+
model = AutoModel.from_pretrained(model_ckpt)
|
20 |
+
lsh_builder = BuildLSHTable(model)
|
21 |
+
lsh_builder.lsh = loaded_lsh
|
22 |
+
|
23 |
+
# Candidate images.
|
24 |
+
dataset = load_dataset("gjuggler/bird-data")
|
25 |
+
candidate_dataset = dataset["train"].shuffle(seed=seed)
|
26 |
+
|
27 |
+
def query(image, top_k):
|
28 |
+
results = lsh_builder.query(image)
|
29 |
+
|
30 |
+
# Should be a list of string file paths for gr.Gallery to work
|
31 |
+
images = []
|
32 |
+
# List of labels for each image in the gallery
|
33 |
+
labels = []
|
34 |
+
|
35 |
+
candidates = []
|
36 |
+
|
37 |
+
for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
|
38 |
+
if idx == top_k:
|
39 |
+
break
|
40 |
+
image_id, label = r.split("_")[0], r.split("_")[1]
|
41 |
+
candidates.append(candidate_dataset[int(image_id)]["image"])
|
42 |
+
labels.append(f"Label: {label}")
|
43 |
+
|
44 |
+
for i, candidate in enumerate(candidates):
|
45 |
+
filename = f"similar_{i}.png"
|
46 |
+
candidate.save(filename)
|
47 |
+
images.append(filename)
|
48 |
+
|
49 |
+
# The gallery component can be a list of tuples, where the first element is a path to a file
|
50 |
+
# and the second element is an optional caption for that image
|
51 |
+
return list(zip(images, labels))
|
52 |
+
|
53 |
+
|
54 |
+
title = "Fetch similar birds"
|
55 |
+
|
56 |
+
# You can set the type of gr.Image to be PIL, numpy or str (filepath)
|
57 |
+
# Not sure what the best for this demo is.
|
58 |
+
gr.Interface(
|
59 |
+
query,
|
60 |
+
inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
|
61 |
+
outputs=gr.Gallery().style(grid=[3], height="auto"),
|
62 |
+
# Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
|
63 |
+
title=title,
|
64 |
+
description=description,
|
65 |
+
examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
|
66 |
+
).launch()
|
67 |
+
|