|
import numpy as np |
|
from PIL import Image |
|
|
|
from tensorflow.keras.datasets import cifar10 |
|
|
|
from huggingface_hub import from_pretrained_keras |
|
import gradio as gr |
|
|
|
|
|
def prepare_output(neighbours):
    """Build the image grid showing the nearest-neighbour test images.

    Tiles are pasted in the reverse order of ``neighbours``, filling the grid
    left to right and top to bottom, five tiles per row. Each tile is taken
    from the module-level ``x_test`` array and is ``HEIGHT_WIDTH`` pixels
    square.

    @params neighbours: Sequence of indices into ``x_test``, one per tile
    @returns: RGB ``PIL.Image.Image`` containing the pasted tiles
    """
    columns = 5

    # The caller passes the tail of an argsort (ascending similarity), so
    # reversing puts the closest match in the top-left corner.
    ordered = list(reversed(neighbours))

    # Two rows is the historical grid size; add rows when more tiles are
    # requested so none of them is pasted outside the canvas.
    rows = max(2, -(-len(ordered) // columns))
    img_grid = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))

    for idx, nn_idx in enumerate(ordered):
        # x_test is stored as float32 scaled by 1/255. astype() truncates, so
        # round first: otherwise a value such as 2.9999998 would become 2.
        img_arr = np.rint(np.array(x_test[nn_idx]) * 255).astype(np.uint8)
        img_grid.paste(
            Image.fromarray(img_arr, "RGB"),
            ((idx % columns) * HEIGHT_WIDTH, (idx // columns) * HEIGHT_WIDTH),
        )

    return img_grid
|
|
|
|
|
def get_nearest_neighbours(img):
    """Run inference and return a grid of the test images closest to ``img``.

    The image is appended to ``x_test``, the whole batch is embedded with the
    module-level ``model``, and the test images are ranked by the dot product
    of their embedding with the query embedding.

    @params img: Image to be fed to the model; assumed to be an array with
        0-255 pixel values shaped like one entry of ``x_test`` (TODO confirm
        against the Gradio input component)
    @returns: NumPy array of the grid image built by ``prepare_output``
    """
    # Scale like x_test and add a batch axis so the query becomes the last row.
    img = np.expand_dims(img / 255, axis=0)
    img_x_test = np.append(x_test, img, axis=0)

    # NOTE(review): the test-set embeddings are recomputed on every call;
    # caching them once would make each request considerably cheaper.
    embeddings = model.predict(img_x_test)

    # Only the query's similarities are needed, so compute that single row
    # instead of the full N x N Gram matrix and an argsort over every row.
    similarities = np.einsum("ae,e->a", embeddings, embeddings[-1])
    near_neighbours = np.argsort(similarities)[-(NEAR_NEIGHBOURS + 1):]

    # Drop the top-ranked entry: presumably the query itself, which holds
    # when embeddings are L2-normalised (verify against the model).
    img_grid = prepare_output(near_neighbours[:-1])
    return np.array(img_grid)
|
|
|
|
|
if __name__ == "__main__":
    # Module-level configuration read by prepare_output() and
    # get_nearest_neighbours(): tile edge length in pixels and the number of
    # neighbours shown in the grid.
    HEIGHT_WIDTH = 32
    NEAR_NEIGHBOURS = 10

    # Only the test split is searched; scale it to [0, 1] floats.
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    # Pretrained embedding model pulled from the Hugging Face Hub.
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    # NOTE(review): these paths are anchored at the filesystem root; confirm
    # that is intended rather than paths relative to the app directory.
    examples = ["/examples/boat.jpeg", "/examples/horse.jpeg", "/examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"

    more_text = (
        "Embeddings for the input image are computed using the model trained "
        "using the metric learning technique.\n"
        "The nearest neighbours are calculated using the cosine distance and "
        "these are shown in the image grid."
    )

    description = f"This space uses a model trained on CIFAR10 dataset using metric learning.\n\n{more_text}"

    article = """
    <p style='text-align: center'>
        <a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example by Mat Kelcey</a>
        <br>
        Space by Vrinda Prabhu
    </p>
    """

    # NOTE(review): `shape=`, `allow_flagging=` and `enable_queue=` appear to
    # be accepted only by older Gradio releases; check the pinned Gradio
    # version before upgrading.
    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(32, 32)),
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)
|
|