import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from utils import load_index, load_model


def app(model_name):
    images_directory = "images/val2017"
    features_directory = f"features/val2017/{model_name}.tsv"

    files, index = load_index(features_directory)
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Text to Image Search Engine")
    st.markdown(
        """
        This demonstration explores the capability of KoCLIP as a Korean-language image search engine. Embeddings for each of
        the 5,000 images in the [MSCOCO](https://cocodataset.org/#home) 2017 validation set were generated with the trained
        KoCLIP vision model. Images are ranked by cosine similarity to the embedding of the input text query, and the top 10
        are displayed below.

        KoCLIP is a retraining of OpenAI's CLIP model on 82,783 images from the [MSCOCO](https://cocodataset.org/#home) dataset with
        Korean caption annotations. The Korean translations of the captions were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
        The base model `koclip` uses `klue/roberta` as the text encoder and `openai/clip-vit-base-patch32` as the image encoder.
        The larger model `koclip-large` uses `klue/roberta` as the text encoder and the bigger `google/vit-large-patch16-224` as the image encoder.
        
        Example queries: 컴퓨터하는 고양이 (Cat playing on a computer), 길 위에서 달리는 자동차 (Car running on the road)
    """
    )

    query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="컴퓨터하는 고양이")
    if st.button("질문 (Query)"):
        # Tokenize the query and embed it with the KoCLIP text encoder.
        proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
        vec = np.asarray(model.get_text_features(**proc))
        # Retrieve the 10 nearest image embeddings from the prebuilt index.
        ids, dists = index.knnQuery(vec, k=10)
        result_imgs, result_captions = [], []
        for idx, dist in zip(ids, dists):
            result_imgs.append(plt.imread(os.path.join(images_directory, files[idx])))
            # Assuming the index returns cosine distance, 1 - dist is the cosine similarity.
            result_captions.append(f"Score: {1.0 - dist:.3f}")

        # Display the results in rows of three images.
        for start in range(0, len(result_imgs), 3):
            st.image(result_imgs[start : start + 3], caption=result_captions[start : start + 3], width=200)