"""Gradio Space: metric-learning image similarity search on CIFAR-10.

A query image is embedded with the ``keras-io/cifar10_metric_learning`` model
and compared (dot product) against the embeddings of the CIFAR-10 test set;
the most similar test images are returned as an image grid.
"""
import numpy as np
from PIL import Image
from tensorflow.keras.datasets import cifar10
from huggingface_hub import from_pretrained_keras
import gradio as gr

# Constants. Defined at module level (not only under ``__main__``) so the
# functions below also work when this file is imported.
HEIGHT_WIDTH = 32  # input is resized to 32x32 to match CIFAR; also one grid tile
NEAR_NEIGHBOURS = 10  # number of neighbours shown in the output grid
GRID_COLUMNS = 5  # tiles per row of the output grid

# Populated in the ``__main__`` block: CIFAR-10 test images scaled to [0, 1]
# and the embedding model.
x_test = None
model = None
# Cache for ``model.predict(x_test)``; filled by ``_reference_embeddings``.
_test_embeddings = None


def prepare_output(neighbours, images=None, tile_size=HEIGHT_WIDTH, columns=GRID_COLUMNS):
    """Function to return the image grid based on the nearest neighbours

    @params neighbours: Indices (into ``images``) of the nearest neighbours,
        ordered from least to most similar, i.e. the ascending order produced
        by ``np.argsort``. The grid is filled most-similar first, row by row.
    @params images: Images to draw tiles from, with float values in [0, 1].
        Defaults to the module-level ``x_test``. Each image is assumed to be a
        ``tile_size`` x ``tile_size`` RGB array.
    @params tile_size: Side length in pixels of one grid cell.
    @params columns: Number of grid cells per row; must be positive.
    @returns: An RGB ``PIL.Image.Image`` that is ``columns`` tiles wide and
        as many rows tall as needed to hold every neighbour (at least one).
    @raises ValueError: If ``columns`` is not positive.
    """
    if columns < 1:
        raise ValueError(f"columns must be positive, got {columns}")
    if images is None:
        images = x_test

    # argsort order is ascending by similarity, so reverse to put the best first.
    ordered = list(neighbours)[::-1]
    rows = max(1, -(-len(ordered) // columns))  # ceiling division
    img_grid = Image.new("RGB", (tile_size * columns, tile_size * rows))

    for idx, nn_idx in enumerate(ordered):
        # Round before casting: float round-off in x/255*255 can land just
        # below the original integer, which astype() would truncate.
        tile = np.clip(np.rint(np.asarray(images[nn_idx]) * 255), 0, 255).astype(np.uint8)
        row, col = divmod(idx, columns)
        img_grid.paste(Image.fromarray(tile), (col * tile_size, row * tile_size))
    return img_grid


def _reference_embeddings():
    """Return ``model.predict(x_test)``, computing it only on first use."""
    global _test_embeddings
    if _test_embeddings is None:
        _test_embeddings = model.predict(x_test)
    return _test_embeddings


def get_nearest_neighbours(img):
    """Has the inference code to get the nearest neighbours from the model

    @params img: Image to be fed to the model, with pixel values in [0, 255]
        (scaled here the same way as ``x_test``).
    @returns: The neighbour grid from ``prepare_output`` as a numpy array.
    """
    # Pre-process the query exactly like the test set and add a batch axis.
    query = np.expand_dims(np.asarray(img, dtype="float32") / 255.0, axis=0)
    query_embedding = model.predict(query)[0]

    # Score every test image against the query only. The test-set embeddings
    # are cached, and the query is never part of the candidate set, so no
    # self-match has to be dropped and no out-of-range index can leak through.
    # NOTE(review): a dot product equals cosine similarity only if the model
    # emits unit-norm embeddings — confirm.
    similarities = _reference_embeddings() @ query_embedding
    near_neighbours = np.argsort(similarities)[-NEAR_NEIGHBOURS:]

    # Make image grid output
    img_grid = prepare_output(near_neighbours)
    return np.array(img_grid)


if __name__ == "__main__":
    (_, _), (x_test, _) = cifar10.load_data()
    x_test = x_test.astype("float32") / 255.0
    model = from_pretrained_keras("keras-io/cifar10_metric_learning")
    # Embed the test set once at start-up rather than inside every request.
    _reference_embeddings()

    # NOTE(review): absolute paths — confirm these exist where the app runs
    # (a relative "examples/..." may be what was intended).
    examples = ["/examples/boat.jpeg", "/examples/horse.jpeg", "/examples/car.jpeg"]
    title = "Metric Learning for Image Similarity Search"
    more_text = (
        "Embeddings for the input image are computed using the model trained "
        "using the metric learning technique.\n"
        "The nearest neighbours are calculated using the cosine distance and "
        "these are shown in the image grid."
    )
    description = f"This space uses a model trained on CIFAR10 dataset using metric learning.\n\n{more_text}"
    article = "\n\nKeras Example by Mat Kelcey\nSpace by Vrinda Prabhu\n\n"

    gr.Interface(
        fn=get_nearest_neighbours,
        inputs=gr.Image(shape=(HEIGHT_WIDTH, HEIGHT_WIDTH)),  # Resize to CIFAR
        outputs=gr.Image(),
        examples=examples,
        article=article,
        allow_flagging="never",
        analytics_enabled=False,
        title=title,
        description=description,
    ).launch(enable_queue=True)