"""Gradio demo that retrieves candidate images similar to a query image via an LSH index."""

import pickle

import gradio as gr
from datasets import load_dataset
from transformers import AutoModel

# `LSH` and `Table` are not referenced directly, but the imports are necessary
# in order for the `lsh.pickle` file to load successfully.
from similarity_utils import LSH, BuildLSHTable, Table  # noqa: F401

seed = 42
num_samples = 10000

# Only runs once when the script is first run.
# NOTE: unpickling can execute arbitrary code, so `lsh.pickle` must come from
# a trusted source.
with open("lsh.pickle", "rb") as handle:
    loaded_lsh = pickle.load(handle)

# Load model for computing embeddings.
model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
model = AutoModel.from_pretrained(model_ckpt)
lsh_builder = BuildLSHTable(model)
lsh_builder.lsh = loaded_lsh

# Candidate images.
dataset = load_dataset("gjuggler/bird-data")
candidate_dataset = dataset["train"].shuffle(seed=seed).select(range(num_samples))


def query(image, top_k):
    """Return up to ``top_k`` candidate images matching ``image`` as gallery items.

    Args:
        image: Query image; passed unchanged to ``lsh_builder.query``.
        top_k: Maximum number of matches to return. Coerced to a non-negative
            ``int`` before use.

    Returns:
        A list of ``(file_path, caption)`` tuples ordered by descending value
        in the mapping returned by ``lsh_builder.query``. ``gr.Gallery``
        accepts such tuples, where the first element is a path to a file and
        the second element is an optional caption for that image.
    """
    results = lsh_builder.query(image)

    # Keys look like "<image_id>_<label>"; rank them by their associated
    # value, largest first, and keep at most `top_k` of them.
    limit = max(int(top_k), 0)
    best_keys = sorted(results, key=results.get, reverse=True)[:limit]

    gallery_items = []
    for rank, key in enumerate(best_keys):
        parts = key.split("_")
        image_id, label = parts[0], parts[1]
        candidate = candidate_dataset[int(image_id)]["image"]

        # The gallery is fed file paths, so each match is written to disk.
        # NOTE: these files live in the working directory and are overwritten
        # on every call.
        filename = f"similar_{rank}.png"
        candidate.save(filename)
        gallery_items.append((filename, f"Label: {label}"))

    return gallery_items


# You can set the type of gr.Image to be PIL, numpy or str (filepath).
# Not sure what the best for this demo is.
# Inputs: the query image (delivered to `query` as a PIL image) and a slider
# selecting how many matches to show.
query_image = gr.Image(type="pil")
top_k_slider = gr.Slider(value=5, minimum=1, maximum=10, step=1)

# Output: a gallery laid out on a three-column grid.
results_gallery = gr.Gallery().style(grid=[3], height="auto")

# Each example row supplies one value per input: [image file, top_k].
example_rows = [
    ["0eea9ce8757c431e99b05afbc2bfbee2.jpg", 5],
    ["8cedc91ec6584e7c847877d4f7ac4d65.jpg", 5],
    ["0b47c4950bbf4b999f0bcd3e5d61dc94.jpg", 5],
]

# Build the interface around `query` and start serving it.
gr.Interface(
    fn=query,
    inputs=[query_image, top_k_slider],
    outputs=results_gallery,
    examples=example_rows,
).launch()