jusancp99 commited on
Commit
ffccc19
1 Parent(s): 7314049

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ import gradio as gr
4
+ from datasets import load_dataset
5
+ from transformers import AutoModel
6
+
7
+ # `LSH` and `Table` imports are necessary in order for the
8
+ # `lsh.pickle` file to load successfully.
9
+ from similarity_utils import LSH, BuildLSHTable, Table
10
+
11
+ seed = 42
12
+
13
+ # Only runs once when the script is first run.
14
+ with open("lsh.pickle", "rb") as handle:
15
+ loaded_lsh = pickle.load(handle)
16
+
17
+ # Load model for computing embeddings.
18
+ model_ckpt = "gjuggler/swin-tiny-patch4-window7-224-finetuned-birds"
19
+ model = AutoModel.from_pretrained(model_ckpt)
20
+ lsh_builder = BuildLSHTable(model)
21
+ lsh_builder.lsh = loaded_lsh
22
+
23
+ # Candidate images.
24
+ dataset = load_dataset("gjuggler/bird-data")
25
+ candidate_dataset = dataset["train"].shuffle(seed=seed)
26
+
27
+ def query(image, top_k):
28
+ results = lsh_builder.query(image)
29
+
30
+ # Should be a list of string file paths for gr.Gallery to work
31
+ images = []
32
+ # List of labels for each image in the gallery
33
+ labels = []
34
+
35
+ candidates = []
36
+
37
+ for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
38
+ if idx == top_k:
39
+ break
40
+ image_id, label = r.split("_")[0], r.split("_")[1]
41
+ candidates.append(candidate_dataset[int(image_id)]["image"])
42
+ labels.append(f"Label: {label}")
43
+
44
+ for i, candidate in enumerate(candidates):
45
+ filename = f"similar_{i}.png"
46
+ candidate.save(filename)
47
+ images.append(filename)
48
+
49
+ # The gallery component can be a list of tuples, where the first element is a path to a file
50
+ # and the second element is an optional caption for that image
51
+ return list(zip(images, labels))
52
+
53
+
54
+ title = "Fetch similar birds"
55
+
56
+ # You can set the type of gr.Image to be PIL, numpy or str (filepath)
57
+ # Not sure what the best for this demo is.
58
+ gr.Interface(
59
+ query,
60
+ inputs=[gr.Image(type="pil"), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
61
+ outputs=gr.Gallery().style(grid=[3], height="auto"),
62
+ # Filenames denote the integer labels. Know here: https://hf.co/datasets/beans
63
+ title=title,
64
+ description=description,
65
+ examples=[["0.png", 5], ["1.png", 5], ["2.png", 5]],
66
+ ).launch()
67
+