"""Streamlit page that searches images by text with a CLIP model and an nmslib index."""
import functools
import json
import os

import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

_MODEL_NAME = "flax-community/clip-rsicd-v2"
_PROCESSOR_NAME = "openai/clip-vit-base-patch32"
_IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
_IMAGE_DIR = "./images"


def load_index(image_vector_file):
    """Build an nmslib HNSW (cosine similarity) index from a vector file.

    Each non-blank line holds a filename followed by a comma-separated
    vector; the two fields are separated by whitespace.

    Args:
        image_vector_file: Path of the vector file to read.

    Returns:
        A ``(filenames, index)`` tuple. Vectors are added to the index in
        file order, and callers use the ids returned by ``index.knnQuery``
        as positions into ``filenames``.
    """
    filenames, image_vecs = [], []
    with open(image_vector_file, "r", encoding="utf-8") as fvec:
        for line in fvec:
            # Split on any whitespace run so both tab- and space-separated
            # files are accepted; blank lines are skipped instead of crashing.
            cols = line.split()
            if not cols:
                continue
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(",")]))
    vectors = np.array(image_vecs)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(vectors)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


def load_captions(caption_file):
    """Read a JSON-lines caption file into a filename -> captions mapping.

    Args:
        caption_file: Path of a file with one JSON object per line, each
            providing the keys ``"filename"`` and ``"captions"``.

    Returns:
        A dict mapping each ``"filename"`` value to its ``"captions"`` value.
    """
    image2caption = {}
    with open(caption_file, "r", encoding="utf-8") as fcap:
        for line in fcap:
            data = json.loads(line.strip())
            image2caption[data["filename"]] = data["captions"]
    return image2caption


@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the CLIP model, its processor and the image index; cached after the first call."""
    model = FlaxCLIPModel.from_pretrained(_MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(_PROCESSOR_NAME)
    filenames, index = load_index(_IMAGE_VECTOR_FILE)
    return model, processor, filenames, index


def get_image(text, number):
    """Render the images closest to ``text`` on the Streamlit page.

    The query text is embedded with the CLIP text encoder and looked up in
    the image index; each hit is shown with its rank, filename and a score
    computed as ``1.0 - distance``.

    Args:
        text: Query text to search for.
        number: Number of nearest images to display.
    """
    # Loading the model and building the index is expensive, so it is done
    # once and reused instead of being repeated on every search.
    model, processor, filenames, index = _load_search_resources()

    inputs = processor(text=[text], images=None, return_tensors="jax", padding=True)
    query_vec = np.asarray(model.get_text_features(**inputs))
    # k is coerced to int in case the caller passes a float count.
    ids, distances = index.knnQuery(query_vec, k=int(number))

    for rank, (image_id, distance) in enumerate(zip(ids, distances), start=1):
        result_filename = filenames[image_id]
        caption = f"{result_filename:s} (score: {1.0 - distance:.3f})"
        # The third column is intentionally left empty; it only keeps the
        # existing page layout (no captions are rendered there).
        col1, col2, _ = st.columns([2, 10, 10])
        col1.markdown(f"{rank:d}.")
        # Close the image file once Streamlit has consumed it.
        with Image.open(os.path.join(_IMAGE_DIR, result_filename)) as img:
            col2.image(img, caption=caption)
        st.markdown("---")


# Not referenced anywhere in this module; kept in case other code reads it.
suggest_idx = -1


def app():
    """Render the search page: a text box, a result count and a Search button."""
    st.title("Welcome to Space \nVector")
    st.text("You want search an image with given text.")
    text = st.text_input("Enter text: ")
    number = st.number_input("Enter number of images result: ", min_value=1, max_value=10)
    if st.button("Search"):
        # An empty query has nothing meaningful to match against.
        if not text.strip():
            st.warning("Please enter some text to search for.")
            return
        get_image(text, number)