import numpy as np
from PIL import Image
from tensorflow.keras.datasets import cifar10
from huggingface_hub import from_pretrained_keras
import gradio as gr

# Side length (pixels) of a CIFAR-10 image.
HEIGHT_WIDTH = 32
# Number of nearest neighbours shown for each query image.
NEAR_NEIGHBOURS = 10


def _l2_normalize(vectors, eps=1e-12):
    """Scale every row of *vectors* to unit L2 norm.

    Rows whose norm is (near) zero are divided by *eps* instead, so they stay
    (near) zero rather than becoming NaN. After normalisation a dot product
    between two rows equals their cosine similarity.
    """
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / np.maximum(norms, eps)


def prepare_output(neighbours, columns=5):
    """Return an image grid of the nearest neighbours.

    @params neighbours: Indices into ``x_test`` ordered from least to most
        similar (the order ``np.argsort`` produces). The grid is filled with
        the most similar image first, left to right, top to bottom.
    @params columns: Number of images per grid row. The default of 5 lays the
        default 10 neighbours out as a 5x2 grid.
    @returns: A ``PIL.Image`` in RGB mode of size
        ``(HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows)``. Unused cells stay black.
    """
    ordered = list(reversed(neighbours))
    # Ceiling division; always at least one row so Image.new gets a valid size.
    rows = max(1, (len(ordered) + columns - 1) // columns)
    img_grid = Image.new("RGB", (HEIGHT_WIDTH * columns, HEIGHT_WIDTH * rows))
    for idx, nn_idx in enumerate(ordered):
        # x_test holds floats in [0, 1]. Round before casting so values like
        # 254.9999 become 255 instead of being truncated to 254.
        img_arr = np.clip(np.rint(np.asarray(x_test[nn_idx]) * 255), 0, 255).astype(np.uint8)
        img_grid.paste(
            Image.fromarray(img_arr, "RGB"),
            ((idx % columns) * HEIGHT_WIDTH, (idx // columns) * HEIGHT_WIDTH),
        )
    return img_grid


def get_nearest_neighbours(img):
    """Run inference and return a grid of the closest CIFAR-10 test images.

    @params img: Query image as an (H, W, 3) uint8 array from the Gradio input,
        already resized to HEIGHT_WIDTH x HEIGHT_WIDTH. ``None`` when no image
        was provided.
    @returns: The neighbour grid as a numpy array, or ``None`` for no input.

    Relies on the module globals ``model``, ``x_test`` and ``test_embeddings``,
    which are set up in the ``__main__`` section.
    """
    if img is None:
        return None

    # Pre-process: scale to [0, 1] like x_test and add a batch axis.
    query = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    query_embedding = _l2_normalize(model.predict(query, verbose=0))[0]

    # Cosine similarity of the query against the pre-computed gallery
    # embeddings: one row instead of a full gram matrix. The query is not
    # part of the gallery, so there is no "self" match to drop.
    similarities = test_embeddings @ query_embedding
    k = min(NEAR_NEIGHBOURS, similarities.shape[0])
    order = np.argsort(similarities)
    near_neighbours = order[order.shape[0] - k:]  # ascending; most similar last

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)


if __name__ == "__main__":
    # The CIFAR-10 test split is the gallery that queries are matched against;
    # the training split and the labels are not needed.
    _, (x_test, _) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0

    model = from_pretrained_keras("keras-io/cifar10_metric_learning")
    # Embed the gallery once at start-up rather than on every request.
    test_embeddings = _l2_normalize(model.predict(x_test, verbose=0))

    examples = ["examples/yatch.jpeg", "examples/horse.jpeg", "examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"
    more_text = """Embeddings for the input image are computed using the model. The nearest neighbours are then calculated using the cosine similarity and these are shown in the image grid."""
    description = f"This space uses a model trained on the CIFAR10 dataset using a metric learning technique.\n\n{more_text}"
    article = """

Keras Example given by Mat Kelcey
Space by Vrinda Prabhu

"""
    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(HEIGHT_WIDTH, HEIGHT_WIDTH)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)