File size: 3,267 Bytes
357b0b8
 
 
 
 
 
f58917e
357b0b8
 
f9d31ee
357b0b8
 
96ac3ab
357b0b8
 
17476c1
357b0b8
 
f9d31ee
 
f58917e
357b0b8
6d88167
357b0b8
 
 
 
 
 
 
 
9cde513
357b0b8
6d88167
 
 
 
 
 
 
 
 
 
862e020
6d88167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357b0b8
6d88167
357b0b8
 
 
 
 
f58917e
 
862e020
f58917e
 
 
 
 
 
 
 
862e020
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

import utils

# Configuration for the text-to-image retrieval page.

# Second argument to utils.load_model(); presumably the stock CLIP checkpoint
# that supplies the processor/tokenizer -- confirm against utils.load_model.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# First argument to utils.load_model(); presumably the fine-tuned checkpoint
# used to embed queries -- confirm against utils.load_model.
MODEL_PATH = "flax-community/clip-rsicd-v2"
# File handed to utils.load_index(), which returns the filenames and the
# search index queried in app().
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# Directory that result image filenames are resolved against before opening.
IMAGES_DIR = "./images"
# File handed to utils.load_captions(); lives inside IMAGES_DIR.
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")

def app():
    """Render the "Retrieve Images given Text" page of the Streamlit demo.

    Loads the search index, the model/processor pair and the caption lookup
    through ``utils``, then draws a row of suggested-query buttons, a
    free-text input and a "Query" button. When a suggestion or the "Query"
    button is pressed, the query text is embedded with
    ``model.get_text_features`` and the 10 nearest entries returned by
    ``index.knnQuery`` are shown, each with its image, score and captions.

    Returns:
        None. All output is written to the Streamlit page.
    """
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    image2caption = utils.load_captions(CAPTIONS_FILE)

    st.title("Retrieve Images given Text")
    st.markdown("""
        This demo shows the text to image retrieval capabilities of this model, i.e.,
        given a text query, we use our fine-tuned CLIP model to project the text query
        to the image/caption embedding space and search for nearby images (by
        cosine similarity) in this space.

        Our fine-tuned CLIP model was previously used to generate image vectors for
        our demo, and NMSLib was used for fast vector access.

    """)
    suggested_queries = [
        "ships",
        "school house",
        "military installation",
        "mountains",
        "beaches",
        "airports",
        "lakes",
    ]
    st.text("Some suggested queries to start you off with...")

    # One button per suggestion, each in its own column. A button only returns
    # True on the rerun triggered by its click, so at most one index is set.
    suggest_idx = -1
    button_columns = st.columns(len(suggested_queries))
    for idx, (column, suggestion) in enumerate(
            zip(button_columns, suggested_queries)):
        with column:
            if st.button(suggestion):
                suggest_idx = idx

    typed_query = st.text_input("OR enter a text Query:")
    # The "Query" button must be created on every run so it stays on the page,
    # regardless of whether a suggestion was clicked.
    query_clicked = st.button("Query")

    suggestion_clicked = suggest_idx > -1
    if not (query_clicked or suggestion_clicked):
        return

    # A clicked suggestion takes precedence over whatever is in the text box.
    query = suggested_queries[suggest_idx] if suggestion_clicked else typed_query
    if not query.strip():
        # Avoid running a meaningless search on an empty string.
        st.warning("Please enter a text query or pick one of the suggestions.")
        return

    inputs = processor(text=[query], images=None, return_tensors="jax",
                       padding=True)
    query_vec = np.asarray(model.get_text_features(**inputs))
    ids, distances = index.knnQuery(query_vec, k=10)

    for rank, (image_id, distance) in enumerate(zip(ids, distances), start=1):
        result_filename = filenames[image_id]
        # Shown as 1 - distance; presumably the index uses cosine distance so
        # this is a cosine similarity -- confirm against utils.load_index.
        image_caption = "{:s} (score: {:.3f})".format(
            result_filename, 1.0 - distance)

        rank_col, image_col, captions_col = st.columns([2, 10, 10])
        rank_col.markdown("{:d}.".format(rank))
        image_col.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
                        caption=image_caption)
        # Render the captions stored for this image as a markdown bullet list.
        captions_col.markdown("".join(
            "* {:s}\n".format(text) for text in image2caption[result_filename]))
        st.markdown("---")