vdprabhu commited on
Commit
c6d13a8
1 Parent(s): e7ac340

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -0
app.py CHANGED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+ from tensorflow.keras.datasets import cifar10
5
+
6
+ from huggingface_hub import from_pretrained_keras
7
+ import gradio as gr
8
+
9
+
10
+ def prepare_output(neighbours):
11
+ """Function to return the image grid based on the nearest neighbours
12
+ @params neighbours: List of indices of the nearest neighbours"""
13
+ anchor_near_neighbours = reversed(neighbours)
14
+ img_grid = Image.new("RGB", (HEIGHT_WIDTH * 5, HEIGHT_WIDTH * 2))
15
+
16
+ # Image Grid of top-10 neighbours
17
+ for idx, nn_idx in enumerate(anchor_near_neighbours):
18
+ img_arr = (np.array(x_test[nn_idx]) * 255).astype(np.uint8)
19
+ img_grid.paste(
20
+ Image.fromarray(img_arr, "RGB"),
21
+ ((idx % 5) * HEIGHT_WIDTH, (idx // 5) * HEIGHT_WIDTH),
22
+ )
23
+
24
+ return img_grid
25
+
26
+
27
+ def get_nearest_neighbours(img):
28
+ """Has the inference code to get the nearest neighbours from the model
29
+ @params img: Image to be fed to the model"""
30
+
31
+ # Pre-process image
32
+ img = np.expand_dims(img / 255, axis=0)
33
+ img_x_test = np.append(x_test, img, axis=0)
34
+
35
+ # Get the embeddings and check the cosine distance
36
+ embeddings = model.predict(img_x_test)
37
+ gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
38
+ near_neighbours = np.argsort(gram_matrix.T)[:, -(NEAR_NEIGHBOURS + 1) :]
39
+
40
+ # Make image grid output
41
+ img_grid = prepare_output(near_neighbours[-1][:-1])
42
+ return np.array(img_grid)
43
+
44
+
45
+ if __name__ == "__main__":
46
+ # Constants
47
+ HEIGHT_WIDTH = 32
48
+ NEAR_NEIGHBOURS = 10
49
+
50
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
51
+ x_test = x_test.astype("float32") / 255.0
52
+
53
+ model = from_pretrained_keras("keras-io/cifar10_metric_learning")
54
+
55
+ examples = ["/examples/boat.jpeg", "/examples/horse.jpeg", "/examples/car.jpeg"]
56
+ title = "Metric Learning for Image Similarity Search"
57
+
58
+ more_text = """Embeddings for the input image are xomputed using the model trained using the metric learning technique.
59
+ The nearest neighbours are calculated using the cosine distance and these shown in the image grid."""
60
+
61
+ description = f"This space uses model trained on CIFAR10 dataset using metric learning.\n\n{more_text}"
62
+
63
+ article = """
64
+ <p style='text-align: center'>
65
+ <a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example by Mat Kelcey</a>
66
+ <br>
67
+ Space by Vrinda Prabhu
68
+ </p>
69
+ """
70
+
71
+ gr.Interface(
72
+ fn=get_nearest_neighbours,
73
+ inputs=gr.Image(shape=(32, 32)), # Resize to CIFAR
74
+ outputs=gr.Image(),
75
+ examples=examples,
76
+ article=article,
77
+ allow_flagging="never",
78
+ analytics_enabled=False,
79
+ title=title,
80
+ description=description,
81
+ ).launch(enable_queue=True)