Spaces:
Runtime error
Runtime error
File size: 1,975 Bytes
ffccc19 acce508 ffccc19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import heapq
import pickle

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

# `LSH` and `Table` imports are necessary in order for the
# `lsh.pickle` file to load successfully.
from similarity_utils import LSH, BuildLSHTable, Table
# Seed for the deterministic shuffle of the candidate images below.
seed = 42

# Executed once, at startup, when the script is first run.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so this
# must only ever read the `lsh.pickle` shipped alongside the app.
with open("lsh.pickle", mode="rb") as lsh_file:
    loaded_lsh = pickle.load(lsh_file)

# Model used for computing embeddings; the unpickled LSH index is attached
# to the builder that wraps it.
model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images, shuffled with `seed`; `query` indexes into this split
# using the image ids found in the LSH results.
dataset = load_dataset("gjuggler/bird-data")
train_split = dataset["train"]
candidate_dataset = train_split.shuffle(seed=seed)
def query(image, top_k):
    """Return the ``top_k`` candidate images most similar to ``image``.

    Args:
        image: Query image from the ``gr.Image(type="pil")`` input; it is
            passed unchanged to ``lsh_builder.query``.
        top_k: Maximum number of matches to return. It may arrive from the
            Gradio slider as a float, so it is converted with ``int()``.

    Returns:
        A list of ``(file_path, caption)`` tuples, a form ``gr.Gallery``
        accepts. Each path is a PNG written to the working directory as
        ``similar_<rank>.png``, and each caption reads ``"Label: <label>"``.
    """
    results = lsh_builder.query(image)

    # Rank result keys by their value in `results`, highest first, keeping
    # only the first `top_k`. heapq.nlargest is equivalent to
    # sorted(..., reverse=True)[:k] (ties keep their original order) but
    # avoids fully sorting every candidate.
    top_keys = heapq.nlargest(int(top_k), results, key=results.get)

    gallery = []
    for rank, key in enumerate(top_keys):
        # Keys have the form "<image_id>_<label>". Split only on the first
        # "_" so that a label which itself contains underscores is kept whole.
        image_id, label = key.split("_", 1)
        candidate = candidate_dataset[int(image_id)]["image"]
        # The gallery is fed file paths, so each match is written to disk.
        filename = f"similar_{rank}.png"
        candidate.save(filename)
        gallery.append((filename, f"Label: {label}"))
    return gallery
# Build and launch the demo UI.
# - The image input is handed to `query` as a PIL image (`type="pil"`);
#   gr.Image also supports "numpy" and "filepath" if `query` is adapted.
# - The slider supplies `top_k` (1-10, default 5).
# - The gallery shows the (file_path, caption) tuples returned by `query`.
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(value=5, minimum=1, maximum=10, step=1),
    ],
    # `Gallery.style()` was deprecated in Gradio 3.x and removed in 4.0 (it
    # raises AttributeError there); layout options are constructor arguments.
    outputs=gr.Gallery(columns=3, height="auto"),
    examples=[["309.jpg", 5], ["81.jpg", 5], ["93.jpg", 5]],
)
demo.launch()
|