# koclip/text2image.py
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from utils import load_index, load_model
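
# `utils` is not shown in this file. Below is a minimal, hypothetical sketch of
# what `load_index` might do. Assumptions: the TSV stores one image filename
# plus its tab-separated feature vector per row, and the index is built with
# nmslib, whose `knnQuery` API is what `app` calls below. The real helper in
# `utils` may differ.
def load_index_sketch(features_path):
    import nmslib  # assumed backend; the knnQuery call in app() matches nmslib's API

    files, embeddings = [], []
    with open(features_path) as f:
        for line in f:
            name, *values = line.strip().split("\t")
            files.append(name)
            embeddings.append([float(v) for v in values])
    # The "cosinesimil" space returns cosine *distances*, so the `1.0 - dist`
    # score in app() below reads as a cosine similarity.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(np.asarray(embeddings, dtype=np.float32))
    index.createIndex()
    return files, index
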
def app(model_name):
    images_directory = "images/val2017"
    features_directory = f"features/val2017/{model_name}.tsv"

    # Load the precomputed image-feature index and the KoCLIP model/processor.
    files, index = load_index(features_directory)
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Text to Image Search Engine")
    st.markdown(
        """
This demo explores KoCLIP's use case as a Korean image search engine. We precomputed embeddings of 5,000 images from the [MSCOCO](https://cocodataset.org/#home) 2017 validation set using KoCLIP's ViT backbone. Given a text query from the user, the image embeddings are ranked by cosine similarity, and the top matches are displayed below.

Example queries: 컴퓨터하는 고양이 (a cat playing on a computer), 길 위에서 달리는 자동차 (a car running on the road)
"""
    )
    query = st.text_input(
        "한글 질문을 적어주세요 (Korean Text Query):", value="컴퓨터하는 고양이"
    )
    if st.button("질문 (Query)"):
        st.markdown("---")
        with st.spinner("Computing..."):
            # Encode the Korean text query with KoCLIP's text encoder.
            proc = processor(
                text=[query], images=None, return_tensors="jax", padding=True
            )
            vec = np.asarray(model.get_text_features(**proc))

            # Retrieve the 10 nearest image embeddings from the index.
            ids, dists = index.knnQuery(vec, k=10)
            result_files = [files[i] for i in ids]

            result_imgs, result_captions = [], []
            for file, dist in zip(result_files, dists):
                result_imgs.append(plt.imread(os.path.join(images_directory, file)))
                # The index returns cosine distance; report similarity instead.
                result_captions.append("Score: {:.3f}".format(1.0 - dist))

            # Display the top 10 matches in rows of three.
            st.image(result_imgs[:3], caption=result_captions[:3], width=200)
            st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
            st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
            st.image(result_imgs[9:], caption=result_captions[9:], width=200)
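

# The intro above mentions that the image embeddings were precomputed offline.
# A minimal, hypothetical sketch of that step: it assumes the same `load_model`
# helper and that the model exposes `get_image_features` as the counterpart of
# the `get_text_features` call in app(). The function name, output path, and
# processor call here are illustrative assumptions, not the project's actual
# preprocessing script.
def precompute_features_sketch(model_name, out_path):
    from PIL import Image

    model, processor = load_model(f"koclip/{model_name}")
    with open(out_path, "w") as f:
        for file in sorted(os.listdir("images/val2017")):
            image = Image.open(os.path.join("images/val2017", file))
            proc = processor(text=None, images=image, return_tensors="jax")
            # One embedding per image, written as filename + tab-separated floats.
            vec = np.asarray(model.get_image_features(proc["pixel_values"]))[0]
            f.write(file + "\t" + "\t".join(str(float(v)) for v in vec) + "\n")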