File size: 2,182 Bytes
bf3fe47
 
 
 
e336559
eb6e722
bf3fe47
2a88d86
bf3fe47
 
5834f42
bf3fe47
f0c8df9
9c8a7bc
2a88d86
f0c8df9
2a88d86
 
f0c8df9
 
 
2a88d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3fe47
 
 
 
 
 
5834f42
bf3fe47
eb6e722
 
 
bf3fe47
f0c8df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5834f42
 
f0c8df9
 
5834f42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor

st.title("VitB32 Bert Ko Small Clip Test")
st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_dual_encoder_model():
    """Load the ``Bingsu/vitB32_bert_ko_small_clip`` model and its processor.

    Both objects come from ``from_pretrained`` on the same checkpoint id, and
    the model is switched to eval mode before being returned.

    ``st.cache`` memoizes the result across Streamlit reruns.
    ``allow_output_mutation=True`` skips the mutation check on the returned
    objects. Streamlit's own spinner is disabled because this function shows
    its own "Loading model..." spinner.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "Bingsu/vitB32_bert_ko_small_clip"
    with st.spinner("Loading model..."):
        encoder = AutoModel.from_pretrained(checkpoint)
        encoder.eval()
        text_image_processor = AutoProcessor.from_pretrained(checkpoint)
    return encoder, text_image_processor


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_clip_model():
    """Load the ``Bingsu/clip-vit-base-patch32-ko`` model and its processor.

    Both objects come from ``from_pretrained`` on the same checkpoint id, and
    the model is switched to eval mode before being returned.

    ``st.cache`` memoizes the result across Streamlit reruns.
    ``allow_output_mutation=True`` skips the mutation check on the returned
    objects. Streamlit's own spinner is disabled because this function shows
    its own "Loading model..." spinner.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "Bingsu/clip-vit-base-patch32-ko"
    with st.spinner("Loading model..."):
        encoder = AutoModel.from_pretrained(checkpoint)
        encoder.eval()
        text_image_processor = AutoProcessor.from_pretrained(checkpoint)
    return encoder, text_image_processor


# Checkpoints the user can choose between for encoding the query text.
_MODEL_OPTIONS = [
    "Bingsu/clip-vit-base-patch32-ko",
    "Bingsu/vitB32_bert_ko_small_clip",
]
model_type = st.radio("Select model", _MODEL_OPTIONS, horizontal=True)

# Pick the matching (st.cache-wrapped) loader. Anything other than the CLIP
# checkpoint falls through to the dual-encoder loader.
load_selected = (
    get_clip_model
    if model_type == "Bingsu/clip-vit-base-patch32-ko"
    else get_dual_encoder_model
)
model, processor = load_selected()

@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_search_index():
    """Load the image-search data files once and reuse them across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching these three files would be re-read from disk on each
    keystroke.

    Returns:
        ``(info, img_id, img_emb)``:
        - ``info``: DataFrame read from ``info.csv``; the app reads its
          ``photo_id`` and ``photo_image_url`` columns.
        - ``img_id``: unpickled from ``img_id.pkl`` and indexed by
          ``semantic_search``'s ``corpus_id``. Presumably a sequence of photo
          ids aligned with the rows of ``img_emb``; confirm against the
          export script.
        - ``img_emb``: array loaded from ``img_emb.npy``, used as the search
          corpus.
    """
    photo_info = pd.read_csv("info.csv")
    with open("img_id.pkl", "rb") as fh:
        # NOTE: pickle can execute arbitrary code on load; this is only safe
        # because img_id.pkl is a trusted file shipped alongside the app.
        photo_ids = pickle.load(fh)
    embeddings = np.load("img_emb.npy")
    return photo_info, photo_ids, embeddings


info, img_id, img_emb = _load_search_index()

# Free-text query. The default value is "검은 고양이" ("black cat"); it was
# mis-encoded mojibake before.
text = st.text_input("Input Text", value="검은 고양이")
tokens = processor(text=text, return_tensors="pt")

# Inference only: skip autograd bookkeeping while encoding the query.
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

# Rank the precomputed image embeddings against the query embedding
# (semantic_search scores with cosine similarity by default). Keep the 15 best
# hits for the first and only query. Each hit is a dict with "corpus_id" (a row
# index into img_emb, also used to index img_id) and "score". Presumably 15 > 6
# grid cells leaves headroom for hits whose photo is missing from info.csv.
result = semantic_search(text_emb, img_emb, top_k=15)[0]
# Shared iterator that get_url() consumes one hit at a time.
_result = iter(result)


def get_url() -> str:
    """Return the image URL of the next search hit that has metadata.

    Advances the module-level ``_result`` iterator. Some images are not
    present in the info.csv data, so hits whose ``photo_id`` has no row in
    ``info`` are skipped.

    Raises:
        StopIteration: when ``_result`` is exhausted before a hit with a
            known URL is found.
    """
    for hit in _result:
        photo_id = img_id[hit["corpus_id"]]
        matching_urls = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
        if not matching_urls.empty:
            return matching_urls.iloc[0]
    raise StopIteration


# Show the hits in a 2 x 3 grid: two rows of three columns each.
columns = st.columns(3) + st.columns(3)
for col in columns:
    try:
        img_url = get_url()
    except StopIteration:
        # get_url() ran out of top_k hits that have a URL in info.csv. Leave the
        # remaining cells empty instead of letting the exception crash the app.
        st.warning("Not enough matching images with a known URL to fill the grid.")
        break
    col.image(img_url, use_column_width=True)