File size: 1,251 Bytes
bf3fe47
 
 
 
 
 
 
 
 
 
5834f42
bf3fe47
5834f42
 
 
 
 
 
 
bf3fe47
 
 
 
 
 
5834f42
bf3fe47
 
5834f42
 
bf3fe47
5834f42
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

st.title("VitB32 Bert Ko Small Clip Test")
st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")


@st.cache_resource
def _load_model():
    """Load the dual-encoder CLIP model and its processor.

    Cached with ``st.cache_resource`` so Streamlit reruns (triggered by every
    widget interaction) do not re-download or re-instantiate the model.

    Returns:
        tuple: ``(model, processor)`` for "Bingsu/vitB32_bert_ko_small_clip".
    """
    model = VisionTextDualEncoderModel.from_pretrained(
        "Bingsu/vitB32_bert_ko_small_clip"
    )
    processor = VisionTextDualEncoderProcessor.from_pretrained(
        "Bingsu/vitB32_bert_ko_small_clip"
    )
    return model, processor


@st.cache_data
def _load_corpus():
    """Load photo metadata, photo ids, and precomputed image embeddings.

    Cached with ``st.cache_data`` so the csv/pickle/npy files are read from
    disk only once per session instead of on every rerun.

    Returns:
        tuple: ``(info, img_id, img_emb)`` where ``info`` is a DataFrame with
        at least ``photo_id`` and ``photo_image_url`` columns, ``img_id`` maps
        corpus indices to photo ids, and ``img_emb`` is the embedding matrix.
    """
    info = pd.read_csv("info.csv")
    # SECURITY NOTE(review): pickle.load executes arbitrary code if the file
    # is tampered with — img_id.pkl must come from a trusted source.
    with open("img_id.pkl", "rb") as f:
        img_id = pickle.load(f)
    img_emb = np.load("img_emb.npy")
    return info, img_id, img_emb


with st.spinner("Loading model..."):
    model, processor = _load_model()

info, img_id, img_emb = _load_corpus()

text = st.text_input("Input Text", value="๊ฒ€์€ ๊ณ ์–‘์ด")
tokens = processor(text=text, return_tensors="pt")

with st.spinner("Predicting..."):
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        text_emb = model.get_text_features(**tokens)

# Top 6 most similar images for the single query text (index [0]).
result = semantic_search(text_emb, img_emb, top_k=6)[0]

# Two rows of three columns -> six image slots, one per search hit.
columns = st.columns(3) + st.columns(3)
for i, col in enumerate(columns):
    photo_id = img_id[result[i]["corpus_id"]]
    img_url = info.loc[info["photo_id"] == photo_id, "photo_image_url"].values[0]
    # NOTE(review): use_column_width is deprecated in recent Streamlit in
    # favor of use_container_width — kept as-is for compatibility; confirm
    # the deployed Streamlit version before migrating.
    col.image(img_url, use_column_width=True)