# Hugging Face Spaces page header captured with this file (Space status: Runtime error).
import pickle | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
import torch | |
from sentence_transformers.util import semantic_search | |
from transformers import AutoModel, AutoProcessor | |
st.title("VitB32 Bert Ko Small Clip Test") | |
st.markdown("Unsplash data์์ ์ ๋ ฅ ํ ์คํธ์ ๊ฐ์ฅ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ํฉ๋๋ค.") | |
def get_dual_encoder_model():
    """Load the ``Bingsu/vitB32_bert_ko_small_clip`` model and its processor.

    A "Loading model..." spinner is shown while both objects are created.

    Returns:
        A ``(model, processor)`` tuple; the model has been switched to eval mode.
    """
    # NOTE(review): no caching decorator is visible on this function, so it
    # presumably reloads the weights on every Streamlit rerun - confirm and
    # consider Streamlit's resource caching if the installed version has it.
    checkpoint = "Bingsu/vitB32_bert_ko_small_clip"
    with st.spinner("Loading model..."):
        model = AutoModel.from_pretrained(checkpoint)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
def get_clip_model():
    """Load the ``Bingsu/clip-vit-base-patch32-ko`` model and its processor.

    A "Loading model..." spinner is shown while both objects are created.

    Returns:
        A ``(model, processor)`` tuple; the model has been switched to eval mode.
    """
    # NOTE(review): no caching decorator is visible on this function, so it
    # presumably reloads the weights on every Streamlit rerun - confirm and
    # consider Streamlit's resource caching if the installed version has it.
    checkpoint = "Bingsu/clip-vit-base-patch32-ko"
    with st.spinner("Loading model..."):
        model = AutoModel.from_pretrained(checkpoint)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
# Let the user choose which checkpoint encodes the query text.
model_type = st.radio(
    "Select model",
    ["Bingsu/clip-vit-base-patch32-ko", "Bingsu/vitB32_bert_ko_small_clip"],
    horizontal=True,
)
if model_type == "Bingsu/clip-vit-base-patch32-ko":
    model, processor = get_clip_model()
else:
    model, processor = get_dual_encoder_model()

# Precomputed search corpus shipped next to the app:
#   info.csv    - photo metadata; looked up later by "photo_id" to get
#                 "photo_image_url"
#   img_id.pkl  - sequence mapping a corpus index to a photo id
#   img_emb.npy - image embeddings searched against the text embedding
info = pd.read_csv("info.csv")
# NOTE: unpickling can execute arbitrary code; img_id.pkl must remain a
# trusted file bundled with the app, never user-supplied data.
with open("img_id.pkl", "rb") as f:
    img_id = pickle.load(f)
img_emb = np.load("img_emb.npy")

# Default query is Korean for "black cat". The literal was mojibake and is
# restored here to the Korean text it encodes.
text = st.text_input("Input Text", value="검은 고양이")
tokens = processor(text=text, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    text_emb = model.get_text_features(**tokens)

# Keep the 15 best hits for the single query; more than the grid displays so
# that hits without metadata can be skipped.
result = semantic_search(text_emb, img_emb, top_k=15)[0]
# Shared iterator consumed hit by hit when the image URLs are resolved.
_result = iter(result)
def get_url() -> str:
    """Return the image URL of the next search hit that has metadata.

    Hits are pulled one at a time from the module-level ``_result`` iterator.
    Some images are missing from the info.csv data, so a hit whose photo id
    has no row in ``info`` is skipped.

    Raises:
        StopIteration: if the hits run out before one with a URL is found.
    """
    while True:
        hit = next(_result)  # propagates StopIteration once hits are exhausted
        photo_id = img_id[hit["corpus_id"]]
        matches = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
        if len(matches) > 0:
            return matches.iloc[0]
# Lay the results out as two rows of three images.
columns = st.columns(3) + st.columns(3)
for col in columns:
    try:
        img_url = get_url()
    except StopIteration:
        # get_url() skips hits that have no metadata row, so the hits can run
        # out before all six cells are filled. Leave the remaining cells empty
        # instead of crashing the script with an uncaught StopIteration.
        break
    col.image(img_url, use_column_width=True)