"""Gradio demo: natural-language image search with a dual-encoder model.

Loads precomputed image embeddings and their file paths, embeds a text query
with the ``keras-io/dual-encoder-image-search`` text encoder, and shows the
best-matching images as a 3x3 grid.
"""
import json

import gradio as gr
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# NOTE(review): tensorflow_text and tensorflow_addons are never referenced by
# name below; they are imported only for their side effects (presumably
# registering custom ops/layers the saved text encoder needs). Confirm before
# removing them.
import tensorflow_text  # noqa: F401
import tensorflow_addons  # noqa: F401
from huggingface_hub import from_pretrained_keras

IMAGE_PATHS_FILE = "image_paths.json"
IMAGE_EMBEDDINGS_FILE = "image_embeddings.npy"
MODEL_ID = "keras-io/dual-encoder-image-search"
GRID_SIZE = 3  # results are shown in a GRID_SIZE x GRID_SIZE grid
# NOTE(review): every request writes to this single fixed path; concurrent
# requests would overwrite each other's result if the queue ever runs more
# than one job at a time.
OUTPUT_IMAGE = "img.png"

# load config
with open(IMAGE_PATHS_FILE, "r", encoding="utf-8") as f:
    image_paths = json.load(f)

image_embeddings = np.load(IMAGE_EMBEDDINGS_FILE)
text_encoder = from_pretrained_keras(MODEL_ID)


def find_matches(image_paths, image_embeddings, queries, k=9, normalize=True):
    """Return the paths of the ``k`` images most similar to each query.

    Args:
        image_paths: Sequence of image file paths; entry ``i`` is assumed to
            correspond to row ``i`` of ``image_embeddings``.
        image_embeddings: 2-D array holding one image embedding per row.
        queries: List of query strings, embedded with the module-level
            ``text_encoder``.
        k: Number of matches to return per query. Clamped to the number of
            available embeddings, because ``tf.math.top_k`` raises when ``k``
            exceeds the size of the searched dimension.
        normalize: If true, L2-normalize both the query and the image
            embeddings so the dot product is a cosine similarity.

    Returns:
        One list of image paths per query, ordered from most to least similar.
    """
    k = min(k, len(image_embeddings))

    # Get the embedding for the query.
    query_embedding = text_encoder(tf.convert_to_tensor(queries))

    # Normalize the query and the image embeddings.
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)

    # (num_queries, num_images) matrix of query/image similarities.
    dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)

    # Retrieve top k indices per query, then map them back to image paths.
    results = tf.math.top_k(dot_similarity, k).indices.numpy()
    return [[image_paths[idx] for idx in indices] for indices in results]


def inference(query):
    """Search for ``query`` and render the top matches as one grid image.

    Args:
        query: Free-text search query from the Gradio text box.

    Returns:
        Path of the PNG file the result grid was written to.
    """
    max_results = GRID_SIZE * GRID_SIZE
    matches = find_matches(
        image_paths, image_embeddings, [query], k=max_results, normalize=True
    )[0]

    fig = plt.figure(figsize=(20, 20))
    try:
        # Iterate over the matches actually returned: a small search space may
        # yield fewer than GRID_SIZE * GRID_SIZE results.
        for i, path in enumerate(matches):
            ax = fig.add_subplot(GRID_SIZE, GRID_SIZE, i + 1)
            ax.imshow(mpimg.imread(path))
            ax.axis("off")
        fig.savefig(OUTPUT_IMAGE)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak a 20x20-inch figure.
        plt.close(fig)
    return OUTPUT_IMAGE


examples = [
    "a family standing next to the ocean on a sandy beach with a surf board",
    "a group of people sitting in an audience with pen and paper",
    "a couple of cows that are in some grass",
]

# NOTE(review): `enable_queue` in `launch()` is deprecated in later Gradio 3.x
# releases in favour of `Interface.queue()`; kept as-is to match the Gradio
# version this app was written against.
gr.Interface(
    fn=inference,
    title="Natural language image search with a Dual Encoder",
    description=(
        "Implementation of a dual encoder model for retrieving images that match "
        "natural language queries (Note: for demo purposes, only 1k images were "
        "used as search space)"
    ),
    inputs="text",
    examples=examples,
    outputs="image",
    cache_examples=False,
    article="Author: Vu Minh Chien. \nBased on the keras example from Khalid Salama",
).launch(debug=True, enable_queue=True)