Spaces:
Runtime error
Runtime error
File size: 2,182 Bytes
bf3fe47 e336559 eb6e722 bf3fe47 2a88d86 bf3fe47 5834f42 bf3fe47 f0c8df9 9c8a7bc 2a88d86 f0c8df9 2a88d86 f0c8df9 2a88d86 bf3fe47 5834f42 bf3fe47 eb6e722 bf3fe47 f0c8df9 5834f42 f0c8df9 5834f42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor
st.title("VitB32 Bert Ko Small Clip Test")
st.markdown("Unsplash data์์ ์
๋ ฅ ํ
์คํธ์ ๊ฐ์ฅ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ๊ฒ์ํฉ๋๋ค.")
# ``st.cache`` is deprecated in newer Streamlit releases; ``st.cache_resource``
# is its replacement for shared objects such as models. Fall back to the legacy
# decorator so the app keeps working on Streamlit versions that predate it.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource(show_spinner=False)
else:
    _cache_model = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_model
def _load_model_cached(repo_id):
    """Load ``(model, processor)`` for ``repo_id``; the model is put in eval mode.

    Cached per ``repo_id`` so each checkpoint is downloaded/built only once.
    """
    model = AutoModel.from_pretrained(repo_id).eval()
    processor = AutoProcessor.from_pretrained(repo_id)
    return model, processor


def _load_model(repo_id):
    """Return the cached ``(model, processor)`` pair, showing a loading spinner.

    The spinner is kept outside the cached function so the cached body contains
    no Streamlit UI calls.
    """
    with st.spinner("Loading model..."):
        return _load_model_cached(repo_id)


def get_dual_encoder_model():
    """Return ``(model, processor)`` for ``Bingsu/vitB32_bert_ko_small_clip``."""
    return _load_model("Bingsu/vitB32_bert_ko_small_clip")


def get_clip_model():
    """Return ``(model, processor)`` for ``Bingsu/clip-vit-base-patch32-ko``."""
    return _load_model("Bingsu/clip-vit-base-patch32-ko")
# Each selectable checkpoint id maps to the loader that returns its
# (model, processor) pair; dict order fixes the order of the radio options.
_MODEL_LOADERS = {
    "Bingsu/clip-vit-base-patch32-ko": get_clip_model,
    "Bingsu/vitB32_bert_ko_small_clip": get_dual_encoder_model,
}

model_type = st.radio("Select model", list(_MODEL_LOADERS), horizontal=True)

# Any id other than the CLIP checkpoint falls back to the dual encoder.
_load_selected = _MODEL_LOADERS.get(model_type, get_dual_encoder_model)
model, processor = _load_selected()
# Streamlit re-executes this script on every widget interaction, so cache the
# read-only search data instead of re-reading it from disk on each rerun.
# Prefer st.cache_resource (one shared object, no copy); fall back to the
# legacy st.cache on Streamlit versions that lack it.
if hasattr(st, "cache_resource"):
    _cache_index = st.cache_resource(show_spinner=False)
else:
    _cache_index = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_index
def _load_search_index():
    """Return ``(info, img_id, img_emb)`` loaded from the app's data files.

    - ``info``: DataFrame read from ``info.csv``.
    - ``img_id``: object unpickled from ``img_id.pkl``; presumably maps each
      embedding row index to a ``photo_id`` — confirm against the data export.
    - ``img_emb``: array loaded from ``img_emb.npy``.
    """
    info_df = pd.read_csv("info.csv")
    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # trusted file here.
    with open("img_id.pkl", "rb") as f:
        ids = pickle.load(f)
    emb = np.load("img_emb.npy")
    return info_df, ids, emb


info, img_id, img_emb = _load_search_index()
# Default query (Korean): "black cat".
text = st.text_input("Input Text", value="검은 고양이")

# Embed the query with the selected model's text encoder; inference only.
tokens = processor(text=text, return_tensors="pt")
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

# Fetch the top 15 images for the single query (hence ``[0]``). This is more
# than the 6 that are displayed, so hits missing from info.csv can be skipped.
result = semantic_search(text_emb, img_emb, top_k=15)[0]
def _iter_image_urls(hits):
    """Yield the image URL for each search hit, in rank order.

    Some images are missing from info.csv; hits whose ``photo_id`` has no row
    there are skipped, so fewer URLs than hits may be produced.
    """
    for hit in hits:
        photo_id = img_id[hit["corpus_id"]]
        matches = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
        if matches.empty:
            continue
        yield matches.iloc[0]


# Two rows of three columns. zip() stops as soon as either side runs out, so if
# too many of the top hits lack a URL the grid is left partly empty instead of
# crashing with an uncaught StopIteration.
columns = st.columns(3) + st.columns(3)
for col, img_url in zip(columns, _iter_image_urls(result)):
    col.image(img_url, use_column_width=True)
|