clip-rsicd-demo / dashboard_text2image.py
Sujit Pal
fix: replace st.beta_column with st.column
862e020
raw history blame
No virus
3.27 kB
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
import utils
# Hugging Face model identifiers; both are handed to utils.load_model()
# (fine-tuned checkpoint first, stock CLIP checkpoint second).
BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = "flax-community/clip-rsicd-v2"

# Local data: the image directory and its caption file, which lives inside it.
IMAGES_DIR = "./images"
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")

# Precomputed image vectors loaded into the search index by utils.load_index().
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
def _suggested_query_buttons(queries):
    """Render one button per suggested query, laid out in a single row.

    Args:
        queries: non-empty list of query strings; one column is created per
            entry.

    Returns:
        The query whose button was clicked on this rerun, or None if no
        suggestion button was clicked.
    """
    clicked = None
    for column, suggestion in zip(st.columns(len(queries)), queries):
        with column:
            if st.button(suggestion):
                clicked = suggestion
    return clicked


def _search(query, model, processor, index, filenames, k=10):
    """Return the ``k`` indexed images nearest to a text query.

    The query is tokenized by ``processor``, projected with
    ``model.get_text_features`` and looked up with ``index.knnQuery``.

    Args:
        query: the text query.
        model: CLIP model exposing ``get_text_features``.
        processor: CLIP processor used to tokenize ``query``.
        index: NMSLib index searched with ``knnQuery``.
        filenames: sequence mapping index ids to image filenames.
        k: number of neighbours to return (default 10).

    Returns:
        A list of ``(filename, score)`` tuples in the order returned by
        ``index.knnQuery``, where ``score`` is ``1.0 - distance``.
    """
    inputs = processor(text=[query], images=None, return_tensors="jax",
                       padding=True)
    query_vec = np.asarray(model.get_text_features(**inputs))
    ids, distances = index.knnQuery(query_vec, k=k)
    # NOTE(review): 1 - distance is a similarity only if the index was built
    # in NMSLib's cosine-distance space (the page text says cosine
    # similarity is used) -- confirm against utils.load_index.
    return [(filenames[i], 1.0 - dist) for i, dist in zip(ids, distances)]


def _render_result(rank, filename, score, captions):
    """Render one result row: rank, the image with its score, and captions.

    Args:
        rank: 1-based position of the result.
        filename: image filename relative to IMAGES_DIR.
        score: value shown next to the filename, formatted to 3 decimals.
        captions: iterable of reference caption strings for the image.
    """
    rank_col, image_col, caption_col = st.columns([2, 10, 10])
    rank_col.markdown("{:d}.".format(rank))
    image_caption = "{:s} (score: {:.3f})".format(filename, score)
    # Close the file handle after rendering instead of leaking it.
    # NOTE(review): assumes st.image encodes the PIL image synchronously
    # during the call -- confirm for the Streamlit version in use.
    with Image.open(os.path.join(IMAGES_DIR, filename)) as image:
        image_col.image(image, caption=image_caption)
    caption_col.markdown("".join("* {:s}\n".format(c) for c in captions))


def app():
    """Render the text-to-image retrieval page of the demo.

    The user either clicks one of the suggested query buttons or types a
    free-text query and presses "Query". The query is encoded with the
    fine-tuned CLIP text encoder, the nearest image vectors are looked up in
    the NMSLib index, and each match is shown with its rank, filename,
    score, and reference captions.
    """
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    image2caption = utils.load_captions(CAPTIONS_FILE)

    st.title("Retrieve Images given Text")
    # st.markdown dedents its body, so the indentation below is not rendered.
    st.markdown("""
    This demo shows the text to image retrieval capabilities of this model, i.e.,
    given a text query, we use our fine-tuned CLIP model to project the text query
    to the image/caption embedding space and search for nearby images (by
    cosine similarity) in this space.

    Our fine-tuned CLIP model was previously used to generate image vectors for
    our demo, and NMSLib was used for fast vector access.
    """)

    suggested_queries = [
        "ships",
        "school house",
        "military installation",
        "mountains",
        "beaches",
        "airports",
        "lakes",
    ]
    st.text("Some suggested queries to start you off with...")
    suggestion = _suggested_query_buttons(suggested_queries)

    query = st.text_input("OR enter a text Query:")
    if suggestion is not None:
        query = suggestion

    # st.button is evaluated first so the "Query" button is always rendered.
    if not (st.button("Query") or suggestion is not None):
        return
    if not query.strip():
        # Searching with an empty string would return arbitrary neighbours.
        st.warning("Please enter a text query.")
        return

    results = _search(query, model, processor, index, filenames)
    for rank, (filename, score) in enumerate(results, start=1):
        # .get avoids crashing the page on an image without captions
        # (assumes load_captions returns a dict keyed by filename).
        _render_result(rank, filename, score, image2caption.get(filename, []))
        st.markdown("---")