SpaceVector_v0 / text_to_image.py
LayBraid
:construction: change id to index
cd52a4f
import functools
import json
import os

import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
def load_index(image_vector_file):
    """Build a cosine-similarity HNSW index over precomputed image vectors.

    Each non-blank line of ``image_vector_file`` holds an image filename and
    its embedding as comma-separated floats, separated by whitespace (a tab
    in a ``.tsv`` file, or a space).

    Args:
        image_vector_file: path to the vector file.

    Returns:
        ``(filenames, index)`` where ``filenames[i]`` is the image stored at
        position ``i`` of the nmslib ``index``; ids returned by
        ``index.knnQuery`` map back to filenames through this list.

    Raises:
        ValueError: if a line has no vector column or the file holds no vectors.
    """
    filenames, image_vecs = [], []
    with open(image_vector_file, "r", encoding="utf-8") as fvec:
        for line_no, raw_line in enumerate(fvec, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank and trailing lines instead of crashing on them.
                continue
            # Split once on any whitespace so both tab- and space-separated
            # rows work; the vector column itself contains no whitespace.
            cols = line.split(maxsplit=1)
            if len(cols) != 2:
                raise ValueError(
                    f"{image_vector_file}:{line_no}: expected '<filename> <v1,v2,...>'"
                )
            filename, vec_text = cols
            filenames.append(filename)
            image_vecs.append([float(x) for x in vec_text.split(",")])

    if not image_vecs:
        raise ValueError(f"{image_vector_file} contains no image vectors")

    vectors = np.array(image_vecs)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(vectors)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
def load_captions(caption_file):
    """Map each image filename to its list of captions.

    ``caption_file`` is read as JSON Lines: every non-blank line is a JSON
    object containing at least the keys ``"filename"`` and ``"captions"``.
    Blank lines are skipped. If a filename appears more than once, the last
    occurrence wins.

    Args:
        caption_file: path to the JSON-Lines caption file.

    Returns:
        dict mapping filename -> the ``"captions"`` value from that line.
    """
    image2caption = {}
    with open(caption_file, "r", encoding="utf-8") as fcap:
        for raw_line in fcap:
            line = raw_line.strip()
            if not line:
                # A trailing newline or an empty separator line is not a record.
                continue
            record = json.loads(line)
            image2caption[record["filename"]] = record["captions"]
    return image2caption
_MODEL_NAME = "flax-community/clip-rsicd-v2"
_PROCESSOR_NAME = "openai/clip-vit-base-patch32"
_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
_IMAGES_DIR = "./images"


@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the CLIP model, its processor and the image index once per process.

    Loading the model and building the HNSW index are expensive, so the
    result is memoised instead of being redone on every search click.
    """
    model = FlaxCLIPModel.from_pretrained(_MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(_PROCESSOR_NAME)
    filenames, index = load_index(_VECTOR_FILE)
    return model, processor, filenames, index


def get_image(text, number):
    """Render the ``number`` indexed images nearest to ``text`` on the page.

    The query is embedded with the CLIP text encoder and looked up in the
    nmslib index. One ranked row per hit is written to the current Streamlit
    page, followed by a horizontal rule.

    Args:
        text: free-text query.
        number: how many nearest neighbours to show (passed as ``k``).
    """
    model, processor, filenames, index = _load_search_resources()

    inputs = processor(text=[text], images=None, return_tensors="jax", padding=True)
    # get_text_features returns a batch holding one query; take that single
    # row so knnQuery receives one query vector.
    query_vec = np.asarray(model.get_text_features(**inputs))[0]
    ids, distances = index.knnQuery(query_vec, k=int(number))

    for rank, (image_id, distance) in enumerate(zip(ids, distances), start=1):
        result_filename = filenames[image_id]
        # The index uses the 'cosinesimil' space, which returns a cosine
        # distance; 1 - distance turns it back into a similarity score.
        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - distance)
        col1, col2, col3 = st.columns([2, 10, 10])
        col1.markdown("{:d}.".format(rank))
        # Close the file handle once Streamlit has consumed the image.
        with Image.open(os.path.join(_IMAGES_DIR, result_filename)) as image:
            col2.image(image, caption=caption)
        # TODO: col3 is reserved for ground-truth captions, e.g. from
        # load_captions("./images/test-captions.json"). When re-enabling it,
        # join the "* caption" items with "\n" so they render as a markdown list.
    st.markdown("---")
# NOTE(review): `suggest_idx` is not read anywhere in the visible code; it looks
# like a leftover from an earlier version. Confirm nothing imports it before removing.
suggest_idx = -1
def app():
    """Render the text-to-image search page.

    Asks for a free-text query and a result count between 1 and 10. When
    *Search* is pressed with a non-empty query, it shows the matching images
    via ``get_image``.
    """
    st.title("Welcome to Space Vector")
    st.text("Search for images that match a text description.")
    text = st.text_input("Enter text: ")
    number = st.number_input("Enter number of images to show: ", min_value=1, max_value=10)
    if st.button("Search"):
        # Don't run the model on an empty or whitespace-only query.
        if not text.strip():
            st.warning("Please enter some text to search for.")
            return
        get_image(text, number)