import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import requests
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

import utils

# Model identifiers: the baseline OpenAI CLIP checkpoint and the fine-tuned
# CLIP-RSICD checkpoint used for retrieval.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = "flax-community/clip-rsicd-v2"

# Precomputed CLIP embeddings for the test images, plus the images and captions.
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGES_DIR = "./images"
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")
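

# Illustrative sketch only, not used by the app: utils.load_index() is assumed
# to turn IMAGE_VECTOR_FILE into a (filenames, nmslib index) pair. The TSV
# layout assumed below (one "<filename>\t<comma-separated floats>" row per
# image) and the helper name are guesses for illustration; the real
# implementation lives in utils.py.
def _load_index_sketch(vector_file):
    filenames, vectors = [], []
    with open(vector_file, "r") as f:
        for line in f:
            filename, vector_str = line.strip().split("\t")
            filenames.append(filename)
            vectors.append([float(v) for v in vector_str.split(",")])
    # Build an HNSW index over the embeddings using cosine distance, which is
    # consistent with the "1 - distance" similarity score shown in the UI below.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(np.array(vectors))
    index.createIndex({"post": 2}, print_progress=False)
    return filenames, index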


def load_example_images():
    # Group the test images by class (the filename prefix before "_"), then
    # pick one random image from each of the first ten classes to show as
    # clickable examples.
    example_images = {}
    image_names = os.listdir(IMAGES_DIR)
    for image_name in image_names:
        if image_name.find("_") < 0:
            continue
        image_class = image_name.split("_")[0]
        example_images.setdefault(image_class, []).append(image_name)
    example_image_list = sorted([v[np.random.randint(0, len(v))]
                                 for k, v in example_images.items()][0:10])
    return example_image_list


def get_image_thumbnail(image_filename):
    image = Image.open(os.path.join(IMAGES_DIR, image_filename))
    image = image.resize((100, 100))
    return image


def download_and_prepare_image(image_url):
    # Download the image, scale its shorter side to 224 pixels, then center-crop
    # to the 224x224 input size expected by CLIP. Returns None on any failure.
    try:
        image_raw = requests.get(image_url, stream=True).raw
        image = Image.open(image_raw).convert("RGB")
        width, height = image.size
        resize_mult = width / 224 if width < height else height / 224
        image = image.resize((int(width // resize_mult),
                              int(height // resize_mult)))
        width, height = image.size
        left = int((width - 224) // 2)
        top = int((height - 224) // 2)
        right = int((width + 224) // 2)
        bottom = int((height + 224) // 2)
        image = image.crop((left, top, right, bottom))
        return image
    except Exception:
        return None
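

# Worked example of the resize-and-crop arithmetic above: a 640x480 download
# gets resize_mult = 480 / 224 (about 2.14), so it is resized to 298x224; the
# center-crop box is then (37, 0, 261, 224), which yields a 224x224 image.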


def app():
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    image2caption = utils.load_captions(CAPTIONS_FILE)
    example_image_list = load_example_images()

    st.title("Retrieve Images given Images")
    st.markdown("""
This demo shows the image-to-image retrieval capabilities of this model: given
an image as the query, our fine-tuned CLIP model projects it into the shared
image/caption embedding space, and we search for nearby images (by cosine
similarity) in that space.

The image vectors for our corpus were generated ahead of time with the same
fine-tuned CLIP model, and NMSLib is used for fast nearest-neighbor lookup.

Below are some randomly selected images from our corpus. Click the button under
a thumbnail to find images similar to it, or provide the URL of your own image
instead.
""")
    suggest_idx = -1

    # Show the ten example thumbnails in two rows of five, with a button under
    # each thumbnail that selects it as the query image. Streamlit reruns this
    # script on every click, so suggest_idx is recomputed on each interaction.
    for row_start in (0, 5):
        thumbnail_cols = st.columns(5)
        for offset, col in enumerate(thumbnail_cols):
            col.image(get_image_thumbnail(example_image_list[row_start + offset]))
        button_cols = st.columns(5)
        for offset, col in enumerate(button_cols):
            with col:
                if st.button("Image-{:d}".format(row_start + offset + 1)):
                    suggest_idx = row_start + offset

    image_url = st.text_input(
        "OR provide an image URL",
        value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")
    submit_button = st.button("Find Similar")

    if submit_button or suggest_idx > -1:
        # Build the query image, either from the chosen corpus example or from
        # the user-supplied URL.
        image_name = None
        if suggest_idx > -1:
            image_name = example_image_list[suggest_idx]
            image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
        else:
            image = download_and_prepare_image(image_url)
        if image is None:
            st.error("Image could not be downloaded, please try another one!")
        else:
            st.image(image, caption="Input Image")
            st.markdown("---")
            # Embed the query with the fine-tuned CLIP model and look up its
            # nearest neighbors in the NMSLib index. We ask for k=11 so that up
            # to 10 results remain after skipping the query image itself when it
            # comes from the corpus.
            inputs = processor(images=image, return_tensors="jax", padding=True)
            query_vec = model.get_image_features(**inputs)
            query_vec = np.asarray(query_vec)
            ids, distances = index.knnQuery(query_vec, k=11)
            result_filenames = [filenames[idx] for idx in ids]
            rank = 0
            for result_filename, distance in zip(result_filenames, distances):
                if image_name is not None and result_filename == image_name:
                    continue
                # The index returns cosine distances; report 1 - distance as a
                # similarity score.
                caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - distance)
                col1, col2, col3 = st.columns([2, 10, 10])
                col1.markdown("{:d}.".format(rank + 1))
                col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
                           caption=caption)
                caption_text = []
                for image_caption in image2caption[result_filename]:
                    caption_text.append("* {:s}\n".format(image_caption))
                col3.markdown("".join(caption_text))
                rank += 1
                st.markdown("---")
        suggest_idx = -1
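

# Assumed entry point (an addition for convenience; in the original Space,
# app() is presumably invoked from a top-level page-selector script). With this
# guard the page can also be run on its own via `streamlit run <this file>`.
if __name__ == "__main__":
    app()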