File size: 2,769 Bytes
c6d13a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e9c9e7
c6d13a8
 
5859bbc
 
c6d13a8
5859bbc
c6d13a8
 
4dd1784
 
 
 
c6d13a8
 
 
 
 
 
 
 
 
 
 
 
 
4dd1784
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from PIL import Image

from tensorflow.keras.datasets import cifar10

from huggingface_hub import from_pretrained_keras
import gradio as gr


def prepare_output(neighbours):
    """Function to return the image grid based on the nearest neighbours

    @params neighbours: Sequence of indices into ``x_test``, ordered from least
        to most similar (as produced by ``np.argsort``). The grid is filled
        most-similar first, left-to-right then top-to-bottom.
    @returns: ``PIL.Image`` in RGB mode, 5 tiles wide and as many rows tall as
        needed to hold every neighbour (2 rows for the default 10 neighbours).
    """
    columns = 5
    # Reverse so the most similar neighbour is drawn in the top-left tile.
    anchor_near_neighbours = list(reversed(neighbours))
    # Size the canvas from the input instead of assuming exactly 10 tiles, so
    # extra neighbours are not silently pasted outside the image.
    rows = max(1, -(-len(anchor_near_neighbours) // columns))  # ceil division
    img_grid = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))

    for idx, nn_idx in enumerate(anchor_near_neighbours):
        # x_test holds floats in [0, 1]; round back to uint8 rather than
        # truncating, since (k / 255) * 255 may land just below k in float32.
        pixels = np.rint(np.asarray(x_test[nn_idx]) * 255)
        img_arr = np.clip(pixels, 0, 255).astype(np.uint8)
        row, col = divmod(idx, columns)
        img_grid.paste(
            Image.fromarray(img_arr, "RGB"),
            (col * HEIGHT_WIDTH, row * HEIGHT_WIDTH),
        )

    return img_grid


def get_nearest_neighbours(img):
    """Has the inference code to get the nearest neighbours from the model

    @params img: Image to be fed to the model. Presumably an ``(H, W, 3)``
        uint8 array in [0, 255] as supplied by the Gradio ``Image`` input,
        already resized to ``HEIGHT_WIDTH`` x ``HEIGHT_WIDTH`` — TODO confirm.
    @returns: ``np.ndarray`` of the image grid built by ``prepare_output``
        from the ``NEAR_NEIGHBOURS`` most similar test images.
    """
    # Embed the test set once and reuse it. The original re-embedded (and
    # copied) all of x_test on every request.
    # NOTE(review): assumes `model` and `x_test` never change after startup.
    test_embeddings = getattr(get_nearest_neighbours, "_test_embeddings", None)
    if test_embeddings is None:
        test_embeddings = model.predict(x_test)
        get_nearest_neighbours._test_embeddings = test_embeddings

    # Pre-process image the same way as x_test: scale to [0, 1], add batch axis.
    query = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    query_embedding = model.predict(query)[0]

    # Dot-product similarity of the query against every test image only.
    # This avoids building the full (N+1) x (N+1) Gram matrix. It equals
    # cosine similarity only if the model's embeddings are L2-normalised.
    # NOTE(review): normalisation is assumed from the upstream example — confirm.
    similarities = test_embeddings @ query_embedding

    # argsort is ascending, so the tail holds the most similar test images,
    # ordered least -> most similar, which is what prepare_output expects.
    # The query is not part of `similarities`, so nothing needs to be dropped.
    near_neighbours = np.argsort(similarities)[-NEAR_NEIGHBOURS:]

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)


if __name__ == "__main__":
    # Constants, read as module globals by prepare_output / get_nearest_neighbours.
    HEIGHT_WIDTH = 32  # tile size; the input image is resized to this as well
    NEAR_NEIGHBOURS = 10  # number of neighbours shown in the output grid

    # Only the test split is searched; the training split and labels are unused.
    (_, _), (x_test, _) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    examples = ["examples/boat.jpeg", "examples/horse.jpeg", "examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"

    more_text = """Embeddings for the input image are computed using the model.
                   The nearest neighbours are then calculated using the cosine distance and are shown in the image grid."""

    description = f"This space uses a model trained on the CIFAR10 dataset using a metric learning technique.\n\n{more_text}"

    article = """
        <p style='text-align: center'>
            <a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example by Mat Kelcey</a>
            <br>
            Space by Vrinda Prabhu
        </p>
    """

    # NOTE(review): `gr.Image(shape=...)`, `allow_flagging` and
    # `launch(enable_queue=...)` are Gradio 3.x arguments; newer Gradio
    # releases renamed or removed them — pin the Gradio version or migrate.
    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(HEIGHT_WIDTH, HEIGHT_WIDTH)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True, debug=True)