import os

import matplotlib.pyplot as plt
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

# The processor (image preprocessing) is loaded from the base CLIP model,
# while the weights come from the fine-tuned checkpoint in MODEL_PATH.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd-v2"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"

# Number of similar images shown for a query.
NUM_RESULTS = 10
# Number of result images per displayed row.
RESULTS_PER_ROW = 3
# Number of example file names suggested to the user (one per class).
NUM_EXAMPLES = 10


@st.cache(allow_output_mutation=True)
def load_index():
    """Load precomputed image vectors and build an NMSLib HNSW index over them.

    Each non-blank line of IMAGE_VECTOR_FILE holds a file name, a TAB, and
    the image vector as comma-separated floats.

    Returns:
        (filenames, index): vectors are added to ``index`` in file order, so
        an id returned by ``index.knnQuery`` indexes directly into
        ``filenames``.
    """
    filenames, image_vecs = [], []
    # Context manager so the file handle is released once loading is done.
    with open(IMAGE_VECTOR_FILE, "r") as fvec:
        for line in fvec:
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of failing
                # with an IndexError on cols[1].
                continue
            cols = line.split("\t")
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(",")]))
    V = np.array(image_vecs)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(V)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned Flax CLIP model and the base CLIP processor.

    Returns:
        (model, processor)
    """
    model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
    processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
    return model, processor


@st.cache(allow_output_mutation=True)
def load_example_images():
    """Group the file names in IMAGES_DIR by their class prefix.

    The class is the text before the first underscore, so a name such as
    ``airport_1.jpg`` is grouped under ``airport``. Names without an
    underscore are skipped.

    Returns:
        dict mapping class name -> list of file names (in os.listdir order).
    """
    example_images = {}
    for image_name in os.listdir(IMAGES_DIR):
        if "_" not in image_name:
            continue
        image_class = image_name.split("_")[0]
        example_images.setdefault(image_class, []).append(image_name)
    return example_images


def _render_results(images, captions, per_row=RESULTS_PER_ROW):
    """Display ``images`` with their ``captions`` in rows of ``per_row``."""
    # Iterating over the actual result count avoids calling st.image with
    # empty lists when fewer than NUM_RESULTS neighbours come back.
    for start in range(0, len(images), per_row):
        st.image(images[start:start + per_row],
                 caption=captions[start:start + per_row])


def app():
    """Render the image-to-image retrieval page.

    The user enters the file name of an image in IMAGES_DIR. The image is
    embedded with the fine-tuned CLIP model, and its nearest neighbours in
    the NMSLib index are shown with their similarity scores.
    """
    filenames, index = load_index()
    model, processor = load_model()
    example_images = load_example_images()

    # One random file per class, for the first NUM_EXAMPLES classes, sorted.
    example_image_list = sorted(
        [v[np.random.randint(0, len(v))] for v in example_images.values()][:NUM_EXAMPLES])

    st.title("Image to Image Retrieval")
    st.markdown("""
    The CLIP model from OpenAI is trained in a self-supervised manner using
    contrastive learning to project images and caption text onto a common
    embedding space. We have fine-tuned the model using the RSICD dataset
    (10k images and ~50k captions from the remote sensing domain).

    This demo shows the image to image retrieval capabilities of this model,
    i.e., given an image file name as a query, we use our fine-tuned CLIP
    model to project the query image to the image/caption embedding space
    and search for nearby images (by cosine similarity) in this space.

    Our fine-tuned CLIP model was previously used to generate image vectors
    for our demo, and NMSLib was used for fast vector access.

    Here are some randomly selected image files from our corpus. You can
    copy and paste one of these below or use one from the results of a text
    to image search -- {:s}
    """.format(", ".join("`{:s}`".format(example) for example in example_image_list)))

    image_name = st.text_input("Provide an Image File Name")
    submit_button = st.button("Find Similar")
    if not submit_button:
        return

    # Copy-pasted names often carry stray whitespace.
    image_name = image_name.strip()
    image_path = os.path.join(IMAGES_DIR, image_name)
    # Accept only a bare file name that exists in IMAGES_DIR. An empty name
    # would resolve to the directory itself, and a name containing path
    # separators could reach files outside IMAGES_DIR.
    if (not image_name
            or os.path.basename(image_name) != image_name
            or not os.path.isfile(image_path)):
        st.error("Image file not found: {:s}".format(image_name or "(empty)"))
        return

    # convert("RGB") hands the processor a 3-channel image whatever the
    # file's own mode is (grayscale, RGBA, palette, ...).
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="jax")
    # get_image_features returns a batch of one embedding; take that row.
    query_vec = np.asarray(model.get_image_features(**inputs))[0]

    # Request one extra neighbour because the query image may itself be in
    # the index, and it is filtered out below.
    ids, distances = index.knnQuery(query_vec, k=NUM_RESULTS + 1)

    images, captions = [], []
    for idx, distance in zip(ids, distances):
        result_filename = filenames[idx]
        if result_filename == image_name:
            continue
        images.append(plt.imread(os.path.join(IMAGES_DIR, result_filename)))
        # The 'cosinesimil' distance is 1 - cosine similarity.
        captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - distance))
        if len(images) == NUM_RESULTS:
            # Don't read an image from disk only to throw it away.
            break

    _render_results(images, captions)