File size: 6,421 Bytes
1768dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import requests
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

import utils

# Model names passed to utils.load_model() in app(). How each one is used is
# decided inside utils (presumably MODEL_PATH holds the fine-tuned weights and
# BASELINE_MODEL supplies the matching processor -- confirm in utils).
BASELINE_MODEL = "openai/clip-vit-base-patch32"
MODEL_PATH = "flax-community/clip-rsicd-v2"
# TSV file passed to utils.load_index() in app() to obtain the file-name list
# and the nearest-neighbour index.
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# Directory that holds the corpus images; example thumbnails, suggested query
# images and result images are all opened relative to it.
IMAGES_DIR = "./images"
# Captions file inside IMAGES_DIR, passed to utils.load_captions() in app().
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")

@st.cache(allow_output_mutation=True)
def load_example_images():
    """Pick up to ten random example image file names from IMAGES_DIR.

    Files whose name contains an underscore are grouped by the text before
    the first underscore. One file is drawn at random from every group, the
    first ten draws (in directory-listing order of the groups) are kept, and
    the kept names are returned sorted.

    Returns:
        A sorted list of at most ten file names (relative to IMAGES_DIR).
    """
    names_by_class = {}
    for file_name in os.listdir(IMAGES_DIR):
        class_prefix, separator, _ = file_name.partition("_")
        if not separator:
            # No underscore in the name, so it carries no class prefix.
            continue
        names_by_class.setdefault(class_prefix, []).append(file_name)

    # One random draw per class, in the order the classes were first seen.
    random_picks = [
        candidates[np.random.randint(0, len(candidates))]
        for candidates in names_by_class.values()
    ]
    return sorted(random_picks[:10])


def get_image_thumbnail(image_filename):
    """Return a 100x100 thumbnail of an image stored in IMAGES_DIR.

    The image is resized to exactly 100x100 pixels, so the aspect ratio of
    the source image is not preserved.

    Args:
        image_filename: File name relative to IMAGES_DIR.

    Returns:
        A new PIL image of size (100, 100).
    """
    # resize() returns an independent image, so the source file can be closed
    # as soon as the thumbnail exists instead of lingering until GC.
    with Image.open(os.path.join(IMAGES_DIR, image_filename)) as image:
        return image.resize((100, 100))


def download_and_prepare_image(image_url, size=224, timeout=10):
    """Download an image and center-crop it to a ``size`` x ``size`` square.

    The image is converted to RGB, scaled so that its shorter side equals
    ``size`` pixels, and then cropped around its center.

    Args:
        image_url: URL of the image to fetch.
        size: Edge length in pixels of the returned square image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The prepared PIL image, or None if the image could not be
        downloaded, decoded or resized.
    """
    import io

    try:
        # A timeout keeps an unresponsive host from hanging the page forever.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as downloaded:
            image = downloaded.convert("RGB")

        width, height = image.size
        scale = size / min(width, height)
        # Round (rather than floor) and clamp so that floating point error
        # can never leave a side one pixel short of the crop window.
        new_width = max(size, round(width * scale))
        new_height = max(size, round(height * scale))
        image = image.resize((new_width, new_height))

        left = (new_width - size) // 2
        top = (new_height - size) // 2
        return image.crop((left, top, left + size, top + size))
    except Exception:
        # Deliberately broad: callers rely on the "None means failure"
        # contract to show an error message instead of crashing the page.
        return None

def app():
    """Render the "Retrieve Images given Images" Streamlit page.

    The page shows the example images returned by load_example_images(),
    each with a button underneath, plus a text box for an image URL. When an
    example button or "Find Similar" is pressed, the query image is embedded
    with the loaded model, the 11 nearest neighbours are looked up in the
    index, and every neighbour other than the query image itself is listed
    with its score (1.0 - distance) and its captions.
    """
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    image2caption = utils.load_captions(CAPTIONS_FILE)

    example_image_list = load_example_images()

    st.title("Retrieve Images given Images")
    st.markdown("""
        This demo shows the image to image retrieval capabilities of this model, i.e., 
        given an image file name as a query, we use our fine-tuned CLIP model 
        to project the query image to the image/caption embedding space and search 
        for nearby images (by cosine similarity) in this space.
        
        Our fine-tuned CLIP model was previously used to generate image vectors for 
        our demo, and NMSLib was used for fast vector access.

        Here are some randomly generated image files from our corpus, that you can 
        find similar images for by selecting the button below it. Alternatively you
        can upload your own image from the Internet.
    """)

    # Lay the examples out in rows of five: one row of thumbnails followed by
    # one row of buttons. Iterating over the actual list length keeps the
    # page working when fewer than ten examples are available.
    suggest_idx = -1
    columns_per_row = 5
    for row_start in range(0, len(example_image_list), columns_per_row):
        row_names = example_image_list[row_start:row_start + columns_per_row]
        thumbnail_columns = st.columns(columns_per_row)
        for column, example_name in zip(thumbnail_columns, row_names):
            column.image(get_image_thumbnail(example_name))
        button_columns = st.columns(columns_per_row)
        for offset, column in enumerate(button_columns[:len(row_names)]):
            example_idx = row_start + offset
            with column:
                if st.button("Image-{:d}".format(example_idx + 1)):
                    suggest_idx = example_idx

    image_url = st.text_input(
        "OR provide an image URL",
        value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")

    submit_button = st.button("Find Similar")

    if submit_button or suggest_idx > -1:
        image_name = None
        if suggest_idx > -1:
            image_name = example_image_list[suggest_idx]
            image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
        else:
            image = download_and_prepare_image(image_url)
            # Only echo the input once the download is known to have worked;
            # a failed download is reported by the error branch below.
            if image is not None:
                st.image(image, caption="Input Image")
                st.markdown("---")

        if image is None:
            st.error("Image could not be downloaded, please try another one!")
        else:
            inputs = processor(images=image, return_tensors="jax", padding=True)
            query_vec = model.get_image_features(**inputs)
            query_vec = np.asarray(query_vec)
            # Ask for one extra neighbour because a corpus image used as the
            # query is dropped from its own result list.
            ids, distances = index.knnQuery(query_vec, k=11)
            matches = [
                (filenames[neighbor_id], distance)
                for neighbor_id, distance in zip(ids, distances)
                if filenames[neighbor_id] != image_name
            ]
            for rank, (result_filename, distance) in enumerate(matches, start=1):
                image_caption = "{:s} (score: {:.3f})".format(
                    result_filename, 1.0 - distance)
                col1, col2, col3 = st.columns([2, 10, 10])
                col1.markdown("{:d}.".format(rank))
                # Close the result file as soon as Streamlit has rendered it.
                with Image.open(os.path.join(IMAGES_DIR, result_filename)) as result_image:
                    col2.image(result_image, caption=image_caption)
                caption_lines = [
                    "* {:s}\n".format(text)
                    for text in image2caption[result_filename]
                ]
                col3.markdown("".join(caption_lines))
                st.markdown("---")