vdprabhu's picture
Update app.py
94c41ad
import numpy as np
from PIL import Image
from tensorflow.keras.datasets import cifar10
from huggingface_hub import from_pretrained_keras
import gradio as gr
def prepare_output(neighbours, images=None, tile_size=None, columns=5):
    """Return the image grid based on the nearest neighbours.

    Tiles are laid out left-to-right, top-to-bottom, with the LAST entry of
    ``neighbours`` placed first (the ascending ``np.argsort`` order used by
    the caller puts the most similar index last).

    @params neighbours: Sequence of indices into ``images`` of the nearest
        neighbours.
    @params images: Indexable collection of images with pixel values in
        [0, 1]; defaults to the module-level ``x_test``.
    @params tile_size: Side length of each square tile in pixels; defaults to
        the module-level ``HEIGHT_WIDTH``.
    @params columns: Number of tiles per grid row (5 gives the original
        5 x 2 grid for ten neighbours).
    @returns: An RGB ``PIL.Image.Image``; cells without a neighbour stay black.
    @raises ValueError: If ``columns`` is not positive.
    """
    if images is None:
        images = x_test
    if tile_size is None:
        tile_size = HEIGHT_WIDTH
    if columns <= 0:
        raise ValueError("columns must be a positive integer")

    # Most similar neighbour first.
    ordered = list(neighbours)[::-1]

    # Size the grid from the number of neighbours instead of hard-coding two
    # rows, which silently dropped every neighbour past the tenth.
    rows = max(1, (len(ordered) + columns - 1) // columns)
    img_grid = Image.new("RGB", (tile_size * columns, tile_size * rows))

    for position, nn_idx in enumerate(ordered):
        # Defensive: round to the nearest 8-bit level and clip, rather than
        # truncating, so values that are not exact multiples of 1/255 (or that
        # fall marginally outside [0, 1]) cannot drop a level or wrap around.
        tile = np.clip(np.rint(np.asarray(images[nn_idx]) * 255), 0, 255)
        row, col = divmod(position, columns)
        img_grid.paste(
            Image.fromarray(tile.astype(np.uint8), "RGB"),
            (col * tile_size, row * tile_size),
        )
    return img_grid
def get_nearest_neighbours(img):
    """Return an image grid of the test images most similar to ``img``.

    The query is embedded together with the test set by ``model`` and
    compared to every test embedding by cosine similarity. The
    ``NEAR_NEIGHBOURS`` best matches are rendered by ``prepare_output``.

    @params img: Image to be fed to the model, as an array of 0-255 pixel
        values (presumably HEIGHT_WIDTH x HEIGHT_WIDTH x 3 after Gradio's
        resize — TODO confirm against the input component).
    @returns: The neighbour grid as a NumPy array.
    @raises ValueError: If no image was supplied (``img`` is ``None``).
    """
    if img is None:
        raise ValueError("No input image was provided.")

    # Pre-process the query the same way x_test was (scale to [0, 1]) and add
    # a batch axis so it can be embedded in the same call as the test set.
    query = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    embeddings = model.predict(np.concatenate([x_test, query], axis=0))
    # NOTE(review): the test-set embeddings do not depend on the query and
    # could be computed once at start-up instead of on every request.

    # Cosine similarity is the dot product of L2-normalised vectors. The
    # upstream model presumably normalises already; doing it here keeps the
    # "cosine distance" described in the UI true regardless.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / np.maximum(norms, 1e-12)
    test_embeddings, query_embedding = embeddings[:-1], embeddings[-1]

    # Only the query's row of the similarity matrix is needed. Building the
    # full (N+1) x (N+1) Gram matrix and sorting every row costs O(N^2)
    # memory. Scoring only test images also means the query can never be
    # returned as its own neighbour (its index is out of range for x_test),
    # even when it exactly duplicates a test image and ties.
    similarities = test_embeddings @ query_embedding
    order = np.argsort(similarities)
    k = max(0, min(NEAR_NEIGHBOURS, order.shape[0]))
    # Ascending similarity, most similar last (the order prepare_output expects).
    near_neighbours = order[order.shape[0] - k :]

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)
if __name__ == "__main__":
    # Constants. These, x_test and model are module globals read by
    # prepare_output / get_nearest_neighbours, so those functions only work
    # when this file is run as a script.
    HEIGHT_WIDTH = 32  # CIFAR-10 images are 32 x 32 pixels
    NEAR_NEIGHBOURS = 10  # number of neighbours shown in the output grid

    # Only the test split is searched; the training split and labels are unused.
    _, (x_test, _) = cifar10.load_data()
    # Scale pixels to [0, 1]; the query is scaled the same way before
    # embedding, and prepare_output multiplies back by 255 for display.
    x_test = x_test.astype("float32") / 255.0
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")

    # Paths are relative to the app directory and must match the files on disk.
    examples = ["examples/yatch.jpeg", "examples/horse.jpeg", "examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"
    more_text = """Embeddings for the input image are computed using the model. The nearest neighbours are then calculated
using cosine distance. These are shown here as an image grid."""
    description = f"This space uses model trained on CIFAR10 dataset using metric learning technique.\n{more_text}\n\n"
    article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example given by Mat Kelcey</a>
<br>
Space by Vrinda Prabhu
</p>
"""
    # NOTE(review): `shape=` on gr.Image and `enable_queue=` on launch() are
    # Gradio 3.x APIs that were removed in Gradio 4 — keep the Gradio
    # version pinned accordingly.
    gr.Interface(
        fn=get_nearest_neighbours,
        # Resize uploads to the CIFAR-10 image size.
        inputs=gr.Image(shape=(HEIGHT_WIDTH, HEIGHT_WIDTH)),
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)