# vdprabhu's picture
# Update app.py
# c6d13a8
# raw history blame
# No virus
# 2.77 kB
import functools

import numpy as np
from PIL import Image
from tensorflow.keras.datasets import cifar10
from huggingface_hub import from_pretrained_keras
import gradio as gr
def prepare_output(neighbours, columns=5):
    """Build an image grid of the nearest neighbours, closest match first.

    @params neighbours: Indices into the global ``x_test`` ordered from least
        to most similar (ascending ``np.argsort`` order). They are drawn in
        reverse so the closest match lands in the top-left cell.
    @params columns: Number of images per grid row (default 5, giving a 5x2
        grid for the default 10 neighbours).
    @returns: A ``PIL.Image`` of size
        ``(HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows)``.
    """
    anchor_near_neighbours = list(reversed(neighbours))
    # Ceiling division so every neighbour gets a cell; always at least one row.
    rows = max(1, -(-len(anchor_near_neighbours) // columns))
    img_grid = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))
    for idx, nn_idx in enumerate(anchor_near_neighbours):
        # x_test holds floats in [0, 1]. Round (rather than truncate) when
        # mapping back to bytes, otherwise float32 round-off such as 254.99998
        # turns into 254 and pixels come out one level too dark.
        scaled = np.rint(np.asarray(x_test[nn_idx]) * 255)
        img_arr = np.clip(scaled, 0, 255).astype(np.uint8)
        row, col = divmod(idx, columns)
        img_grid.paste(
            Image.fromarray(img_arr, "RGB"),
            (col * HEIGHT_WIDTH, row * HEIGHT_WIDTH),
        )
    return img_grid
def _l2_normalise(vectors):
    """Scale each row of ``vectors`` to unit length (all-zero rows stay zero)."""
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / np.maximum(norms, 1e-12)


@functools.lru_cache(maxsize=1)
def _test_set_embeddings():
    """Return unit-length embeddings of every image in the global ``x_test``.

    The test split does not change while the app runs, so it is embedded once
    and reused instead of being pushed through the model on every request.
    """
    return _l2_normalise(model.predict(x_test))


def get_nearest_neighbours(img):
    """Run inference and return an image grid of the most similar test images.

    @params img: Uploaded image as a uint8 array of shape (32, 32, 3), or
        ``None`` when the user submits without an image.
    @returns: The neighbour grid as a numpy array, or ``None`` for no input.
    """
    if img is None:
        # Gradio passes None when nothing was uploaded; there is nothing to search.
        return None
    # Pre-process the query the same way x_test was prepared (scale to [0, 1]).
    query = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    query_embedding = _l2_normalise(model.predict(query))[0]
    # With unit-length embeddings the dot product is the cosine similarity.
    # Only the query's row is needed, not the full Gram matrix. Searching
    # x_test alone also means the query can never match itself.
    similarities = _test_set_embeddings() @ query_embedding
    # Top matches in ascending order, which is the order prepare_output expects.
    near_neighbours = np.argsort(similarities)[-NEAR_NEIGHBOURS:]
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)
# Constants. Defined at module level because prepare_output and
# get_nearest_neighbours read them as globals.
HEIGHT_WIDTH = 32  # side length of a CIFAR-10 image, in pixels
NEAR_NEIGHBOURS = 10

if __name__ == "__main__":
    # Only the test split is searched; training data and labels are unused.
    _, (x_test, _) = cifar10.load_data()
    # Scale pixels to [0, 1]; prepare_output multiplies by 255 to display them.
    x_test = x_test.astype("float32") / 255.0
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")
    # NOTE(review): these paths are absolute (filesystem root); confirm they
    # resolve in the deployment environment - a relative "examples/..." is
    # the usual layout for a Space.
    examples = ["/examples/boat.jpeg", "/examples/horse.jpeg", "/examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"
    more_text = """Embeddings for the input image are computed using the model trained using the metric learning technique.
The nearest neighbours are calculated using the cosine distance and are shown in the image grid."""
    description = f"This space uses model trained on CIFAR10 dataset using metric learning.\n\n{more_text}"
    article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/vision/metric_learning/' target='_blank'>Keras Example by Mat Kelcey</a>
<br>
Space by Vrinda Prabhu
</p>
"""
    gr.Interface(
        fn=get_nearest_neighbours,
        # Resize uploads to the CIFAR-10 input size the model was trained on.
        inputs=gr.Image(shape=(HEIGHT_WIDTH, HEIGHT_WIDTH)),
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)