File size: 2,773 Bytes
c6d13a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from PIL import Image

from tensorflow.keras.datasets import cifar10

from huggingface_hub import from_pretrained_keras
import gradio as gr


def prepare_output(neighbours):
    """Tile the neighbour images into one picture.

    The canvas is five tiles wide and two tiles high; each tile is
    HEIGHT_WIDTH pixels square. Indices are consumed in reverse order, so
    the last entry of ``neighbours`` ends up in the top-left cell.

    @params neighbours: List of indices of the nearest neighbours
    @returns: PIL image containing the grid"""
    columns, rows = 5, 2
    ordered_indices = reversed(neighbours)
    canvas = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))

    # Fill the grid row by row, left to right
    for position, test_index in enumerate(ordered_indices):
        row, col = divmod(position, columns)
        # Scale back to 8-bit pixel values
        # (assumes x_test holds values in [0, 1] -- confirm where it is loaded)
        pixels = (np.array(x_test[test_index]) * 255).astype(np.uint8)
        tile = Image.fromarray(pixels, "RGB")
        canvas.paste(tile, (col * HEIGHT_WIDTH, row * HEIGHT_WIDTH))

    return canvas


def get_nearest_neighbours(img):
    """Has the inference code to get the nearest neighbours from the model

    The query image is scaled by 1/255, appended to ``x_test`` and embedded
    together with it by ``model``. Each test image is scored by the dot
    product of its embedding with the query embedding, and the
    ``NEAR_NEIGHBOURS`` best-scoring test images are drawn by
    ``prepare_output``.

    @params img: Image to be fed to the model (presumably a
        HEIGHT_WIDTH x HEIGHT_WIDTH x 3 array with 0-255 values, as delivered
        by the Gradio input -- confirm against the interface definition)
    @returns: The neighbour image grid as a NumPy array"""

    # Pre-process image: same scaling as x_test, plus a leading batch axis
    img = np.expand_dims(img / 255, axis=0)
    img_x_test = np.append(x_test, img, axis=0)

    # Embed the test set and the query in one batch; the query is the last row.
    # NOTE: this re-embeds the whole test set on every call.
    embeddings = model.predict(img_x_test)
    query_embedding = embeddings[-1]
    test_embeddings = embeddings[:-1]

    # Only the query's similarities are needed, so score it against the test
    # embeddings directly instead of building the full pairwise gram matrix.
    # Leaving the query out of the candidates also guarantees that its own
    # index (which is out of range for x_test) can never be selected, even
    # when a test image ties with it.
    # The dot product equals cosine similarity only for unit-length
    # embeddings -- presumably what the model emits; confirm.
    similarities = np.einsum("e,be->b", query_embedding, test_embeddings)
    ranking = np.argsort(similarities)  # ascending: best matches are last
    first_kept = max(len(ranking) - NEAR_NEIGHBOURS, 0)
    near_neighbours = ranking[first_kept:]

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)


if __name__ == "__main__":
    # Constants (module-level globals read by prepare_output and
    # get_nearest_neighbours)
    HEIGHT_WIDTH = 32  # side length, in pixels, of one tile in the output grid
    NEAR_NEIGHBOURS = 10  # number of neighbours shown per query

    # Only the test split is used as the search gallery; scale it to [0, 1]
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    # Pre-trained embedding model fetched from the Hugging Face Hub
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    # NOTE(review): these paths are anchored at the filesystem root ("/");
    # confirm the example images really live there in the deployment.
    examples = ["/examples/boat.jpeg", "/examples/horse.jpeg", "/examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"

    more_text = """Embeddings for the input image are computed using the model trained using the metric learning technique. 
                   The nearest neighbours are calculated using the cosine distance and these are shown in the image grid."""

    description = f"This space uses model trained on CIFAR10 dataset using metric learning.\n\n{more_text}"

    article = """
    <p style='text-align: center'>
        <a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example by Mat Kelcey</a>
        <br>
        Space by Vrinda Prabhu
        </p>
    """

    # Build the UI around get_nearest_neighbours and start serving it
    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(32, 32)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)