import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from utils import load_index, load_model


def _search(query, processor, model, index, files, top_k):
    """Return ``(file_name, similarity)`` pairs for the images nearest to ``query``.

    The query is tokenized by ``processor``, embedded with
    ``model.get_text_features`` and looked up in ``index`` via ``knnQuery``.
    ``files`` maps the integer ids returned by the index to image file names.
    At most ``top_k`` pairs are returned, in the order the index reports them.
    """
    proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    ids, dists = index.knnQuery(vec, k=top_k)
    # The index returns distances; 1 - distance is displayed as the similarity.
    # (Assumes a cosine-distance index, as the page text describes — TODO confirm in utils.)
    return [(files[i], 1.0 - dist) for i, dist in zip(ids, dists)]


def app(model_name, *, top_k=10, images_per_row=3, image_width=200):
    """Render the KoCLIP text-to-image search page.

    Loads the precomputed image-feature index and the ``koclip/<model_name>``
    model, shows a description of the demo, and, when the query button is
    pressed, displays the images closest to the Korean text query.

    Args:
        model_name: Model identifier. It selects both the feature index
            ``features/val2017/<model_name>.tsv`` and the checkpoint
            ``koclip/<model_name>``.
        top_k: Number of images to retrieve and display (default 10).
        images_per_row: Number of images passed to each ``st.image`` row
            (default 3).
        image_width: Display width of each image in pixels (default 200).

    Raises:
        ValueError: If ``images_per_row`` is less than 1.
    """
    if images_per_row < 1:
        raise ValueError(f"images_per_row must be >= 1, got {images_per_row}")

    images_directory = "images/val2017"
    features_directory = f"features/val2017/{model_name}.tsv"

    # NOTE(review): Streamlit re-executes this function on every widget
    # interaction; unless load_index/load_model cache internally, both reload
    # on each rerun.
    files, index = load_index(features_directory)
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Text to Image Search Engine")
    st.markdown(
        f"""
        This demonstration explores the capability of KoCLIP as a Korean-language
        image search engine. Embeddings for each of the 5000 images in the
        [MSCOCO](https://cocodataset.org/#home) 2017 validation set were generated
        using the trained KoCLIP vision model. The images are ranked by cosine
        similarity to the embedding of the input text query, and the top {top_k}
        images are displayed below.

        KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from the
        [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption
        annotations. Korean translations of the caption annotations were obtained
        from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
        The base model `koclip` uses `klue/roberta` as its text encoder and
        `openai/clip-vit-base-patch32` as its image encoder. The larger model
        `koclip-large` uses `klue/roberta` as its text encoder and the bigger
        `google/vit-large-patch16-224` as its image encoder.

        Example Queries : 컴퓨터하는 고양이(Cat playing on a computer), 길 위에서 달리는 자동차(Car running on the road)
        """
    )

    query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="컴퓨터하는 고양이")
    if not st.button("질문 (Query)"):
        return
    # An empty query would still be embedded and return arbitrary neighbours.
    if not query.strip():
        st.warning("질문을 입력해주세요 (Please enter a query).")
        return

    results = _search(query, processor, model, index, files, top_k)

    # Show results in rows of `images_per_row`; with the defaults this matches
    # the original 3/3/3/1 layout, and no empty rows are emitted when the
    # index returns fewer than `top_k` hits.
    for start in range(0, len(results), images_per_row):
        row = results[start : start + images_per_row]
        st.image(
            [plt.imread(os.path.join(images_directory, name)) for name, _ in row],
            caption=[f"{name} (유사도: {similarity:.3f})" for name, similarity in row],
            width=image_width,
        )