File size: 2,039 Bytes
cf349fd
 
2cf3514
 
f1d50b1
 
2cf3514
f1d50b1
 
 
2cf3514
 
f1d50b1
cf349fd
2cf3514
f1d50b1
cf349fd
2cf3514
 
a811816
cf349fd
a811816
2cf3514
 
f1d50b1
938499b
cf349fd
48a1fa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from utils import load_index, load_model


def app(model_name):
    """Streamlit page: Korean text-to-image search over precomputed embeddings.

    Looks up the feature index built for MSCOCO val2017 with the given KoCLIP
    variant, embeds the user's Korean text query with the model's text tower,
    and shows the 10 nearest images ranked by cosine similarity.

    Args:
        model_name: KoCLIP variant name; selects both the feature file
            ``features/val2017/{model_name}.tsv`` and the checkpoint
            ``koclip/{model_name}``.
    """
    images_directory = "images/val2017"
    features_directory = f"features/val2017/{model_name}.tsv"

    # `files` maps index ids back to image filenames; `index` is a kNN index
    # over the precomputed image embeddings (see utils.load_index).
    files, index = load_index(features_directory)
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Text to Image Search Engine")
    st.markdown(
        """
        This demo explores KoCLIP's use case as a Korean image search engine. We pre-computed embeddings of 5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation using KoCLIP's ViT backbone. Then, given a text query from the user, these image embeddings are ranked based on cosine similarity. Top matches are displayed below.
        
        Example Queries: ์ปดํ“จํ„ฐํ•˜๋Š” ๊ณ ์–‘์ด (Cat playing on a computer), ๊ธธ ์œ„์—์„œ ๋‹ฌ๋ฆฌ๋Š” ์ž๋™์ฐจ (Car on the road)
    """
    )

    query = st.text_input("ํ•œ๊ธ€ ์งˆ๋ฌธ์„ ์ ์–ด์ฃผ์„ธ์š” (Korean Text Query) :", value="์ปดํ“จํ„ฐํ•˜๋Š” ๊ณ ์–‘์ด")
    if st.button("์งˆ๋ฌธ (Query)"):
        st.markdown("""---""")
        with st.spinner("Computing..."):
            # Embed the query text; `return_tensors="jax"` matches the Flax model.
            proc = processor(
                text=[query], images=None, return_tensors="jax", padding=True
            )
            vec = np.asarray(model.get_text_features(**proc))
            # Top-10 nearest image embeddings in the index.
            ids, dists = index.knnQuery(vec, k=10)
            result_imgs, result_captions = [], []
            # Use `idx` rather than the original lambda's `id`, which
            # shadowed the builtin of the same name.
            for idx, dist in zip(ids, dists):
                result_imgs.append(
                    plt.imread(os.path.join(images_directory, files[idx]))
                )
                # knnQuery returns a distance; report 1 - distance as the score.
                result_captions.append("Score: {:.3f}".format(1.0 - dist))

            # Render results in rows of three (the last row holds the 10th hit),
            # replacing four copy-pasted st.image calls with one loop.
            for start in range(0, len(result_imgs), 3):
                st.image(
                    result_imgs[start : start + 3],
                    caption=result_captions[start : start + 3],
                    width=200,
                )